import json
import random
import numpy as np
from termcolor import colored
import trax
from trax import layers as tl
from trax.supervised import training
import w4_unittest
1 Introduction
In this project, we are going to use the Reformer, also known as the efficient Transformer, to generate a dialogue between two bots. We will feed conversations to our model, and it will learn to understand the context of each one. Not only will it learn how to answer questions, but it will also learn to ask questions when it needs more information. For example, after a customer asks for a train ticket, the chatbot can ask what time the customer wants to leave. You could use this concept to automate call centers, hotel receptions, personal trainers, or any type of customer service.
We will:
- Understand how the Reformer works
- Explore the MultiWoz dataset
- Process the data to feed it into the model
- Train our model
- Generate a dialogue by feeding a question to the model
2 Exploring the MultiWoz Dataset
We will start by exploring the MultiWoz dataset. The dataset we are about to use has more than 10,000 human-annotated dialogues and spans multiple domains and topics. Some dialogues span multiple domains, while others cover a single domain. In this section, we will load and explore this dataset and develop a function to extract the dialogues.
The modules we will be using are imported at the top of this notebook.
Let’s also declare some constants we will be using in the exercises.
# filename of the MultiWOZ dialogue dataset
DATA_FILE = 'data.json'

# data directory
DATA_DIR = './data'

# dictionary where we will load the dialogue dataset
DIALOGUE_DB = {}

# vocabulary filename
VOCAB_FILE = 'en_32k.subword'

# vocabulary file directory
VOCAB_DIR = 'data/vocabs'
Let’s now load the MultiWOZ 2.1 dataset, which has already been downloaded.
# helper function to load a JSON file
def load_json(directory, file):
    with open(f'{directory}/{file}') as f:
        db = json.load(f)
    return db

# load the dialogue data set into our dictionary
DIALOGUE_DB = load_json(DATA_DIR, DATA_FILE)
Let’s see how many dialogues we have in the dictionary. Each key-value pair is one dialogue, so we can simply take the dictionary’s length.
print(f'The number of dialogues is: {len(DIALOGUE_DB)}')
The number of dialogues is: 10438
Each dialogue comes from its own file, and the filenames are used as keys in our dictionary. Multi-domain dialogues have “MUL” in their filenames, while single-domain dialogues have either “SNG” or “WOZ”.
# print 7 keys from the dataset to see the filenames
print(list(DIALOGUE_DB.keys())[0:7])
['SNG01856.json', 'SNG0129.json', 'PMUL1635.json', 'MUL2168.json', 'SNG0073.json', 'SNG01445.json', 'MUL2105.json']
As we can see from the cells above, there are 10,438 conversations, each in its own file. We will train our model on all of those conversations. Each file is also loaded into a dictionary, and each has the following two keys:
# get keys of the fifth file in the list above
print(DIALOGUE_DB['SNG0073.json'].keys())
dict_keys(['goal', 'log'])
The goal also points to a dictionary, which contains several keys describing the objectives of the conversation. In the example below, we can see that the conversation will be about booking a taxi.
DIALOGUE_DB['SNG0073.json']['goal']
{'taxi': {'info': {'leaveAt': '17:15',
'destination': 'pizza hut fen ditton',
'departure': "saint john's college"},
'reqt': ['car type', 'phone'],
'fail_info': {}},
'police': {},
'hospital': {},
'hotel': {},
'attraction': {},
'train': {},
'message': ["You want to book a <span class='emphasis'>taxi</span>. The taxi should go to <span class='emphasis'>pizza hut fen ditton</span> and should depart from <span class='emphasis'>saint john's college</span>",
"The taxi should <span class='emphasis'>leave after 17:15</span>",
"Make sure you get <span class='emphasis'>car type</span> and <span class='emphasis'>contact number</span>"],
'restaurant': {}}
The log, on the other hand, contains the dialogue. It is a list of dictionaries, and each element of this list also contains several descriptions. Let’s look at an example:
# get first element of the log list
DIALOGUE_DB['SNG0073.json']['log'][0]
{'text': "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
'metadata': {},
'dialog_act': {'Taxi-Inform': [['Dest', 'pizza hut fen ditton'],
['Depart', "saint john 's college"]]},
'span_info': [['Taxi-Inform', 'Dest', 'pizza hut fen ditton', 11, 14],
['Taxi-Inform', 'Depart', "saint john 's college", 6, 9]]}
For this project, we are only interested in the conversation, which is in the text field. The conversation goes back and forth between two people. Let’s call them ‘Person 1’ and ‘Person 2’. This means that DIALOGUE_DB['SNG0073.json']['log'][0]['text'] is ‘Person 1’, DIALOGUE_DB['SNG0073.json']['log'][1]['text'] is ‘Person 2’, and so on. The even offsets are ‘Person 1’ and the odd offsets are ‘Person 2’.
print(' Person 1: ', DIALOGUE_DB['SNG0073.json']['log'][0]['text'])
print(' Person 2: ',DIALOGUE_DB['SNG0073.json']['log'][1]['text'])
Person 1: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.
Person 2: What time do you want to leave and what time do you want to arrive by?
2.1 get_conversation
We will now implement the get_conversation() function, which extracts the conversation from a file in the dataset.
As described above, the conversation is in the text field of each element in the file’s log list. If the log list has x elements, the function collects the text entry of each of those elements. It should return the conversation as a single string, prepending each text with ' Person 1: ' if its index is even or ' Person 2: ' if its index is odd. We can use the Python modulus operator % to select the even/odd entries. Important note: do not insert a newline character (i.e. \n) when building the string. For example, for the conversation printed in the code cell above, your function should output something like:
Person 1: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton. Person 2: What time do you want to leave and what time do you want to arrive by?
and not:
Person 1: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.
Person 2: What time do you want to leave and what time do you want to arrive by?
def get_conversation(file, data_db):
    '''
    Args:
        file (string): filename of the dialogue file saved as json
        data_db (dict): dialogue database
    Returns:
        string: A string containing the 'text' fields of data_db[file]['log'][x]
    '''
    # initialize empty string
    result = ''

    # get length of file's log list
    len_msg_log = len(data_db[file]['log'])

    # set the delimiter strings
    delimiter_1 = ' Person 1: '
    delimiter_2 = ' Person 2: '

    # loop over the file's log list
    for i in range(len_msg_log):
        # get i'th element of file log list
        cur_log = data_db[file]['log'][i]

        # check if i is even
        if i % 2 == 0:
            # append the 1st delimiter string
            result += delimiter_1
        else:
            # append the 2nd delimiter string
            result += delimiter_2

        # append the message text from the log
        result += cur_log['text']

    return result
file = 'SNG01856.json'
conversation = get_conversation(file, DIALOGUE_DB)

# print raw output
print(conversation)
Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel Person 2: Okay, do you have a specific area you want to stay in? Person 1: no, i just need to make sure it's cheap. oh, and i need parking Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? Person 1: Yes, please. 6 people 3 nights starting on tuesday. Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? Person 1: how about only 2 nights. Person 2: Booking was successful.
Reference number is : 7GAWK763. Anything else I can do for you? Person 1: No, that will be all. Good bye. Person 2: Thank you for using our services.
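The loop in get_conversation can also be written more compactly with enumerate() and str.join(). The sketch below keeps the same ' Person 1: ' / ' Person 2: ' delimiters. get_conversation_compact and toy_db are hypothetical names used only for illustration, and the toy database simply copies the nested structure of DIALOGUE_DB.

```python
def get_conversation_compact(file, data_db):
    """Join the 'text' field of every log entry, alternating speaker labels."""
    delimiters = (' Person 1: ', ' Person 2: ')
    # even indices belong to Person 1, odd indices to Person 2
    return ''.join(
        delimiters[i % 2] + entry['text']
        for i, entry in enumerate(data_db[file]['log'])
    )


# a tiny hypothetical database with the same nested structure as DIALOGUE_DB
toy_db = {
    'TOY001.json': {
        'goal': {},
        'log': [
            {'text': 'I need a taxi.'},
            {'text': 'Where are you going?'},
            {'text': 'To the station.'},
        ],
    }
}

print(get_conversation_compact('TOY001.json', toy_db))
# prints: " Person 1: I need a taxi. Person 2: Where are you going? Person 1: To the station."
```

Both versions produce the same string. The generator-plus-join form builds it in a single pass instead of repeated `+=`, although for dialogues of this size the difference is negligible.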
A small pretty-print utility makes it easier to follow the conversation visually.
def print_conversation(conversation):
    delimiter_1 = 'Person 1: '
    delimiter_2 = 'Person 2: '

    split_list_d1 = conversation.split(delimiter_1)

    for sublist in split_list_d1[1:]:
        split_list_d2 = sublist.split(delimiter_2)
        print(colored(f'Person 1: {split_list_d2[0]}', 'red'))

        if len(split_list_d2) > 1:
            print(colored(f'Person 2: {split_list_d2[1]}', 'green'))

print_conversation(conversation)
Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel
Person 2: Okay, do you have a specific area you want to stay in?
Person 1: no, i just need to make sure it's cheap. oh, and i need parking
Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it?
Person 1: Yes, please. 6 people 3 nights starting on tuesday.
Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay?
Person 1: how about only 2 nights.
Person 2: Booking was successful.
Reference number is : 7GAWK763. Anything else I can do for you?
Person 1: No, that will be all. Good bye.
Person 2: Thank you for using our services.
For this project, we will only use the outputs of get_conversation to train the model. However, the MultiWoz dataset also contains other information that can be useful in other contexts: each element of the log list carries additional annotations. For example, the other fields of the first utterance above, “am looking for a place to to stay that has cheap price range it should be in a type of hotel”, are the following.
DIALOGUE_DB['SNG01856.json']['log'][0]
{'text': 'am looking for a place to to stay that has cheap price range it should be in a type of hotel',
'metadata': {},
'dialog_act': {'Hotel-Inform': [['Type', 'hotel'], ['Price', 'cheap']]},
'span_info': [['Hotel-Inform', 'Type', 'hotel', 20, 20],
['Hotel-Inform', 'Price', 'cheap', 10, 10]]}
The dataset also comes with attraction, hotel, hospital, taxi, train, police, and restaurant databases. For example, if you need to call a doctor, a hotel, or a taxi, these databases supply details such as addresses and phone numbers that help automate the entire conversation.
# this is an example of the attractions file
with open('data/attraction_db.json') as attraction_file:
    attractions = json.load(attraction_file)
print(attractions[0])
{'address': 'pool way, whitehill road, off newmarket road', 'area': 'east', 'entrance fee': '?', 'id': '1', 'location': [52.208789, 0.154883], 'name': 'abbey pool and astroturf pitch', 'openhours': '?', 'phone': '01223902088', 'postcode': 'cb58nt', 'pricerange': '?', 'type': 'swimmingpool'}
# this is an example of the hospital file
with open('data/hospital_db.json') as hospital_file:
    hospitals = json.load(hospital_file)
print(hospitals[0])  # feel free to index into other indices
{'department': 'neurosciences critical care unit', 'id': 0, 'phone': '01223216297'}
# this is an example of the hotel file
with open('data/hotel_db.json') as hotel_file:
    hotels = json.load(hotel_file)
print(hotels[0])  # feel free to index into other indices
{'address': '124 tenison road', 'area': 'east', 'internet': 'yes', 'parking': 'no', 'id': '0', 'location': [52.1963733, 0.1987426], 'name': 'a and b guest house', 'phone': '01223315702', 'postcode': 'cb12dp', 'price': {'double': '70', 'family': '90', 'single': '50'}, 'pricerange': 'moderate', 'stars': '4', 'takesbookings': 'yes', 'type': 'guesthouse'}
# this is an example of the police file
with open('data/police_db.json') as police_file:
    police = json.load(police_file)
print(police[0])  # feel free to index into other indices
{'name': 'Parkside Police Station', 'address': 'Parkside, Cambridge', 'id': 0, 'phone': '01223358966'}
# this is an example of a restaurant file
with open('data/restaurant_db.json') as restaurant_file:
    restaurants = json.load(restaurant_file)
print(restaurants[0])  # feel free to index into other indices
{'address': 'Regent Street City Centre', 'area': 'centre', 'food': 'italian', 'id': '19210', 'introduction': 'Pizza hut is a large chain with restaurants nationwide offering convenience pizzas pasta and salads to eat in or take away', 'location': [52.20103, 0.126023], 'name': 'pizza hut city centre', 'phone': '01223323737', 'postcode': 'cb21ab', 'pricerange': 'cheap', 'type': 'restaurant'}
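Each database is just a list of dictionaries. Finding candidates for a request like the one in SNG01856 (a cheap hotel with parking) is therefore a simple filter. The sketch below uses a few hypothetical records: the first copies some fields of the hotel entry printed above, and the other two are invented. find_hotels is an illustrative helper and is not part of the dataset tooling.

```python
def find_hotels(hotels, **constraints):
    """Return the names of hotels whose fields match every constraint exactly."""
    return [
        h['name']
        for h in hotels
        if all(h.get(field) == value for field, value in constraints.items())
    ]


# hypothetical records using the field names of hotel_db.json shown above
sample_hotels = [
    {'name': 'a and b guest house', 'pricerange': 'moderate', 'parking': 'no', 'type': 'guesthouse'},
    {'name': 'budget inn', 'pricerange': 'cheap', 'parking': 'yes', 'type': 'hotel'},
    {'name': 'riverside lodge', 'pricerange': 'cheap', 'parking': 'no', 'type': 'guesthouse'},
]

print(find_hotels(sample_hotels, pricerange='cheap', parking='yes', type='hotel'))
# ['budget inn']
```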
For more information about the MultiWOZ 2.1 dataset, run the cell below to read the README file.
with open('data/README') as file:
    print(file.read())
#####################################################
#####################################################
# Copyright Cambridge Dialogue Systems Group, 2018 #
#####################################################
#####################################################
Dataset contains the following files:
1. data.json: the woz dialogue dataset, which contains the conversation users and wizards, as well as a set of coarse labels for each user turn. This file contains both system and user dialogue acts annotated at the turn level. Files with multi-domain dialogues have "MUL" in their names. Single domain dialogues have either "SNG" or "WOZ" in their names.
2. restaurant_db.json: the Cambridge restaurant database file, containing restaurants in the Cambridge UK area and a set of attributes.
3. attraction_db.json: the Cambridge attraction database file, contining attractions in the Cambridge UK area and a set of attributes.
4. hotel_db.json: the Cambridge hotel database file, containing hotels in the Cambridge UK area and a set of attributes.
5. train_db.json: the Cambridge train (with artificial connections) database file, containing trains in the Cambridge UK area and a set of attributes.
6. hospital_db.json: the Cambridge hospital database file, contatining information about departments.
7. police_db.json: the Cambridge police station information.
8. taxi_db.json: slot-value list for taxi domain.
9. valListFile.txt: list of dialogues for validation.
10. testListFile.txt: list of dialogues for testing.
11. system_acts.json:
There are 6 domains ('Booking', 'Restaurant', 'Hotel', 'Attraction', 'Taxi', 'Train') and 1 dummy domain ('general').
A domain-dependent dialogue act is defined as a domain token followed by a domain-independent dialogue act, e.g. 'Hotel-inform' means it is an 'inform' act in the Hotel domain.
Dialogue acts which cannot take slots, e.g., 'good bye', are defined under the 'general' domain.
A slot-value pair defined as a list with two elements. The first element is slot token and the second one is its value.
If a dialogue act takes no slots, e.g., dialogue act 'offer booking' for an utterance 'would you like to take a reservation?', its slot-value pair is ['none', 'none']
There are four types of values:
1) If a slot takes a binary value, e.g., 'has Internet' or 'has park', the value is either 'yes' or 'no'.
2) If a slot is under the act 'request', e.g., 'request' about 'area', the value is expressed as '?'.
3) The value that appears in the utterance e.g., the name of a restaurant.
4) If for some reason the turn does not have an annotation then it is labeled as "No Annotation."
12. ontology.json: Data-based ontology containing all the values for the different slots in the domains.
13. slot_descriptions.json: A collection of human-written slot descriptions for each slot in the dataset. Each slot has at least two descriptions.
14. tokenization.md: A description of the tokenization preprocessing we had to perform to maintain consistency between the dialogue act annotations of DSTC 8 Track 1 and the existing MultiWOZ 2.0 data.
As we can see, there are many other aspects of the MultiWoz dataset. Nonetheless, we’ll see that even with just the conversations, our model will still be able to generate useful responses. This concludes our exploration of the dataset. In the next section, we will do some preprocessing before we feed it into our model for training.
3 Processing the Data for Reformer Inputs
We will now use the get_conversation() function to process the data. The Reformer expects inputs of this form:
Person 1: Why am I so happy? Person 2: Because you are learning NLP Person 1: … Person 2: …
The conversation then continues with more text. ‘Person 1’ and ‘Person 2’ act as delimiters, so the model can learn who is talking and generate the corresponding responses for each person. Let’s process the text in this fashion for the Reformer. First, let’s grab the conversation strings from all dialogue files and put them in a list.
# the keys are the file names
all_files = DIALOGUE_DB.keys()

# initialize empty list
untokenized_data = []

# loop over all files
for file in all_files:
    # this is the function we implemented above
    # returns a string delimited by Person 1 and Person 2
    result = get_conversation(file, DIALOGUE_DB)

    # append to the list
    untokenized_data.append(result)

# print the first element to check if it's the same as the one we got before
print(untokenized_data[0])
Person 1: am looking for a place to to stay that has cheap price range it should be in a type of hotel Person 2: Okay, do you have a specific area you want to stay in? Person 1: no, i just need to make sure it's cheap. oh, and i need parking Person 2: I found 1 cheap hotel for you that includes parking. Do you like me to book it? Person 1: Yes, please. 6 people 3 nights starting on tuesday. Person 2: I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay? Person 1: how about only 2 nights. Person 2: Booking was successful.
Reference number is : 7GAWK763. Anything else I can do for you? Person 1: No, that will be all. Good bye. Person 2: Thank you for using our services.
Now let’s split the list into a training set and an evaluation set.
# shuffle the list we generated above
random.shuffle(untokenized_data)

# define a cutoff (5% of the total length)
# convert to int because we will use it as a list index
cut_off = int(len(untokenized_data) * .05)

# slice the list. the last elements after the cut_off value will be the eval set. the rest is for training.
train_data, eval_data = untokenized_data[:-cut_off], untokenized_data[-cut_off:]

print(f'number of conversations in the data set: {len(untokenized_data)}')
print(f'number of conversations in train set: {len(train_data)}')
print(f'number of conversations in eval set: {len(eval_data)}')
number of conversations in the data set: 10438
number of conversations in train set: 9917
number of conversations in eval set: 521
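Slicing with -cut_off has one subtlety. If the dataset were small enough that cut_off rounded down to 0, untokenized_data[:-0] would be an empty list and untokenized_data[-0:] would be the whole list, so the training set would silently end up empty. The hypothetical helper below is a sketch, not the project’s code. It splits at an explicit index to avoid that case and takes a seed so the split is reproducible.

```python
import random


def train_eval_split(data, eval_fraction=0.05, seed=None):
    """Shuffle a copy of `data` and split off the last `eval_fraction` as eval."""
    items = list(data)
    # a dedicated Random instance leaves the global random state untouched
    random.Random(seed).shuffle(items)
    # keep at least one eval example whenever there is more than one item
    cut_off = max(1, int(len(items) * eval_fraction)) if len(items) > 1 else 0
    split = len(items) - cut_off
    return items[:split], items[split:]


train, evaluation = train_eval_split(range(10438), seed=0)
print(len(train), len(evaluation))  # 9917 521
```

With 10,438 conversations this gives the same 9,917 / 521 sizes as the cell above.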
3.1 Tokenizing, Batching with Bucketing
We can now proceed to generate tokenized batches of our data. Let’s first define a utility generator function to yield elements from our data sets:
def stream(data):
    # sample from the data forever
    while True:
        # get a random element
        d = random.choice(data)

        # yield a tuple pair of identical values
        # (i.e. our inputs to the model will also be our targets during training)
        yield (d, d)
Now let’s define our data pipeline for tokenizing and batching our data. We will bucket by length and also have an upper bound on the token length.
# trax allows us to use combinators to generate our data pipeline
data_pipeline = trax.data.Serial(
    # randomize the stream
    trax.data.Shuffle(),

    # tokenize the data
    trax.data.Tokenize(vocab_dir=VOCAB_DIR,
                       vocab_file=VOCAB_FILE),

    # filter too long sequences
    trax.data.FilterByLength(2048),

    # bucket by length
    trax.data.BucketByLength(boundaries=[128, 256, 512, 1024],
                             batch_sizes=[16, 8, 4, 2, 1]),

    # add loss weights but do not add it to the padding tokens (i.e. 0)
    trax.data.AddLossWeights(id_to_mask=0)
)

# apply the data pipeline to our train and eval sets
train_stream = data_pipeline(stream(train_data))
eval_stream = data_pipeline(stream(eval_data))
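Bucketing can be pictured as follows. Each tokenized example goes into the first bucket whose boundary is at least as long as the example, and a batch is emitted once that bucket holds as many examples as its batch size. Longer sequences therefore get smaller batches. The sketch below mirrors the boundaries and batch sizes used above. assign_bucket is a hypothetical helper and a simplification of what trax.data.BucketByLength does internally: padding and batch emission are not modelled.

```python
import bisect

BOUNDARIES = [128, 256, 512, 1024]
BATCH_SIZES = [16, 8, 4, 2, 1]  # one more entry than BOUNDARIES: the overflow bucket


def assign_bucket(length, boundaries=BOUNDARIES, batch_sizes=BATCH_SIZES):
    """Return (bucket index, batch size) for a sequence of `length` tokens."""
    # bisect_left finds the first boundary >= length; lengths above the
    # last boundary land in the final overflow bucket
    idx = bisect.bisect_left(boundaries, length)
    return idx, batch_sizes[idx]


for n in (100, 128, 300, 900, 2000):
    print(n, assign_bucket(n))
```

Under this rule, sequences of 257 to 512 tokens share the third bucket and are batched four at a time. That is consistent with the (4, 512) input shape printed when peeking into the train stream.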
Peek into the train stream.
# the stream generators will yield (input, target, weights). let's just grab the input for inspection
inp, _, _ = next(train_stream)

# print the shape. format is (batch size, token length)
print("input shape: ", inp.shape)

# detokenize the first element
print(trax.data.detokenize(inp[0], vocab_dir=VOCAB_DIR, vocab_file=VOCAB_FILE))
input shape: (4, 512)
Person 1: Hello- I would like some information about visiting Corpus Christi please Person 2: Corpus christi is a college located in the centre of town. The phone number is 01223338000 and is located at king's parade. Person 1: Can I have the post code please? Person 2: The postcode is cb21rh. Person 1: Is there an entrance fee? Person 2: the admission is 2 pounds. Person 1: Can you also find me a place to stay in the centre? Person 2: There are several places that are located in the same area, can you give me some more preferences? Person 1: I'd like a moderately priced hotel with free wifi and parking. Person 2: I have 4 available hotels in the centre. Two of them have a cheap price range, and two have an expensive range. Would one of these do? Person 1: I'm looking for a moderate priced hotel for 6 people and 5 nights from Sunday. Person 2: I'm sorry, I'm not pulling up any matches. Person 1: Okay, how about a moderately-priced hotel in the south area instead that has free wifi and free parking? Person 2: I have two guesthouses that match your request; the Aylesbray Lodge and Bridge Guesthouse. Aylesbray has 4 stars and Bridge Guesthouse has 3. Which would you prefer? Person 1: Aylesbray sounds good. I need a booking for six, five nights starting from sunday. Person 2: Booking was successful reference number is GS1J7NYI. Is there anything else I can help you with today? Person 1: That is all I need today, thank you for your help. Person 2: You are welcome, have a blessed day.
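The third element we discarded above is the loss-weight array produced by AddLossWeights(id_to_mask=0). Conceptually it is 1 for real tokens and 0 for padding (token id 0), so padded positions contribute nothing to the loss. Below is a minimal NumPy sketch of that mask on a made-up padded batch. It illustrates the idea and is not the Trax implementation.

```python
import numpy as np

# hypothetical padded batch of token ids: 0 is the padding id
targets = np.array([
    [17, 42, 5, 0, 0],
    [8, 3, 91, 12, 7],
])

# weight 1.0 for real tokens, 0.0 for padding
weights = (targets != 0).astype(np.float32)
print(weights)
print("real tokens per row:", weights.sum(axis=1))  # [3. 5.]
```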
4 Reversible Layers
When running large deep models, you will often run out of memory, because each layer allocates memory to store activations for use in backpropagation. To save this resource, we need to be able to recompute these activations during the backward pass without storing them during the forward pass. Let’s first take a look at the leftmost diagram below.
- This is how residual connections are implemented in the standard Transformer. Given that F() is attention and G() is feed-forward (FF), it follows that:
\[\begin{align} \mathrm{y}_\mathrm{a} &= \mathrm{x} + \mathrm{F}\left(\mathrm{x}\right)\tag{1} \\ \mathrm{y}_{b}&=\mathrm{y}_{a}+\mathrm{G}\left(\mathrm{y}_{a}\right)\tag{2}\\ \end{align}\]
As we can see, this requires that \(\mathrm{x}\) and \(\mathrm{y}_{a}\) be saved so that they can be used during backpropagation. We want to avoid this to conserve memory, and this is where reversible residual connections come in. They are shown in the middle and rightmost diagrams above. The key idea is that we start with two copies of the input to the model and, at each layer, update only one of them. The activations that we don’t update are the ones used to compute the residuals.
In this reversible setup, we get the following instead:
\[\begin{align} \mathrm{y}_{1}&=\mathrm{x}_{1}+\mathrm{F}\left(\mathrm{x}_{2}\right)\tag{3}\\ \mathrm{y}_{2}&=\mathrm{x}_{2}+\mathrm{G}\left(\mathrm{y}_{1}\right)\tag{4}\\ \end{align}\]
To recover \(\mathrm{(x_1,x_2)}\) from \(\mathrm{(y_1, y_2)}\), we invert these equations in reverse order:
\[\begin{align} \mathrm{x}_{2}&=\mathrm{y}_{2}-\mathrm{G}\left(\mathrm{y}_{1}\right)\tag{5}\\ \mathrm{x}_{1}&=\mathrm{y}_{1}-\mathrm{F}\left(\mathrm{x}_{2}\right)\tag{6}\\ \end{align}\]
With this configuration, we’re now able to run the network fully in reverse. During the backward pass, \(\mathrm{x}_{2}\) and \(\mathrm{x}_{1}\) can be recomputed solely from \(\mathrm{y}_{2}\) and \(\mathrm{y}_{1}\), so there is no need to save them during the forward pass.
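A small worked example makes the inversion concrete. Take scalar inputs \(x_1 = 1\) and \(x_2 = 2\), and the toy functions \(\mathrm{F}(x) = x + 2\) and \(\mathrm{G}(x) = 3x\). These are the same functions used in the unit test for the reversible layer. Equations 3 and 4 give the outputs, and equations 5 and 6 recover the inputs from them:

\[\begin{align} \mathrm{y}_{1}&=1+\mathrm{F}\left(2\right)=1+4=5\\ \mathrm{y}_{2}&=2+\mathrm{G}\left(5\right)=2+15=17\\ \mathrm{x}_{2}&=17-\mathrm{G}\left(5\right)=17-15=2\\ \mathrm{x}_{1}&=5-\mathrm{F}\left(2\right)=5-4=1\\ \end{align}\]

Note that \(\mathrm{x}_{2}\) must be recovered first, because equation 6 needs it to evaluate \(\mathrm{F}\left(\mathrm{x}_{2}\right)\).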
We will implement the reversible_layer_forward function using equations 3 and 4 above. This function takes in the input vector x and the functions f and g, and returns the concatenation of \(y_1\) and \(y_2\). For this, we split x into two halves before going through the reversible residual steps\(\mathrm{^1}\). We can then use those two vectors for the reversible_layer_reverse function. Use np.concatenate() to form the output, being careful to match the axis used in np.split().
\(\mathrm{^1}\)Note that splitting is used here only to demonstrate the concept in this exercise; there are other ways of processing the input. As we’ll see in the Reformer architecture later, the initial input (i.e. x) can be duplicated instead of split.
def reversible_layer_forward(x, f, g):
    """
    Args:
        x (np.array): an input vector or matrix
        f (function): a function which operates on a vector/matrix
        g (function): a function which operates on a vector/matrix
    Returns:
        y (np.array): an output vector or matrix whose form is determined by 'x', f and g
    """
    # split the input vector into two (* along the last axis because it is the depth dimension)
    x1, x2 = np.split(x, 2, axis=-1)

    # get y1 using equation 3
    y1 = x1 + f(x2)

    # get y2 using equation 4
    y2 = x2 + g(y1)

    # concatenate y1 and y2 along the depth dimension. be sure output is of type np.ndarray
    y = np.concatenate([y1, y2], axis=-1)

    return y
4.1 reversible_layer_reverse
We will now implement the reversible_layer_reverse function. Reversal is possible because the layer outputs \(y_1\) and \(y_2\), together with the functions f and g (where f is the attention and g is the feed-forward layer), are enough to compute equations 5 and 6. These recover \(x_2\) and then \(x_1\).
Our function takes in the output vector from reversible_layer_forward and the functions f and g. Using equations 5 and 6 above, it computes the inputs to the layer, \(x_1\) and \(x_2\), and returns their concatenation x. Use np.concatenate() to form the output, being careful to match the axis used in np.split().
def reversible_layer_reverse(y, f, g):
    """
    Args:
        y (np.array): an input vector or matrix (the output of reversible_layer_forward)
        f (function): a function which operates on a vector/matrix of the form of 'y'
        g (function): a function which operates on a vector/matrix of the form of 'y'
    Returns:
        x (np.array): an output vector or matrix whose form is determined by 'y', f and g
    """
    # split the input vector into two (* along the last axis because it is the depth dimension)
    y1, y2 = np.split(y, 2, axis=-1)

    # compute x2 using equation 5
    x2 = y2 - g(y1)

    # compute x1 using equation 6
    x1 = y1 - f(x2)

    # concatenate x1 and x2 along the depth dimension
    x = np.concatenate([x1, x2], axis=-1)

    return x
# UNIT TEST
f = lambda x: x + 2
g = lambda x: x * 3
input_vector = np.random.uniform(size=(32,))

output_vector = reversible_layer_forward(input_vector, f, g)
reversed_vector = reversible_layer_reverse(output_vector, f, g)

assert np.allclose(reversed_vector, input_vector)
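Each layer can be inverted from its outputs alone, so a whole stack of reversible layers can be undone by applying the reverse steps in the opposite order. This is how activations are recomputed layer by layer during the backward pass. The self-contained sketch below restates equations 3 to 6 in NumPy and gives each layer an arbitrary, hypothetical pair of f/g functions (fixed random linear maps followed by tanh). It keeps only the final output of the forward pass and recovers the original input from it.

```python
import numpy as np


def forward(x, f, g):
    # equations 3 and 4
    x1, x2 = np.split(x, 2, axis=-1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return np.concatenate([y1, y2], axis=-1)


def reverse(y, f, g):
    # equations 5 and 6
    y1, y2 = np.split(y, 2, axis=-1)
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return np.concatenate([x1, x2], axis=-1)


rng = np.random.default_rng(0)
# hypothetical per-layer functions: fixed random linear maps followed by tanh
layers = []
for _ in range(4):
    wf, wg = rng.normal(size=(8, 8)), rng.normal(size=(8, 8))
    layers.append((lambda v, w=wf: np.tanh(v @ w), lambda v, w=wg: np.tanh(v @ w)))

x = rng.normal(size=(3, 16))  # batch of 3, depth 16 (split into two halves of 8)

# forward pass: only the final output is kept
y = x
for f, g in layers:
    y = forward(y, f, g)

# recomputation: undo the layers in reverse order
recovered = y
for f, g in reversed(layers):
    recovered = reverse(recovered, f, g)

print(np.allclose(recovered, x))  # True
```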
4.2 Reversible Layers and Randomness
When called with the same key, trax.fastmath.random.uniform() returns the same values. This is required for the backward pass to recover the correct layer inputs when a layer introduces random noise.
# Layers like dropout have noise, so let's simulate it here:
f = lambda x: x + np.random.uniform(size=x.shape)

# See that the above doesn't work any more:
output_vector = reversible_layer_forward(input_vector, f, g)
reversed_vector = reversible_layer_reverse(output_vector, f, g)

# the reversal fails: the recovered vector no longer matches the input
assert not np.allclose(reversed_vector, input_vector)

# It failed because reversing drew different random noise than the forward pass.

random_seed = 27686
rng = trax.fastmath.random.get_prng(random_seed)
f = lambda x: x + trax.fastmath.random.uniform(key=rng, shape=x.shape)

# See that it works now as the same rng is used on forward and reverse.
output_vector = reversible_layer_forward(input_vector, f, g)
reversed_vector = reversible_layer_reverse(output_vector, f, g)

assert np.allclose(reversed_vector, input_vector, atol=1e-07)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
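The same principle can be shown with NumPy alone. If the noisy function regenerates its noise from a fixed seed on every call, the forward and reverse passes see identical noise and the inversion works again. This is a sketch under that assumption; a real layer would manage its random key itself rather than rely on a hard-coded module-level seed. noisy_f and NOISE_SEED are hypothetical names.

```python
import numpy as np


def forward(x, f, g):
    x1, x2 = np.split(x, 2, axis=-1)
    y1 = x1 + f(x2)
    return np.concatenate([y1, x2 + g(y1)], axis=-1)


def reverse(y, f, g):
    y1, y2 = np.split(y, 2, axis=-1)
    x2 = y2 - g(y1)
    return np.concatenate([y1 - f(x2), x2], axis=-1)


NOISE_SEED = 27686


def noisy_f(v):
    # a fresh Generator with the same seed yields the same noise on every call
    noise = np.random.default_rng(NOISE_SEED).uniform(size=v.shape)
    return v + noise


g = lambda v: v * 3
x = np.random.default_rng(1).uniform(size=(32,))

restored = reverse(forward(x, noisy_f, g), noisy_f, g)
print(np.allclose(restored, x))  # True
```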
5 ReformerLM Training
We will now proceed to train our model. We already know the two main components that differentiate the Reformer from the standard Transformer: LSH attention and reversible layers. We can therefore simply use the pre-built model already implemented in Trax. It will have this architecture:
As in the Transformer we learned about earlier, we apply attention and feed-forward layers to our inputs. The Reformer improves memory efficiency by using reversible decoder blocks, and we can picture its implementation in Trax as shown below:
We can see that it takes the initial inputs x1 and x2 and computes the first equation of the reversible network described earlier (equation 3). Since the reversible residual has two equations for the forward pass, computing just one of them constitutes half of the reversible decoder block. Before computing the second equation (i.e. the second half of the reversible residual), the block first swaps the two elements to account for the stack semantics in Trax. This puts x2 on top of the stack so it can be fed to the add block of the half-residual layer. It then swaps the two outputs again so they can be fed to the next layer of the network. Together, these steps implement the two forward equations, which can then be inverted to recompute the activations during the backward pass.
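The swap-based wiring described above can be checked with a toy model of a two-element stack, where the first element is the top. In this sketch a half-residual adds fn(second element) to the top element, and a swap exchanges the two; this is our own simplification, not Trax’s implementation. Running attention half-residual, swap, feed-forward half-residual, swap reproduces equations 3 and 4. half_residual, swap and reversible_block are hypothetical helper names.

```python
def half_residual(stack, fn):
    """[acc, ctx] -> [acc + fn(ctx), ctx]: adds fn of the second element to the top."""
    acc, ctx = stack
    return [acc + fn(ctx), ctx]


def swap(stack):
    """[a, b] -> [b, a]"""
    return [stack[1], stack[0]]


def reversible_block(x1, x2, f, g):
    stack = [x1, x2]
    stack = half_residual(stack, f)  # [x1 + f(x2), x2] = [y1, x2]
    stack = swap(stack)              # [x2, y1]: x2 is now on top
    stack = half_residual(stack, g)  # [x2 + g(y1), y1] = [y2, y1]
    stack = swap(stack)              # [y1, y2], ready for the next block
    return stack


f = lambda v: v + 2   # stand-in for attention
g = lambda v: v * 3   # stand-in for feed-forward
print(reversible_block(1, 2, f, g))  # [5, 17]
```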
5.1 ReformerLM
We will now implement a wrapper function that returns a Reformer Language Model. We can use Trax’s ReformerLM to do this quickly. It will have the same architecture as shown above.
def ReformerLM(vocab_size=33000, n_layers=2, mode='train', attention_type=tl.SelfAttention):
    # initialize an instance of Trax's ReformerLM class
    model = tl.Serial(
        trax.models.reformer.ReformerLM(
            # set vocab size
            vocab_size=vocab_size,
            # set number of layers
            n_layers=n_layers,
            # set mode
            mode=mode,
            # set attention type
            attention_type=attention_type
        ),
        # convert the output logits to log-probabilities
        tl.LogSoftmax()
    )
    return model
# display the model
# pass mode as a keyword; the first positional argument is vocab_size
temp_model = ReformerLM(mode='train')
print(str(temp_model))

# free memory
#del temp_model
Serial[
  Serial[
    Serial[
      ShiftRight(1)
    ]
    Embedding_33000_512
    Dropout
    Serial[
      PositionalEncoding
    ]
    Dup_out2
    ReversibleSerial_in2_out2[
      ReversibleHalfResidualDecoderAttn_in2_out2[
        Serial[
          LayerNorm
        ]
        SelfAttention
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidualDecoderFF_in2_out2[
        Serial[
          LayerNorm
          Dense_2048
          Dropout
          Serial[
            FastGelu
          ]
          Dense_512
          Dropout
        ]
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidualDecoderAttn_in2_out2[
        Serial[
          LayerNorm
        ]
        SelfAttention
      ]
      ReversibleSwap_in2_out2
      ReversibleHalfResidualDecoderFF_in2_out2[
        Serial[
          LayerNorm
          Dense_2048
          Dropout
          Serial[
            FastGelu
          ]
          Dense_512
          Dropout
        ]
      ]
      ReversibleSwap_in2_out2
    ]
    Concatenate_in2
    LayerNorm
    Dropout
    Serial[
      Dense_33000
    ]
  ]
  LogSoftmax
]
5.2 training_loop
We will now implement the training_loop function, which takes in our model and trains it. Here is what it should do:
- Create a TrainTask and an EvalTask.
- Create the training loop with trax.supervised.training.Loop.
- Pass the following to the train_task:
  - labeled_data=train_gen
  - loss_layer=tl.CrossEntropyLoss()
  - optimizer=trax.optimizers.Adam(0.01)
  - lr_schedule=lr_schedule
  - n_steps_per_checkpoint=10
- Pass the following to the eval_task:
  - labeled_data=eval_gen
  - metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]

We will use the CrossEntropyLoss loss function with the Adam optimizer; please read the Trax documentation for a full understanding. The function should return a training.Loop object.
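Our model ends with LogSoftmax, so its outputs are log-probabilities. A weighted cross-entropy then amounts to averaging the negative log-probability of each target token over the positions with non-zero weight. The NumPy sketch below illustrates that idea along with a matching weighted accuracy. It is not the source of Trax’s CrossEntropyLoss or Accuracy, and the function names are hypothetical. It also shows why a freshly initialized model’s loss should start close to ln(33000) ≈ 10.40: a model that spreads probability uniformly over a vocabulary of V tokens has a cross-entropy of exactly ln V.

```python
import math

import numpy as np


def weighted_cross_entropy(log_probs, targets, weights):
    """Mean negative log-probability of the targets, counting only weighted positions."""
    # pick log p(target) at every (batch, position)
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return float(-(picked * weights).sum() / weights.sum())


def weighted_accuracy(log_probs, targets, weights):
    """Fraction of weighted positions where the most likely token is the target."""
    correct = (log_probs.argmax(axis=-1) == targets).astype(np.float32)
    return float((correct * weights).sum() / weights.sum())


vocab_size = 33000
batch, length = 2, 5
# uniform predictions: every token has log-probability -ln(V)
uniform = np.full((batch, length, vocab_size), -math.log(vocab_size))
targets = np.array([[17, 42, 5, 0, 0], [8, 3, 91, 12, 7]])
weights = (targets != 0).astype(np.float32)

print(weighted_cross_entropy(uniform, targets, weights))  # about 10.404, i.e. ln(33000)
```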
def training_loop(ReformerLM, train_gen, eval_gen, output_dir="./model/"):
    """
    Args:
        ReformerLM: the Reformer language model you are building
        train_gen (generator): train data generator.
        eval_gen (generator): Validation generator.
        output_dir (string): Path to save the model output. Defaults to './model/'.
    Returns:
        trax.supervised.training.Loop: Training loop for the model.
    """
    # use the warmup_and_rsqrt_decay learning rate schedule
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(
        n_warmup_steps=1000, max_value=0.01)

    # define the train task
    train_task = training.TrainTask(
        # labeled data
        labeled_data=train_gen,
        # loss layer
        loss_layer=tl.CrossEntropyLoss(),
        # optimizer
        optimizer=trax.optimizers.Adam(0.01),
        # lr_schedule
        lr_schedule=lr_schedule,
        # n_steps
        n_steps_per_checkpoint=10
    )

    # define the eval task
    eval_task = training.EvalTask(
        # labeled data
        labeled_data=eval_gen,
        # metrics
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]
    )

    loop = training.Loop(ReformerLM(mode='train'),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    return loop
# we will now test our function
!rm -f model/model.pkl.gz
loop = training_loop(ReformerLM, train_stream, eval_stream)
loop.run(10)
Step 1: Total number of trainable weights: 58072296
Step 1: Ran 1 train steps in 53.39 secs
Step 1: train CrossEntropyLoss | 10.45205879
Step 1: eval CrossEntropyLoss | 10.43009472
Step 1: eval Accuracy | 0.00000000
Step 10: Ran 9 train steps in 116.91 secs
Step 10: train CrossEntropyLoss | 10.23098850
Step 10: eval CrossEntropyLoss | 9.81040001
Step 10: eval Accuracy | 0.05645161
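The first logged losses can be sanity-checked. A model that predicts a uniform distribution over V tokens has cross-entropy ln V, and ln(33000) ≈ 10.40, close to the step-1 train loss of 10.45 (assuming the training model uses vocab_size=33000, as the decoding model below does). The NumPy sketch below computes cross-entropy and argmax accuracy from LogSoftmax outputs to show this; unlike Trax's tl.CrossEntropyLoss and tl.Accuracy, it applies no per-token weights or padding mask.

```python
import numpy as np

def cross_entropy_and_accuracy(log_probs, targets):
    """Mean negative log-likelihood of the targets and argmax accuracy.

    log_probs: (batch, seq_len, vocab) output of a LogSoftmax layer.
    targets:   (batch, seq_len) integer token ids.
    No weights/masking, unlike Trax's metrics.
    """
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    loss = -picked.mean()
    accuracy = (log_probs.argmax(axis=-1) == targets).mean()
    return loss, accuracy

vocab_size = 33000
batch, seq_len = 2, 5
# An untrained model is close to uniform: every token gets log(1 / vocab_size).
uniform_log_probs = np.full((batch, seq_len, vocab_size), -np.log(vocab_size))
targets = np.random.default_rng(0).integers(1, vocab_size, size=(batch, seq_len))
loss, acc = cross_entropy_and_accuracy(uniform_log_probs, targets)
print(round(float(loss), 2), round(float(np.log(vocab_size)), 2))  # 10.4 10.4
print(acc)  # 0.0: argmax of a tie picks token 0, which no target uses here
```

So a loss near 10.4 means "no better than guessing", and the drop to about 9.8 on the eval set after 10 steps shows the model has started to learn something.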
6 Decode from a Pretrained Model
We will now proceed with decoding using the model architecture we just implemented. As before, we will use a pretrained model so we can observe meaningful output during inference, and we will use the autoregressive_sample_stream() decoding method from Trax for fast inference. Let’s define a few parameters to initialize our model.
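autoregressive_sample_stream() produces one token at a time, feeding each sampled token back in as the next input, and the stream does not stop on its own, so the caller decides when to stop (generate_dialogue below breaks after max_len tokens). The pure-Python sketch below captures that loop, with a hypothetical toy next-token function standing in for the Reformer. It re-reads the whole context every step, whereas a model in 'predict' mode only feeds the newest token and keeps the earlier ones in its cached state, which is what makes Trax's version fast.

```python
from typing import Callable, Iterator, List

def autoregressive_stream(next_token: Callable[[List[int]], int],
                          prompt: List[int]) -> Iterator[int]:
    """Yield tokens forever, feeding each prediction back in as input.

    `next_token` is a hypothetical stand-in for the model: it maps the
    tokens seen so far to the next token id.
    """
    context = list(prompt)
    while True:
        token = next_token(context)
        context.append(token)
        yield token

# Toy "model": predicts the previous token plus one.
def toy_next_token(context: List[int]) -> int:
    return context[-1] + 1

stream = autoregressive_stream(toy_next_token, [3, 4])
first_five = [next(stream) for _ in range(5)]
print(first_five)  # [5, 6, 7, 8, 9]
```

Because the generator is lazy, nothing is computed until the consumer asks for the next token, so the caller is free to stop at a token budget or at a delimiter.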
# define the `predict_mem_len` and `predict_drop_len` of tl.SelfAttention
def attention(*args, **kwargs):
    # number of input positions to remember in a cache when doing fast inference.
    kwargs['predict_mem_len'] = 120
    # number of input elements to drop once the fast inference input cache fills up.
    kwargs['predict_drop_len'] = 120
    # return the attention layer with the parameters defined above
    return tl.SelfAttention(*args, **kwargs)

# define the model using the ReformerLM function you implemented earlier.
model = ReformerLM(
    vocab_size=33000,
    n_layers=6,
    mode='predict',
    attention_type=attention,
)

# define an input signature so we can initialize our model. shape will be (1, 1) and the data type is int32.
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)
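Going by the comments in attention(), the two settings describe a bounded cache: at most predict_mem_len positions are remembered, and once the cache is full, predict_drop_len of the oldest ones are discarded. The class below is a hypothetical, simplified model of that described behaviour (not Trax's implementation). In this model, equal values (120 and 120) mean the cache empties completely when it fills up, while a smaller drop length gives a sliding window.

```python
class RollingCache:
    """Hypothetical model of a bounded inference cache: holds at most
    mem_len positions; when full, discards the oldest drop_len of them."""

    def __init__(self, mem_len: int, drop_len: int):
        assert 0 < drop_len <= mem_len
        self.mem_len = mem_len
        self.drop_len = drop_len
        self.items = []

    def append(self, item):
        if len(self.items) == self.mem_len:
            # cache is full: make room by dropping the oldest entries
            self.items = self.items[self.drop_len:]
        self.items.append(item)

# mem_len == drop_len (like 120/120 above, scaled down): the cache starts over
cache = RollingCache(mem_len=4, drop_len=4)
for token in range(6):
    cache.append(token)
print(cache.items)  # [4, 5]

# drop_len < mem_len: a sliding window over the most recent positions
window = RollingCache(mem_len=4, drop_len=1)
for token in range(6):
    window.append(token)
print(window.items)  # [2, 3, 4, 5]
```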
We can now initialize our model from a file containing the pretrained weights. We will save this starting state so we can reset the model state when we generate a new conversation. This will become clearer in the generate_dialogue() function later.
# initialize from file
model.init_from_file('chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

# save the starting state
STARTING_STATE = model.state
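In 'predict' mode the model keeps its decoding cache in model.state, so without a reset a second conversation would continue from whatever context the first one left behind. Saving STARTING_STATE right after loading the weights, and assigning it back before each new dialogue, gives every conversation a fresh start. A toy illustration, using a hypothetical ToyStatefulDecoder class whose state simply records the tokens it has seen:

```python
class ToyStatefulDecoder:
    """Hypothetical stand-in for a model in 'predict' mode: `state` holds
    every token it has been fed so far (like an attention cache)."""

    def __init__(self):
        self.state = ()  # immutable tuple, so a saved snapshot cannot change

    def feed(self, token):
        self.state = self.state + (token,)
        return len(self.state)  # how much context the next prediction sees

toy_model = ToyStatefulDecoder()
toy_start = toy_model.state  # snapshot taken right after initialization

for t in [1, 2, 3]:          # first "dialogue"
    toy_model.feed(t)
print(toy_model.feed(9))     # 4: leftover context from the first dialogue

toy_model.state = toy_start  # reset before the next dialogue
print(toy_model.feed(9))     # 1: fresh context
```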
Let’s also define a few utility functions to help us tokenize and detokenize. We can use tokenize() and detokenize() from trax.data.tf_inputs (available as trax.data.tokenize and trax.data.detokenize) to do this.
def tokenize(sentence, vocab_file, vocab_dir):
    return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]

def detokenize(tokens, vocab_file, vocab_dir):
    return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)
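trax.data.tokenize works on streams (an iterator of strings in, an iterator of token arrays out), which is why the tokenize helper wraps a single sentence in iter([sentence]) and keeps element [0]. To make that interface concrete without the real en_32k.subword vocabulary, the sketch below uses a toy greedy longest-match subword tokenizer over a hypothetical eight-entry vocabulary. The real subword encoder works differently; only the stream-in, stream-out shape and the round trip are the point here.

```python
from typing import Dict, Iterable, Iterator, List

# Hypothetical toy subword vocabulary (the real one is en_32k.subword).
VOCAB: Dict[str, int] = {"Person": 1, " ": 2, "1": 3, ":": 4,
                         "Are": 5, "the": 6, "re": 7, "?": 8}
INV_VOCAB = {i: s for s, i in VOCAB.items()}

def toy_tokenize_stream(sentences: Iterable[str]) -> Iterator[List[int]]:
    """Greedy longest-match tokenization of each string in a stream.
    Mirrors only the interface of trax.data.tokenize (iterator in, iterator out)."""
    longest = max(len(piece) for piece in VOCAB)
    for text in sentences:
        ids, i = [], 0
        while i < len(text):
            # try the longest candidate piece first, then shorter ones
            for size in range(min(longest, len(text) - i), 0, -1):
                piece = text[i:i + size]
                if piece in VOCAB:
                    ids.append(VOCAB[piece])
                    i += size
                    break
            else:
                raise ValueError(f"no vocabulary entry covers {text[i]!r}")
        yield ids

def tokenize_one(sentence: str) -> List[int]:
    # Same trick as the helper above: wrap one sentence in an iterator, take the first result.
    return list(toy_tokenize_stream(iter([sentence])))[0]

def detokenize_one(ids: List[int]) -> str:
    return "".join(INV_VOCAB[i] for i in ids)

ids = tokenize_one("Are there?")
print(ids)                  # [5, 2, 6, 7, 8]
print(detokenize_one(ids))  # Are there?
```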
We are now ready to define our decoding function. It returns a generator that yields the next symbol output by the model, so the model can continue a conversation from nothing more than a starting sentence.
6.1 ReformerLM_output_gen
We will implement the function below to return a generator that predicts the next word of the conversation.
def ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file, vocab_dir, temperature, tokenize=tokenize):
    """
    Args:
        ReformerLM: the Reformer language model you just trained
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """
    # Create input tokens using the tokenize function
    input_tokens = tokenize(start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir)

    # Add batch dimension to array. Convert from (n,) to (x, n) where
    # x is the batch size. Default is 1. (hint: you can use np.expand_dims() with axis=0)
    input_tokens_with_batch = np.array(input_tokens)[None, :]

    # call the autoregressive_sample_stream function from trax
    output_gen = trax.supervised.decoding.autoregressive_sample_stream(
        # model
        ReformerLM,
        # inputs will be the tokens with batch dimension
        inputs=input_tokens_with_batch,
        # temperature
        temperature=temperature
    )

    return output_gen
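The temperature argument controls how each token is picked from the model's output distribution. The NumPy sketch below illustrates temperature sampling from log-probabilities (the general technique, not Trax's internal code): dividing the log-probabilities by a small temperature sharpens the distribution towards the most likely token, which is why the dialogues below, generated at temperature=0.2, behave close to greedy decoding.

```python
import numpy as np

def sample_with_temperature(log_probs, temperature, rng):
    """Pick one token id from a 1-D array of log-probabilities.

    temperature == 0.0 -> argmax (deterministic);
    temperature == 1.0 -> sample from the model's own distribution;
    values in between sharpen the distribution toward the argmax.
    """
    if temperature == 0.0:
        return int(np.argmax(log_probs))
    scaled = log_probs / temperature
    probs = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

log_probs = np.log(np.array([0.5, 0.3, 0.2]))
rng = np.random.default_rng(0)
print(sample_with_temperature(log_probs, 0.0, rng))  # 0
samples = [sample_with_temperature(log_probs, 0.2, rng) for _ in range(1000)]
counts = np.bincount(samples, minlength=3)
print(counts)  # most samples are token 0
```

At temperature 0.2 the probabilities become proportional to p^5, so token 0 gets roughly 92% of the mass instead of 50%.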
Now we can see the model in action. First we re-create the model in 'predict' mode and reload the pretrained weights; then the utility function below calls the generator we just implemented and formats its output so it is easier to read.
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 120  # number of positions cached during fast inference
    kwargs['predict_drop_len'] = 120  # number of cached positions dropped once the cache is full
    return tl.SelfAttention(*args, **kwargs)

model = ReformerLM(
    vocab_size=33000,
    n_layers=6,
    mode='predict',
    attention_type=attention,
)

model.init_from_file('chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

STARTING_STATE = model.state
def generate_dialogue(ReformerLM, model_state, start_sentence, vocab_file, vocab_dir, max_len, temperature):
    """
    Args:
        ReformerLM: the Reformer language model you just trained
        model_state (np.array): initial state of the model before decoding
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        max_len (int): maximum number of tokens to generate
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        None: the dialogue is printed turn by turn as it is generated
    """
    # define the delimiters we used during training
    delimiter_1 = 'Person 1: '
    delimiter_2 = 'Person 2: '

    # initialize detokenized output
    sentence = ''

    # token counter
    counter = 0

    # output tokens. we insert a ': ' for formatting
    result = [tokenize(': ', vocab_file=vocab_file, vocab_dir=vocab_dir)]

    # reset the model state when starting a new dialogue
    ReformerLM.state = model_state

    # calls the output generator implemented earlier
    output = ReformerLM_output_gen(ReformerLM, start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir, temperature=temperature)

    # print the starting sentence
    print(start_sentence.split(delimiter_2)[0].strip())

    # the loop consumes tokens until max_len is reached; the if/elif only prettifies the output
    for o in output:

        result.append(o)

        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=vocab_file, vocab_dir=vocab_dir)

        if sentence.endswith(delimiter_1):
            sentence = sentence.split(delimiter_1)[0]
            print(f'{delimiter_2}{sentence}')
            sentence = ''
            result.clear()

        elif sentence.endswith(delimiter_2):
            sentence = sentence.split(delimiter_2)[0]
            print(f'{delimiter_1}{sentence}')
            sentence = ''
            result.clear()

        counter += 1

        if counter > max_len:
            break
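The if/elif in generate_dialogue cuts the decoded text into turns: whenever the accumulated text ends with one speaker's delimiter, everything before it belongs to the other speaker. The pure-Python sketch below reproduces that rule on plain strings, with a hypothetical list of text pieces standing in for detokenized model output. It also makes one property of the original loop visible: text produced after the last delimiter is never printed when the token budget runs out.

```python
from typing import Iterable, List

def split_turns(pieces: Iterable[str], max_pieces: int) -> List[str]:
    """Accumulate streamed text pieces and cut a turn whenever the text ends
    with a speaker delimiter, mirroring the if/elif in generate_dialogue.
    `pieces` stands in for detokenized model output (hypothetical input)."""
    delimiter_1, delimiter_2 = 'Person 1: ', 'Person 2: '
    turns, text = [], ''
    for count, piece in enumerate(pieces, start=1):
        text += piece
        if text.endswith(delimiter_1):
            # the text before "Person 1: " was spoken by Person 2
            turns.append(delimiter_2 + text.split(delimiter_1)[0])
            text = ''
        elif text.endswith(delimiter_2):
            # the text before "Person 2: " was spoken by Person 1
            turns.append(delimiter_1 + text.split(delimiter_2)[0])
            text = ''
        if count > max_pieces:  # same cut-off rule as `counter > max_len`
            break
    return turns

pieces_demo = ["There are ", "4 theatres.", " Person 1: ",
               "Which one?", " Person 2: ", "Try the Mumford."]
for turn in split_turns(pieces_demo, max_pieces=120):
    print(turn)
# Person 2: There are 4 theatres.
# Person 1: Which one?
# ("Try the Mumford." is never printed: no delimiter follows it)
```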
We can now feed in different starting sentences and see how the model generates the dialogue. We can even input our own starting sentence; just remember to ask about topics covered by the MultiWOZ dataset so the model can generate a meaningful conversation.
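Each sample sentence follows the same pattern: a leading space, 'Person 1: ' followed by the question, and a trailing 'Person 2: ' that cues the model to answer as the second speaker. A hypothetical make_prompt helper (not part of the original code) that builds such prompts, together with the split that generate_dialogue uses to print the first line:

```python
def make_prompt(question: str) -> str:
    """Hypothetical helper: wrap a user question in the speaker format used by
    the sample sentences, ending with 'Person 2: ' so the model answers next."""
    return f' Person 1: {question} Person 2: '

prompt = make_prompt('Is there a hospital nearby?')
print(repr(prompt))
# generate_dialogue prints only the part before 'Person 2: ' as the first line:
print(prompt.split('Person 2: ')[0].strip())  # Person 1: Is there a hospital nearby?
```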
sample_sentence = ' Person 1: Are there theatres in town? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)
Person 1: Are there theatres in town?
Person 2: : There are 4 theatres in town. Do you have a specific area in mind?
Person 1: No, I don't have a preference. Which one do you recommend?
Person 2: I would recommend the Mumford Theatre. Would you like their phone number?
Person 1: Yes, please. I would also like to find a train to cambridge on thursday.
Person 1: There are 202 trains that meet your criteria. Do you have a specific you would like to go to a cinema?
sample_sentence = ' Person 1: Is there a hospital nearby? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)
Person 1: Is there a hospital nearby?
Person 2: : Addensbrookes Hospital is located at Hills Rd, Cambridge, postcode CB20QQ. Do you need the phone number?
Person 1: No, that's all I needed. Thank you.
Person 2: You're welcome. Have a good day.m.Thanks for contacting the Cambridge TownInfo centre. Goodbye.
Person 1: Thank you for your help.
Person 1: You're welcome. Have a good day.I can find something.
sample_sentence = ' Person 1: Can you book a taxi? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)
Person 1: Can you book a taxi?
Person 2: : I sure can. When would you like to arrive?
Person 1: I need to leave after 13:00.
Person 2: I'm sorry, but I'm not able to book that for you. Would you like to try a different time?
Person 1: Yes, let's try for 13:00.
Person 2: I was able to book you a table for 1 at 13:00 on Saturday. Your reference number is YYYOOO
7 Acknowledgements
I’d like to express my thanks to the great Natural Language Processing with Attention Models course, which I completed, and acknowledge the use of some images and other materials from the course in this article.