Commit a1dc61b

Test newly uploaded Flan-T5 weights (#2074)

* Add tests for loading Flan-T5 weights from HF checkpoints
* Add expected outputs and update tests for Flan
* Add newline at end of file
* Pin transformers version for testing
* Simplify test for HF loading
* Fix linting
* Fix integration tests w/ proper download path
1 parent bd10e28 commit a1dc61b

9 files changed: +83 -172 lines changed

.github/workflows/integration-test.yml
Lines changed: 0 additions & 1 deletion

@@ -55,7 +55,6 @@ jobs:
       python3 -m pip --quiet install sentencepiece
       python3 -m pip --quiet install tqdm
       python3 -m pip --quiet install expecttest
-      python3 -m pip --quiet install transformers
       # Run Tests
       python3 -m torch.utils.collect_env
       cd test

README.rst
Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ The library currently consist of following pre-trained models:
 * `DistilRoBERTa <https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/README.md>`_
 * XLM-RoBERTa: `Base and Large Architure <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`_
 * T5: `Small, Base, Large, 3B, and 11B Architecture <https://github.com/google-research/text-to-text-transfer-transformer>`_
-* Flan-T5: `Small, Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_
+* Flan-T5: `Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_

 Tokenizers
 ==========
Lines changed: 75 additions & 109 deletions

@@ -1,10 +1,15 @@
+import os
 import tempfile

 import pytest  # noqa: F401
 import torch
 from parameterized import parameterized_class
-from torchtext.models import T5Bundle
+from torchtext import _TEXT_BUCKET
+from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER
 from torchtext.models import (
+    FLAN_T5_BASE,
+    FLAN_T5_BASE_ENCODER,
+    FLAN_T5_BASE_GENERATION,
     T5_BASE,
     T5_BASE_ENCODER,
     T5_BASE_GENERATION,

@@ -14,11 +19,11 @@
     T5_SMALL,
     T5_SMALL_ENCODER,
     T5_SMALL_GENERATION,
+    T5Bundle,
 )
 from torchtext_unittest.common.assets import get_asset_path
 from torchtext_unittest.common.parameterized_utils import nested_params
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
-from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Model

 BUNDLERS = {
     "base_model": T5_BASE,

@@ -30,6 +35,9 @@
     "large_model": T5_LARGE,
     "large_encoder": T5_LARGE_ENCODER,
     "large_generation": T5_LARGE_GENERATION,
+    "flan_base_encoder": FLAN_T5_BASE_ENCODER,
+    "flan_base_model": FLAN_T5_BASE,
+    "flan_base_generation": FLAN_T5_BASE_GENERATION,
 }


@@ -45,6 +53,9 @@
         ("large_model",),
         ("large_encoder",),
         ("large_generation",),
+        ("flan_base_encoder",),
+        ("flan_base_model",),
+        ("flan_base_generation",),
     ],
 )
 class TestT5Model(TorchtextTestCase):

@@ -74,126 +85,81 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):

     def _t5_get_encoder(self, model, model_input, encoder_output):
         encoder = model.get_encoder()
-        # Need to set the tgt_key_padding_mask to ensure the same results
+        # Need to set the key_padding_mask to ensure the same results
         encoder_padding_mask = model_input.eq(model.padding_idx)
         output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"]
         assert torch.all(output_from_get_encoder.eq(encoder_output))

-    @nested_params(["jit", "not_jit"])
+    @nested_params(["not_jit", "jit"])
     def test_t5_model(self, name) -> None:
-        configuration, type = self.model_name.split("_")
+        names = self.model_name.split("_")
+
+        num_names = len(names)
+
+        if num_names == 3:
+            # Handled slightly differently for Flan-T5 model naming
+            configuration = names[1]
+            type = names[2]
+            expected_asset_name = f"t5.flan.{configuration}.{type}.output.pt"
+            t5_model = BUNDLERS["flan_" + configuration + "_" + type]
+        elif num_names == 2:
+            configuration = names[0]
+            type = names[1]
+            expected_asset_name = f"t5.{configuration}.{type}.output.pt"
+            t5_model = BUNDLERS[configuration + "_" + type]
+        else:
+            raise RuntimeError(f"Unknown model name: {self.model_name}")

-        expected_asset_name = f"t5.{configuration}.{type}.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
         is_jit = name == "jit"
-        t5_model = BUNDLERS[configuration + "_" + type]
         self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)


+@parameterized_class(
+    ("model",),
+    [
+        ("hf_t5_small_encoder",),
+        ("hf_t5_small",),
+        ("hf_t5_small_generation",),
+        ("hf_flan_base_encoder",),
+        ("hf_flan_base",),
+        ("hf_flan_base_generation",),
+    ],
+)
 class TestLoadFromHFCheckpoints(TorchtextTestCase):
     def setUp(self) -> None:
         super().setUp()
         self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
-        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
+        self.encoder_padding_mask = torch.tensor(
+            [[False, False, False, False, False, False], [False, False, False, True, True, True]]
+        )
         self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
-        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
-
-    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
-        # check that encoder layers match
-        for i in range(config.num_encoder_layers + 1):
-            if i < config.num_encoder_layers:
-                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
-                # self-attention scores
-                assert torch.equal(
-                    our_output["encoder_sa_scores"][i], hf_output_sa
-                ), f"Mismatched self-attention scores for encoder layer {i}"
-            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
-            # encoder hidden states
-            assert torch.equal(
-                our_output["encoder_hidden_states"][i], hf_output_hs
-            ), f"Mismatched hidden states for encoder layer {i}"
-
-        if not encoder_only:
-            # check that decoder layers match
-            for i in range(config.num_decoder_layers + 1):
-                if i < config.num_encoder_layers:
-                    # self-attention scores
-                    assert torch.equal(
-                        our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
-                    ), f"Mismatched self-attention scores for decoder layer {i}"
-                    # cross-attention scores
-                    assert torch.equal(
-                        our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
-                    ), f"Mismatched cross-attention scores for decoder layer {i}"
-                # decoder hidden states
-                assert torch.equal(
-                    our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
-                ), f"Mismatched hidden states for decoder layer {i}"
-
-    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model_path = f"{tmp_dir}/hf_t5_small_enc"
-
-            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
-            t5_small_enc.save_pretrained(model_path)
-
-            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=True)
-
-            hf_output = t5_small_enc(
-                input_ids=self.encoder_input_ids,
-                attention_mask=self.encoder_padding_mask,
-                output_hidden_states=True,
-                output_attentions=True,
-            )
-
-            our_output = our_encoder(self.encoder_input_ids)
-
-            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, True)
-
-    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model_path = f"{tmp_dir}/hf_t5_small"
-
-            t5_small = T5Model.from_pretrained("t5-small")
-            t5_small.save_pretrained(model_path)
-
-            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
-
-            hf_output = t5_small(
-                input_ids=self.encoder_input_ids,
-                decoder_input_ids=self.decoder_input_ids,
-                attention_mask=self.encoder_padding_mask,
-                decoder_attention_mask=self.decoder_padding_mask,
-                output_hidden_states=True,
-                output_attentions=True,
-            )
-
-            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)
-
-            self.check_outputs_of_models(our_output, hf_output, our_t5.config, False)
-
-    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model_path = f"{tmp_dir}/hf_t5_small_gen"
-
-            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
-            t5_small_gen.save_pretrained(model_path)
-
-            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
-
-            hf_output = t5_small_gen(
-                input_ids=self.encoder_input_ids,
-                decoder_input_ids=self.decoder_input_ids,
-                attention_mask=self.encoder_padding_mask,
-                decoder_attention_mask=self.decoder_padding_mask,
-                output_hidden_states=True,
-                output_attentions=True,
-            )
-
-            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)
-
-            self.check_outputs_of_models(our_output, hf_output, our_t5.config, False)
-
-    def test_flan_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
-        # TODO(joecummings): Download FLAN-T5 chkpts and test here
-        pass
+        self.decoder_padding_mask = torch.tensor(
+            [[False, False, False, True, True, True], [False, False, False, True, True, True]]
+        )
+
+    def test_t5_bundler_load_hf_ckpt_pretrained(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            local_path = f"{tmp}/{self.model}"
+            remote_bucket = f"{_TEXT_BUCKET}test_models"
+
+            os.mkdir(local_path)
+
+            for f in {"config.json", "pytorch_model.bin"}:
+                destination = f"{local_path}/{f}"
+                remote_path = f"{remote_bucket}/{self.model}/{f}"
+                _TEST_DOWNLOAD_MANAGER.get_local_path(url=remote_path, destination=destination)
+
+            names = self.model.split("_")
+            is_encoder_only = names[-1] == "encoder"
+
+            model = T5Bundle.build_model_from_huggingface_ckpt(local_path, encoder_only=is_encoder_only)
+            if is_encoder_only:
+                model(self.encoder_input_ids, encoder_padding_mask=self.encoder_padding_mask)
+            else:
+                model(
+                    self.encoder_input_ids,
+                    self.decoder_input_ids,
+                    encoder_padding_mask=self.encoder_padding_mask,
+                    decoder_padding_mask=self.decoder_padding_mask,
+                )
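
A note on the mask change above: the test fixtures move from HF-style integer attention masks (1 = attend, 0 = pad) to the boolean key-padding-mask convention used by torchtext's T5 (True = pad), matching the `model_input.eq(model.padding_idx)` pattern in `_t5_get_encoder`. A minimal sketch of how the two conventions relate, with made-up example values:

import torch

pad_idx = 0
input_ids = torch.tensor([[7, 8, 9, 0, 0, 0]])

# torchtext-style key padding mask: True marks padding positions
key_padding_mask = input_ids.eq(pad_idx)

# HF-style attention mask: 1 marks positions that should be attended to
attention_mask = (~key_padding_mask).long()

assert torch.equal(attention_mask, torch.tensor([[1, 1, 1, 0, 0, 0]]))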
3 binary files changed (not shown)

torchtext/_download_hooks.py
Lines changed: 1 addition & 0 deletions

@@ -59,3 +59,4 @@ def get_local_path(self, url, destination):


 _DATASET_DOWNLOAD_MANAGER = DownloadManager()
+_TEST_DOWNLOAD_MANAGER = DownloadManager()
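
For context, this new `_TEST_DOWNLOAD_MANAGER` is what the rewritten test uses to pull per-model checkpoint files from the `test_models` folder of the text bucket. A condensed sketch of that pattern (mirroring the test above; `hf_t5_small_encoder` is just one of the parameterized model names):

import os
import tempfile

from torchtext import _TEXT_BUCKET
from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER

model = "hf_t5_small_encoder"
with tempfile.TemporaryDirectory() as tmp:
    local_path = f"{tmp}/{model}"
    os.mkdir(local_path)
    for f in ("config.json", "pytorch_model.bin"):
        # fetch {_TEXT_BUCKET}test_models/{model}/{f} into the local directory
        _TEST_DOWNLOAD_MANAGER.get_local_path(
            url=f"{_TEXT_BUCKET}test_models/{model}/{f}",
            destination=f"{local_path}/{f}",
        )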

torchtext/models/t5/__init__.py
Lines changed: 0 additions & 6 deletions

@@ -1,7 +1,4 @@
 from .bundler import (
-    FLAN_T5_SMALL_ENCODER,
-    FLAN_T5_SMALL,
-    FLAN_T5_SMALL_GENERATION,
     FLAN_T5_BASE_ENCODER,
     FLAN_T5_BASE,
     FLAN_T5_BASE_GENERATION,

@@ -53,9 +50,6 @@
     "T5_11B_ENCODER",
     "T5_11B",
     "T5_11B_GENERATION",
-    "FLAN_T5_SMALL_ENCODER",
-    "FLAN_T5_SMALL",
-    "FLAN_T5_SMALL_GENERATION",
     "FLAN_T5_BASE_ENCODER",
     "FLAN_T5_BASE",
     "FLAN_T5_BASE_GENERATION",

torchtext/models/t5/bundler.py
Lines changed: 6 additions & 55 deletions

@@ -155,6 +155,7 @@ def build_model_from_huggingface_ckpt(
         """Build T5Model model from a HuggingFace checkpoint.

         Note: Only works with Huggingface models saved in the PyTorch format. Will not work with TensorFlow or JAX.
+        This also requires a fully saved model; sharded checkpoints are not supported.

         Args:
             ckpt_path (str, Path): Path to the HF checkpoint file. Assumes that the file is local.
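
For orientation, the loading pattern this docstring describes (and which the removed HF-dependent tests exercised) looks roughly like the following sketch; it assumes `transformers` is installed and saves the checkpoint to a temporary directory:

import tempfile

from torchtext.models import T5Bundle
from transformers import T5Model

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save an HF checkpoint locally (config.json + pytorch_model.bin) ...
    T5Model.from_pretrained("t5-small").save_pretrained(f"{tmp_dir}/t5_small")
    # ... then build the torchtext model from it
    model = T5Bundle.build_model_from_huggingface_ckpt(f"{tmp_dir}/t5_small", encoder_only=False)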
@@ -238,12 +239,12 @@ def build_model_from_huggingface_ckpt(

         for i in range(config.num_decoder_layers):
             if config.is_gated_act:
-                t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
-                    f"decoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
+                t5_model_state_dict[f"decoder.layers.{i}.linear1_0.weight"] = hf_weights[
+                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"
                 ]

-                t5_model_state_dict[f"encoder.layers.{i}.linear1_1.weight"] = hf_weights[
-                    f"decoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
+                t5_model_state_dict[f"decoder.layers.{i}.linear1_1.weight"] = hf_weights[
+                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"
                 ]
             else:
                 t5_model_state_dict[f"decoder.layers.{i}.linear1.weight"] = hf_weights[
@@ -650,56 +651,6 @@ def t5_transform() -> T5Transform:

 T5_11B_GENERATION.__doc__ = GENERATION_DOC.format("11B", "11B")


-FLAN_T5_SMALL_ENCODER = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.encoder.pt"),
-    _config=T5Conf(
-        encoder_only=True,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL_ENCODER.__doc__ = FLAN_ENCODER_DOC.format("SMALL", "SMALL")
-
-FLAN_T5_SMALL = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.pt"),
-    _config=T5Conf(
-        encoder_only=False,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL.__doc__ = FLAN_DOC.format("SMALL", "SMALL")
-
-FLAN_T5_SMALL_GENERATION = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.generation.pt"),
-    _config=T5Conf(
-        encoder_only=False,
-        linear_head=True,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL_GENERATION.__doc__ = FLAN_GENERATION_DOC.format("SMALL", "SMALL")
-
 FLAN_T5_BASE_ENCODER = T5Bundle(
     _path=urljoin(_TEXT_BUCKET, "t5.flan.base.encoder.pt"),
     _config=T5Conf(encoder_only=True, ffn_dimension=2048, feed_forward_proj="gated-gelu"),
@@ -762,7 +713,7 @@ def t5_transform() -> T5Transform:


 FLAN_T5_LARGE_GENERATION = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.large.encoder.pt"),
+    _path=urljoin(_TEXT_BUCKET, "t5.flan.large.generation.pt"),
     _config=T5Conf(
         encoder_only=False,
         linear_head=True,