
Commit ec328c1

Authored by sjmonson, MML-coder, and markurtz
Synthetic Dataset: Prefix Caching Controls (#183)
## Summary

Work to allow control of token prefix cache rates with the synthetic data generator. First, adds an auto-incrementing single-token prefix to ensure we never repeat the same prefix. Second, adds controls for sharing a fixed prefix between samples.

## Details

### 1. Ensure every prompt is unique

When generating a prompt, the first token is now taken from an iterator over the tokenizer vocab.

### 2. Add a configurable prefix to simulate a system prompt or other common token prefixes

Example usage:

```yaml
data:
  prefix_tokens: 2048
  prompt_tokens: 256
  output_tokens: 256
  samples: 1024
```

## Test Plan

- PR includes unit tests for all synthetic dataset changes (`pytest tests/unit/dataset`)
- The scenario in the Details section can be run against a model server with prefix caching enabled; the cache rate can then be confirmed by inspecting the server's console output.

## Related Issues

- Resolves #104
- Resolves #186

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [x] Includes AI-generated tests (NOTE: AI-written tests should have a docstring that includes `## WRITTEN BY AI ##`)

---

Signed-off-by: Samuel Monson <smonson@redhat.com>
Co-authored-by: Mehul <MEHTMEHUL@GMAIL.COM>
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
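To illustrate the first mechanism, here is a minimal sketch of drawing a unique leading token id by cycling over a tokenizer vocabulary. The `gpt2` tokenizer and the `AutoTokenizer` usage are assumptions for demonstration only; the generator itself works through its own configured processor.

```python
from itertools import cycle

from transformers import AutoTokenizer

# Assumed tokenizer for illustration; any Hugging Face tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Cycle over every token id in the vocab; each prompt takes the next id as
# its first token, so no two prompts (up to the vocab size) share an
# identical token prefix, which defeats unintended prefix-cache hits.
unique_prefix_iter = cycle(tokenizer.get_vocab().values())

body = tokenizer.encode("The quick brown fox jumps over the lazy dog")
prompts = [[next(unique_prefix_iter)] + body for _ in range(3)]

# The bodies are identical, but every prompt differs at position 0.
assert len({p[0] for p in prompts}) == 3
```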
1 parent 3e274d3 · commit ec328c1

File tree: 4 files changed, +904 -9 lines


docs/datasets.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -76,6 +76,7 @@ guidellm benchmark \
 - `output_tokens_stdev`: Standard deviation for output tokens. If not supplied and min/max are not specified, no deviation is applied. If not supplied and min/max are specified, a uniform distribution is used.
 - `output_tokens_min`: Minimum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the minimum is 1.
 - `output_tokens_max`: Maximum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the maximum is 5 times the standard deviation.
+- `prefix_tokens`: Number of tokens to share as a prefix across all prompts. Is additive to the prompt tokens distribution so each request is `prefix_tokens + prompt_tokens_sample()`. If unset, defaults to 0.
 - `samples`: Number of samples to generate (default: 1000). More samples will increase the time taken to generate the dataset before benchmarking, but will also decrease the likelihood of caching requests.
 - `source`: Source text for generation (default: `data:prideandprejudice.txt.gz`). This can be any text file, URL containing a text file, or a compressed text file. The text is used to sample from at a word and punctuation granularity and then combined into a single string of the desired lengths.
```
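To make the additive sizing concrete, a small sketch of the resulting request length. The sampler below is a stand-in normal distribution, not the library's actual sampler, and the stdev value is assumed for illustration:

```python
import random

prefix_tokens = 2048       # shared prefix from the scenario above
prompt_tokens_mean = 256   # per-prompt distribution mean
prompt_tokens_stdev = 32   # assumed for illustration

rng = random.Random(42)

def prompt_tokens_sample() -> int:
    # Stand-in for the generator's configured sampler.
    return max(1, round(rng.gauss(prompt_tokens_mean, prompt_tokens_stdev)))

# Each request's total prompt length is the shared prefix plus a fresh sample.
totals = [prefix_tokens + prompt_tokens_sample() for _ in range(5)]
print(totals)  # five values clustered around 2048 + 256
```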

src/guidellm/dataset/synthetic.py

Lines changed: 30 additions & 9 deletions
```diff
@@ -1,6 +1,7 @@
 import json
 import random
 from collections.abc import Iterable, Iterator
+from itertools import cycle
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
```

```diff
@@ -25,6 +26,11 @@
 
 
 class SyntheticDatasetConfig(BaseModel):
+    prefix_tokens: int = Field(
+        description="The number of shared prefix tokens to prepend to each prompt.",
+        ge=0,
+        default=0,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
```
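A minimal sketch of how the new field's constraints behave under Pydantic validation. `PrefixConfigSketch` below is a hypothetical stand-in that mirrors only this field; the real `SyntheticDatasetConfig` also carries the prompt and output token fields:

```python
from pydantic import BaseModel, Field, ValidationError

class PrefixConfigSketch(BaseModel):
    # Mirrors the new field: non-negative, defaults to 0 (no shared prefix).
    prefix_tokens: int = Field(
        description="The number of shared prefix tokens to prepend to each prompt.",
        ge=0,
        default=0,
    )

print(PrefixConfigSketch().prefix_tokens)                    # 0
print(PrefixConfigSketch(prefix_tokens=2048).prefix_tokens)  # 2048

try:
    PrefixConfigSketch(prefix_tokens=-1)
except ValidationError as err:
    print(err)  # ge=0 rejects negative values
```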
```diff
@@ -163,39 +169,54 @@ def __iter__(
         )
         # ensure diff distribution from output tokens
         rand = random.Random(self.random_seed + 2)  # noqa: S311
+        unique_prefix_iter = cycle(self.processor.get_vocab().values())
+
+        prefix_index = rand.randint(0, len(self.text_creator.words))
+        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
 
         for _, prompt_tokens, output_tokens in zip(
             range(self.config.samples),
             prompt_tokens_sampler,
             output_tokens_sampler,
         ):
             start_index = rand.randint(0, len(self.text_creator.words))
+            prompt_text = self.processor.decode(
+                prefix_tokens
+                + self._create_prompt(
+                    prompt_tokens, start_index, next(unique_prefix_iter)
+                ),
+                skip_special_tokens=True,
+            )
             yield {
-                "prompt": self._create_prompt(prompt_tokens, start_index),
-                "prompt_tokens_count": prompt_tokens,
+                "prompt": prompt_text,
+                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
                 "output_tokens_count": output_tokens,
             }
 
-    def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
+    def _create_prompt(
+        self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
+    ) -> list[int]:
         if prompt_tokens <= 0:
-            return ""
+            return []
 
         left = start_index
         right = start_index + 4 * prompt_tokens
+        start_tokens = [unique_prefix] if unique_prefix else []
 
         while left < right:
             mid = (left + right) // 2
             test_prompt = self.text_creator.create_text(start_index, mid - start_index)
-            test_tokens = len(self.processor.tokenize(test_prompt))
+            test_tokens = start_tokens + self.processor.encode(test_prompt)
 
-            if test_tokens == prompt_tokens:
-                return test_prompt
-            elif test_tokens < prompt_tokens:
+            if len(test_tokens) == prompt_tokens:
+                return test_tokens
+            elif len(test_tokens) < prompt_tokens:
                 left = mid + 1
             else:
                 right = mid
 
-        return self.text_creator.create_text(start_index, left - start_index)
+        final_text = self.text_creator.create_text(start_index, left - start_index)
+        return start_tokens + self.processor.encode(final_text)
 
 
 class SyntheticDatasetCreator(DatasetCreator):
```
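The retained binary search is worth spelling out: because a tokenizer emits a variable number of tokens per word, the code searches over word counts until the encoded length hits the requested token count. Below is a self-contained sketch of the same idea; `WORDS`, `encode`, and `create_text` are toy stand-ins for the generator's source text and processor, not the library's API:

```python
from typing import Optional

WORDS = ("pride and prejudice " * 2000).split()  # stand-in source text

def encode(text: str) -> list[int]:
    # Toy tokenizer: one token per word, ids by hash. A real tokenizer
    # yields a variable number of tokens per word, which is why the
    # search below cannot simply slice a fixed word count.
    return [hash(w) % 50257 for w in text.split()]

def create_text(start: int, length: int) -> str:
    return " ".join(WORDS[start : start + length])

def create_prompt(
    prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
) -> list[int]:
    if prompt_tokens <= 0:
        return []

    start_tokens = [unique_prefix] if unique_prefix else []
    left = start_index
    right = start_index + 4 * prompt_tokens  # generous upper bound on words needed

    # Binary-search the word count whose tokenization reaches the target.
    while left < right:
        mid = (left + right) // 2
        test_tokens = start_tokens + encode(create_text(start_index, mid - start_index))
        if len(test_tokens) == prompt_tokens:
            return test_tokens
        elif len(test_tokens) < prompt_tokens:
            left = mid + 1
        else:
            right = mid

    return start_tokens + encode(create_text(start_index, left - start_index))

print(len(create_prompt(128, 0, unique_prefix=7)))  # 128
```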

tests/unit/dataset/__init__.py

Whitespace-only changes.
