Skip to content

Commit 69f6575

Browse files
committed
Add advanced shared prefix support
Signed-off-by: Samuel Monson <smonson@redhat.com>
1 parent da29a71 commit 69f6575

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

src/guidellm/dataset/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .hf_datasets import HFDatasetsCreator
55
from .in_memory import InMemoryDatasetCreator
66
from .synthetic import (
7+
PrefixBucketConfig,
78
SyntheticDatasetConfig,
89
SyntheticDatasetCreator,
910
SyntheticTextItemsGenerator,
@@ -15,6 +16,7 @@
1516
"FileDatasetCreator",
1617
"HFDatasetsCreator",
1718
"InMemoryDatasetCreator",
19+
"PrefixBucketConfig",
1820
"SyntheticDatasetConfig",
1921
"SyntheticDatasetCreator",
2022
"SyntheticTextItemsGenerator",

src/guidellm/dataset/synthetic.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import random
3-
from collections.abc import Iterable, Iterator
3+
from collections.abc import Iterable, Iterator, Sequence
44
from itertools import cycle
55
from pathlib import Path
66
from typing import Any, Literal, Optional, Union
@@ -19,18 +19,36 @@
1919
from guidellm.utils import EndlessTextCreator, IntegerRangeSampler, check_load_processor
2020

2121
__all__ = [
22+
"PrefixBucketConfig",
2223
"SyntheticDatasetConfig",
2324
"SyntheticDatasetCreator",
2425
"SyntheticTextItemsGenerator",
2526
]
2627

2728

28-
class SyntheticDatasetConfig(BaseModel):
29+
class PrefixBucketConfig(BaseModel):
    # Settings for one bucket of shared prefixes. Comments are used instead of
    # a class docstring so the generated model schema is left untouched.

    # Relative weight of this bucket; validated to be strictly positive.
    bucket_weight: int = Field(
        default=100,
        gt=0,
        description="Weight of this bucket in the overall distribution.",
    )
    # How many distinct prefixes the bucket contributes; at least one.
    prefix_count: int = Field(
        default=1,
        ge=1,
        description="The number of unique prefixs to generate for this bucket.",
    )
    # Token length of each prefix in this bucket; zero is allowed.
    prefix_tokens: int = Field(
        default=0,
        ge=0,
        description="The number of prefix tokens per-prompt for this bucket.",
    )
45+
46+
47+
class SyntheticDatasetConfig(BaseModel):
48+
prefix_buckets: Optional[list[PrefixBucketConfig]] = Field(
49+
description="Buckets for the prefix tokens distribution.",
50+
default=None,
51+
)
3452
prompt_tokens: int = Field(
3553
description="The average number of text tokens generated for prompts.",
3654
gt=0,
@@ -169,17 +187,16 @@ def __iter__(
169187
)
170188
# ensure diff distribution from output tokens
171189
rand = random.Random(self.random_seed + 2) # noqa: S311
190+
shared_prefix_iter = iter(self._create_prefixes(rand))
172191
unique_prefix_iter = cycle(self.processor.get_vocab().values())
173192

174-
prefix_index = rand.randint(0, len(self.text_creator.words))
175-
prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
176-
177193
for _, prompt_tokens, output_tokens in zip(
178194
range(self.config.samples),
179195
prompt_tokens_sampler,
180196
output_tokens_sampler,
181197
):
182-
start_index = rand.randint(0, len(self.text_creator.words))
198+
start_index = self._rand_start_index(rand)
199+
prefix_tokens = next(shared_prefix_iter, [])
183200
prompt_text = self.processor.decode(
184201
prefix_tokens
185202
+ self._create_prompt(
@@ -189,10 +206,40 @@ def __iter__(
189206
)
190207
yield {
191208
"prompt": prompt_text,
192-
"prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
209+
"prompt_tokens_count": len(prefix_tokens) + prompt_tokens,
193210
"output_tokens_count": output_tokens,
194211
}
195212

213+
def _rand_start_index(self, rand: random.Random) -> int:
    """Pick a random position inside ``self.text_creator.words``.

    :param rand: random generator the index is drawn from.
    :return: an index between 0 and ``len(words) - 1``, both inclusive.
    """
    # randint is inclusive on both ends, so the upper bound is the last
    # valid index rather than the length.
    last_index = len(self.text_creator.words) - 1
    return rand.randint(0, last_index)
216+
217+
def _create_prefixes(self, rand: random.Random) -> Sequence[list[int]]:
    """Build the shuffled list of shared prefixes, one entry per sample.

    Every bucket in ``self.config.prefix_buckets`` contributes
    ``prefix_count`` prefixes of ``prefix_tokens`` tokens each. A bucket's
    share of ``self.config.samples`` is ``bucket_weight / total_weight``,
    split evenly between that bucket's prefixes.

    The per-prefix shares are usually fractional. Rounding each one on its
    own can produce fewer entries than samples (leaving some samples with
    no prefix) or drop a bucket entirely when every share is below one
    half. The shares are therefore apportioned with the largest-remainder
    method, so the result always holds exactly ``self.config.samples``
    entries.

    :param rand: random generator used for the prefix start indices and
        for the final shuffle.
    :return: a list of token lists in random order; empty when no prefix
        buckets are configured.
    :raises ValueError: if the summed bucket weights are not positive.
    """
    buckets = self.config.prefix_buckets

    if not buckets:
        return []

    total_weight = sum(bucket.bucket_weight for bucket in buckets)
    if total_weight <= 0:
        raise ValueError("Total weight of prefix buckets must be greater than 0.")

    samples = self.config.samples

    # Generate every unique prefix together with its ideal (fractional)
    # number of occurrences.
    prefixes: list[list[int]] = []
    shares: list[float] = []
    for bucket in buckets:
        # The share is identical for every prefix of a bucket.
        share = bucket.bucket_weight / bucket.prefix_count / total_weight * samples
        for _ in range(bucket.prefix_count):
            start_index = self._rand_start_index(rand)
            prefixes.append(self._create_prompt(bucket.prefix_tokens, start_index))
            shares.append(share)

    # Largest-remainder apportionment: floor every share, then hand the
    # leftover slots to the prefixes with the biggest fractional parts.
    # The sort is stable, so ties go to the earlier prefix.
    counts = [int(share) for share in shares]
    leftover = max(samples - sum(counts), 0)
    by_fraction = sorted(range(len(shares)), key=lambda i: counts[i] - shares[i])
    for index in by_fraction[:leftover]:
        counts[index] += 1

    prompts = [
        prefix for prefix, count in zip(prefixes, counts) for _ in range(count)
    ]
    rand.shuffle(prompts)
    return prompts
242+
196243
def _create_prompt(
197244
self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
198245
) -> list[int]:

0 commit comments

Comments
 (0)