1
1
import json
2
2
import random
3
- from collections .abc import Iterable , Iterator
3
+ from collections .abc import Iterable , Iterator , Sequence
4
4
from itertools import cycle
5
5
from pathlib import Path
6
6
from typing import Any , Literal , Optional , Union
19
19
from guidellm .utils import EndlessTextCreator , IntegerRangeSampler , check_load_processor
20
20
21
21
__all__ = [
22
+ "PrefixBucketConfig" ,
22
23
"SyntheticDatasetConfig" ,
23
24
"SyntheticDatasetCreator" ,
24
25
"SyntheticTextItemsGenerator" ,
25
26
]
26
27
27
28
28
- class SyntheticDatasetConfig (BaseModel ):
29
+ class PrefixBucketConfig (BaseModel ):
30
+ bucket_weight : int = Field (
31
+ description = "Weight of this bucket in the overall distribution." ,
32
+ gt = 0 ,
33
+ default = 100 ,
34
+ )
35
+ prefix_count : int = Field (
36
+ description = "The number of unique prefixs to generate for this bucket." ,
37
+ ge = 1 ,
38
+ default = 1 ,
39
+ )
29
40
prefix_tokens : int = Field (
30
- description = "The number of shared prefix tokens to prepend to each prompt ." ,
41
+ description = "The number of prefix tokens per-prompt for this bucket ." ,
31
42
ge = 0 ,
32
43
default = 0 ,
33
44
)
45
+
46
+
47
+ class SyntheticDatasetConfig (BaseModel ):
48
+ prefix_buckets : Optional [list [PrefixBucketConfig ]] = Field (
49
+ description = "Buckets for the prefix tokens distribution." ,
50
+ default = None ,
51
+ )
34
52
prompt_tokens : int = Field (
35
53
description = "The average number of text tokens generated for prompts." ,
36
54
gt = 0 ,
@@ -169,17 +187,16 @@ def __iter__(
169
187
)
170
188
# ensure diff distribution from output tokens
171
189
rand = random .Random (self .random_seed + 2 ) # noqa: S311
190
+ shared_prefix_iter = iter (self ._create_prefixes (rand ))
172
191
unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
173
192
174
- prefix_index = rand .randint (0 , len (self .text_creator .words ))
175
- prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
176
-
177
193
for _ , prompt_tokens , output_tokens in zip (
178
194
range (self .config .samples ),
179
195
prompt_tokens_sampler ,
180
196
output_tokens_sampler ,
181
197
):
182
- start_index = rand .randint (0 , len (self .text_creator .words ))
198
+ start_index = self ._rand_start_index (rand )
199
+ prefix_tokens = next (shared_prefix_iter , [])
183
200
prompt_text = self .processor .decode (
184
201
prefix_tokens
185
202
+ self ._create_prompt (
@@ -189,10 +206,40 @@ def __iter__(
189
206
)
190
207
yield {
191
208
"prompt" : prompt_text ,
192
- "prompt_tokens_count" : self . config . prefix_tokens + prompt_tokens ,
209
+ "prompt_tokens_count" : len ( prefix_tokens ) + prompt_tokens ,
193
210
"output_tokens_count" : output_tokens ,
194
211
}
195
212
213
+ def _rand_start_index (self , rand : random .Random ) -> int :
214
+ """Generate a random start index for text generation."""
215
+ return rand .randint (0 , len (self .text_creator .words ) - 1 )
216
+
217
+ def _create_prefixes (self , rand : random .Random ) -> Sequence [list [int ]]:
218
+ """Create an iterator for shared prefix tokens."""
219
+ buckets = self .config .prefix_buckets
220
+
221
+ if not buckets :
222
+ return []
223
+
224
+ total_weight = sum (bucket .bucket_weight for bucket in buckets )
225
+ if total_weight <= 0 :
226
+ raise ValueError ("Total weight of prefix buckets must be greater than 0." )
227
+
228
+ prompts = []
229
+ for bucket in buckets :
230
+ for _ in range (bucket .prefix_count ):
231
+ start_index = self ._rand_start_index (rand )
232
+ prompt_tokens = self ._create_prompt (bucket .prefix_tokens , start_index )
233
+ sample_percent = (
234
+ bucket .bucket_weight / bucket .prefix_count / total_weight
235
+ )
236
+ sample_count = sample_percent * self .config .samples
237
+ for _ in range (int (round (sample_count ))):
238
+ prompts .append (prompt_tokens )
239
+
240
+ rand .shuffle (prompts )
241
+ return prompts
242
+
196
243
def _create_prompt (
197
244
self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
198
245
) -> list [int ]:
0 commit comments