|
1 | 1 | import json
|
2 | 2 | import random
|
3 | 3 | from collections.abc import Iterable, Iterator
|
| 4 | +from itertools import cycle |
4 | 5 | from pathlib import Path
|
5 | 6 | from typing import Any, Literal, Optional, Union
|
6 | 7 |
|
|
25 | 26 |
|
26 | 27 |
|
27 | 28 | class SyntheticDatasetConfig(BaseModel):
|
| 29 | + prefix_tokens: int = Field( |
| 30 | + description="The number of shared prefix tokens to prepend to each prompt.", |
| 31 | + ge=0, |
| 32 | + default=0, |
| 33 | + ) |
28 | 34 | prompt_tokens: int = Field(
|
29 | 35 | description="The average number of text tokens generated for prompts.",
|
30 | 36 | gt=0,
|
@@ -163,39 +169,54 @@ def __iter__(
|
163 | 169 | )
|
164 | 170 | # ensure diff distribution from output tokens
|
165 | 171 | rand = random.Random(self.random_seed + 2) # noqa: S311
|
| 172 | + unique_prefix_iter = cycle(self.processor.get_vocab().values()) |
| 173 | + |
| 174 | + prefix_index = rand.randint(0, len(self.text_creator.words)) |
| 175 | + prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) |
166 | 176 |
|
167 | 177 | for _, prompt_tokens, output_tokens in zip(
|
168 | 178 | range(self.config.samples),
|
169 | 179 | prompt_tokens_sampler,
|
170 | 180 | output_tokens_sampler,
|
171 | 181 | ):
|
172 | 182 | start_index = rand.randint(0, len(self.text_creator.words))
|
| 183 | + prompt_text = self.processor.decode( |
| 184 | + prefix_tokens |
| 185 | + + self._create_prompt( |
| 186 | + prompt_tokens, start_index, next(unique_prefix_iter) |
| 187 | + ), |
| 188 | + skip_special_tokens=True, |
| 189 | + ) |
173 | 190 | yield {
|
174 |
| - "prompt": self._create_prompt(prompt_tokens, start_index), |
175 |
| - "prompt_tokens_count": prompt_tokens, |
| 191 | + "prompt": prompt_text, |
| 192 | + "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens, |
176 | 193 | "output_tokens_count": output_tokens,
|
177 | 194 | }
|
178 | 195 |
|
179 |
| - def _create_prompt(self, prompt_tokens: int, start_index: int) -> str: |
| 196 | + def _create_prompt( |
| 197 | + self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None |
| 198 | + ) -> list[int]: |
180 | 199 | if prompt_tokens <= 0:
|
181 |
| - return "" |
| 200 | + return [] |
182 | 201 |
|
183 | 202 | left = start_index
|
184 | 203 | right = start_index + 4 * prompt_tokens
|
| 204 | + start_tokens = [unique_prefix] if unique_prefix else [] |
185 | 205 |
|
186 | 206 | while left < right:
|
187 | 207 | mid = (left + right) // 2
|
188 | 208 | test_prompt = self.text_creator.create_text(start_index, mid - start_index)
|
189 |
| - test_tokens = len(self.processor.tokenize(test_prompt)) |
| 209 | + test_tokens = start_tokens + self.processor.encode(test_prompt) |
190 | 210 |
|
191 |
| - if test_tokens == prompt_tokens: |
192 |
| - return test_prompt |
193 |
| - elif test_tokens < prompt_tokens: |
| 211 | + if len(test_tokens) == prompt_tokens: |
| 212 | + return test_tokens |
| 213 | + elif len(test_tokens) < prompt_tokens: |
194 | 214 | left = mid + 1
|
195 | 215 | else:
|
196 | 216 | right = mid
|
197 | 217 |
|
198 |
| - return self.text_creator.create_text(start_index, left - start_index) |
| 218 | + final_text = self.text_creator.create_text(start_index, left - start_index) |
| 219 | + return start_tokens + self.processor.encode(final_text) |
199 | 220 |
|
200 | 221 |
|
201 | 222 | class SyntheticDatasetCreator(DatasetCreator):
|
|
0 commit comments