
Commit bc57394

Update T5 tutorial for 2.0 release (#2080)
* Update T5 tutorial for 2.0 release
* Fix lint
1 parent a1dc61b commit bc57394

File tree: 1 file changed (+18 −146)


examples/tutorials/t5_demo.py

Lines changed: 18 additions & 146 deletions
@@ -3,6 +3,7 @@
 ==========================================================================
 
 **Author**: `Pendo Abbo <pabbo@fb.com>`__
+**Author**: `Joe Cummings <jrcummings@fb.com>`__
 
 """
 
@@ -24,7 +25,6 @@
 # Common imports
 # --------------
 import torch
-import torch.nn.functional as F
 
 DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
@@ -47,7 +47,7 @@
 # the T5 model expects the input to be batched.
 #
 
-from torchtext.prototype.models import T5Transform
+from torchtext.models import T5Transform
 
 padding_idx = 0
 eos_idx = 1
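
For orientation, these constants feed the transform that the tutorial builds a few lines below this hunk. A minimal sketch of that construction follows; the sentencepiece model URL and the max_seq_len value are assumptions taken from the published tutorial, not part of this diff:

    # Sketch: building the T5 transform with the constants shown above.
    # The sp_model_path URL and max_seq_len are assumed values, not from this diff.
    from torchtext.models import T5Transform

    padding_idx = 0
    eos_idx = 1
    max_seq_len = 512

    transform = T5Transform(
        sp_model_path="https://download.pytorch.org/models/text/t5_tokenizer_base.model",
        max_seq_len=max_seq_len,
        eos_idx=eos_idx,
        padding_idx=padding_idx,
    )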
@@ -66,7 +66,7 @@
 #
 # ::
 #
-# from torchtext.prototype.models import T5_BASE_GENERATION
+# from torchtext.models import T5_BASE_GENERATION
 # transform = T5_BASE_GENERATION.transform()
 #
 

@@ -81,7 +81,7 @@
 # https://pytorch.org/text/main/models.html
 #
 #
-from torchtext.prototype.models import T5_BASE_GENERATION
+from torchtext.models import T5_BASE_GENERATION
 
 
 t5_base = T5_BASE_GENERATION
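
The bundle imported here packages both the text transform and the pre-trained weights. A minimal sketch of how the tutorial consumes it right after this hunk; the .transform()/.get_model() calls are the standard torchtext bundler API, while the eval() line is assumed from the published tutorial:

    # Sketch: consuming the T5_BASE_GENERATION bundle (assumed continuation of this hunk).
    from torchtext.models import T5_BASE_GENERATION

    t5_base = T5_BASE_GENERATION
    transform = t5_base.transform()  # text -> padded token IDs with EOS appended
    model = t5_base.get_model()      # pre-trained T5 encoder-decoder with LM head
    model.eval()                     # disable dropout for generation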
@@ -92,146 +92,18 @@
 
 
 #######################################################################
-# Sequence Generator
+# GenerationUtils
 # ------------------
 #
-# We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
+# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the
 # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
-# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
-# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
-# a greedy decoder.
-#
-
-from torch import Tensor
-from torchtext.prototype.models import T5Model
-
-
-def beam_search(
-    beam_size: int,
-    step: int,
-    bsz: int,
-    decoder_output: Tensor,
-    decoder_tokens: Tensor,
-    scores: Tensor,
-    incomplete_sentences: Tensor,
-):
-    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-    top = torch.topk(probs, beam_size)
-
-    # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
-    # decoder_tokens has shape (N,L) -> (N,B,L)
-    # top.indices has shape (N,B) -> (N,B,1)
-    # x has shape (N,B,L+1)
-    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
-    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)
-
-    # beams are first created for a given sequence
-    if step == 1:
-        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
-        # new_scores has shape (batch_size,B)
-        # incomplete_sentences has shape (batch_size * B) = (N)
-        new_decoder_tokens = x.view(-1, step + 1)
-        new_scores = top.values
-        new_incomplete_sentences = incomplete_sentences
-
-    # beams already exist, want to expand each beam into possible new tokens to add
-    # and for all expanded beams belonging to the same sequences, choose the top k
-    else:
-        # scores has shape (batch_size,B) -> (N,1) -> (N,B)
-        # top.values has shape (N,B)
-        # new_scores has shape (N,B) -> (batch_size, B^2)
-        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)
-
-        # v, i have shapes (batch_size, B)
-        v, i = torch.topk(new_scores, beam_size)
-
-        # x has shape (N,B,L+1) -> (batch_size, B, L+1)
-        # i has shape (batch_size, B) -> (batch_size, B, L+1)
-        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
-        x = x.view(bsz, -1, step + 1)
-        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)
-
-        # need to update incomplete sentences in case one of the beams was kicked out
-        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
-        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)
-
-        # now can use i to extract those beams that were selected
-        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N
-        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)
-
-        # new_scores has shape (batch_size, B)
-        new_scores = v
-
-    return new_decoder_tokens, new_scores, new_incomplete_sentences
-
-
-def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:
-
-    # pass tokens through encoder
-    bsz = encoder_tokens.size(0)
-    encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
-    encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
-    encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]
-
-    encoder_output = model.norm1(encoder_output)
-    encoder_output = model.dropout2(encoder_output)
-
-    # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence
-    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
-    scores = torch.zeros((bsz, beam_size))
-
-    # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
-    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)
-
-    # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
-    for step in range(model.config.max_seq_len):
-
-        if step == 1:
-            # duplicate and order encoder output so that each beam is treated as its own independent sequence
-            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
-            new_order = new_order.to(encoder_tokens.device).long()
-            encoder_output = encoder_output.index_select(0, new_order)
-            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
-
-        # causal mask and padding mask for decoder sequence
-        tgt_len = decoder_tokens.shape[1]
-        decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
-        decoder_padding_mask = decoder_tokens.eq(model.padding_idx)
-
-        # T5 implementation uses padding idx to start sequence. Want to ignore this when masking
-        decoder_padding_mask[:, 0] = False
-
-        # pass decoder sequence through decoder
-        decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens))
-        decoder_output = model.decoder(
-            decoder_embeddings,
-            memory=encoder_output,
-            tgt_mask=decoder_mask,
-            tgt_key_padding_mask=decoder_padding_mask,
-            memory_key_padding_mask=encoder_padding_mask,
-        )[0]
-
-        decoder_output = model.norm2(decoder_output)
-        decoder_output = model.dropout4(decoder_output)
-        decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
-        decoder_output = model.lm_head(decoder_output)
-
-        decoder_tokens, scores, incomplete_sentences = beam_search(
-            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
-        )
-        # ignore newest tokens for sentences that are already complete
-        decoder_tokens[:, -1] *= incomplete_sentences
-
-        # update incomplete_sentences to remove those that were just ended
-        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()
-
-        # early stop if all sentences have been ended
-        if (incomplete_sentences == 0).all():
-            break
-
-    # take most likely sequence
-    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
-    return decoder_tokens
+# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and
+# other decoding strategies are also supported.
+#
+#
+from torchtext.prototype.generate import GenerationUtils
+
+sequence_generator = GenerationUtils(model)
 
 
 #######################################################################
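
The new utility replaces the hand-rolled beam search deleted above with a wrapper around the model's encoder-decoder loop. A minimal end-to-end sketch using the names this diff introduces; the prompt string is a made-up example, and the call signature mirrors the usage shown in the hunks below:

    # Sketch: greedy generation through the new GenerationUtils wrapper.
    # `model` and `transform` come from the T5_BASE_GENERATION bundle above;
    # the prompt text is an illustrative assumption, not from this commit.
    from torchtext.prototype.generate import GenerationUtils

    sequence_generator = GenerationUtils(model)

    model_input = transform(["summarize: studies report that owning a dog lowers stress"])
    tokens = sequence_generator.generate(model_input, eos_idx=1, beam_size=1)
    print(transform.decode(tokens.tolist())[0])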
@@ -343,16 +215,16 @@ def process_labels(labels, x):
 # ------------------
 #
 # We can put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
-# using a beam size of 3.
+# using a beam size of 1.
 #
 
 batch = next(iter(cnndm_dataloader))
 input_text = batch["article"]
 target = batch["abstract"]
-beam_size = 3
+beam_size = 1
 
 model_input = transform(input_text)
-model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
+model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(cnndm_batch_size):
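
Since summarization now defaults to greedy decoding here, a natural experiment is to re-run the same batch with a wider beam. A hedged sketch reusing only calls that appear in this diff; the quality-versus-compute trade-off is taken from the prose this commit removed:

    # Sketch: comparing greedy decoding with a wider beam on the same CNNDM batch.
    # beam_size=1 is greedy; larger beams spend more compute exploring candidate
    # sequences (per the removed tutorial prose). All names come from the hunk above.
    greedy_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=1)
    beam_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=4)

    print(transform.decode(greedy_output.tolist())[0])
    print(transform.decode(beam_output.tolist())[0])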
@@ -442,7 +314,7 @@ def process_labels(labels, x):
 beam_size = 1
 
 model_input = transform(input_text)
-model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
+model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(imdb_batch_size):
@@ -536,7 +408,7 @@ def process_labels(labels, x):
 beam_size = 4
 
 model_input = transform(input_text)
-model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
+model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(multi_batch_size):
