|
3 | 3 | ==========================================================================
|
4 | 4 |
|
5 | 5 | **Author**: `Pendo Abbo <pabbo@fb.com>`__
|
| 6 | +**Author**: `Joe Cummings <jrcummings@fb.com>`__ |
6 | 7 |
|
7 | 8 | """
|
8 | 9 |
|
|
24 | 25 | # Common imports
|
25 | 26 | # --------------
|
26 | 27 | import torch
|
27 |
| -import torch.nn.functional as F |
28 | 28 |
|
29 | 29 | DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
30 | 30 |
|
|
47 | 47 | # the T5 model expects the input to be batched.
|
48 | 48 | #
|
49 | 49 |
|
50 |
| -from torchtext.prototype.models import T5Transform |
| 50 | +from torchtext.models import T5Transform |
51 | 51 |
|
52 | 52 | padding_idx = 0
|
53 | 53 | eos_idx = 1
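For reference, instantiating the transform directly looks roughly like the sketch below. This is a minimal sketch, assuming the sentencepiece asset URL and a `max_seq_len` of 512 (neither value appears in this diff); check the torchtext docs for the canonical values::

    transform = T5Transform(
        sp_model_path="https://download.pytorch.org/models/text/t5_tokenizer_base.model",  # assumed asset URL
        max_seq_len=512,  # assumed; long enough for the article inputs used later
        eos_idx=eos_idx,
        padding_idx=padding_idx,
    )
    # the transform maps a list of strings to a batched tensor of token ids
    model_input = transform(["summarize: example article text"])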
|
|
66 | 66 | #
|
67 | 67 | # ::
|
68 | 68 | #
|
69 |
| -# from torchtext.prototype.models import T5_BASE_GENERATION |
| 69 | +# from torchtext.models import T5_BASE_GENERATION |
70 | 70 | # transform = T5_BASE_GENERATION.transform()
|
71 | 71 | #
|
72 | 72 |
|
|
81 | 81 | # https://pytorch.org/text/main/models.html
|
82 | 82 | #
|
83 | 83 | #
|
84 |
| -from torchtext.prototype.models import T5_BASE_GENERATION |
| 84 | +from torchtext.models import T5_BASE_GENERATION |
85 | 85 |
|
86 | 86 |
|
87 | 87 | t5_base = T5_BASE_GENERATION
|
|
92 | 92 |
|
93 | 93 |
|
94 | 94 | #######################################################################
|
95 |
| -# Sequence Generator |
| 95 | +# GenerationUtils |
96 | 96 | # ------------------
|
97 | 97 | #
|
98 |
| -# We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the |
| 98 | +# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the |
99 | 99 | # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
|
100 |
| -# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger |
101 |
| -# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to |
102 |
| -# a greedy decoder. |
103 |
| -# |
104 |
| - |
105 |
| -from torch import Tensor |
106 |
| -from torchtext.prototype.models import T5Model |
107 |
| - |
108 |
| - |
109 |
| -def beam_search( |
110 |
| - beam_size: int, |
111 |
| - step: int, |
112 |
| - bsz: int, |
113 |
| - decoder_output: Tensor, |
114 |
| - decoder_tokens: Tensor, |
115 |
| - scores: Tensor, |
116 |
| - incomplete_sentences: Tensor, |
117 |
| -): |
118 |
| - probs = F.log_softmax(decoder_output[:, -1], dim=-1) |
119 |
| - top = torch.topk(probs, beam_size) |
120 |
| - |
121 |
| - # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size |
122 |
| - # decoder_tokens has shape (N,L) -> (N,B,L) |
123 |
| - # top.indices has shape (N,B) -> (N,B,1) |
124 |
| - # x has shape (N,B,L+1) |
125 |
| - # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size |
126 |
| - x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1) |
127 |
| - |
128 |
| - # beams are first created for a given sequence |
129 |
| - if step == 1: |
130 |
| - # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1) |
131 |
| - # new_scores has shape (batch_size,B) |
132 |
| - # incomplete_sentences has shape (batch_size * B) = (N) |
133 |
| - new_decoder_tokens = x.view(-1, step + 1) |
134 |
| - new_scores = top.values |
135 |
| - new_incomplete_sentences = incomplete_sentences |
136 |
| - |
137 |
| - # beams already exist, want to expand each beam into possible new tokens to add |
138 |
| - # and for all expanded beams belonging to the same sequences, choose the top k |
139 |
| - else: |
140 |
| - # scores has shape (batch_size,B) -> (N,1) -> (N,B) |
141 |
| - # top.values has shape (N,B) |
142 |
| - # new_scores has shape (N,B) -> (batch_size, B^2) |
143 |
| - new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1) |
144 |
| - |
145 |
| - # v, i have shapes (batch_size, B) |
146 |
| - v, i = torch.topk(new_scores, beam_size) |
147 |
| - |
148 |
| - # x has shape (N,B,L+1) -> (batch_size, B, L+1) |
149 |
| - # i has shape (batch_size, B) -> (batch_size, B, L+1) |
150 |
| - # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L) |
151 |
| - x = x.view(bsz, -1, step + 1) |
152 |
| - new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1) |
153 |
| - |
154 |
| - # need to update incomplete sentences in case one of the beams was kicked out |
155 |
| - # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2) |
156 |
| - y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1) |
157 |
| - |
158 |
| - # now can use i to extract those beams that were selected |
159 |
| - # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N |
160 |
| - new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1) |
161 |
| - |
162 |
| - # new_scores has shape (batch_size, B) |
163 |
| - new_scores = v |
164 |
| - |
165 |
| - return new_decoder_tokens, new_scores, new_incomplete_sentences |
166 |
| - |
167 |
| - |
168 |
| -def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor: |
169 |
| - |
170 |
| - # pass tokens through encoder |
171 |
| - bsz = encoder_tokens.size(0) |
172 |
| - encoder_padding_mask = encoder_tokens.eq(model.padding_idx) |
173 |
| - encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens)) |
174 |
| - encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0] |
175 |
| - |
176 |
| - encoder_output = model.norm1(encoder_output) |
177 |
| - encoder_output = model.dropout2(encoder_output) |
178 |
| - |
179 |
| - # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence |
180 |
| - decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx |
181 |
| - scores = torch.zeros((bsz, beam_size)) |
182 |
| - |
183 |
| - # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet |
184 |
| - incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long) |
185 |
| - |
186 |
| - # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token |
187 |
| - for step in range(model.config.max_seq_len): |
188 |
| - |
189 |
| - if step == 1: |
190 |
| - # duplicate and order encoder output so that each beam is treated as its own independent sequence |
191 |
| - new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) |
192 |
| - new_order = new_order.to(encoder_tokens.device).long() |
193 |
| - encoder_output = encoder_output.index_select(0, new_order) |
194 |
| - encoder_padding_mask = encoder_padding_mask.index_select(0, new_order) |
195 |
| - |
196 |
| - # causal mask and padding mask for decoder sequence |
197 |
| - tgt_len = decoder_tokens.shape[1] |
198 |
| - decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool() |
199 |
| - decoder_padding_mask = decoder_tokens.eq(model.padding_idx) |
200 |
| - |
201 |
| - # T5 implementation uses padding idx to start sequence. Want to ignore this when masking |
202 |
| - decoder_padding_mask[:, 0] = False |
203 |
| - |
204 |
| - # pass decoder sequence through decoder |
205 |
| - decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens)) |
206 |
| - decoder_output = model.decoder( |
207 |
| - decoder_embeddings, |
208 |
| - memory=encoder_output, |
209 |
| - tgt_mask=decoder_mask, |
210 |
| - tgt_key_padding_mask=decoder_padding_mask, |
211 |
| - memory_key_padding_mask=encoder_padding_mask, |
212 |
| - )[0] |
213 |
| - |
214 |
| - decoder_output = model.norm2(decoder_output) |
215 |
| - decoder_output = model.dropout4(decoder_output) |
216 |
| - decoder_output = decoder_output * (model.config.embedding_dim ** -0.5) |
217 |
| - decoder_output = model.lm_head(decoder_output) |
218 |
| - |
219 |
| - decoder_tokens, scores, incomplete_sentences = beam_search( |
220 |
| - beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences |
221 |
| - ) |
222 |
| - # ignore newest tokens for sentences that are already complete |
223 |
| - decoder_tokens[:, -1] *= incomplete_sentences |
224 |
| - |
225 |
| - # update incomplete_sentences to remove those that were just ended |
226 |
| - incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long() |
227 |
| - |
228 |
| - # early stop if all sentences have been ended |
229 |
| - if (incomplete_sentences == 0).all(): |
230 |
| - break |
231 |
| - |
232 |
| - # take most likely sequence |
233 |
| - decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :] |
234 |
| - return decoder_tokens |
| 100 | +# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and |
| 101 | +# other decoding strategies are also supported. |
| 102 | +# |
| 103 | +# |
| 104 | +from torchtext.prototype.generate import GenerationUtils |
| 105 | + |
| 106 | +sequence_generator = GenerationUtils(model) |
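For intuition, greedy search keeps only the single most probable token at each step, which is what the generation utility does with a beam size of 1. The sketch below is illustrative only: the keyword-argument forward signature of the torchtext `T5Model` (returning a dict with a `decoder_output` logits entry) is an assumption here, not taken from this diff::

    def greedy_decode(model, encoder_tokens, eos_idx, max_len=512):
        # T5 seeds the decoder with the padding index (0)
        decoder_tokens = torch.zeros((encoder_tokens.size(0), 1), dtype=torch.long)
        for _ in range(max_len):
            # assumed forward signature and output key
            logits = model(encoder_tokens=encoder_tokens, decoder_tokens=decoder_tokens)["decoder_output"]
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy: argmax, no beam bookkeeping
            decoder_tokens = torch.cat([decoder_tokens, next_token], dim=-1)
            if (next_token == eos_idx).all():  # stop once every sequence has emitted EOS
                break
        return decoder_tokens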
235 | 107 |
|
236 | 108 |
|
237 | 109 | #######################################################################
|
@@ -343,16 +215,16 @@ def process_labels(labels, x):
|
343 | 215 | # ------------------
|
344 | 216 | #
|
345 | 217 | # We can put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
|
346 |
| -# using a beam size of 3. |
| 218 | +# using a beam size of 1. |
347 | 219 | #
|
348 | 220 |
|
349 | 221 | batch = next(iter(cnndm_dataloader))
|
350 | 222 | input_text = batch["article"]
|
351 | 223 | target = batch["abstract"]
|
352 |
| -beam_size = 3 |
| 224 | +beam_size = 1 |
353 | 225 |
|
354 | 226 | model_input = transform(input_text)
|
355 |
| -model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size) |
| 227 | +model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size) |
356 | 228 | output_text = transform.decode(model_output.tolist())
|
357 | 229 |
|
358 | 230 | for i in range(cnndm_batch_size):
|
@@ -442,7 +314,7 @@ def process_labels(labels, x):
|
442 | 314 | beam_size = 1
|
443 | 315 |
|
444 | 316 | model_input = transform(input_text)
|
445 |
| -model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size) |
| 317 | +model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size) |
446 | 318 | output_text = transform.decode(model_output.tolist())
|
447 | 319 |
|
448 | 320 | for i in range(imdb_batch_size):
|
@@ -536,7 +408,7 @@ def process_labels(labels, x):
|
536 | 408 | beam_size = 4
|
537 | 409 |
|
538 | 410 | model_input = transform(input_text)
|
539 |
| -model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size) |
| 411 | +model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size) |
540 | 412 | output_text = transform.decode(model_output.tolist())
|
541 | 413 |
|
542 | 414 | for i in range(multi_batch_size):
|
|