|
| 1 | +import os |
1 | 2 | import tempfile
|
2 | 3 |
|
3 | 4 | import pytest # noqa: F401
|
4 | 5 | import torch
|
5 | 6 | from parameterized import parameterized_class
|
6 |
| -from torchtext.models import T5Bundle |
| 7 | +from torchtext import _TEXT_BUCKET |
| 8 | +from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER |
7 | 9 | from torchtext.models import (
|
| 10 | + FLAN_T5_BASE, |
| 11 | + FLAN_T5_BASE_ENCODER, |
| 12 | + FLAN_T5_BASE_GENERATION, |
8 | 13 | T5_BASE,
|
9 | 14 | T5_BASE_ENCODER,
|
10 | 15 | T5_BASE_GENERATION,
|
|
14 | 19 | T5_SMALL,
|
15 | 20 | T5_SMALL_ENCODER,
|
16 | 21 | T5_SMALL_GENERATION,
|
| 22 | + T5Bundle, |
17 | 23 | )
|
18 | 24 | from torchtext_unittest.common.assets import get_asset_path
|
19 | 25 | from torchtext_unittest.common.parameterized_utils import nested_params
|
20 | 26 | from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
|
21 |
| -from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Model |
22 | 27 |
|
23 | 28 | BUNDLERS = {
|
24 | 29 | "base_model": T5_BASE,
|
|
30 | 35 | "large_model": T5_LARGE,
|
31 | 36 | "large_encoder": T5_LARGE_ENCODER,
|
32 | 37 | "large_generation": T5_LARGE_GENERATION,
|
| 38 | + "flan_base_encoder": FLAN_T5_BASE_ENCODER, |
| 39 | + "flan_base_model": FLAN_T5_BASE, |
| 40 | + "flan_base_generation": FLAN_T5_BASE_GENERATION, |
33 | 41 | }
|
34 | 42 |
|
35 | 43 |
|
|
45 | 53 | ("large_model",),
|
46 | 54 | ("large_encoder",),
|
47 | 55 | ("large_generation",),
|
| 56 | + ("flan_base_encoder",), |
| 57 | + ("flan_base_model",), |
| 58 | + ("flan_base_generation",), |
48 | 59 | ],
|
49 | 60 | )
|
50 | 61 | class TestT5Model(TorchtextTestCase):
|
@@ -74,126 +85,81 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
|
74 | 85 |
|
def _t5_get_encoder(self, model, model_input, encoder_output):
    """Check that the encoder from ``model.get_encoder()`` reproduces ``encoder_output``.

    The encoder is run on ``model_input`` with a padding mask built from
    ``model.padding_idx``; the ``"encoder_output"`` entry of its result must
    equal ``encoder_output`` element-wise.
    """
    standalone_encoder = model.get_encoder()
    # The key padding mask has to be passed explicitly to get the same results.
    padding_mask = model_input.eq(model.padding_idx)
    encoder_result = standalone_encoder(model_input, src_key_padding_mask=padding_mask)
    elementwise_match = encoder_result["encoder_output"].eq(encoder_output)
    assert torch.all(elementwise_match)
|
81 | 92 |
|
@nested_params(["not_jit", "jit"])
def test_t5_model(self, name) -> None:
    """Run the shared T5 check for ``self.model_name`` in eager or TorchScript mode.

    ``self.model_name`` is either ``"<configuration>_<type>"`` or, for Flan-T5,
    ``"flan_<configuration>_<type>"``. It selects both the bundle in
    ``BUNDLERS`` and the name of the expected-output asset.

    Args:
        name: ``"jit"`` or ``"not_jit"``, as supplied by ``nested_params``.

    Raises:
        RuntimeError: if ``self.model_name`` follows neither naming scheme.
    """
    parts = self.model_name.split("_")

    if len(parts) == 3 and parts[0] == "flan":
        # Flan-T5 names carry an extra leading "flan" component.
        _, configuration, model_type = parts
        expected_asset_name = f"t5.flan.{configuration}.{model_type}.output.pt"
    elif len(parts) == 2:
        configuration, model_type = parts
        expected_asset_name = f"t5.{configuration}.{model_type}.output.pt"
    else:
        raise RuntimeError(f"Unknown model name: {self.model_name}")

    # Both accepted forms re-join to exactly self.model_name, so it is the
    # BUNDLERS key; no need to rebuild it piece by piece.
    t5_model = BUNDLERS[self.model_name]

    test_text = ["Hello world", "Attention rocks!"]
    is_jit = name == "jit"
    self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)
|
91 | 116 |
|
92 | 117 |
|
@parameterized_class(
    ("model",),
    [
        (checkpoint_name,)
        for checkpoint_name in (
            "hf_t5_small_encoder",
            "hf_t5_small",
            "hf_t5_small_generation",
            "hf_flan_base_encoder",
            "hf_flan_base",
            "hf_flan_base_generation",
        )
    ],
)
class TestLoadFromHFCheckpoints(TorchtextTestCase):
    """Smoke test for ``T5Bundle.build_model_from_huggingface_ckpt``.

    Each parameterized ``model`` names a checkpoint directory under
    ``{_TEXT_BUCKET}test_models``. The test downloads it, builds the model and
    runs one forward pass; it makes no assertions on the output values.
    """

    def setUp(self) -> None:
        """Create small fixed encoder/decoder inputs and their padding masks."""
        super().setUp()
        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
        # Boolean masks that are True exactly where the ids above are 0.
        self.encoder_padding_mask = self.encoder_input_ids.eq(0)
        self.decoder_padding_mask = self.decoder_input_ids.eq(0)

    def test_t5_bundler_load_hf_ckpt_pretrained(self) -> None:
        """Download the checkpoint for ``self.model``, build it and run a forward pass."""
        checkpoint_files = ("config.json", "pytorch_model.bin")
        # Names ending in "_encoder" denote encoder-only checkpoints.
        is_encoder_only = self.model.split("_")[-1] == "encoder"

        with tempfile.TemporaryDirectory() as tmp:
            local_path = f"{tmp}/{self.model}"
            remote_dir = f"{_TEXT_BUCKET}test_models/{self.model}"

            os.mkdir(local_path)
            for file_name in checkpoint_files:
                _TEST_DOWNLOAD_MANAGER.get_local_path(
                    url=f"{remote_dir}/{file_name}",
                    destination=f"{local_path}/{file_name}",
                )

            model = T5Bundle.build_model_from_huggingface_ckpt(local_path, encoder_only=is_encoder_only)

            # Encoder-only models take just the encoder inputs; full models
            # additionally take the decoder ids and decoder padding mask.
            positional = [self.encoder_input_ids]
            masks = {"encoder_padding_mask": self.encoder_padding_mask}
            if not is_encoder_only:
                positional.append(self.decoder_input_ids)
                masks["decoder_padding_mask"] = self.decoder_padding_mask
            model(*positional, **masks)
0 commit comments