From a5d57728480d5ef9c4d215a1ad089bbe942ed62e Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 10 Jun 2025 13:54:50 -0400 Subject: [PATCH 1/4] Add fixed prefix option to synthetic data Signed-off-by: Samuel Monson --- src/guidellm/dataset/synthetic.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 9868ab52..854f478f 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -25,6 +25,11 @@ class SyntheticDatasetConfig(BaseModel): + prefix_tokens: int = Field( + description="The number of shared prefix tokens to prepend to each prompt.", + ge=0, + default=0, + ) prompt_tokens: int = Field( description="The average number of text tokens generated for prompts.", gt=0, @@ -164,6 +169,10 @@ def __iter__( # ensure diff distribution from output tokens rand = random.Random(self.random_seed + 2) # noqa: S311 + prefix_index = rand.randint(0, len(self.text_creator.words)) + prefix_tokens = self.config.prefix_tokens + prefix = self._create_prompt(prefix_tokens, prefix_index) + for _, prompt_tokens, output_tokens in zip( range(self.config.samples), prompt_tokens_sampler, @@ -171,8 +180,8 @@ def __iter__( ): start_index = rand.randint(0, len(self.text_creator.words)) yield { - "prompt": self._create_prompt(prompt_tokens, start_index), - "prompt_tokens_count": prompt_tokens, + "prompt": prefix + self._create_prompt(prompt_tokens, start_index), + "prompt_tokens_count": prefix_tokens + prompt_tokens, "output_tokens_count": output_tokens, } From a3eed17215e8588ec288e19443e05c96f1f4a153 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Tue, 10 Jun 2025 14:35:18 -0400 Subject: [PATCH 2/4] Add prefix before decode Signed-off-by: Samuel Monson --- src/guidellm/dataset/synthetic.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 854f478f..94dd3aa6 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -170,8 +170,7 @@ def __iter__( rand = random.Random(self.random_seed + 2) # noqa: S311 prefix_index = rand.randint(0, len(self.text_creator.words)) - prefix_tokens = self.config.prefix_tokens - prefix = self._create_prompt(prefix_tokens, prefix_index) + prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) for _, prompt_tokens, output_tokens in zip( range(self.config.samples), @@ -179,15 +178,19 @@ def __iter__( output_tokens_sampler, ): start_index = rand.randint(0, len(self.text_creator.words)) + prompt_text = self.processor.decode( + prefix_tokens + self._create_prompt(prompt_tokens, start_index), + skip_special_tokens=True, + ) yield { - "prompt": prefix + self._create_prompt(prompt_tokens, start_index), - "prompt_tokens_count": prefix_tokens + prompt_tokens, + "prompt": prompt_text, + "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens, "output_tokens_count": output_tokens, } - def _create_prompt(self, prompt_tokens: int, start_index: int) -> str: + def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]: if prompt_tokens <= 0: - return "" + return [] left = start_index right = start_index + 4 * prompt_tokens @@ -195,16 +198,17 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str: while left < right: mid = (left + right) // 2 test_prompt = self.text_creator.create_text(start_index, mid - start_index) - test_tokens = 
len(self.processor.tokenize(test_prompt)) + test_tokens = self.processor.encode(test_prompt) - if test_tokens == prompt_tokens: - return test_prompt - elif test_tokens < prompt_tokens: + if len(test_tokens) == prompt_tokens: + return test_tokens + elif len(test_tokens) < prompt_tokens: left = mid + 1 else: right = mid - return self.text_creator.create_text(start_index, left - start_index) + final_text = self.text_creator.create_text(start_index, left - start_index) + return self.processor.encode(final_text) class SyntheticDatasetCreator(DatasetCreator): From 94a4508e41549f2217dfc82d679741a31102cfc3 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Wed, 11 Jun 2025 13:20:28 -0400 Subject: [PATCH 3/4] Document prefix_tokens arg Signed-off-by: Samuel Monson --- docs/datasets.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/datasets.md b/docs/datasets.md index a5d0aa4e..781b23b8 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -76,6 +76,7 @@ guidellm benchmark \ - `output_tokens_stdev`: Standard deviation for output tokens. If not supplied and min/max are not specified, no deviation is applied. If not supplied and min/max are specified, a uniform distribution is used. - `output_tokens_min`: Minimum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the minimum is 1. - `output_tokens_max`: Maximum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the maximum is 5 times the standard deviation. +- `prefix_tokens`: Number of tokens to share as a prefix across all prompts. Is additive to the prompt tokens distribution so each request is `prefix_tokens + prompt_tokens_sample()`. If unset, defaults to 0. - `samples`: Number of samples to generate (default: 1000). More samples will increase the time taken to generate the dataset before benchmarking, but will also decrease the likelihood of caching requests. - `source`: Source text for generation (default: `data:prideandprejudice.txt.gz`). This can be any text file, URL containing a text file, or a compressed text file. The text is used to sample from at a word and punctuation granularity and then combined into a single string of the desired lengths. 
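
As a usage note for the prefix_tokens option introduced above: the generator can be exercised directly from Python. The sketch below is illustrative only; it assumes a local gpt2 tokenizer from transformers and the (config, tokenizer, random_seed) constructor arguments used by the unit tests added later in this series, and the token counts and source text are arbitrary.

from transformers import AutoTokenizer

from guidellm.dataset.synthetic import (
    SyntheticDatasetConfig,
    SyntheticTextItemsGenerator,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = SyntheticDatasetConfig(
    prefix_tokens=8,   # shared prefix prepended to every prompt
    prompt_tokens=16,  # sampled per-request prompt length
    output_tokens=8,
    samples=3,
    source=(
        "The quick brown fox jumps over the lazy dog. Machine learning models "
        "require diverse training data."
    ),
)

items = list(SyntheticTextItemsGenerator(config, tokenizer, random_seed=42))

for item in items:
    # Reported count is prefix_tokens plus the sampled per-request length.
    print(item["prompt_tokens_count"], item["output_tokens_count"])
    print(item["prompt"])

# The prefix is sampled once per generator, so every prompt starts with the
# same decoded text and a serving engine can reuse its prefix cache for it.
print(len({item["prompt"][:15] for item in items}) == 1)
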
From 6662be688d8e3aa75245c988bc8b7a942e8d1c34 Mon Sep 17 00:00:00 2001 From: Mehul Date: Tue, 1 Jul 2025 16:48:01 -0400 Subject: [PATCH 4/4] Add unique single-token prefix to every request Co-authored-by: Mehul Co-authored-by: Samuel Monson Signed-off-by: Samuel Monson --- src/guidellm/dataset/synthetic.py | 16 +- tests/unit/dataset/__init__.py | 0 tests/unit/dataset/test_synthetic.py | 634 +++++++++++++++++++++++++++ 3 files changed, 646 insertions(+), 4 deletions(-) create mode 100644 tests/unit/dataset/__init__.py create mode 100644 tests/unit/dataset/test_synthetic.py diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py index 94dd3aa6..8c30f0f7 100644 --- a/src/guidellm/dataset/synthetic.py +++ b/src/guidellm/dataset/synthetic.py @@ -1,6 +1,7 @@ import json import random from collections.abc import Iterable, Iterator +from itertools import cycle from pathlib import Path from typing import Any, Literal, Optional, Union @@ -168,6 +169,7 @@ def __iter__( ) # ensure diff distribution from output tokens rand = random.Random(self.random_seed + 2) # noqa: S311 + unique_prefix_iter = cycle(self.processor.get_vocab().values()) prefix_index = rand.randint(0, len(self.text_creator.words)) prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index) @@ -179,7 +181,10 @@ def __iter__( ): start_index = rand.randint(0, len(self.text_creator.words)) prompt_text = self.processor.decode( - prefix_tokens + self._create_prompt(prompt_tokens, start_index), + prefix_tokens + + self._create_prompt( + prompt_tokens, start_index, next(unique_prefix_iter) + ), skip_special_tokens=True, ) yield { @@ -188,17 +193,20 @@ def __iter__( "output_tokens_count": output_tokens, } - def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]: + def _create_prompt( + self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None + ) -> list[int]: if prompt_tokens <= 0: return [] left = start_index right = start_index + 4 * prompt_tokens + start_tokens = [unique_prefix] if unique_prefix else [] while left < right: mid = (left + right) // 2 test_prompt = self.text_creator.create_text(start_index, mid - start_index) - test_tokens = self.processor.encode(test_prompt) + test_tokens = start_tokens + self.processor.encode(test_prompt) if len(test_tokens) == prompt_tokens: return test_tokens @@ -208,7 +216,7 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> list[int]: right = mid final_text = self.text_creator.create_text(start_index, left - start_index) - return self.processor.encode(final_text) + return start_tokens + self.processor.encode(final_text) class SyntheticDatasetCreator(DatasetCreator): diff --git a/tests/unit/dataset/__init__.py b/tests/unit/dataset/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/dataset/test_synthetic.py b/tests/unit/dataset/test_synthetic.py new file mode 100644 index 00000000..1979d8a7 --- /dev/null +++ b/tests/unit/dataset/test_synthetic.py @@ -0,0 +1,634 @@ +""" +Unit tests for guidellm.dataset.synthetic module. 
+""" + +import json +import tempfile +from pathlib import Path + +import pytest +import yaml +from datasets import Dataset +from transformers import AutoTokenizer + +from guidellm.dataset.synthetic import ( + SyntheticDatasetConfig, + SyntheticDatasetCreator, + SyntheticTextItemsGenerator, +) + + +class TestSyntheticDatasetConfig: + """Test cases for SyntheticDatasetConfig class.""" + + def test_config_creation_with_defaults(self): + """Test creating config with default values.""" + config = SyntheticDatasetConfig(prompt_tokens=50, output_tokens=20) + + assert config.prompt_tokens == 50 + assert config.output_tokens == 20 + assert config.samples == 1000 # default + assert config.source == "data:prideandprejudice.txt.gz" # default + assert config.prompt_tokens_stdev is None + assert config.prompt_tokens_min is None + assert config.prompt_tokens_max is None + + def test_config_creation_with_all_params(self): + """Test creating config with all parameters specified.""" + config = SyntheticDatasetConfig( + prompt_tokens=100, + prompt_tokens_stdev=10, + prompt_tokens_min=50, + prompt_tokens_max=150, + output_tokens=30, + output_tokens_stdev=5, + output_tokens_min=20, + output_tokens_max=40, + samples=500, + source="custom_text.txt", + ) + + assert config.prompt_tokens == 100 + assert config.prompt_tokens_stdev == 10 + assert config.prompt_tokens_min == 50 + assert config.prompt_tokens_max == 150 + assert config.output_tokens == 30 + assert config.output_tokens_stdev == 5 + assert config.output_tokens_min == 20 + assert config.output_tokens_max == 40 + assert config.samples == 500 + assert config.source == "custom_text.txt" + + def test_parse_json_string(self): + json_str = json.dumps( + { + "prompt_tokens": 75, + "output_tokens": 25, + "samples": 200, + "source": "test.txt", + } + ) + + config = SyntheticDatasetConfig.parse_str(json_str) + + assert config.prompt_tokens == 75 + assert config.output_tokens == 25 + assert config.samples == 200 + assert config.source == "test.txt" + + def test_parse_key_value_pairs(self): + kv_str = "prompt_tokens=80,output_tokens=30,samples=300,source=data.txt" + + config = SyntheticDatasetConfig.parse_str(kv_str) + + assert config.prompt_tokens == 80 + assert config.output_tokens == 30 + assert config.samples == 300 + assert config.source == "data.txt" + + def test_parse_yaml_file(self): + config_data = { + "prompt_tokens": 60, + "output_tokens": 15, + "samples": 100, + "source": "yaml_test.txt", + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + yaml_path = f.name + + try: + config = SyntheticDatasetConfig.parse_str(yaml_path) + + assert config.prompt_tokens == 60 + assert config.output_tokens == 15 + assert config.samples == 100 + assert config.source == "yaml_test.txt" + finally: + Path(yaml_path).unlink() + + def test_parse_config_file(self): + config_data = {"prompt_tokens": 90, "output_tokens": 35, "samples": 150} + + with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + config = SyntheticDatasetConfig.parse_str(config_path) + + assert config.prompt_tokens == 90 + assert config.output_tokens == 35 + assert config.samples == 150 + finally: + Path(config_path).unlink() + + def test_parse_invalid_format(self): + with pytest.raises(ValueError, match="Unsupported data format"): + SyntheticDatasetConfig.parse_str("invalid_format_string") + + def test_validation_positive_values(self): + """Test that negative 
values are rejected.""" + with pytest.raises(ValueError): + SyntheticDatasetConfig(prompt_tokens=-1, output_tokens=20) + + with pytest.raises(ValueError): + SyntheticDatasetConfig(prompt_tokens=20, output_tokens=-1) + + with pytest.raises(ValueError): + SyntheticDatasetConfig(prompt_tokens=20, output_tokens=10, samples=-1) + + +class TestSyntheticTextItemsGenerator: + @pytest.fixture + def tokenizer(self): + """Fixture to provide a tokenizer for testing.""" + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + @pytest.fixture + def simple_config(self): + return SyntheticDatasetConfig( + prompt_tokens=15, + output_tokens=10, + samples=5, + source=( + "The quick brown fox jumps over the lazy dog. Machine learning models " + "require diverse training data." + ), + ) + + @pytest.fixture + def complex_config(self): + return SyntheticDatasetConfig( + prompt_tokens=20, + prompt_tokens_stdev=5, + prompt_tokens_min=10, + prompt_tokens_max=30, + output_tokens=15, + output_tokens_stdev=3, + output_tokens_min=10, + output_tokens_max=20, + samples=10, + source=( + "The quick brown fox jumps over the lazy dog. Machine learning models " + "require diverse training data." + ), + ) + + def test_generator_initialization(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + assert generator.config == simple_config + assert generator.processor == tokenizer + assert generator.random_seed == 42 + assert generator.request_counter == 0 + assert generator.text_creator is not None + + def test_basic_prompt_generation(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + items = list(generator) + + # Verify we get the expected number of items + assert len(items) == simple_config.samples + + # Verify each item has the required keys + for item in items: + assert "prompt" in item + assert "prompt_tokens_count" in item + assert "output_tokens_count" in item + + # Verify types + assert isinstance(item["prompt"], str) + assert isinstance(item["prompt_tokens_count"], int) + assert isinstance(item["output_tokens_count"], int) + + # Verify non-empty prompt + assert len(item["prompt"]) > 0 + + def test_unique_prefix_generation(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + items = list(generator) + prompts = [str(item["prompt"]) for item in items] + + # Verify each prompt starts with a unique request ID + for i, prompt in enumerate(prompts, 1): + assert prompt.startswith(f"{i}: "), ( + f"Prompt {i} should start with '{i}: ', got '{prompt[:10]}...'" + ) + + # Verify no two prompts are identical + assert len(set(prompts)) == len(prompts), "All prompts should be unique" + + def test_prefix_caching_prevention(self, simple_config, tokenizer): + """Test that prefix caching is effectively prevented.""" + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + items = list(generator) + prompts = [str(item["prompt"]) for item in items] + + # Test that no prompt is a prefix of another + for i, prompt1 in enumerate(prompts): + for j, prompt2 in enumerate(prompts): + if i != j: + assert not prompt1.startswith(prompt2), ( + f"Prompt {i} starts with prompt {j}" + ) + assert not prompt2.startswith(prompt1), ( + f"Prompt {j} starts with prompt {i}" + ) + + # Test that first characters 
are all different + first_chars = [prompt[0] for prompt in prompts] + assert len(set(first_chars)) == len(first_chars), ( + "First characters should all be different" + ) + + def test_token_count_accuracy(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + items = list(generator) + + for item in items: + actual_tokens = len(tokenizer.tokenize(str(item["prompt"]))) + target_tokens = int(item["prompt_tokens_count"]) + + # Allow small variance due to tokenization differences + assert abs(actual_tokens - target_tokens) <= 2, ( + f"Token count mismatch: expected ~{target_tokens}, got {actual_tokens}" + ) + + def test_variance_in_token_counts(self, complex_config, tokenizer): + generator = SyntheticTextItemsGenerator( + complex_config, tokenizer, random_seed=42 + ) + + items = list(generator) + + prompt_token_counts = [int(item["prompt_tokens_count"]) for item in items] + output_token_counts = [int(item["output_tokens_count"]) for item in items] + + # With variance, we should see different token counts + assert len(set(prompt_token_counts)) > 1, ( + "Should have variance in prompt token counts" + ) + assert len(set(output_token_counts)) > 1, ( + "Should have variance in output token counts" + ) + + # Verify bounds are respected + assert all( + complex_config.prompt_tokens_min + <= count + <= complex_config.prompt_tokens_max + for count in prompt_token_counts + ), "Prompt tokens should be within bounds" + assert all( + complex_config.output_tokens_min + <= count + <= complex_config.output_tokens_max + for count in output_token_counts + ), "Output tokens should be within bounds" + + def test_reproducibility_with_same_seed(self, simple_config, tokenizer): + generator1 = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + generator2 = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + items1 = list(generator1) + items2 = list(generator2) + + # Results should be identical with same seed + assert len(items1) == len(items2) + for item1, item2 in zip(items1, items2): + assert str(item1["prompt"]) == str(item2["prompt"]) + assert int(item1["prompt_tokens_count"]) == int( + item2["prompt_tokens_count"] + ) + assert int(item1["output_tokens_count"]) == int( + item2["output_tokens_count"] + ) + + def test_different_seeds_produce_different_results(self, simple_config, tokenizer): + """Test that different seeds produce different results.""" + generator1 = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + generator2 = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=123 + ) + + items1 = list(generator1) + items2 = list(generator2) + + # Results should be different with different seeds + prompts1 = [str(item["prompt"]) for item in items1] + prompts2 = [str(item["prompt"]) for item in items2] + + different_content = False + for p1, p2 in zip(prompts1, prompts2): + # Remove the prefix and compare content + content1 = p1.split(": ", 1)[1] if ": " in p1 else p1 + content2 = p2.split(": ", 1)[1] if ": " in p2 else p2 + if content1 != content2: + different_content = True + break + + assert different_content, "Different seeds should produce different content" + + def test_create_prompt_method_directly(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + # Test normal prompt creation + prompt = generator._create_prompt(10, 0, 5) + assert prompt.startswith("5: "), "Prompt 
should start with request ID" + + actual_tokens = len(tokenizer.tokenize(prompt)) + assert abs(actual_tokens - 10) <= 1, ( + "Token count should be approximately correct" + ) + + # Test empty prompt + empty_prompt = generator._create_prompt(0, 0, 3) + assert empty_prompt == "3: ", "Empty prompt should just be the prefix" + + def test_request_counter_increments_correctly(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + # Initially should be 0 + assert generator.request_counter == 0 + + # Get items one by one and check counter + items = [] + for i, item in enumerate(generator, 1): + items.append(item) + # Counter should increment for each item + assert generator.request_counter == i + if i >= 3: # Just test first 3 + break + + # Verify prompts have correct prefixes + for i, item in enumerate(items, 1): + assert str(item["prompt"]).startswith(f"{i}: ") + + def test_prefix_format_consistency(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + items = list(generator) + + for i, item in enumerate(items, 1): + prompt = str(item["prompt"]) + + # Should start with number followed by colon and space + assert prompt.startswith(f"{i}: "), f"Prompt should start with '{i}: '" + + # Should be able to split on ': ' to get request ID and content + parts = prompt.split(": ", 1) + assert len(parts) == 2, "Prompt should have exactly one ': ' separator" + assert parts[0] == str(i), f"First part should be request ID {i}" + + # Content part should not be empty (unless it's a zero-token prompt) + if int(item["prompt_tokens_count"]) > 0: + assert len(parts[1]) > 0, ( + "Content part should not be empty for non-zero token prompts" + ) + + def test_binary_search_token_accuracy(self, simple_config, tokenizer): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + # Test various token counts + test_cases = [5, 10, 15, 20, 25] + + for target_tokens in test_cases: + prompt = generator._create_prompt(target_tokens, 0, 999) + actual_tokens = len(tokenizer.tokenize(prompt)) + + # Should be very close to target + # (allowing for small tokenization differences) + assert abs(actual_tokens - target_tokens) <= 1, ( + f"Target: {target_tokens}, Actual: {actual_tokens}, " + f"Prompt: '{prompt[:50]}...'" + ) + + def test_vllm_cache_simulation_comprehensive(self, simple_config, tokenizer): + # Use larger sample for more thorough testing + config = SyntheticDatasetConfig( + prompt_tokens=20, output_tokens=10, samples=20, source=simple_config.source + ) + + generator = SyntheticTextItemsGenerator(config, tokenizer, random_seed=42) + items = list(generator) + prompts = [str(item["prompt"]) for item in items] + + # Simulate vLLM cache with different granularities + cache_scenarios = [ + ("Character-level", 1), + ("Token-level", 4), + ("Word-level", 10), + ] + + for scenario_name, granularity in cache_scenarios: + cache_hits = 0 + total_comparisons = 0 + + for i, prompt1 in enumerate(prompts): + for _, prompt2 in enumerate(prompts[i + 1 :], i + 1): + total_comparisons += 1 + + # Check for common prefix at specified granularity + min_len = min(len(prompt1), len(prompt2)) + common_prefix_len = 0 + + for k in range(0, min_len, granularity): + chunk1 = prompt1[k : k + granularity] + chunk2 = prompt2[k : k + granularity] + if chunk1 == chunk2: + common_prefix_len += len(chunk1) + else: + break + + # If meaningful common prefix exists, it's a cache hit 
+ if common_prefix_len > granularity: + cache_hits += 1 + + cache_hit_rate = ( + (cache_hits / total_comparisons) * 100 if total_comparisons > 0 else 0 + ) + + # All scenarios should have 0% cache hit rate + assert cache_hit_rate == 0.0, ( + f"{scenario_name} caching: Expected 0% hit rate, " + f"got {cache_hit_rate:.1f}%" + ) + + def test_edge_case_very_short_prompts(self, tokenizer): + config = SyntheticDatasetConfig( + prompt_tokens=1, + output_tokens=5, + samples=5, + source="A B C D E F G H I J K L M N O P Q R S T U V W X Y Z", + ) + + generator = SyntheticTextItemsGenerator(config, tokenizer, random_seed=42) + items = list(generator) + + for i, item in enumerate(items, 1): + # Even very short prompts should have unique prefixes + assert str(item["prompt"]).startswith(f"{i}: ") + + # Should have at least the prefix + assert len(str(item["prompt"])) >= len(f"{i}: ") + + def test_create_prompt_method_signature_and_documentation( + self, simple_config, tokenizer + ): + generator = SyntheticTextItemsGenerator( + simple_config, tokenizer, random_seed=42 + ) + + # Test method exists and is callable + assert hasattr(generator, "_create_prompt") + assert callable(generator._create_prompt) + + # Test method signature by calling with expected parameters + prompt = generator._create_prompt(prompt_tokens=10, start_index=0, request_id=1) + + # Should return a string + assert isinstance(prompt, str) + + # Should start with the request ID + assert prompt.startswith("1: ") + + # Test that docstring exists and mentions key concepts + docstring = generator._create_prompt.__doc__ + assert docstring is not None + assert "prefix" in docstring.lower() + assert "cache" in docstring.lower() or "caching" in docstring.lower() + assert "request_id" in docstring + + +class TestIntegration: + """Integration tests for the complete synthetic dataset workflow.""" + + @pytest.fixture + def tokenizer(self): + """Fixture to provide a tokenizer for testing.""" + tokenizer = AutoTokenizer.from_pretrained("gpt2") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def test_end_to_end_workflow(self, tokenizer): + """Test the complete workflow from config to dataset.""" + # Create configuration + config_dict = { + "prompt_tokens": 20, + "output_tokens": 15, + "samples": 10, + "source": ( + "The quick brown fox jumps over the lazy dog. Machine learning models " + "require diverse training data to perform well across different tasks " + "and domains." 
+ ), + } + + config_str = json.dumps(config_dict) + + # Create dataset + dataset = SyntheticDatasetCreator.handle_create( + data=config_str, + data_args=None, + processor=tokenizer, + processor_args=None, + random_seed=42, + ) + + # Verify dataset properties + assert isinstance(dataset, Dataset) + assert len(dataset) == 10 + + # Verify all prompts are unique and have correct prefixes + prompts = dataset["prompt"] + for i, prompt in enumerate(prompts, 1): + assert prompt.startswith(f"{i}: "), f"Prompt {i} should start with '{i}: '" + + # Verify no cache hits would occur + for i, prompt1 in enumerate(prompts): + for j, prompt2 in enumerate(prompts): + if i != j: + assert not prompt1.startswith(prompt2) + assert not prompt2.startswith(prompt1) + + # Verify token counts are reasonable + for i, row in enumerate(dataset): + actual_tokens = len(tokenizer.tokenize(row["prompt"])) + target_tokens = row["prompt_tokens_count"] + assert abs(actual_tokens - target_tokens) <= 2, ( + f"Row {i}: token count mismatch" + ) + + def test_cache_prevention_effectiveness(self, tokenizer): + """Test that the cache prevention is effective across larger datasets.""" + config = SyntheticDatasetConfig( + prompt_tokens=25, + output_tokens=20, + samples=50, + source=( + "The quick brown fox jumps over the lazy dog. Machine learning models " + "require diverse training data to perform well across different tasks " + "and domains. Natural language processing has advanced significantly " + "in recent years." + ), + ) + + generator = SyntheticTextItemsGenerator(config, tokenizer, random_seed=42) + items = list(generator) + prompts = [str(item["prompt"]) for item in items] + + prefixes = [prompt.split(": ", 1)[0] for prompt in prompts] + assert len(set(prefixes)) == len(prefixes), "All prefixes should be unique" + + for i, prefix in enumerate(prefixes, 1): + assert prefix == str(i), f"Prefix should be '{i}', got '{prefix}'" + + # Test that no prompt starts with the same prefix as another + for i, prompt1 in enumerate(prompts): + for j, prompt2 in enumerate(prompts): + if i != j: + prefix1 = prompt1.split(": ", 1)[0] + ": " + prefix2 = prompt2.split(": ", 1)[0] + ": " + assert not prompt1.startswith(prefix2), ( + f"Prompt {i} starts with prefix from prompt {j}" + ) + assert not prompt2.startswith(prefix1), ( + f"Prompt {j} starts with prefix from prompt {i}" + ) + + # Test that all prompts are unique + assert len(set(prompts)) == len(prompts), "All prompts should be unique"
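
To make the intent of the final patch concrete: each request's sampled portion is led by a distinct token id cycled from the tokenizer vocabulary, layered after any shared prefix_tokens block, so requests never share an unintended prefix. The sketch below mirrors that token-level assembly on its own rather than calling the generator; the shared-prefix text and prompt bodies are placeholders.

from itertools import cycle

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Cycle over vocabulary token ids so each request gets its own leading token
# for the per-request portion, as the generator does.
unique_prefix_iter = cycle(tokenizer.get_vocab().values())

# Stand-in for the shared prefix built from config.prefix_tokens.
shared_prefix = tokenizer.encode("The quick brown fox jumps over the lazy dog.")

bodies = ["first sampled prompt body", "second sampled prompt body"]
prompts = []
for body in bodies:
    token_ids = shared_prefix + [next(unique_prefix_iter)] + tokenizer.encode(body)
    prompts.append(tokenizer.decode(token_ids, skip_special_tokens=True))

# The requests share only the intended prefix and diverge at the very next
# token, so an engine's prefix cache cannot merge their per-request portions.
print(prompts)
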