import sys
import os
import w2_tests
import numpy as np
import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# to print the entire np array
np.set_printoptions(threshold=sys.maxsize)
1 Introduction
In an earlier article we created a transformer decoder model, the same kind used to build the famous GPT-2. In this article we will explore summarization using a transformer decoder model.
Summarization is an important task in natural language processing and can be useful for a number of businesses and use cases. For example, bots can scrape articles and summarize them, and sentiment analysis can then be applied to the summaries to identify the sentiment about certain stocks. Why read a whole article or a long email when you can build a transformer to summarize the text for you?
In this project we will:
- Use built-in functions to preprocess data
- Implement DotProductAttention
- Implement Causal Attention
- Understand how attention works
- Build the transformer model
- Evaluate the model
- Summarize an article
This model is slightly different from the ones we have looked at previously: it is based heavily on attention and does not rely on sequential recurrence, which allows for parallel computation.
2 Import Libraries
3 Importing the dataset
The Trax library makes it easy to work with TensorFlow Datasets:
# This will download the dataset if no data_dir is specified.
# Downloading and processing can take a bit of time,
# So I have the data already in 'data/'
# Importing CNN/DailyMail articles dataset
train_stream_fn = trax.data.TFDS('cnn_dailymail',
                                 data_dir='data/',
                                 keys=('article', 'highlights'),
                                 train=True)

# This should be much faster as the data is downloaded already.
eval_stream_fn = trax.data.TFDS('cnn_dailymail',
                                data_dir='data/',
                                keys=('article', 'highlights'),
                                train=False)
3.1 Tokenize & Detokenize helper functions
A pre-built subword vocabulary (loaded from vocab_dir/ below) serves as the encoder for us. Given any data set, we have to be able to map words to their indices, and indices to their words. The inputs and outputs of Trax models are usually tensors of numbers where each number corresponds to a word. If we were to process the data manually, we would have to make use of the following:
- word2Ind: a dictionary mapping the word to its index.
- ind2Word: a dictionary mapping the index to its word.
- word2Count: a dictionary mapping the word to the number of times it appears.
- num_words: total number of words that have appeared.
We have created helper functions to simplify this process.
- tokenize: converts a text sentence to its corresponding token list (i.e. list of indices). Also converts words to subwords.
- detokenize: converts a token list to its corresponding sentence (i.e. string).
def tokenize(input_str, EOS=1):
    """Input str to features dict, ready for inference"""
    # Use the trax.data.tokenize method. It takes streams and returns streams,
    # we get around it by making a 1-element stream with `iter`.
    inputs = next(trax.data.tokenize(iter([input_str]),
                                     vocab_dir='vocab_dir/',
                                     vocab_file='summarize32k.subword.subwords'))

    # Mark the end of the sentence with EOS
    return list(inputs) + [EOS]

def detokenize(integers):
    """List of ints to str"""
    s = trax.data.detokenize(integers,
                             vocab_dir='vocab_dir/',
                             vocab_file='summarize32k.subword.subwords')

    return wrapper.fill(s)
3.2 Preprocessing for Language Models: Concatenate It!
So we will use a language model – a Transformer decoder – to solve an input-output problem. Language models only predict the next word; they have no notion of inputs. To create a single input suitable for a language model, we concatenate inputs with targets, putting a separator in between.
We also need to create a mask – with 0s at inputs and 1s at targets – so that the model is not penalized for mis-predicting the article and only focuses on the summary.
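To make the layout concrete, here is a minimal sketch of what the concatenation and mask look like for one tiny article/summary pair (the token values are made up, not real vocabulary indices):

# Sketch with made-up token ids: article = [5, 7, 8], summary = [9, 4]
article_toks, summary_toks = [5, 7, 8], [9, 4]
joint = article_toks + [1, 0] + summary_toks + [1]   # [article][EOS][SEP][summary][EOS]
mask = [0] * (len(article_toks) + 2) + [1] * (len(summary_toks) + 1)
print(joint)  # [5, 7, 8, 1, 0, 9, 4, 1]
print(mask)   # [0, 0, 0, 0, 0, 1, 1, 1]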
# Special tokens
SEP = 0  # Padding or separator token
EOS = 1  # End of sentence token

# Concatenate tokenized inputs and targets using 0 as separator.
def preprocess(stream):
    for (article, summary) in stream:
        joint = np.array(list(article) + [EOS, SEP] + list(summary) + [EOS])
        mask = [0] * (len(list(article)) + 2) + [1] * (len(list(summary)) + 1)  # Accounting for EOS and SEP
        yield joint, joint, np.array(mask)
# We can combine a few data preprocessing steps into a pipeline like this.
input_pipeline = trax.data.Serial(
    # Tokenizes
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                       vocab_file='summarize32k.subword.subwords'),
    # Uses function defined above
    preprocess,
    # Filters out examples longer than 2048
    trax.data.FilterByLength(2048)
)

# Apply preprocessing to data streams.
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

train_input, train_target, train_mask = next(train_stream)

assert sum((train_input - train_target)**2) == 0  # They are the same in Language Model (LM).

# prints mask, 0s on article, 1s on summary
print(f'Single example mask:\n\n {train_mask}')
Single example mask:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
# prints: [Example][<EOS>][<pad>][Example Summary][<EOS>]
print(f'Single example:\n\n {detokenize(train_input)}')
Single example:
By . Associated Press . PUBLISHED: . 14:11 EST, 25 October 2013 . | .
UPDATED: . 15:36 EST, 25 October 2013 . The bishop of the Fargo
Catholic Diocese in North Dakota has exposed potentially hundreds of
church members in Fargo, Grand Forks and Jamestown to the hepatitis A
virus in late September and early October. The state Health Department
has issued an advisory of exposure for anyone who attended five
churches and took communion. Bishop John Folda (pictured) of the Fargo
Catholic Diocese in North Dakota has exposed potentially hundreds of
church members in Fargo, Grand Forks and Jamestown to the hepatitis A
. State Immunization Program Manager Molly Howell says the risk is
low, but officials feel it's important to alert people to the possible
exposure. The diocese announced on Monday that Bishop John Folda is
taking time off after being diagnosed with hepatitis A. The diocese
says he contracted the infection through contaminated food while
attending a conference for newly ordained bishops in Italy last month.
Symptoms of hepatitis A include fever, tiredness, loss of appetite,
nausea and abdominal discomfort. Fargo Catholic Diocese in North
Dakota (pictured) is where the bishop is located .<EOS><pad>BishopJohn
Folda, of North Dakota, is taking time off after being diagnosed . He
contracted the infection through contaminated food in Italy . Church
members in Fargo, Grand Forks and Jamestown could have been exposed
.<EOS>
3.3 Batching with bucketing
We use bucketing to create batches of data.
# Bucketing to create batched generators.
# Buckets are defined in terms of boundaries and batch sizes.
# Batch_sizes[i] determines the batch size for items with length < boundaries[i]
# So below, we'll take a batch of 16 sentences of length < 128, 8 of length < 256,
# 4 of length < 512. And so on.
boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]

# Create the streams.
train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(eval_stream)

# Every execution will result in generation of a different article
# We can try running this cell multiple times to see how the length of the examples affects the batch size
input_batch, _, mask_batch = next(train_batch_stream)

# Shape of the input_batch
input_batch.shape
(1, 1201)
# print corresponding integer values
print(input_batch[0])
[ 27 23176 4694 1779 1343 28 506 1091 132 28 570 6
78 7124 192 14454 15 3570 2067 23 46 26133 17 1019
635 91 3 5349 23421 494 6 10487 2 728 2 1353
3156 278 1838 28 736 809 28 13481 7511 22 625 28
1311 2396 3 187 22 1353 1510 181 16146 1049 320 103
2 22 26563 651 467 213 826 192 3156 1262 28 13131
4 186 16949 17 71 12319 6604 828 29725 4 5 1081
1083 213 54 138 3 5349 23421 494 6 10487 2 728
8 346 12 1353 354 15 3570 2067 7511 22 24497 570
6 78 71 213 1081 144 3360 691 12319 6604 828 2
705 8 231 24 305 710 272 1838 68 6341 379 9
570 6 78 7124 436 219 132 560 429 3 368 23421
494 6 10487 7 5 1081 1353 10874 20919 217 8 12370
21 12 2713 127 23421 494 6 10487 40 23176 809 518
150 181 290 3892 275 527 8947 171 1269 936 213 9025
3 69 1353 233 8272 527 6056 583 691 4398 3156 809
14507 5429 812 7356 3 3622 6604 828 2 28 705 6
104 6 292 15004 181 29725 4 5 21961 1838 10687 45
2 11985 527 11907 5364 2 40 43 1383 213 2801 1248
1078 809 28 13481 35 40 19 23176 116 4016 2 864
127 3 305 1353 3156 17775 12979 3095 186 77 1353 669
27439 6050 13459 1628 1290 131 143 18 757 320 2501 213
25725 29725 2 41 969 3 16978 1822 9855 1962 2 17347
16 2 127 4601 27439 6050 13459 1628 5349 23421 494 6
10487 29725 4 5 3156 2868 132 213 15191 583 527 28
506 1091 2 12319 6604 828 2 28 583 285 143 18
46 13488 23707 6050 13459 1628 368 23421 494 6 10487 436
213 884 320 3429 61 15 3570 2067 6715 3156 186 2
673 1510 181 16146 1049 320 824 1311 2396 2 1353 90
15438 17 285 22 2214 320 17950 28 346 6 650 13131
4 2 7228 213 1052 763 314 71 213 2358 527 3622
6604 828 29725 4 5 18352 2398 1081 3 3622 6604 828
1353 7214 213 19839 277 527 68 27439 9275 1628 12320 5403
9242 5590 2385 35 710 272 1838 68 6341 132 2642 11969
27439 6050 13459 1628 3622 6604 828 669 27884 4 40 27872
391 28 5302 531 2504 527 68 3 305 1353 43 4925
278 523 1383 163 20812 2801 1248 1078 186 1353 3156 17775
12979 3095 23707 6050 13459 1628 305 40 5945 320 1242 68
1078 7511 131 540 278 320 8916 285 131 40 2362 15627
3 1561 1078 8075 114 369 1613 1838 68 102 41 7584
17 458 23707 6050 13459 1628 3622 6604 828 29725 4 5
583 132 97 2861 6107 17946 5 213 6349 527 354 28
650 6 475 3570 2067 6715 3156 4172 29725 391 2713 25
3630 320 245 17388 181 1884 4140 1838 23421 494 6 10487
1820 2 35 132 4140 329 926 102 213 5556 22 1353
86 25070 918 155 213 6700 6 2057 3602 3 9 4038
2256 1248 864 285 22 62 18 46 95 213 3602 809
213 55 15 651 6866 4604 279 1205 3622 6604 828 29725
4 5 2498 12320 5403 9242 5590 2385 78 28 826 542
15902 3569 2 11985 527 11907 5364 2 78 560 253 2
429 3 405 2067 992 1606 22 1353 43 17997 595 239
213 55 527 213 7124 3 6753 1565 8120 479 2 1838
12887 26509 21380 328 29725 4 5 1839 25725 2694 1676 2
127 3611 871 5784 1435 1248 12319 7 5 228 809 824
55 3 305 40 46 64 1248 1078 809 28 13481 132
15010 7301 285 2801 2 35 40 19 40 116 4016 1782
871 2694 1606 285 77 1353 1290 131 143 18 757 320
2501 213 25725 186 8075 114 103 919 68 68 177 1782
368 23421 494 6 10487 40 346 126 132 15902 3569 186
1326 1248 1078 809 28 13481 4872 22 6005 6929 809 518
150 320 290 3892 275 527 7468 81 3 69 12402 7
26 209 346 213 13481 320 955 278 7511 213 25725 1841
809 239 128 10 3229 2535 1782 129 8198 7 26 217
320 245 17388 181 1884 4140 1838 134 1820 186 849 1884
576 329 926 102 213 25725 1606 22 1353 25070 918 155
213 3602 2 51 2253 22 62 18 46 95 213 3602
809 213 55 527 213 25725 186 132 13040 2398 61 592
2 213 4038 2256 1782 9 641 527 15 2067 992 1606
285 22 1353 17997 595 78 15 2067 239 213 55 527
213 25725 90 103 7 5 1232 761 824 62 43 18
3625 320 15 4398 3156 186 1201 527 490 2002 23421 494
6 10487 1353 233 8272 527 6056 583 691 4398 3156 355
28 2145 809 14507 5429 812 8 12370 21 12 69 969
3611 368 23421 494 6 10487 39 169 3263 635 91 936
5892 2 35 12319 7 5 228 18 913 68 8232 1782
13 1525 824 39 191 101 362 3060 171 6642 116 4016
186 1269 936 213 9025 2 181 354 28 2067 640 41
7 165 78 213 826 1782 9 26024 527 6700 3156 186
3156 6715 354 28 3570 2067 1435 3787 3 2994 1779 952
320 124 90 993 3736 28 3537 55 132 2173 3 56
347 6335 141 7270 15191 213 4472 527 16972 595 97 23891
6412 49 1151 20327 27439 6050 13459 1628 368 23421 494 6
10487 39 169 3263 635 91 936 5892 2 35 12319 29725
4 5 228 18 913 68 1019 545 3 13 1525 824
39 191 101 362 3060 171 6642 116 4016 186 1269 936
213 9025 2 181 354 28 2067 640 41 29725 4 165
78 213 826 3 56 347 6335 141 7270 15191 213 4472
527 16972 595 97 23891 6412 49 1151 4172 29725 391 23421
494 6 10487 2 527 14735 2 11985 527 11907 5364 2
1353 43 24306 5831 4461 1838 3156 1019 1223 91 27439 9275
1628 102 1480 22 39 18 320 976 163 2008 165 6
1166 10 1 0 5349 23421 494 6 10487 2 728 2
40 23176 809 518 150 3892 275 171 3156 1081 16346 27439
6774 1628 5670 354 2067 7511 22 26563 651 467 826 132
15902 3569 2 11985 527 11907 5364 16346 27439 6774 1628 3481
3094 570 6 78 71 705 6 104 6 292 12319 6604
828 7 5 1081 2 1779 710 132 2642 16346 27439 6774
1628 2713 476 22 62 18 46 95 904 6700 6 2057
3602 809 55 527 7124 16346 27439 6774 1628 69 1353 233
8272 809 14507 5429 812 527 6056 583 691 4398 3156 2104
1]
Things to notice:
- First we see the corresponding values of the words.
- The first 1 represents the <EOS> tag of the article.
- It is followed by a 0, which represents a <pad> tag.
- After the first 0 (<pad> tag), the corresponding values are of the words that are used for the summary of the article.
- The second 1 represents the <EOS> tag for the summary.
- All the trailing 0s represent <pad> tags, which are appended to maintain consistent length (if you don't see them, it means the example is already at max length).
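As a quick sanity check (a small sketch, not part of the original pipeline), we can use this structure to slice the summary out of a tokenized example by finding the first <EOS> token:

# Sketch: the first token equal to 1 is the article's <EOS>;
# the summary starts two positions later (after <EOS> and the <pad>/separator 0).
first_eos = int(np.argmax(train_input == 1))
summary_start = first_eos + 2
print(detokenize(train_input[summary_start:]))  # prints only the summary part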
# print the article and its summary
print('Article:\n\n', detokenize(input_batch[0]))
Article:
A drunk driver who killed a young woman in a head-on crash while
checking his mobile phone has been jailed for six years. Craig
Eccleston-Todd, 27, was driving home from a night at a pub when he
received a text message. As he was reading or replying to it, he
veered across the road while driving round a bend and smashed into
Rachel Titley’s car coming the other way. Craig Eccleston-Todd, 27
(left) was using his mobile phone when he crashed head-on into the car
being driven by Rachel Titley, 28 (right). She died later from her
injuries . The head-on crash took place in October 2013. Mr Eccleston-
Todd's car was barely recognisable (pictured) Police said Eccleston-
Todd had drunk at least three or four pints of beer before getting
behind the wheel. He was found guilty of causing death by dangerous
driving at Portsmouth Crown Court yesterday. Miss Titley, a 28-year-
old solicitor’s clerk from Cowes, Isle of Wight, had also spent the
evening with friends at a pub but had not drunk any alcohol, police
said. She was driving responsibly and there was ‘nothing she could
have done to avoid the collision’, they added. Lindsay Pennell,
prosecuting, said: ‘Craig Eccleston-Todd’s driving resulted in the
tragic death of a young woman, Rachel Titley, a death that could have
been avoided. ‘Mr Eccleston-Todd took the decision to pick up his
mobile phone whilst driving and, either reading or replying to this
text message, was so distracted that he failed to negotiate a left-
hand bend, crossing the central white line into the path of Miss
Titley’s oncoming car. Miss Titley was pulled the wreckage of
her Daihatsu Cuore but died later from her injuries in hospital .
‘Miss Titley [had] a bright future ahead of her. She was also
returning home having spent an enjoyable evening with friends and was
driving responsibly. ‘She had arranged to contact her friends when she
got home to confirm that she had arrived safely. Her friends sadly
never heard from her after they parted company. ‘Miss Titley’s death
in these circumstances reiterates the danger of using a hand-held
mobile phone whilst driving.’ Police were unable to take breath or
blood tests from Eccleston-Todd immediately, but in tests several
hours after the accident he was only marginally under the drink-drive
limit. The judge agreed with police that he would have been over the
limit at the time his red Citroen hit Miss Titley’s blue Daihatsu
Cuore on a road near Yarmouth, Isle of Wight, on October 11, 2013. His
phone records showed he was also texting around the time of the crash.
PC Mark Furse, from Hampshire constabulary’s serious collision
investigation unit, said: 'Our thoughts are with Rachel's family at
this time. She had been out with friends at a pub in Shalfleet that
evening, but had not had any alcohol. 'Our investigation showed that
there was nothing she could have done to avoid the collision and sadly
it cost her her life. 'Mr Eccleston-Todd had left work in Yarmouth and
met with friends at a pub where he drank at least three to four pints
of lager. He hadn't long left the pub to return home when the
collision occurred at around 9.30pm. 'We weren't able to take breath
or blood tests from him immediately and although blood taken several
hours after the collision showed he was marginally under the limit, we
maintain he would have been over the limit at the time of the
collision and in summing up today, the judge agreed. 'The analysis of
his phone records showed that he was texting on his phone around the
time of the collision so it's highly likely this would also have
contributed to his dangerous driving and loss of control.' Eccleston-
Todd was found guilty of causing death by dangerous driving following
a trial at Portsmouth Crown Court (pictured) He added: 'Mr Eccleston-
Todd will now spend six years behind bars, but Rachel's family have
lost her forever. 'I hope this will make people think twice before
drinking any alcohol and getting behind the wheel, or using a phone
once they're on the road. 'The dangers of drink driving and driving
whilst using a mobile phone are obvious. Those who continue to do so
risk spending a substantial time in prison. This case highlights just
how tragic the consequences of committing these offences can be.' ‘Mr
Eccleston-Todd will now spend six years behind bars, but Rachel’s
family have lost her for ever. I hope this will make people think
twice before drinking any alcohol and getting behind the wheel, or
using a phone once they’re on the road. This case highlights just how
tragic the consequences of committing these offences can be.’
Eccleston-Todd, of Newport, Isle of Wight, was also disqualified from
driving for eight years after which he will have to complete an
extended re-test.<EOS><pad>CraigEccleston-Todd, 27, had drunk at least
three pints before driving car . Was using phone when he veered across
road in Yarmouth, Isle of Wight . Crashed head-on into 28-year-old
Rachel Titley's car, who died in hospital . Police say he would have
been over legal drink-drive limit at time of crash . He was found
guilty at Portsmouth Crown Court of causing death by dangerous driving
.<EOS>
We can see that the data has the following structure:

[Article] -> <EOS> -> <pad> -> [Article Summary] -> <EOS> -> (possibly) multiple <pad>

The loss is taken only on the summary, using cross entropy as the loss function.
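To see how the mask keeps the article tokens out of the loss, here is a minimal numpy sketch of mask-weighted cross entropy (toy numbers only; the real computation is handled later by tl.CrossEntropyLoss together with the mask):

# Sketch: cross entropy averaged only over positions where mask == 1.
def masked_cross_entropy(log_probs, targets, mask):
    """log_probs: (seqlen, vocab) log-probabilities; targets, mask: (seqlen,)."""
    token_losses = -log_probs[np.arange(len(targets)), targets]   # -log p(correct token)
    return np.sum(token_losses * mask) / np.sum(mask)             # article positions contribute nothing

# Toy example: 3 positions, vocab of size 4; only the last two positions belong to the summary.
log_probs = np.log(np.array([[0.7, 0.1, 0.1, 0.1],
                             [0.1, 0.6, 0.2, 0.1],
                             [0.2, 0.2, 0.5, 0.1]]))
targets = np.array([0, 1, 2])
mask = np.array([0, 1, 1])
print(masked_cross_entropy(log_probs, targets, mask))  # (-log 0.6 - log 0.5) / 2 ≈ 0.602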
4 Summarization with transformer
Now that we have the data generator and have handled the preprocessing, it is time to build our model.
We will be implementing the attention mechanism from scratch and then using it in our transformer model. Concretely, we will understand how attention works and how it is used inside the transformer decoder.
4.1 Dot product attention
Now we will implement dot product attention, which takes in a query, key, value, and mask, and returns the attention output.
These are some helper functions that will help create tensors and display useful information:
- create_tensor: creates a jax numpy array from a list of lists.
- display_tensor: prints out the shape and the actual tensor.
def create_tensor(t):
"""Create tensor from list of lists"""
return jnp.array(t)
def display_tensor(t, name):
"""Display shape and tensor"""
print(f'{name} shape: {t.shape}\n')
print(f'{t}\n')
Before implementing it, we can play around with a toy example of dot product attention without the softmax operation. Technically it would not be dot product attention without the softmax, but working through these intermediate tensors gives a sense of what they look like.
The formula for attention is this one:
\[ \text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}} + M\right) V \tag{1} \]
\(d_{k}\) stands for the dimension of queries and keys.
The query, key, value, and mask vectors are provided for this example.
Notice that the masking is done using very negative values that will yield a similar effect to using \(-\infty\).
q = create_tensor([[1, 0, 0], [0, 1, 0]])
display_tensor(q, 'query')
k = create_tensor([[1, 2, 3], [4, 5, 6]])
display_tensor(k, 'key')
v = create_tensor([[0, 1, 0], [1, 0, 1]])
display_tensor(v, 'value')
m = create_tensor([[0, 0], [-1e9, 0]])
display_tensor(m, 'mask')
query shape: (2, 3)
[[1 0 0]
[0 1 0]]
key shape: (2, 3)
[[1 2 3]
[4 5 6]]
value shape: (2, 3)
[[0 1 0]
[1 0 1]]
mask shape: (2, 2)
[[ 0.e+00 0.e+00]
[-1.e+09 0.e+00]]
q_dot_k = q @ k.T / jnp.sqrt(3)
display_tensor(q_dot_k, 'query dot key')
query dot key shape: (2, 2)
[[0.57735026 2.309401 ]
[1.1547005 2.8867514 ]]
masked = q_dot_k + m
display_tensor(masked, 'masked query dot key')
masked query dot key shape: (2, 2)
[[ 5.7735026e-01 2.3094010e+00]
[-1.0000000e+09 2.8867514e+00]]
display_tensor(masked @ v, 'masked query dot key dot value')
masked query dot key dot value shape: (2, 3)
[[ 2.3094010e+00 5.7735026e-01 2.3094010e+00]
[ 2.8867514e+00 -1.0000000e+09 2.8867514e+00]]
In order to use the previous dummy tensors to test the functions we will implement, a batch dimension should be added to them so that they mimic the shape of real-life examples. The mask is also replaced by a boolean version that resembles the one used by Trax:
q_with_batch = q[None, :]
display_tensor(q_with_batch, 'query with batch dim')
k_with_batch = k[None, :]
display_tensor(k_with_batch, 'key with batch dim')
v_with_batch = v[None, :]
display_tensor(v_with_batch, 'value with batch dim')
m_bool = create_tensor([[True, True], [False, True]])
display_tensor(m_bool, 'boolean mask')
query with batch dim shape: (1, 2, 3)
[[[1 0 0]
[0 1 0]]]
key with batch dim shape: (1, 2, 3)
[[[1 2 3]
[4 5 6]]]
value with batch dim shape: (1, 2, 3)
[[[0 1 0]
[1 0 1]]]
boolean mask shape: (2, 2)
[[ True True]
[False True]]
Let’s now implement the dot product attention. Concretely, we will implement the following equation
\[ \text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}} + M\right) V \tag{1} \]
\(Q\) - query, \(K\) - key, \(V\) - values, \(M\) - mask, \({d_k}\) - depth/dimension of the queries and keys (used for scaling down)
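The implementation below computes the softmax via the log-sum-exp trick for numerical stability. Here is a quick sketch (toy numbers, not part of the model) of why exp(x - logsumexp(x)) equals the softmax of x:

# Sketch: softmax computed via the log-sum-exp trick.
x = jnp.array([1.0, 2.0, 3.0])
softmax_direct = jnp.exp(x) / jnp.sum(jnp.exp(x))
softmax_lse = jnp.exp(x - trax.fastmath.logsumexp(x))
print(softmax_direct)  # [0.09003057 0.24472848 0.66524094]
print(softmax_lse)     # same values, computed in a numerically stable way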
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.
    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k)
    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embedding dimensions of q, k, v aren't all the same"

    # Save depth/dimension of the query embedding for scaling down the dot product
    depth = query.shape[-1]

    # Calculate scaled query key dot product according to formula above
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply the mask
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax formula implementation
    # We use trax.fastmath.logsumexp of dots to avoid underflow by division by large numbers
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    # Take exponential of dots minus logsumexp to get softmax
    dots = jnp.exp(dots - logsumexp)

    # Multiply dots by value to get self-attention
    attention = jnp.matmul(dots, value)

    return attention
DotProductAttention(q_with_batch, k_with_batch, v_with_batch, m_bool)
DeviceArray([[[0.8496746 , 0.15032545, 0.8496746 ],
[1. , 0. , 1. ]]], dtype=float32)
4.2 Causal Attention
Now we are going to implement causal attention: multi-headed attention with a mask to attend only to words that occurred before.
In the image above, a word can see everything that is before it, but not what is after it. To implement causal attention, we will have to transform vectors and do many reshapes.
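As a small preview of the mask we will build below, here is a sketch (with a toy sequence length of 4) of the lower-triangular boolean matrix that lets position \(i\) attend only to positions \(\leq i\):

# Sketch: causal (lower-triangular) mask for a sequence of 4 tokens.
causal_mask = jnp.tril(jnp.ones((1, 4, 4), dtype=jnp.bool_))
print(causal_mask[0])
# [[ True False False False]
#  [ True  True False False]
#  [ True  True  True False]
#  [ True  True  True  True]]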
We will implement the following functions that will be needed for Causal Attention:
- compute_attention_heads: gets an input \(x\) of dimension (n_batch, seqlen, n_heads \(\times\) d_head), splits the last (depth) dimension and stacks it into the zeroth dimension to allow matrix multiplication (n_batch \(\times\) n_heads, seqlen, d_head).
- dot_product_self_attention: creates a mask matrix with False values above the diagonal and True values below, and calls DotProductAttention, which implements dot product self-attention.
- compute_attention_output: undoes compute_attention_heads by splitting the first (vertical) dimension and stacking it into the last (depth) dimension (n_batch, seqlen, n_heads \(\times\) d_head). These operations concatenate (stack/merge) the heads.
We use some toy tensors, which give us an idea of the data shapes and operations involved in Causal Attention. They are also useful for testing our functions!
tensor2d = create_tensor(q)
display_tensor(tensor2d, 'query matrix (2D tensor)')

tensor4d2b = create_tensor([[q, q], [q, q]])
display_tensor(tensor4d2b, 'batch of two (multi-head) collections of query matrices (4D tensor)')

tensor3dc = create_tensor([jnp.concatenate([q, q], axis=-1)])
display_tensor(tensor3dc, 'one batch of concatenated heads of query matrices (3d tensor)')

tensor3dc3b = create_tensor([jnp.concatenate([q, q], axis=-1),
                             jnp.concatenate([q, q], axis=-1),
                             jnp.concatenate([q, q], axis=-1)])
display_tensor(tensor3dc3b, 'three batches of concatenated heads of query matrices (3d tensor)')
query matrix (2D tensor) shape: (2, 3)
[[1 0 0]
[0 1 0]]
batch of two (multi-head) collections of query matrices (4D tensor) shape: (2, 2, 2, 3)
[[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]]
one batch of concatenated heads of query matrices (3d tensor) shape: (1, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
three batches of concatenated heads of query matrices (3d tensor) shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
It is important to know that the following three functions would normally be defined within the CausalAttention function further below. However, that would make them harder to test. Because of this, they are shown individually, using a closure (when necessary) that simulates them being inside the CausalAttention function, since they rely on variables that would otherwise only be accessible from within CausalAttention.
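For readers less familiar with closures, here is a minimal, generic sketch of the pattern (not specific to this model): the inner function keeps access to the arguments of the outer one, which is exactly how compute_attention_heads will see n_heads and d_head without taking them as parameters.

# Sketch: a closure "captures" variables from the enclosing function.
def make_scaler(factor):
    def scale(x):
        return x * factor   # 'factor' is remembered from make_scaler
    return scale

double = make_scaler(2)
print(double(10))  # 20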
4.3 Support Functions
compute_attention_heads : Gets an input \(x\) of dimension (n_batch, seqlen, n_heads \(\times\) d_head) and splits the last (depth) dimension and stacks it to the zeroth dimension to allow matrix multiplication (n_batch \(\times\) n_heads, seqlen, d_head).
def compute_attention_heads_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int): dimensionality of heads
        n_heads (int): number of attention heads
    Returns:
        function: compute_attention_heads function
    """
    def compute_attention_heads(x):
        """ Compute the attention heads.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (n_batch, seqlen, n_heads X d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (n_batch X n_heads, seqlen, d_head).
        """
        # Size of the x's batch dimension
        batch_size = x.shape[0]
        # Length of the sequence
        # Should be size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape()
        # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
        x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
        # Transpose x using jnp.transpose()
        # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
        # Note that the values within the tuple are the indexes of the dimensions of x and we must rearrange them
        x = jnp.transpose(x, (0, 2, 1, 3))
        # Reshape x using jnp.reshape()
        # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
        x = jnp.reshape(x, (batch_size*n_heads, seqlen, d_head))

        return x

    return compute_attention_heads
"input tensor")
display_tensor(tensor3dc3b, = compute_attention_heads_closure(2,3)(tensor3dc3b)
result_cah "output tensor") display_tensor(result_cah,
input tensor shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
output tensor shape: (6, 2, 3)
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
dot_product_self_attention: creates a mask matrix with False values above the diagonal and True values below, and calls DotProductAttention, which implements dot product self-attention.
def dot_product_self_attention(q, k, v):
    """ Masked dot product self attention.
    Args:
        q (jax.interpreters.xla.DeviceArray): queries.
        k (jax.interpreters.xla.DeviceArray): keys.
        v (jax.interpreters.xla.DeviceArray): values.
    Returns:
        jax.interpreters.xla.DeviceArray: masked dot product self attention tensor.
    """
    # Mask size should be equal to L_q. Q has shape (batch_size, L_q, d)
    mask_size = q.shape[1]

    # Creates a matrix with ones below the diagonal and 0s above. It should have shape (1, mask_size, mask_size)
    # Notice that 1's and 0's get cast to True/False by setting dtype to jnp.bool_
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)

    return DotProductAttention(q, k, v, mask)
dot_product_self_attention(q_with_batch, k_with_batch, v_with_batch)
DeviceArray([[[0. , 1. , 0. ],
[0.8496746 , 0.15032543, 0.8496746 ]]], dtype=float32)
compute_attention_output : Undoes compute_attention_heads by splitting first (vertical) dimension and stacking in the last (depth) dimension (n_batch, seqlen, n_heads \(\times\) d_head). These operations concatenate (stack/merge) the heads.
def compute_attention_output_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int): dimensionality of heads
        n_heads (int): number of attention heads
    Returns:
        function: compute_attention_output function
    """
    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (n_batch X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (n_batch, seqlen, n_heads X d_head).
        """
        # Length of the sequence
        # Should be size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape() to shape (n_batch, n_heads, seqlen, d_head)
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        # Transpose x using jnp.transpose() to shape (n_batch, seqlen, n_heads, d_head)
        x = jnp.transpose(x, (0, 2, 1, 3))

        # Reshape to allow to concatenate the heads
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))

    return compute_attention_output
"input tensor")
display_tensor(result_cah, = compute_attention_output_closure(2,3)(result_cah)
result_cao "output tensor") display_tensor(result_cao,
input tensor shape: (6, 2, 3)
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
output tensor shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
4.4 Causal Attention Function
Now it is time for us to put everything together in the CausalAttention (masked multi-head attention) function.
Our model returns the causal attention through a tl.Serial containing the following:
- tl.Branch: consists of 3 [tl.Dense(d_feature), ComputeAttentionHeads] branches to produce the queries, keys, and values.
- tl.Fn: takes in the dot_product_self_attention function and uses it to compute the attention from \(Q\), \(K\), \(V\).
- tl.Fn: takes in compute_attention_output_closure to merge the heads back together.
- tl.Dense: final dense layer, with dimension d_feature.
In order for Trax to properly handle the functions we just defined, they need to be added as layers using the tl.Fn() function.
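As a tiny standalone illustration (a sketch, not part of the model), tl.Fn wraps an ordinary function into a Trax layer; layers created this way have no weights, so they can be applied to a tensor directly:

# Sketch: wrapping a plain function as a Trax layer with tl.Fn.
Double = tl.Fn('Double', lambda x: x * 2, n_out=1)
print(Double)                            # prints the layer name
print(Double(jnp.array([1., 2., 3.])))   # [2. 4. 6.]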
def CausalAttention(d_feature,
                    n_heads,
                    compute_attention_heads_closure=compute_attention_heads_closure,
                    dot_product_self_attention=dot_product_self_attention,
                    compute_attention_output_closure=compute_attention_output_closure,
                    mode='train'):
    """Transformer-style multi-headed causal attention.
    Args:
        d_feature (int): dimensionality of feature embedding.
        n_heads (int): number of attention heads.
        compute_attention_heads_closure (function): Closure around compute_attention_heads.
        dot_product_self_attention (function): dot_product_self_attention function.
        compute_attention_output_closure (function): Closure around compute_attention_output.
        mode (str): 'train' or 'eval'.
    Returns:
        trax.layers.combinators.Serial: Multi-headed self-attention model.
    """
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads

    # The second argument to tl.Fn() is an uncalled function (without the parentheses)
    # Since we are dealing with closures we might need to call the outer
    # function with the correct parameters to get the actual uncalled function.
    ComputeAttentionHeads = tl.Fn('AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out=1)

    return tl.Serial(
        # creates three towers for one input, takes activations and creates queries keys and values
        tl.Branch(
            [tl.Dense(d_feature), ComputeAttentionHeads],  # queries
            [tl.Dense(d_feature), ComputeAttentionHeads],  # keys
            [tl.Dense(d_feature), ComputeAttentionHeads],  # values
        ),
        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1),  # takes QKV
        # The second argument to tl.Fn() is an uncalled function
        # Since we are dealing with closures we might need to call the outer
        # function with the correct parameters to get the actual uncalled function.
        tl.Fn('AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out=1),  # to allow for parallel
        # Final dense layer
        tl.Dense(d_feature)
    )
# Take a look at the causal attention model
print(CausalAttention(d_feature=512, n_heads=8))
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
4.5 Transformer decoder block
Now that we have implemented the causal part of the transformer, we will implement the transformer decoder block: causal attention followed by a feed-forward network, each wrapped in a residual connection.
To implement this function, we will call the CausalAttention (masked multi-head attention) function we implemented above, and add a feed-forward block that consists of:
- tl.LayerNorm : used to layer normalize
- tl.Dense : the dense layer
- ff_activation : feed-forward activation (we use ReLU here).
- tl.Dropout : dropout layer
- tl.Dense : dense layer
- tl.Dropout : dropout layer
Finally, once we implement the feed-forward block, we can implement the entire decoder block using:
- tl.Residual : takes in the tl.LayerNorm(), the causal attention block, and tl.Dropout.
- tl.Residual : takes in the feed-forward block we just described.
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.
    The input is an activation tensor.
    Args:
        d_model (int): depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.
    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
    """
    # Create masked multi-head attention block using CausalAttention function
    causal_attention = CausalAttention(
        d_feature=d_model,
        n_heads=n_heads,
        mode=mode
    )

    # Create feed-forward block (list) with two dense layers with dropout and input normalized
    feed_forward = [
        # Normalize layer inputs
        tl.LayerNorm(),
        # Add first feed forward (dense) layer
        tl.Dense(d_ff),
        # Add activation function passed in as a parameter, generally ReLU
        ff_activation(),
        # Add dropout with rate and mode specified (i.e., don't use dropout during evaluation)
        tl.Dropout(rate=dropout, mode=mode),
        # Add second feed forward layer
        tl.Dense(d_model),
        # Add dropout with rate and mode specified (i.e., don't use dropout during evaluation)
        tl.Dropout(rate=dropout, mode=mode)
    ]

    # Add list of two Residual blocks: the attention with normalization and dropout and feed-forward blocks
    return [
        tl.Residual(
            # Normalize layer input
            tl.LayerNorm(),
            # Add causal attention block previously defined (without parentheses)
            causal_attention,
            # Add dropout with rate and mode specified
            tl.Dropout(rate=dropout, mode=mode)
        ),
        tl.Residual(
            # Add feed forward block (without parentheses)
            feed_forward
        ),
    ]
# Take a look at the decoder block
print(DecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1, mode='train', ff_activation=tl.Relu))
[Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
Add_in2
], Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Serial[
Relu
]
Dropout
Dense_512
Dropout
]
]
Add_in2
]]
4.6 Transformer Language Model
We will now bring it all together. In this part we will use all the subcomponents we previously built to make the final model: a decoder-only transformer language model.
Previously we coded the decoder block. Now we will code the transformer language model. Here is what we will need:
- positional_encoder: a list containing an embedding layer (tl.Embedding), a dropout layer (tl.Dropout), and a positional encoding layer (tl.PositionalEncoding).
- A list of n_layers decoder blocks.
- tl.Serial: takes in the following layers or lists of layers:
  - tl.ShiftRight: shifts the tensor to the right by padding on axis 1.
  - positional_encoder: encodes the text positions.
  - decoder_blocks: the ones we created.
  - tl.LayerNorm: a layer norm.
  - tl.Dense: takes in the vocab_size.
  - tl.LogSoftmax: to predict.
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.
    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)
    Args:
        vocab_size (int): vocab size.
        d_model (int): depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of decoder layers.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast inference.
        ff_activation (function): the non-linearity in feed-forward layer.
    Returns:
        trax.layers.combinators.Serial: A Transformer language model as a layer that maps from a tensor of tokens
        to activations over a vocab set.
    """
    # Embedding inputs and positional encoder
    positional_encoder = [
        # Add embedding layer of dimension (vocab_size, d_model)
        tl.Embedding(vocab_size=vocab_size, d_feature=d_model),
        # Use dropout with rate and mode specified
        tl.Dropout(rate=dropout, mode=mode),
        # Add positional encoding layer with maximum input length and mode specified
        tl.PositionalEncoding(max_len=max_len, mode=mode)
    ]

    # Create stack (list) of decoder blocks with n_layers with necessary parameters
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation)
        for _ in range(n_layers)
    ]

    # Create the complete model
    return tl.Serial(
        # Use teacher forcing (feed output of previous step to current step)
        tl.ShiftRight(mode=mode),  # Specify the mode!
        # Add positional encoder
        positional_encoder,
        # Add decoder blocks
        decoder_blocks,
        # Normalize layer
        tl.LayerNorm(),
        # Add dense layer of vocab_size (since we need to select a word to translate to)
        # (a.k.a., logits layer)
        tl.Dense(vocab_size),
        # Get probabilities with LogSoftmax
        tl.LogSoftmax()
    )
# Take a look at the Transformer
print(TransformerLM(n_layers=1))
Serial[
Serial[
ShiftRight(1)
]
Embedding_33300_512
Dropout
PositionalEncoding
Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
Add_in2
]
Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Serial[
Relu
]
Dropout
Dense_512
Dropout
]
]
Add_in2
]
LayerNorm
Dense_33300
LogSoftmax
]
5 Training
Now we are going to train our model. As usual, we have to define the cost function and the optimizer, and decide whether we will be training on a GPU or a CPU. In this case, we will train the model on a CPU for a few steps, and then load in a pre-trained model that we can use to predict with our own input.
5.1 Training the model
We will now write a function that takes in our model and trains it. To train our model we have to decide how many times we want to iterate over the entire data set. Each full pass over the data is defined as an epoch. For each epoch, we go over all the data using our training iterator.
Let's implement the training_loop function below to train the neural network above. Here is a list of things we should do:
- Create the train task by calling trax.supervised.training.TrainTask and passing in:
  - labeled_data = train_gen
  - loss_layer = tl.CrossEntropyLoss()
  - optimizer = trax.optimizers.Adam(0.01)
  - lr_schedule = lr_schedule
- Create the eval task by calling trax.supervised.training.EvalTask and passing in:
  - labeled_data = eval_gen
  - metrics = tl.CrossEntropyLoss() and tl.Accuracy()
- Create the training loop by calling trax.supervised.training.Loop and passing in:
  - TransformerLM
  - train_task
  - eval_tasks = [eval_task]
  - output_dir = output_dir
We will be using cross entropy as the loss, with the Adam optimizer. Read the Trax documentation to get a full understanding.
The training loop that this function returns can be run using the run() method by passing in the desired number of steps.
from trax.supervised import training
def training_loop(TransformerLM, train_gen, eval_gen, output_dir="~/model"):
    '''
    Input:
        TransformerLM (trax.layers.combinators.Serial): The model you are building.
        train_gen (generator): Training stream of data.
        eval_gen (generator): Evaluation stream of data.
        output_dir (str): folder to save your file.
    Returns:
        trax.supervised.training.Loop: Training loop.
    '''
    output_dir = os.path.expanduser(output_dir)
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)

    train_task = training.TrainTask(
        labeled_data=train_gen,                 # The training generator
        loss_layer=tl.CrossEntropyLoss(),       # Loss function
        optimizer=trax.optimizers.Adam(0.01),   # Optimizer (Don't forget to set LR to 0.01)
        lr_schedule=lr_schedule,
        n_steps_per_checkpoint=10
    )

    eval_task = training.EvalTask(
        labeled_data=eval_gen,                             # The evaluation generator
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()]     # CrossEntropyLoss and Accuracy
    )

    loop = training.Loop(TransformerLM(d_model=4,
                                       d_ff=16,
                                       n_layers=1,
                                       n_heads=2,
                                       mode='train'),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)

    return loop
Notice that the model will be trained for only 10 steps.
Even with this constraint, the model with the original default arguments took a very long time to finish, which is why some parameters are reduced when defining the model that is fed into the training loop in the function above.
# Should take around 1.5 minutes
!rm -f ~/model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream)
loop.run(10)
Step 1: Total number of trainable weights: 316336
Step 1: Ran 1 train steps in 8.90 secs
Step 1: train CrossEntropyLoss | 10.41016102
Step 1: eval CrossEntropyLoss | 10.41146946
Step 1: eval Accuracy | 0.00000000
Step 10: Ran 9 train steps in 52.26 secs
Step 10: train CrossEntropyLoss | 10.41224766
Step 10: eval CrossEntropyLoss | 10.40876579
Step 10: eval Accuracy | 0.00000000
6 Loading in a Pre-trained model
In this part we will load an almost exact version of the model we coded, but one that has already been trained, to save time.
# THIS STEP COULD TAKE BETWEEN 15 SECONDS TO 15 MINUTES
# Get the model architecture
model = TransformerLM(mode='eval')

# Load the pre-trained weights
model.init_from_file('model.pkl.gz', weights_only=True)
7 Testing with our own input
We will now test the model with our own input. We are going to implement greedy decoding, which consists of two functions. The first one identifies the next symbol: it takes the argmax of the model's output and returns that index.
We will now implement this next_symbol function, which takes in cur_output_tokens and the trained model and returns the index of the next word.
def next_symbol(cur_output_tokens, model):
    """Returns the next symbol for a given sentence.
    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end.
        model (trax.layers.combinators.Serial): The transformer model.
    Returns:
        int: tokenized symbol.
    """
    # current output tokens length
    token_length = len(cur_output_tokens)
    # calculate the minimum power of 2 big enough to store token_length
    # add 1 to token_length so np.log2() doesn't receive 0 when token_length is 0
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))

    # Fill cur_output_tokens with 0's until it reaches padded_length
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]  # The 'None' sets the batch dim

    # model expects a tuple containing two padded tensors (with batch)
    output, _ = model((padded_with_batch, padded_with_batch))
    # To get log_probs we index output with 0 in the first dim,
    # token_length in the second dim and all of the entries for the last dim.
    log_probs = output[0, token_length, :]

    return int(np.argmax(log_probs))
# Test it out!
sentence_test_nxt_symbl = "I want to fly in the sky."
detokenize([next_symbol(tokenize(sentence_test_nxt_symbl) + [0], model)])
'The'
7.1 Greedy decoding
Now we will implement the greedy_decode algorithm, which repeatedly calls the next_symbol function. It takes in the input sentence and the trained model, and returns the decoded sentence.
# Decoding functions.
def greedy_decode(input_sentence, model, next_symbol=next_symbol, tokenize=tokenize, detokenize=detokenize):
    """Greedy decode function.
    Args:
        input_sentence (string): a sentence or article.
        model (trax.layers.combinators.Serial): Transformer model.
    Returns:
        string: summary of the input.
    """
    # Use tokenize()
    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = []
    cur_output = 0
    EOS = 1

    while cur_output != EOS:
        # Get next symbol
        cur_output = next_symbol(cur_output_tokens, model)
        # Append next symbol to original sentence
        cur_output_tokens.append(cur_output)
        # Append next symbol to generated sentence
        generated_output.append(cur_output)
        print(detokenize(generated_output))

    return detokenize(generated_output)
# Test it out on a sentence!
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))
It was a sunny day when I went to the market to buy some flowers. But
I only found roses, not tulips.
:
: I
: I just
: I just found
: I just found ros
: I just found roses
: I just found roses,
: I just found roses, not
: I just found roses, not tu
: I just found roses, not tulips
: I just found roses, not tulips
: I just found roses, not tulips.
: I just found roses, not tulips.<EOS>
: I just found roses, not tulips.<EOS>
# Test it out with a whole article!
article = "It’s the posing craze sweeping the U.S. after being brought to fame by skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert Pujols - and even Republican politician Rick Perry. But now four students at Riverhead High School on Long Island, New York, have been suspended for dropping to a knee and taking up a prayer pose to mimic Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were all suspended for one day because the ‘Tebowing’ craze was blocking the hallway and presenting a safety hazard to students. Scroll down for video. Banned: Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured left) were all suspended for one day by Riverhead High School on Long Island, New York, for their tribute to Broncos quarterback Tim Tebow. Issue: Four of the pupils were suspended for one day because they allegedly did not heed to warnings that the 'Tebowing' craze at the school was blocking the hallway and presenting a safety hazard to students."
print(wrapper.fill(article), '\n')
print(greedy_decode(article, model))
It’s the posing craze sweeping the U.S. after being brought to fame by
skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert
Pujols - and even Republican politician Rick Perry. But now four
students at Riverhead High School on Long Island, New York, have been
suspended for dropping to a knee and taking up a prayer pose to mimic
Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel,
Tyler Carroll and Connor Carroll were all suspended for one day
because the ‘Tebowing’ craze was blocking the hallway and presenting a
safety hazard to students. Scroll down for video. Banned: Jordan
Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured
left) were all suspended for one day by Riverhead High School on Long
Island, New York, for their tribute to Broncos quarterback Tim Tebow.
Issue: Four of the pupils were suspended for one day because they
allegedly did not heed to warnings that the 'Tebowing' craze at the
school was blocking the hallway and presenting a safety hazard to
students.
Jordan
Jordan Ful
Jordan Fulcol
Jordan Fulcoly
Jordan Fulcoly,
Jordan Fulcoly, Wayne
Jordan Fulcoly, Wayne Dre
Jordan Fulcoly, Wayne Drexe
Jordan Fulcoly, Wayne Drexel
Jordan Fulcoly, Wayne Drexel,
Jordan Fulcoly, Wayne Drexel, Tyler
Jordan Fulcoly, Wayne Drexel, Tyler Carroll
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day.
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not hee
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warn
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the '
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Te
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebow
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
cra
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocki
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hall
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
students
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
students.
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
students.<EOS>
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
students.<EOS>
8 Acknowledgements
I’d like to express my thanks to the great Natural Language Processing with Attention Models course, which I completed, and acknowledge the use of some images and other materials from the course in this article.