
Commit 79af4ae

Add beam search generation w/ Flashlight Text
1 parent bc57394 commit 79af4ae

4 files changed, +641 −41 lines changed


notebooks/hf_with_torchtext_gen.ipynb

Lines changed: 121 additions & 8 deletions
@@ -16,7 +16,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
 " from .autonotebook import tqdm as notebook_tqdm\n"
 ]
 }
@@ -39,14 +39,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
+"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
 "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
 "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
 "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
@@ -74,7 +74,55 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 6,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n"
+]
+}
+],
+"source": [
+"# Testing HuggingFace's T5 w/ Beam Search\n",
+"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
+"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n",
+"['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n"
+]
+}
+],
+"source": [
+"# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
+"import time\n",
+"\n",
+"start = time.time()\n",
+"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
+"end = time.time()\n",
+"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
+"\n",
+"start = time.time()\n",
+"tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
+"end = time.time()\n",
+"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
@@ -99,7 +147,54 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 9,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['Nearly. PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n"
+]
+}
+],
+"source": [
+"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n",
+"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n",
+"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n"
+]
+}
+],
+"source": [
+"# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
+"import time\n",
+"\n",
+"start = time.time()\n",
+"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n",
+"end = time.time()\n",
+"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
+"\n",
+"start = time.time()\n",
+"tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
+"end = time.time()\n",
+"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 3,
 "metadata": {},
 "outputs": [
 {
@@ -119,11 +214,29 @@
 "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
 "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. The dog loves it so much, that when she']\n"
+]
+}
+],
+"source": [
+"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n",
+"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
+]
 }
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3.9.13 ('torchtext39')",
+"display_name": "torchtext",
 "language": "python",
 "name": "python3"
 },
@@ -137,12 +250,12 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.13"
+"version": "3.9.15"
 },
 "orig_nbformat": 4,
 "vscode": {
 "interpreter": {
-"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
+"hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650"
 }
 }
 },
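
Note: the notebook cells in this diff rely on objects defined in earlier, unchanged cells (generative_hf_t5, generative_hf_bart, generative_hf_gpt2, the tokenizers, and test_sequence_tk), and the timing cells compare the Flashlight Text beam-search path against HuggingFace's native generate on identical inputs. Below is a minimal sketch of how the T5 wrapper might be assembled; the GenerationUtils import path and the input prompt are assumptions for illustration, not part of this commit.

# Minimal sketch; the import path and prompt are assumed, not shown in this diff.
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torchtext.prototype.generate import GenerationUtils  # assumed location

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
generative_hf_t5 = GenerationUtils(t5, is_encoder_decoder=True)

# Illustrative input; the notebook's actual prompt comes from an earlier cell.
test_sequence = ["summarize: studies have shown that owning a dog is good for your overall health and well-being"]
test_sequence_tk = t5_tokenizer(test_sequence, return_tensors="pt").input_ids

# num_beams > 1 selects the Flashlight Text beam-search decoder;
# beam_size_token is set to the model's full vocabulary size.
tokens = generative_hf_t5.generate(
    test_sequence_tk,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=5,
    beam_size_token=t5.config.vocab_size,
)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))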

test/integration_tests/test_generate.py

Lines changed: 86 additions & 2 deletions
@@ -28,7 +28,7 @@ def setUp(self) -> None:
     def test_greedy_generate_with_t5(self) -> None:
         generation_model = GenerationUtils(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
+        tokens = generation_model.generate(self.inputs, num_beams=1)
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
@@ -41,13 +41,97 @@ def setUp(self) -> None:
 
         self.assertEqual(generated_text, expected_generated_text)
 
+    def test_beam_search_generate_t5(self) -> None:
+        generation_model = GenerationUtils(self.model)
+
+        tokens = generation_model.generate(
+            self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30
+        )
+        generated_text = self.transform.decode(tokens.tolist())
+
+        expected_generated_text = [
+            "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for",
+            "Das ist gut.",
+            "acceptable",
+            "4.0",
+            "a tornado ripped through a swath of a lake in southeastern michigan . a spokesman",
+        ]
+
+        self.assertEqual(generated_text, expected_generated_text)
+
+    def test_beam_search_generate_t5_small_batch_size(self) -> None:
+        generation_model = GenerationUtils(self.model)
+
+        tokens = generation_model.generate(
+            self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, max_inference_batch_size=3
+        )
+        generated_text = self.transform.decode(tokens.tolist())
+
+        expected_generated_text = [
+            "kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for",
+            "Das ist gut.",
+            "acceptable",
+            "4.0",
+            "a tornado ripped through a swath of a lake in southeastern michigan . a spokesman",
+        ]
+
+        self.assertEqual(generated_text, expected_generated_text)
+
+    def test_beam_search_generate_t5_with_small_beam_threshold(self) -> None:
+        generation_model = GenerationUtils(self.model)
+
+        tokens = generation_model.generate(
+            self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, beam_threshold=5
+        )
+        generated_text = self.transform.decode(tokens.tolist())
+
+        expected_text = [
+            "kate mccartney: a dog is good for you . kate mccartney: dogs",
+            "Das ist gut.",
+            "acceptable",
+            "4.0",
+            "a tornado ripped through a swath of a lake in southeastern mississippi, causing",
+        ]
+
+        self.assertEqual(generated_text, expected_text)
+
+    def test_beam_search_generate_t5_large_num_beams(self) -> None:
+        generation_model = GenerationUtils(self.model)
+
+        tokens = generation_model.generate(
+            self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30
+        )
+        generated_text = self.transform.decode(tokens.tolist())
+
+        expected_text = [
+            "aaron carroll, aaron jones, aaron jones and aaron jones",
+            "Das ist gut.",
+            "acceptable",
+            "4.0",
+            "a blizzard and power outages have prompted a blizzard and power outages, a spokesman says",
+        ]
+
+        self.assertEqual(generated_text, expected_text)
+
+    def test_beam_search_generate_t5_large_num_beams_eos_score(self) -> None:
+        generation_model = GenerationUtils(self.model)
+
+        tokens = generation_model.generate(
+            self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30, eos_score=10.0
+        )
+        generated_text = self.transform.decode(tokens.tolist())
+
+        expected_text = ["", "Das ist gut.", "acceptable", "4.0", ""]
+
+        self.assertEqual(generated_text, expected_text)
+
     def test_generate_errors_with_incorrect_beams(self) -> None:
         generation_model = GenerationUtils(self.model, is_encoder_decoder=True)
 
         with self.assertRaises(ValueError):
             generation_model.generate(self.inputs, num_beams=0)
 
-    @patch("logging.Logger.warning")
+    @patch("warnings.warn")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtils(self.model)
         generation_model.generate(self.inputs)
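
Taken together, the new tests sweep the beam-search options that GenerationUtils.generate accepts: num_beams, vocab_size, max_length, beam_threshold, eos_score, and max_inference_batch_size. A condensed usage sketch follows; model and transform stand in for the fixtures built in setUp() (outside this diff), and the comments on beam_threshold and eos_score reflect Flashlight Text's decoder options as commonly described, not wording from this commit.

# Condensed sketch of the API surface exercised above; `model`, `inputs`,
# and `transform` stand in for the setUp() fixtures and are shown only for shape.
generation_model = GenerationUtils(model)

tokens = generation_model.generate(
    inputs,
    num_beams=3,                          # beam width; num_beams=1 takes the greedy path
    vocab_size=model.config.vocab_size,   # candidate tokens considered per step
    max_length=30,                        # cap on generated sequence length
    beam_threshold=5,                     # assumed: prune hypotheses scoring far below the best beam
    eos_score=10.0,                       # assumed: additive bonus on EOS; large values end generation early
    max_inference_batch_size=3,           # chunk the beam batch to bound memory
)
generated_text = transform.decode(tokens.tolist())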
