From a28bbe36e64171c8f41414e9cd57cb06bc4ed69b Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 11:48:24 +0000 Subject: [PATCH 1/6] Mock server implementation for guidellm --- src/guidellm/mock_server/__init__.py | 8 + src/guidellm/mock_server/config.py | 84 +++ src/guidellm/mock_server/handlers/__init__.py | 17 + .../mock_server/handlers/chat_completions.py | 280 ++++++++++ .../mock_server/handlers/completions.py | 280 ++++++++++ .../mock_server/handlers/tokenizer.py | 142 +++++ src/guidellm/mock_server/models.py | 510 +++++++++++++++++ src/guidellm/mock_server/server.py | 168 ++++++ src/guidellm/mock_server/utils.py | 307 +++++++++++ tests/unit/mock_server/__init__.py | 1 + tests/unit/mock_server/test_server.py | 518 ++++++++++++++++++ 11 files changed, 2315 insertions(+) create mode 100644 src/guidellm/mock_server/__init__.py create mode 100644 src/guidellm/mock_server/config.py create mode 100644 src/guidellm/mock_server/handlers/__init__.py create mode 100644 src/guidellm/mock_server/handlers/chat_completions.py create mode 100644 src/guidellm/mock_server/handlers/completions.py create mode 100644 src/guidellm/mock_server/handlers/tokenizer.py create mode 100644 src/guidellm/mock_server/models.py create mode 100644 src/guidellm/mock_server/server.py create mode 100644 src/guidellm/mock_server/utils.py create mode 100644 tests/unit/mock_server/__init__.py create mode 100644 tests/unit/mock_server/test_server.py diff --git a/src/guidellm/mock_server/__init__.py b/src/guidellm/mock_server/__init__.py new file mode 100644 index 00000000..f76e98fb --- /dev/null +++ b/src/guidellm/mock_server/__init__.py @@ -0,0 +1,8 @@ +""" +GuideLLM Mock Server for OpenAI and vLLM API compatibility. +""" + +from .config import MockServerConfig +from .server import MockServer + +__all__ = ["MockServer", "MockServerConfig"] diff --git a/src/guidellm/mock_server/config.py b/src/guidellm/mock_server/config.py new file mode 100644 index 00000000..27d1d742 --- /dev/null +++ b/src/guidellm/mock_server/config.py @@ -0,0 +1,84 @@ +""" +Configuration settings for the mock server component. + +Provides centralized configuration management for mock server behavior including +network binding, model identification, response timing characteristics, and token +generation parameters. Supports environment variable configuration for deployment +flexibility with automatic validation through Pydantic settings. +""" + +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings + +__all__ = ["MockServerConfig"] + + +class MockServerConfig(BaseSettings): + """ + Configuration settings for mock server behavior and deployment. + + Centralizes all configurable parameters for mock server operation including + network settings, model identification, response timing characteristics, and + token generation behavior. Environment variables with GUIDELLM_MOCK_SERVER_ + prefix override default values for deployment flexibility. + + Example: + :: + config = MockServerConfig(host="0.0.0.0", port=8080, model="custom-model") + # Use with environment variables: + # GUIDELLM_MOCK_SERVER_HOST=127.0.0.1 GUIDELLM_MOCK_SERVER_PORT=9000 + """ + + host: str = Field( + default="127.0.0.1", description="Host address to bind the server to" + ) + port: int = Field(default=8000, description="Port number to bind the server to") + workers: int = Field(default=1, description="Number of worker processes to spawn") + model: str = Field( + default="llama-3.1-8b-instruct", + description="Model name to present in API responses", + ) + processor: str | None = Field( + default=None, + description=( + "Processor type to use for token stats, tokenize, and detokenize. " + "If None, a mock one is created." + ), + ) + request_latency: float = Field( + default=3.0, + description="Base request latency in seconds for non-streaming responses", + ) + request_latency_std: float = Field( + default=0.0, + description="Standard deviation for request latency variation", + ) + ttft_ms: float = Field( + default=150.0, + description="Time to first token in milliseconds for streaming responses", + ) + ttft_ms_std: float = Field( + default=0.0, + description="Standard deviation for time to first token variation", + ) + itl_ms: float = Field( + default=10.0, + description="Inter-token latency in milliseconds for streaming responses", + ) + itl_ms_std: float = Field( + default=0.0, + description="Standard deviation for inter-token latency variation", + ) + output_tokens: int = Field( + default=128, description="Number of output tokens to generate in responses" + ) + output_tokens_std: float = Field( + default=0.0, + description="Standard deviation for output token count variation", + ) + + class Config: + env_prefix = "GUIDELLM_MOCK_SERVER_" + case_sensitive = False diff --git a/src/guidellm/mock_server/handlers/__init__.py b/src/guidellm/mock_server/handlers/__init__.py new file mode 100644 index 00000000..7dbc209f --- /dev/null +++ b/src/guidellm/mock_server/handlers/__init__.py @@ -0,0 +1,17 @@ +""" +HTTP request handlers for the GuideLLM mock server. + +This module exposes request handlers that implement OpenAI-compatible API endpoints +for the mock server. The handlers provide realistic LLM simulation capabilities +including chat completions, legacy completions, and tokenization services with +configurable timing characteristics, token counting, and proper error handling to +support comprehensive benchmarking and testing scenarios. +""" + +from __future__ import annotations + +from .chat_completions import ChatCompletionsHandler +from .completions import CompletionsHandler +from .tokenizer import TokenizerHandler + +__all__ = ["ChatCompletionsHandler", "CompletionsHandler", "TokenizerHandler"] diff --git a/src/guidellm/mock_server/handlers/chat_completions.py b/src/guidellm/mock_server/handlers/chat_completions.py new file mode 100644 index 00000000..976901f9 --- /dev/null +++ b/src/guidellm/mock_server/handlers/chat_completions.py @@ -0,0 +1,280 @@ +""" +OpenAI Chat Completions API endpoint handler for the mock server. + +Provides a complete implementation of the /v1/chat/completions endpoint that simulates +realistic LLM behavior with configurable timing characteristics. Supports both streaming +and non-streaming responses with proper token counting, latency simulation including +TTFT (Time To First Token) and ITL (Inter-Token Latency), and OpenAI-compatible error +handling for comprehensive benchmarking scenarios. +""" + +from __future__ import annotations + +import asyncio +import json +import math +import time +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse, ResponseStream +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + ChatCompletionChoice, + ChatCompletionsRequest, + ChatCompletionsResponse, + ChatMessage, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import ( + MockTokenizer, + create_fake_text, + create_fake_tokens_str, + sample_number, + times_generator, +) + +__all__ = ["ChatCompletionsHandler"] + + +class ChatCompletionsHandler: + """ + Handles OpenAI Chat Completions API requests with realistic LLM simulation. + + Implements the /v1/chat/completions endpoint behavior including request validation, + response generation, and timing simulation. Supports both streaming and + non-streaming modes with configurable latency characteristics for comprehensive + benchmarking. Uses either a mock tokenizer or a real tokenizer for accurate token + counting and realistic text generation. + + Example: + :: + config = MockServerConfig(ttft_ms=100, itl_ms=50) + handler = ChatCompletionsHandler(config) + response = await handler.handle(request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the Chat Completions handler with server configuration. + + :param config: Mock server configuration containing timing and behavior settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process incoming chat completion requests with validation and routing. + + Validates the request payload, handles errors gracefully, and routes to + appropriate streaming or non-streaming response handlers based on the + request configuration. + + :param request: Sanic HTTP request containing chat completion parameters + :return: HTTP response with completion data or error information + :raises ValidationError: When request payload fails validation + :raises JSONDecodeError: When request contains invalid JSON + """ + try: + # Parse and validate request + req_data = ChatCompletionsRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + # Handle streaming vs non-streaming + if req_data.stream: + return await self._handle_stream(req_data) + else: + return await self._handle_non_stream(req_data) + + async def _handle_non_stream(self, req: ChatCompletionsRequest) -> HTTPResponse: + """ + Generate complete non-streaming chat completion response. + + Simulates realistic LLM behavior with TTFT and ITL delays, generates + appropriate token counts, and returns a complete response with usage + statistics and generated content. + + :param req: Validated chat completion request parameters + :return: Complete HTTP response with generated completion data + """ + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_text = self.tokenizer.apply_chat_template(req.messages) + prompt_tokens = len(self.tokenizer(prompt_text)) + max_tokens = req.max_completion_tokens or req.max_tokens or math.inf + completion_tokens_count = min( + sample_number(self.config.output_tokens, self.config.output_tokens_std), + max_tokens, + ) + + # ITL delay + itl_delay = 0.0 + delays_iter = iter(times_generator(self.config.itl_ms, self.config.itl_ms_std)) + for _ in range(int(completion_tokens_count) - 1): + itl_delay += next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + # Response + chat_response = ChatCompletionsResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:29]}", + model=req.model, + choices=[ + ChatCompletionChoice( + index=0, + message=ChatMessage( + role="assistant", + content=create_fake_text( + int(completion_tokens_count), self.tokenizer + ), + ), + finish_reason="stop", + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=int(completion_tokens_count), + ), + system_fingerprint=f"fp_{uuid.uuid4().hex[:10]}", + ) + + return response.json(chat_response.model_dump()) + + async def _handle_stream(self, req: ChatCompletionsRequest) -> HTTPResponse: + """ + Generate streaming chat completion response with real-time token delivery. + + Creates a streaming response that delivers tokens incrementally with + realistic timing delays. Supports optional usage statistics in the final + stream chunk when requested via stream_options. + + :param req: Validated chat completion request with streaming enabled + :return: Streaming HTTP response delivering tokens with proper timing + """ + + async def generate_stream(stream_response): + completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_text = self.tokenizer.apply_chat_template(req.messages) + prompt_tokens = len(self.tokenizer(prompt_text)) + max_tokens = req.max_completion_tokens or req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number( + self.config.output_tokens, self.config.output_tokens_std + ), + max_tokens, + ) + ) + + # Send tokens + tokens = create_fake_tokens_str(completion_tokens_count, self.tokenizer) + delays_iter = iter( + times_generator(self.config.itl_ms, self.config.itl_ms_std) + ) + + for index, token in enumerate(tokens): + if index > 0: + itl_delay = next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + chunk_data = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "index": 0, + "delta": {"content": token}, + "finish_reason": None, + } + ], + } + await stream_response.write(f"data: {json.dumps(chunk_data)}\n\n") + + # Send final chunk with finish reason + final_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + } + ], + } + await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") + + # Send usage if requested + if req.stream_options and req.stream_options.get("include_usage"): + usage_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": req.model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_count, + "total_tokens": prompt_tokens + completion_tokens_count, + }, + } + await stream_response.write(f"data: {json.dumps(usage_chunk)}\n\n") + + # End stream + await stream_response.write("data: [DONE]\n\n") + + return ResponseStream( # type: ignore[return-value] + generate_stream, + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/src/guidellm/mock_server/handlers/completions.py b/src/guidellm/mock_server/handlers/completions.py new file mode 100644 index 00000000..418d2b3c --- /dev/null +++ b/src/guidellm/mock_server/handlers/completions.py @@ -0,0 +1,280 @@ +""" +Legacy OpenAI Completions API handler for the mock server. + +This module provides the CompletionsHandler class that implements the /v1/completions +endpoint for the guidellm mock server. It supports both streaming and non-streaming +completions with configurable timing parameters (TTFT, ITL) and token generation to +simulate realistic LLM behavior for benchmarking and testing purposes. +""" + +from __future__ import annotations + +import asyncio +import json +import math +import time +import uuid + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse, ResponseStream +from transformers import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + CompletionChoice, + CompletionsRequest, + CompletionsResponse, + ErrorDetail, + ErrorResponse, + Usage, +) +from guidellm.mock_server.utils import ( + MockTokenizer, + create_fake_text, + create_fake_tokens_str, + sample_number, + times_generator, +) + +__all__ = ["CompletionsHandler"] + + +class CompletionsHandler: + """ + Handler for the OpenAI /v1/completions endpoint in the mock server. + + This handler simulates the legacy OpenAI completions API by processing incoming + requests and generating responses with configurable timing and token generation + patterns. It supports both streaming and non-streaming modes, applying realistic + timing delays (TTFT and ITL) to mimic actual LLM behavior for benchmarking. + + Example: + :: + config = MockServerConfig(ttft_ms=100, itl_ms=50) + handler = CompletionsHandler(config) + response = await handler.handle(sanic_request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the completions handler with configuration settings. + + :param config: Mock server configuration containing timing parameters + and tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def handle(self, request: Request) -> HTTPResponse: + """ + Process a completions request and return the appropriate response. + + Validates the incoming request, determines whether to use streaming or + non-streaming mode, and delegates to the appropriate handler method. + + :param request: Sanic request object containing the completions request data + :return: HTTP response with completion data or error information + :raises ValidationError: When request validation fails + :raises json.JSONDecodeError: When request JSON is malformed + """ + try: + # Parse and validate request + req_data = CompletionsRequest(**request.json) + except ValidationError as e: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(e)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (json.JSONDecodeError, TypeError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + # Handle streaming vs non-streaming + if req_data.stream: + return await self._handle_stream(req_data) + else: + return await self._handle_non_stream(req_data) + + async def _handle_non_stream(self, req: CompletionsRequest) -> HTTPResponse: + """ + Generate a non-streaming completion response. + + Simulates TTFT and ITL delays, generates appropriate token counts, and returns + a complete response with the generated text and usage statistics. + + :param req: Validated completions request containing prompt and parameters + :return: JSON HTTP response with completion text and usage data + :raises NotImplementedError: When batch processing is requested + """ + if isinstance(req.prompt, list): + raise NotImplementedError("Batch processing is not supported.") + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_tokens = len(self.tokenizer(req.prompt)) + max_tokens = req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number(self.config.output_tokens, self.config.output_tokens_std), + max_tokens, + ) + if req.stop + else max_tokens + ) + + # ITL delay + itl_delay = 0.0 + delays_iter = iter(times_generator(self.config.itl_ms, self.config.itl_ms_std)) + for _ in range(int(completion_tokens_count) - 1): + itl_delay += next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + # Response + completion_response = CompletionsResponse( + id=f"cmpl-{uuid.uuid4().hex[:29]}", + model=req.model, + choices=[ + CompletionChoice( + text=create_fake_text(completion_tokens_count, self.tokenizer), + index=0, + finish_reason="stop", + ) + ], + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens_count, + ), + system_fingerprint=f"fp_{uuid.uuid4().hex[:10]}", + ) + + return response.json(completion_response.model_dump()) + + async def _handle_stream(self, req: CompletionsRequest) -> HTTPResponse: + """ + Generate a streaming completion response. + + Creates a server-sent events stream that delivers tokens incrementally with + realistic timing delays between each token. Includes usage statistics if + requested and properly terminates the stream. + + :param req: Validated completions request containing prompt and streaming + options + :return: ResponseStream object that generates server-sent events + """ + + async def generate_stream(stream_response): + completion_id = f"cmpl-{uuid.uuid4().hex[:29]}" + + # TTFT delay + await asyncio.sleep( + sample_number(self.config.ttft_ms, self.config.ttft_ms_std) / 1000.0 + ) + + # Token counts + prompt_tokens = len(self.tokenizer(req.prompt)) + max_tokens = req.max_tokens or math.inf + completion_tokens_count = int( + min( + sample_number( + self.config.output_tokens, self.config.output_tokens_std + ), + max_tokens, + ) + if req.stop + else max_tokens + ) + + # Send tokens + tokens = create_fake_tokens_str(completion_tokens_count, self.tokenizer) + delays_iter = iter( + times_generator(self.config.itl_ms, self.config.itl_ms_std) + ) + + for index, token in enumerate(tokens): + if index > 0: + itl_delay = next(delays_iter) + await asyncio.sleep(itl_delay / 1000.0) + + chunk_data = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "text": token, + "index": index, + "finish_reason": None, + } + ], + } + await stream_response.write(f"data: {json.dumps(chunk_data)}\n\n") + + # Send final chunk with finish reason + final_chunk = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [ + { + "text": "", + "index": index, + "finish_reason": "stop", + } + ], + } + await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") + + # Send usage if requested + if req.stream_options and req.stream_options.get("include_usage"): + usage_chunk = { + "id": completion_id, + "object": "text_completion", + "created": int(time.time()), + "model": req.model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_count, + "total_tokens": prompt_tokens + completion_tokens_count, + }, + } + await stream_response.write(f"data: {json.dumps(usage_chunk)}\n\n") + + # End stream + await stream_response.write("data: [DONE]\n\n") + + return ResponseStream( # type: ignore[return-value] + generate_stream, + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/src/guidellm/mock_server/handlers/tokenizer.py b/src/guidellm/mock_server/handlers/tokenizer.py new file mode 100644 index 00000000..430ac0ef --- /dev/null +++ b/src/guidellm/mock_server/handlers/tokenizer.py @@ -0,0 +1,142 @@ +""" +HTTP request handler for vLLM tokenization API endpoints in the mock server. + +This module provides the TokenizerHandler class that implements vLLM-compatible +tokenization and detokenization endpoints for testing and development purposes. +It handles text-to-token conversion, token-to-text reconstruction, request +validation, and error responses with proper HTTP status codes and JSON formatting. +""" + +from __future__ import annotations + +from pydantic import ValidationError +from sanic import response +from sanic.request import Request +from sanic.response import HTTPResponse +from transformers.tokenization_utils import PreTrainedTokenizer + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.models import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorDetail, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from guidellm.mock_server.utils import MockTokenizer + +__all__ = ["TokenizerHandler"] + + +class TokenizerHandler: + """ + HTTP request handler for vLLM tokenization and detokenization endpoints. + + Provides mock implementations of vLLM's tokenization API endpoints including + /tokenize for converting text to tokens and /detokenize for reconstructing + text from token sequences. Handles request validation, error responses, and + JSON serialization with proper HTTP status codes. + + Example: + :: + handler = TokenizerHandler(config) + response = await handler.tokenize(request) + response = await handler.detokenize(request) + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the tokenizer handler with configuration. + + :param config: Server configuration object containing tokenizer settings + """ + self.config = config + self.tokenizer = ( + MockTokenizer() + if config.processor is None + else PreTrainedTokenizer.from_pretrained(config.processor) + ) + + async def tokenize(self, request: Request) -> HTTPResponse: + """ + Convert input text to token IDs via the /tokenize endpoint. + + Validates the request payload, extracts text content, and returns a JSON + response containing the token sequence and count. Handles validation errors + and malformed JSON with appropriate HTTP error responses. + + :param request: Sanic HTTP request containing JSON payload with text field + :return: JSON response with tokens list and count, or error response + """ + try: + req_data = TokenizeRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (ValueError, TypeError, KeyError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + tokens = self.tokenizer.tokenize(req_data.text) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + + return response.json( + TokenizeResponse(tokens=token_ids, count=len(token_ids)).model_dump() + ) + + async def detokenize(self, request: Request) -> HTTPResponse: + """ + Convert token IDs back to text via the /detokenize endpoint. + + Validates the request payload, extracts token sequences, and returns a JSON + response containing the reconstructed text. Handles validation errors and + malformed JSON with appropriate HTTP error responses. + + :param request: Sanic HTTP request containing JSON payload with tokens field + :return: JSON response with reconstructed text, or error response + """ + try: + req_data = DetokenizeRequest(**request.json) + except ValidationError as exc: + return response.json( + ErrorResponse( + error=ErrorDetail( + message=f"Invalid request: {str(exc)}", + type="invalid_request_error", + code="invalid_request", + ) + ).model_dump(), + status=400, + ) + except (ValueError, TypeError, KeyError): + return response.json( + ErrorResponse( + error=ErrorDetail( + message="Invalid JSON in request body", + type="invalid_request_error", + code="invalid_json", + ) + ).model_dump(), + status=400, + ) + + text = self.tokenizer.decode(req_data.tokens, skip_special_tokens=False) + + return response.json(DetokenizeResponse(text=text).model_dump()) diff --git a/src/guidellm/mock_server/models.py b/src/guidellm/mock_server/models.py new file mode 100644 index 00000000..cd342f7a --- /dev/null +++ b/src/guidellm/mock_server/models.py @@ -0,0 +1,510 @@ +""" +Pydantic models for OpenAI API and vLLM API request/response validation. + +This module defines comprehensive data models for validating and serializing API +requests and responses compatible with both OpenAI's API specification and vLLM's +extended parameters. It includes models for chat completions, legacy text completions, +tokenization operations, and error handling, supporting both streaming and non-streaming +responses with full type safety and validation. +""" + +from __future__ import annotations + +import time +from typing import Any, Literal + +from pydantic import BaseModel, Field + +__all__ = [ + "ChatCompletionChoice", + "ChatCompletionChunk", + "ChatCompletionsRequest", + "ChatCompletionsResponse", + "ChatMessage", + "CompletionChoice", + "CompletionsRequest", + "CompletionsResponse", + "DetokenizeRequest", + "DetokenizeResponse", + "ErrorDetail", + "ErrorResponse", + "StreamOptions", + "TokenizeRequest", + "TokenizeResponse", + "Usage", +] + + +class Usage(BaseModel): + """Token usage statistics for API requests and responses. + + Tracks the number of tokens consumed in prompts, completions, and total + usage for billing and monitoring purposes. + """ + + prompt_tokens: int = Field(description="Number of tokens in the input prompt") + completion_tokens: int = Field( + description="Number of tokens in the generated completion" + ) + total_tokens: int = Field(description="Total tokens used (prompt + completion)") + + def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0, **kwargs): + """Initialize usage statistics. + + :param prompt_tokens: Number of tokens in the input prompt + :param completion_tokens: Number of tokens in the generated completion + :param kwargs: Additional keyword arguments passed to BaseModel + """ + super().__init__( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + **kwargs, + ) + + +class StreamOptions(BaseModel): + """Configuration options for streaming API responses. + + Controls the behavior and content of streamed responses including + whether to include usage statistics in the final chunk. + """ + + include_usage: bool | None = Field( + default=None, + description="Whether to include usage statistics in streaming responses", + ) + + +class ChatMessage(BaseModel): + """A single message in a chat conversation. + + Represents one exchange in a conversational interface with role-based + content and optional metadata for advanced features. + """ + + role: Literal["system", "user", "assistant", "tool"] = Field( + description="Role of the message sender in the conversation" + ) + content: str = Field(description="Text content of the message") + name: str | None = Field( + default=None, description="Optional name identifier for the message sender" + ) + + +class ChatCompletionsRequest(BaseModel): + """Request parameters for chat completion API endpoints. + + Comprehensive model supporting both OpenAI standard parameters and vLLM + extensions for advanced generation control, guided decoding, and performance + optimization. + """ + + model: str = Field(description="Model identifier to use for generation") + messages: list[ChatMessage] = Field( + description="List of messages in the conversation" + ) + max_tokens: int | None = Field( + default=None, description="Maximum number of tokens to generate" + ) + max_completion_tokens: int | None = Field( + default=None, description="Maximum tokens in completion (OpenAI naming)" + ) + temperature: float | None = Field( + default=1.0, description="Sampling temperature for randomness control" + ) + top_p: float | None = Field(default=1.0, description="Nucleus sampling parameter") + n: int | None = Field( + default=1, description="Number of completion choices to generate" + ) + stream: bool | None = Field( + default=False, description="Whether to stream response chunks" + ) + stream_options: StreamOptions | None = Field( + default=None, description="Configuration for streaming responses" + ) + stop: str | list[str] | None = Field( + default=None, description="Stop sequences to end generation" + ) + presence_penalty: float | None = Field( + default=0.0, description="Penalty for token presence to encourage diversity" + ) + frequency_penalty: float | None = Field( + default=0.0, description="Penalty for token frequency to reduce repetition" + ) + logit_bias: dict[str, float] | None = Field( + default=None, description="Bias values for specific tokens" + ) + seed: int | None = Field( + default=None, description="Random seed for reproducible outputs" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + # vLLM extensions + use_beam_search: bool | None = Field( + default=False, description="Enable beam search for better quality" + ) + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + min_p: float | None = Field( + default=None, description="Minimum probability threshold for sampling" + ) + repetition_penalty: float | None = Field( + default=None, description="Penalty for repeated tokens" + ) + length_penalty: float | None = Field( + default=1.0, description="Length penalty for sequence scoring" + ) + stop_token_ids: list[int] | None = Field( + default=None, description="Token IDs that trigger generation stop" + ) + include_stop_str_in_output: bool | None = Field( + default=False, description="Include stop sequence in output" + ) + ignore_eos: bool | None = Field( + default=False, description="Ignore end-of-sequence tokens" + ) + min_tokens: int | None = Field( + default=0, description="Minimum number of tokens to generate" + ) + skip_special_tokens: bool | None = Field( + default=True, description="Skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Add spaces between special tokens" + ) + truncate_prompt_tokens: int | None = Field( + default=None, description="Maximum prompt tokens before truncation" + ) + allowed_token_ids: list[int] | None = Field( + default=None, description="Restrict generation to specific token IDs" + ) + prompt_logprobs: int | None = Field( + default=None, description="Number of logprobs to return for prompt tokens" + ) + add_special_tokens: bool | None = Field( + default=True, description="Add special tokens during processing" + ) + guided_json: str | dict[str, Any] | None = Field( + default=None, description="JSON schema for guided generation" + ) + guided_regex: str | None = Field( + default=None, description="Regex pattern for guided generation" + ) + guided_choice: list[str] | None = Field( + default=None, description="List of choices for guided generation" + ) + guided_grammar: str | None = Field( + default=None, description="Grammar specification for guided generation" + ) + guided_decoding_backend: str | None = Field( + default=None, description="Backend to use for guided decoding" + ) + guided_whitespace_pattern: str | None = Field( + default=None, description="Whitespace pattern for guided generation" + ) + priority: int | None = Field( + default=0, description="Request priority for scheduling" + ) + + +class ChatCompletionChoice(BaseModel): + """A single completion choice from a chat completion response. + + Contains the generated message and metadata about why generation + stopped and the choice's position in the response. + """ + + index: int = Field(description="Index of this choice in the response") + message: ChatMessage = Field(description="Generated message content") + finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] | None = ( + Field(description="Reason why generation finished") + ) + + +class ChatCompletionsResponse(BaseModel): + """Response from chat completion API endpoints. + + Contains generated choices, usage statistics, and metadata for + non-streaming chat completion requests. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["chat.completion"] = Field( + default="chat.completion", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[ChatCompletionChoice] = Field( + description="Generated completion choices" + ) + usage: Usage | None = Field(default=None, description="Token usage statistics") + system_fingerprint: str | None = Field( + default=None, description="System configuration fingerprint" + ) + + +class ChatCompletionChunk(BaseModel): + """A single chunk in a streamed chat completion response. + + Represents one piece of a streaming response with delta content + and optional usage statistics in the final chunk. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["chat.completion.chunk"] = Field( + default="chat.completion.chunk", + description="Object type identifier for streaming chunks", + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[dict[str, Any]] = Field(description="Delta choices for streaming") + usage: Usage | None = Field( + default=None, description="Token usage statistics (typically in final chunk)" + ) + + +class CompletionsRequest(BaseModel): + """Request parameters for legacy text completion API endpoints. + + Supports the older text completion format with prompt-based input + and the same extensive parameter set as chat completions for + backward compatibility. + """ + + model: str = Field(description="Model identifier to use for generation") + prompt: str | list[str] = Field(description="Input prompt(s) for completion") + max_tokens: int | None = Field( + default=16, description="Maximum number of tokens to generate" + ) + temperature: float | None = Field( + default=1.0, description="Sampling temperature for randomness control" + ) + top_p: float | None = Field(default=1.0, description="Nucleus sampling parameter") + n: int | None = Field( + default=1, description="Number of completion choices to generate" + ) + stream: bool | None = Field( + default=False, description="Whether to stream response chunks" + ) + stream_options: StreamOptions | None = Field( + default=None, description="Configuration for streaming responses" + ) + logprobs: int | None = Field( + default=None, description="Number of logprobs to return" + ) + echo: bool | None = Field( + default=False, description="Whether to echo the prompt in output" + ) + stop: str | list[str] | None = Field( + default_factory=lambda: ["<|endoftext|>"], + description="Stop sequences to end generation", + ) + presence_penalty: float | None = Field( + default=0.0, description="Penalty for token presence to encourage diversity" + ) + frequency_penalty: float | None = Field( + default=0.0, description="Penalty for token frequency to reduce repetition" + ) + best_of: int | None = Field( + default=1, description="Number of candidates to generate and return the best" + ) + logit_bias: dict[str, float] | None = Field( + default=None, description="Bias values for specific tokens" + ) + seed: int | None = Field( + default=None, description="Random seed for reproducible outputs" + ) + suffix: str | None = Field( + default=None, description="Suffix to append after completion" + ) + user: str | None = Field( + default=None, description="User identifier for tracking and abuse monitoring" + ) + + # vLLM extensions (same as chat completions) + use_beam_search: bool | None = Field( + default=False, description="Enable beam search for better quality" + ) + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + min_p: float | None = Field( + default=None, description="Minimum probability threshold for sampling" + ) + repetition_penalty: float | None = Field( + default=None, description="Penalty for repeated tokens" + ) + length_penalty: float | None = Field( + default=1.0, description="Length penalty for sequence scoring" + ) + stop_token_ids: list[int] | None = Field( + default=None, description="Token IDs that trigger generation stop" + ) + include_stop_str_in_output: bool | None = Field( + default=False, description="Include stop sequence in output" + ) + ignore_eos: bool | None = Field( + default=False, description="Ignore end-of-sequence tokens" + ) + min_tokens: int | None = Field( + default=0, description="Minimum number of tokens to generate" + ) + skip_special_tokens: bool | None = Field( + default=True, description="Skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Add spaces between special tokens" + ) + truncate_prompt_tokens: int | None = Field( + default=None, description="Maximum prompt tokens before truncation" + ) + allowed_token_ids: list[int] | None = Field( + default=None, description="Restrict generation to specific token IDs" + ) + prompt_logprobs: int | None = Field( + default=None, description="Number of logprobs to return for prompt tokens" + ) + add_special_tokens: bool | None = Field( + default=True, description="Add special tokens during processing" + ) + guided_json: str | dict[str, Any] | None = Field( + default=None, description="JSON schema for guided generation" + ) + guided_regex: str | None = Field( + default=None, description="Regex pattern for guided generation" + ) + guided_choice: list[str] | None = Field( + default=None, description="List of choices for guided generation" + ) + guided_grammar: str | None = Field( + default=None, description="Grammar specification for guided generation" + ) + guided_decoding_backend: str | None = Field( + default=None, description="Backend to use for guided decoding" + ) + guided_whitespace_pattern: str | None = Field( + default=None, description="Whitespace pattern for guided generation" + ) + priority: int | None = Field( + default=0, description="Request priority for scheduling" + ) + + +class CompletionChoice(BaseModel): + """A single completion choice from a text completion response. + + Contains the generated text and metadata about completion + quality and stopping conditions. + """ + + text: str = Field(description="Generated text content") + index: int = Field(description="Index of this choice in the response") + logprobs: dict[str, Any] | None = Field( + default=None, description="Log probabilities for generated tokens" + ) + finish_reason: Literal["stop", "length", "content_filter"] | None = Field( + description="Reason why generation finished" + ) + + +class CompletionsResponse(BaseModel): + """Response from legacy text completion API endpoints. + + Contains generated text choices, usage statistics, and metadata + for non-streaming text completion requests. + """ + + id: str = Field(description="Unique identifier for this completion") + object: Literal["text_completion"] = Field( + default="text_completion", description="Object type identifier" + ) + created: int = Field( + default_factory=lambda: int(time.time()), + description="Unix timestamp of creation", + ) + model: str = Field(description="Model used for generation") + choices: list[CompletionChoice] = Field(description="Generated completion choices") + usage: Usage | None = Field(default=None, description="Token usage statistics") + system_fingerprint: str | None = Field( + default=None, description="System configuration fingerprint" + ) + + +class TokenizeRequest(BaseModel): + """Request for tokenizing text into token sequences. + + Converts input text into model-specific token representations + with optional special token handling. + """ + + text: str = Field(description="Text to tokenize") + add_special_tokens: bool | None = Field( + default=True, description="Whether to add model-specific special tokens" + ) + + +class TokenizeResponse(BaseModel): + """Response containing tokenized representation of input text. + + Provides both the token sequence and count for analysis + and token budget planning. + """ + + tokens: list[int] = Field(description="List of token IDs") + count: int = Field(description="Total number of tokens") + + +class DetokenizeRequest(BaseModel): + """Request for converting token sequences back to text. + + Reconstructs human-readable text from model token representations + with configurable special token handling. + """ + + tokens: list[int] = Field(description="List of token IDs to convert") + skip_special_tokens: bool | None = Field( + default=True, description="Whether to skip special tokens in output" + ) + spaces_between_special_tokens: bool | None = Field( + default=True, description="Whether to add spaces between special tokens" + ) + + +class DetokenizeResponse(BaseModel): + """Response containing text reconstructed from tokens. + + Provides the human-readable text representation of the + input token sequence. + """ + + text: str = Field(description="Reconstructed text from tokens") + + +class ErrorDetail(BaseModel): + """Detailed error information for API failures. + + Provides structured error data including message, type classification, + and optional error codes for debugging and error handling. + """ + + message: str = Field(description="Human-readable error description") + type: str = Field(description="Error type classification") + code: str | None = Field( + default=None, description="Optional error code for programmatic handling" + ) + + +class ErrorResponse(BaseModel): + """Standardized error response structure for API failures. + + Wraps error details in a consistent format compatible with + OpenAI API error response conventions. + """ + + error: ErrorDetail = Field(description="Detailed error information") diff --git a/src/guidellm/mock_server/server.py b/src/guidellm/mock_server/server.py new file mode 100644 index 00000000..ff9d5fcd --- /dev/null +++ b/src/guidellm/mock_server/server.py @@ -0,0 +1,168 @@ +""" +High-performance mock server for OpenAI and vLLM API compatibility testing. + +This module provides a Sanic-based mock server that simulates OpenAI and vLLM APIs +with configurable latency, token generation patterns, and response characteristics. +The server supports both streaming and non-streaming endpoints, enabling realistic +performance testing and validation of GuideLLM benchmarking workflows without +requiring actual model deployments. +""" + +from __future__ import annotations + +import time + +from sanic import Sanic, response +from sanic.exceptions import NotFound +from sanic.log import logger +from sanic.request import Request +from sanic.response import HTTPResponse + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.handlers import ( + ChatCompletionsHandler, + CompletionsHandler, + TokenizerHandler, +) + +__all__ = ["MockServer"] + + +class MockServer: + """ + High-performance mock server implementing OpenAI and vLLM API endpoints. + + Provides a Sanic-based web server that simulates API responses with configurable + timing characteristics for testing and benchmarking purposes. Supports chat + completions, text completions, tokenization endpoints, and model listing with + realistic latency patterns to enable comprehensive performance validation. + + Example: + :: + config = ServerConfig(model="test-model", port=8080) + server = MockServer(config) + server.run() + """ + + def __init__(self, config: MockServerConfig) -> None: + """ + Initialize the mock server with configuration. + + :param config: Server configuration containing network settings and response + timing parameters + """ + self.config = config + self.app = Sanic("guidellm-mock-server") + self.chat_handler = ChatCompletionsHandler(config) + self.completions_handler = CompletionsHandler(config) + self.tokenizer_handler = TokenizerHandler(config) + + self._setup_middleware() + self._setup_routes() + self._setup_error_handlers() + + def _setup_middleware(self): + """Setup middleware for CORS, logging, etc.""" + + @self.app.middleware("request") + async def add_cors_headers(_request: Request): + """Add CORS headers to all requests.""" + + @self.app.middleware("response") + async def add_response_headers(_request: Request, resp: HTTPResponse): + """Add standard response headers.""" + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" + resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + resp.headers["Server"] = "guidellm-mock-server" + + def _setup_routes(self): # noqa: C901 + @self.app.get("/health") + async def health_check(_request: Request): + return response.json({"status": "healthy", "timestamp": time.time()}) + + @self.app.get("/v1/models") + async def list_models(_request: Request): + return response.json( + { + "object": "list", + "data": [ + { + "id": self.config.model, + "object": "model", + "created": int(time.time()), + "owned_by": "guidellm-mock", + } + ], + } + ) + + @self.app.route("/v1/chat/completions", methods=["POST", "OPTIONS"]) + async def chat_completions(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.chat_handler.handle(request) + + @self.app.route("/v1/completions", methods=["POST", "OPTIONS"]) + async def completions(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.completions_handler.handle(request) + + @self.app.route("/tokenize", methods=["POST", "OPTIONS"]) + async def tokenize(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.tokenizer_handler.tokenize(request) + + @self.app.route("/detokenize", methods=["POST", "OPTIONS"]) + async def detokenize(request: Request): + if request.method == "OPTIONS": + return response.text("", status=204) + return await self.tokenizer_handler.detokenize(request) + + def _setup_error_handlers(self): + """Setup error handlers.""" + + @self.app.exception(Exception) + async def generic_error_handler(_request: Request, exception: Exception): + logger.error(f"Unhandled exception: {exception}") + return response.json( + { + "error": { + "message": "Internal server error", + "type": type(exception).__name__, + "error": str(exception), + } + }, + status=500, + ) + + @self.app.exception(NotFound) + async def not_found_handler(_request: Request, _exception): + return response.json( + { + "error": { + "message": "Not Found", + "type": "not_found_error", + "code": "not_found", + } + }, + status=404, + ) + + def run(self) -> None: + """ + Start the mock server with configured settings. + + Runs the Sanic application in single-process mode with access logging enabled + for debugging and monitoring request patterns during testing. + """ + self.app.run( + host=self.config.host, + port=self.config.port, + debug=False, + single_process=True, + access_log=True, + register_sys_signals=False, # Disable signal handlers for threading + ) diff --git a/src/guidellm/mock_server/utils.py b/src/guidellm/mock_server/utils.py new file mode 100644 index 00000000..8348d0a6 --- /dev/null +++ b/src/guidellm/mock_server/utils.py @@ -0,0 +1,307 @@ +""" +Mock server utilities for text generation and tokenization testing. + +This module provides mock tokenization and text generation utilities for testing +guidellm's mock server functionality. It includes a mock tokenizer that simulates +tokenization processes, functions to generate reproducible fake text with specific +token counts, and timing generators for realistic benchmarking scenarios. +""" + +from __future__ import annotations + +import random +import re +from collections.abc import Generator + +from faker import Faker +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer, TextInput + +__all__ = [ + "MockTokenizer", + "create_fake_text", + "create_fake_tokens_str", + "sample_number", + "times_generator", +] + + +class MockTokenizer(PreTrainedTokenizer): + """ + Mock tokenizer implementation for testing text processing workflows. + + Provides a simplified tokenizer that splits text using regex patterns and + generates deterministic token IDs based on string hashing. Used for testing + guidellm components without requiring actual model tokenizers. + + :cvar VocabSize: Fixed vocabulary size for the mock tokenizer + """ + + VocabSize = 100000007 + + def __len__(self) -> int: + """ + Get the vocabulary size of the tokenizer. + + :return: The total number of tokens in the vocabulary + """ + return self.VocabSize + + def __call__(self, text: str | list[str], **kwargs) -> list[int]: # noqa: ARG002 + """ + Tokenize text and return token IDs (callable interface). + + :param text: Input text to tokenize + :return: List of token IDs + """ + if isinstance(text, str): + tokens = self.tokenize(text) + return self.convert_tokens_to_ids(tokens) + elif isinstance(text, list): + # Handle batch processing + return [self.__call__(t) for t in text] + else: + msg = f"text input must be of type `str` or `list[str]`, got {type(text)}" + raise ValueError(msg) + + def tokenize(self, text: TextInput, **_kwargs) -> list[str]: + """ + Tokenize input text into a list of token strings. + + Splits text using regex to separate words, punctuation, and whitespace + into individual tokens for processing. + + :param text: Input text to tokenize + :return: List of token strings from the input text + """ + # Split text into tokens: words, spaces, and punctuation + return re.findall(r"\w+|[^\w\s]|\s+", text) + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + """ + Convert token strings to numeric token IDs. + + Uses deterministic hashing to generate consistent token IDs for + reproducible testing scenarios. + + :param tokens: Single token string or list of token strings + :return: Single token ID or list of token IDs + """ + if isinstance(tokens, str): + return hash(tokens) % self.VocabSize + return [hash(token) % self.VocabSize for token in tokens] + + def convert_ids_to_tokens( + self, ids: int | list[int], _skip_special_tokens: bool = False + ) -> str | list[str]: + """ + Convert numeric token IDs back to token strings. + + Generates fake text tokens using Faker library seeded with token IDs + for deterministic and reproducible token generation. + + :param ids: Single token ID or list of token IDs to convert + :return: Single token string or list of token strings + """ + if not ids and not isinstance(ids, list): + return "" + elif not ids: + return [""] + + if isinstance(ids, int): + fake = Faker() + fake.seed_instance(ids % self.VocabSize) + + return fake.word() + + fake = Faker() + fake.seed_instance(sum(ids) % self.VocabSize) + + target_count = len(ids) + current_count = 0 + tokens = [] + + while current_count < target_count: + text = fake.text( + max_nb_chars=(target_count - current_count) * 10 # oversample + ) + new_tokens = self.tokenize(text) + + if current_count > 0: + new_tokens = [".", " "] + new_tokens + + new_tokens = ( + new_tokens[: target_count - current_count] + if len(new_tokens) > (target_count - current_count) + else new_tokens + ) + tokens += new_tokens + current_count += len(new_tokens) + + return tokens + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """ + Convert a list of token strings back to a single text string. + + :param tokens: List of token strings to concatenate + :return: Concatenated string from all tokens + """ + return "".join(tokens) + + def _add_tokens( + self, + new_tokens: list[str] | list[AddedToken], # noqa: ARG002 + special_tokens: bool = False, # noqa: ARG002 + ) -> int: + """ + Add new tokens to the tokenizer vocabulary (mock implementation). + + :param new_tokens: List of tokens to add to the vocabulary + :param special_tokens: Whether the tokens are special tokens + :return: Number of tokens actually added (always 0 for mock) + """ + return 0 + + def apply_chat_template( + self, + conversation: list, + tokenize: bool = False, # Changed default to False to match transformers + add_generation_prompt: bool = False, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> str | list[int]: + """ + Apply a chat template to format conversation messages. + + Mock implementation that concatenates all message content for testing. + + :param conversation: List of chat messages + :param tokenize: Whether to return tokens or string + :param add_generation_prompt: Whether to add generation prompt + :return: Formatted text string or token IDs + """ + # Simple concatenation of all message content + texts = [] + for message in conversation: + if isinstance(message, dict) and "content" in message: + texts.append(message["content"]) + elif hasattr(message, "content"): + texts.append(message.content) + + formatted_text = " ".join(texts) + + if tokenize: + return self.convert_tokens_to_ids(self.tokenize(formatted_text)) + return formatted_text + + def decode( + self, + token_ids: list[int], + skip_special_tokens: bool = True, + **kwargs, # noqa: ARG002 + ) -> str: + """ + Decode token IDs back to text string. + + :param token_ids: List of token IDs to decode + :param skip_special_tokens: Whether to skip special tokens + :return: Decoded text string + """ + tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens) + return self.convert_tokens_to_string(tokens) + + +def create_fake_text( + num_tokens: int, + processor: PreTrainedTokenizer, + seed: int = 42, + fake: Faker | None = None, +) -> str: + """ + Generate fake text using a tokenizer processor with specified token count. + + Creates text by generating fake tokens and joining them into a string, + ensuring the result has the exact number of tokens when processed by + the given tokenizer. + + :param num_tokens: Target number of tokens in the generated text + :param processor: Tokenizer to use for token generation and validation + :param seed: Random seed for reproducible text generation + :param fake: Optional Faker instance for text generation + :return: Generated text string with the specified token count + """ + return "".join(create_fake_tokens_str(num_tokens, processor, seed, fake)) + + +def create_fake_tokens_str( + num_tokens: int, + processor: PreTrainedTokenizer, + seed: int = 42, + fake: Faker | None = None, +) -> list[str]: + """ + Generate fake token strings using a tokenizer processor. + + Creates a list of token strings by generating fake text and tokenizing it + until the desired token count is reached. Uses the provided tokenizer + for accurate token boundary detection. + + :param num_tokens: Target number of tokens to generate + :param processor: Tokenizer to use for token generation and validation + :param seed: Random seed for reproducible token generation + :param fake: Optional Faker instance for text generation + :return: List of token strings with the specified count + """ + if not fake: + fake = Faker() + fake.seed_instance(seed) + + tokens = [] + + while len(tokens) < num_tokens: + text = fake.text( + max_nb_chars=(num_tokens - len(tokens)) * 30 # oversample + ) + new_tokens = processor.tokenize(text) + + if len(tokens) > 0: + new_tokens = [".", " "] + new_tokens + + new_tokens = ( + new_tokens[: num_tokens - len(tokens)] + if len(new_tokens) > (num_tokens - len(tokens)) + else new_tokens + ) + tokens += new_tokens + + return tokens + + +def times_generator(mean: float, standard_dev: float) -> Generator[float]: + """ + Generate infinite timing values from a normal distribution. + + Creates a generator that yields timing values sampled from a normal + distribution, useful for simulating realistic request timing patterns + in benchmarking scenarios. + + :param mean: Mean value for the normal distribution + :param standard_dev: Standard deviation for the normal distribution + :return: Generator yielding positive timing values from the distribution + """ + while True: + yield sample_number(mean, standard_dev) + + +def sample_number(mean: float, standard_dev: float) -> float: + """ + Generate a single timing value from a normal distribution. + + Samples one timing value from a normal distribution with the specified + parameters, ensuring the result is non-negative for realistic timing + simulation in benchmarking scenarios. + + :param mean: Mean value for the normal distribution + :param standard_dev: Standard deviation for the normal distribution + :return: Non-negative timing value from the distribution + """ + return max(0.0, random.gauss(mean, standard_dev)) diff --git a/tests/unit/mock_server/__init__.py b/tests/unit/mock_server/__init__.py new file mode 100644 index 00000000..e02d60bd --- /dev/null +++ b/tests/unit/mock_server/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the GuideLLM mock server package.""" diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py new file mode 100644 index 00000000..ed5c7727 --- /dev/null +++ b/tests/unit/mock_server/test_server.py @@ -0,0 +1,518 @@ +from __future__ import annotations + +import asyncio +import json +import multiprocessing + +import httpx +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from guidellm.mock_server.config import MockServerConfig +from guidellm.mock_server.server import MockServer + + +# Start server in a separate process +def _start_server_process(config: MockServerConfig): + server = MockServer(config) + server.run() + + +@pytest_asyncio.fixture(scope="class") +async def mock_server_instance(): + """Instance-level fixture that provides a running server for HTTP testing.""" + + config = MockServerConfig( + host="127.0.0.1", + port=8012, + model="test-model", + ttft_ms=10.0, + itl_ms=1.0, + request_latency=0.1, + ) + base_url = f"http://{config.host}:{config.port}" + server_process = multiprocessing.Process( + target=_start_server_process, args=(config,) + ) + server_process.start() + + # Wait for server to start up and be ready + async def wait_for_startup(): + poll_frequency = 1.0 + async with httpx.AsyncClient() as client: + while True: + try: + response = await client.get(f"{base_url}/health", timeout=1.0) + if response.status_code == 200: + break + except (httpx.RequestError, httpx.TimeoutException): + pass + await asyncio.sleep(poll_frequency) + poll_frequency = min(poll_frequency * 1.5, 2.0) + + timeout = 30.0 + try: + await asyncio.wait_for(wait_for_startup(), timeout) + except TimeoutError: + # Server failed to start within timeout + server_process.terminate() + server_process.kill() + server_process.join(timeout=5) + pytest.fail(f"Server failed to start within {timeout} seconds") + + yield base_url, config + + # Cleanup: terminate the server process + server_process.terminate() + server_process.kill() + server_process.join(timeout=5) + + +class TestMockServerConfig: + """Test suite for MockServerConfig class.""" + + @pytest.mark.smoke + def test_default_initialization(self): + """Test MockServerConfig initialization with default values.""" + config = MockServerConfig() + assert config.host == "127.0.0.1" + assert config.port == 8000 + assert config.workers == 1 + assert config.model == "llama-3.1-8b-instruct" + assert config.processor is None + assert config.request_latency == 3.0 + assert config.request_latency_std == 0.0 + assert config.ttft_ms == 150.0 + assert config.ttft_ms_std == 0.0 + assert config.itl_ms == 10.0 + assert config.itl_ms_std == 0.0 + assert config.output_tokens == 128 + assert config.output_tokens_std == 0.0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("kwargs", "expected_values"), + [ + ( + {"host": "127.0.0.1", "port": 9000, "model": "custom-model"}, + {"host": "127.0.0.1", "port": 9000, "model": "custom-model"}, + ), + ( + {"request_latency": 1.5, "ttft_ms": 100.0, "output_tokens": 256}, + {"request_latency": 1.5, "ttft_ms": 100.0, "output_tokens": 256}, + ), + ], + ) + def test_custom_initialization(self, kwargs, expected_values): + """Test MockServerConfig initialization with custom values.""" + config = MockServerConfig(**kwargs) + for key, expected_value in expected_values.items(): + assert getattr(config, key) == expected_value + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("port", "not_int"), + ("request_latency", "not_float"), + ("output_tokens", "not_int"), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test MockServerConfig with invalid field values.""" + kwargs = {field: value} + with pytest.raises(ValidationError): + MockServerConfig(**kwargs) + + +class TestMockServer: + """Test suite for MockServer class.""" + + @pytest.mark.smoke + def test_class_signatures(self): + """Test MockServer class signatures and attributes.""" + assert hasattr(MockServer, "__init__") + assert hasattr(MockServer, "run") + assert hasattr(MockServer, "_setup_middleware") + assert hasattr(MockServer, "_setup_routes") + assert hasattr(MockServer, "_setup_error_handlers") + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test MockServer initialization without required config.""" + with pytest.raises(TypeError): + MockServer() + + +class TestMockServerEndpoints: + """Test suite for MockServer HTTP endpoints with real server instances.""" + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_health_endpoint(self, mock_server_instance): + """Test the health check endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "status" in data + assert data["status"] == "healthy" + assert "timestamp" in data + assert isinstance(data["timestamp"], (int, float)) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_models_endpoint(self, mock_server_instance): + """Test the models listing endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/v1/models", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "object" in data + assert data["object"] == "list" + assert "data" in data + assert isinstance(data["data"], list) + assert len(data["data"]) > 0 + + model = data["data"][0] + assert "id" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + assert model["object"] == "model" + assert model["owned_by"] == "guidellm-mock" + assert model["id"] == "test-model" + + @pytest.mark.smoke + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + { + "model": "test-model", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 10, + }, + ["choices", "usage", "model", "object"], + ), + ( + { + "model": "test-model", + "messages": [{"role": "user", "content": "Test"}], + "max_tokens": 5, + "temperature": 0.7, + }, + ["choices", "usage", "model", "object"], + ), + ], + ) + async def test_chat_completions_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the chat completions endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/v1/chat/completions", json=payload, timeout=10.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert len(data["choices"]) > 0 + choice = data["choices"][0] + assert "message" in choice + assert "content" in choice["message"] + assert "role" in choice["message"] + assert choice["message"]["role"] == "assistant" + assert isinstance(choice["message"]["content"], str) + assert len(choice["message"]["content"]) > 0 + + # Verify usage information + assert "prompt_tokens" in data["usage"] + assert "completion_tokens" in data["usage"] + assert "total_tokens" in data["usage"] + assert data["usage"]["total_tokens"] == ( + data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"] + ) + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_streaming_chat_completions(self, mock_server_instance): + """Test streaming chat completions endpoint.""" + server_url, _ = mock_server_instance + + payload = { + "model": "test-model", + "messages": [{"role": "user", "content": "Hi!"}], + "max_tokens": 5, + "stream": True, + } + + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{server_url}/v1/chat/completions", + json=payload, + timeout=10.0, + ) as response, + ): + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + chunks = [] + async for line in response.aiter_lines(): + if line and line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk_data = json.loads(data_str) + chunks.append(chunk_data) + except json.JSONDecodeError: + continue + + assert len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + assert "delta" in chunk["choices"][0] + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + { + "model": "test-model", + "prompt": "Hello", + "max_tokens": 10, + }, + ["choices", "usage", "model", "object"], + ), + ( + { + "model": "test-model", + "prompt": "Test prompt", + "max_tokens": 5, + "temperature": 0.8, + }, + ["choices", "usage", "model", "object"], + ), + ], + ) + @pytest.mark.asyncio + async def test_completions_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the legacy completions endpoint.""" + server_url, _ = mock_server_instance + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/v1/completions", json=payload, timeout=10.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert len(data["choices"]) > 0 + choice = data["choices"][0] + assert "text" in choice + assert isinstance(choice["text"], str) + assert len(choice["text"]) > 0 + + # Verify usage information + assert "prompt_tokens" in data["usage"] + assert "completion_tokens" in data["usage"] + assert "total_tokens" in data["usage"] + + @pytest.mark.smoke + @pytest.mark.asyncio + async def test_streaming_completions(self, mock_server_instance): + """Test streaming completions endpoint.""" + server_url, _ = mock_server_instance + payload = { + "model": "test-model", + "prompt": "Hello", + "max_tokens": 5, + "stream": True, + } + + async with ( + httpx.AsyncClient() as client, + client.stream( + "POST", + f"{server_url}/v1/completions", + json=payload, + timeout=10.0, + ) as response, + ): + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + chunks = [] + async for line in response.aiter_lines(): + if line and line.startswith("data: "): + data_str = line[6:] + if data_str.strip() == "[DONE]": + break + try: + chunk_data = json.loads(data_str) + chunks.append(chunk_data) + except json.JSONDecodeError: + continue + + assert len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + {"text": "Hello world!"}, + ["tokens", "count"], + ), + ( + {"text": "This is a test sentence."}, + ["tokens", "count"], + ), + ], + ) + @pytest.mark.asyncio + async def test_tokenize_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the tokenize endpoint.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/tokenize", json=payload, timeout=5.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert isinstance(data["tokens"], list) + assert isinstance(data["count"], int) + assert data["count"] == len(data["tokens"]) + assert len(data["tokens"]) > 0 + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("payload", "expected_fields"), + [ + ( + {"tokens": [123, 456, 789]}, + ["text"], + ), + ( + {"tokens": [100, 200]}, + ["text"], + ), + ], + ) + @pytest.mark.asyncio + async def test_detokenize_endpoint( + self, mock_server_instance, payload, expected_fields + ): + """Test the detokenize endpoint.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.post( + f"{server_url}/detokenize", json=payload, timeout=5.0 + ) + assert response.status_code == 200 + + data = response.json() + for field in expected_fields: + assert field in data + + assert isinstance(data["text"], str) + assert len(data["text"]) > 0 + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_options_endpoint(self, mock_server_instance): + """Test the OPTIONS endpoint for CORS support.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.options( + f"{server_url}/v1/chat/completions", timeout=5.0 + ) + assert response.status_code == 204 + assert response.text == "" + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_cors_headers(self, mock_server_instance): + """Test CORS headers are properly set.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/health", timeout=5.0) + assert response.status_code == 200 + + # Check for CORS headers + assert response.headers.get("Access-Control-Allow-Origin") == "*" + methods_header = response.headers.get("Access-Control-Allow-Methods", "") + assert "GET, POST, OPTIONS" in methods_header + headers_header = response.headers.get("Access-Control-Allow-Headers", "") + assert "Content-Type, Authorization" in headers_header + assert response.headers.get("Server") == "guidellm-mock-server" + + @pytest.mark.sanity + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("endpoint", "method", "payload"), + [ + ("/v1/chat/completions", "POST", {"invalid": "payload"}), + ("/v1/completions", "POST", {"invalid": "payload"}), + ("/tokenize", "POST", {"invalid": "payload"}), + ("/detokenize", "POST", {"invalid": "payload"}), + ], + ) + async def test_invalid_request_handling( + self, mock_server_instance, endpoint, method, payload + ): + """Test handling of invalid requests.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + if method == "POST": + response = await client.post( + f"{server_url}{endpoint}", json=payload, timeout=5.0 + ) + else: + response = await client.get(f"{server_url}{endpoint}", timeout=5.0) + + # Should return an error response, not crash + assert response.status_code in [400, 422, 500] + + @pytest.mark.sanity + @pytest.mark.asyncio + async def test_nonexistent_endpoint(self, mock_server_instance): + """Test handling of requests to nonexistent endpoints.""" + server_url, _ = mock_server_instance + async with httpx.AsyncClient() as client: + response = await client.get(f"{server_url}/nonexistent", timeout=5.0) + assert response.status_code == 404 From bb981934b690f97965f6616b823905553da48b19 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 12:26:42 +0000 Subject: [PATCH 2/6] fixes from copilot review Signed-off-by: Mark Kurtz --- src/guidellm/mock_server/handlers/chat_completions.py | 2 +- src/guidellm/mock_server/handlers/completions.py | 2 +- tests/unit/mock_server/test_server.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/guidellm/mock_server/handlers/chat_completions.py b/src/guidellm/mock_server/handlers/chat_completions.py index 976901f9..de2781b0 100644 --- a/src/guidellm/mock_server/handlers/chat_completions.py +++ b/src/guidellm/mock_server/handlers/chat_completions.py @@ -251,7 +251,7 @@ async def generate_stream(stream_response): await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") # Send usage if requested - if req.stream_options and req.stream_options.get("include_usage"): + if req.stream_options and req.stream_options.include_usage: usage_chunk = { "id": completion_id, "object": "chat.completion.chunk", diff --git a/src/guidellm/mock_server/handlers/completions.py b/src/guidellm/mock_server/handlers/completions.py index 418d2b3c..5a4fe27d 100644 --- a/src/guidellm/mock_server/handlers/completions.py +++ b/src/guidellm/mock_server/handlers/completions.py @@ -251,7 +251,7 @@ async def generate_stream(stream_response): await stream_response.write(f"data: {json.dumps(final_chunk)}\n\n") # Send usage if requested - if req.stream_options and req.stream_options.get("include_usage"): + if req.stream_options and req.stream_options.include_usage: usage_chunk = { "id": completion_id, "object": "text_completion", diff --git a/tests/unit/mock_server/test_server.py b/tests/unit/mock_server/test_server.py index ed5c7727..008103c3 100644 --- a/tests/unit/mock_server/test_server.py +++ b/tests/unit/mock_server/test_server.py @@ -378,11 +378,11 @@ async def test_streaming_completions(self, mock_server_instance): except json.JSONDecodeError: continue - assert len(chunks) > 0 - # Verify chunk structure - for chunk in chunks: - assert "choices" in chunk - assert len(chunk["choices"]) > 0 + assert len(chunks) > 0 + # Verify chunk structure + for chunk in chunks: + assert "choices" in chunk + assert len(chunk["choices"]) > 0 @pytest.mark.smoke @pytest.mark.parametrize( From a9a082ad3dfc67d7e6842c32cd1085247eb1bfe9 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 13:01:29 +0000 Subject: [PATCH 3/6] Any missing changes / working state for refactor Signed-off-by: Mark Kurtz --- src/guidellm/__main__.py | 645 ++++++++---- src/guidellm/benchmark/scenario.py | 4 +- src/guidellm/logger.py | 3 +- src/guidellm/objects/__init__.py | 17 - src/guidellm/objects/statistics.py | 953 ------------------ src/guidellm/presentation/builder.py | 4 +- src/guidellm/presentation/data_models.py | 4 +- src/guidellm/request/loader.py | 2 +- src/guidellm/request/request.py | 2 +- src/guidellm/settings.py | 49 +- src/guidellm/utils/typing.py | 46 + tests/integration/scheduler/__init__.py | 0 tests/integration/scheduler/test_scheduler.py | 177 ++++ .../scheduler/test_worker_group.py | 181 ++++ tests/unit/conftest.py | 195 ---- tests/unit/mock_backend.py | 266 ++--- tests/unit/mock_benchmark.py | 387 +++---- tests/unit/test_cli.py | 105 -- .../unit/{test_config.py => test_settings.py} | 0 tests/unit/utils/test_typing.py | 123 +++ 20 files changed, 1295 insertions(+), 1868 deletions(-) delete mode 100644 src/guidellm/objects/__init__.py delete mode 100644 src/guidellm/objects/statistics.py create mode 100644 src/guidellm/utils/typing.py create mode 100644 tests/integration/scheduler/__init__.py create mode 100644 tests/integration/scheduler/test_scheduler.py create mode 100644 tests/integration/scheduler/test_worker_group.py delete mode 100644 tests/unit/test_cli.py rename tests/unit/{test_config.py => test_settings.py} (100%) create mode 100644 tests/unit/utils/test_typing.py diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index f82c19cf..675003a9 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -1,33 +1,117 @@ +""" +GuideLLM command-line interface providing benchmarking, dataset preprocessing, and +mock server functionality. + +This module serves as the primary entry point for the GuideLLM CLI application, +offering a comprehensive suite of tools for language model evaluation and testing. +It provides three main command groups: benchmark operations for performance testing +against generative models, dataset preprocessing utilities for data preparation and +transformation, and a mock server for testing and development scenarios. The CLI +supports various backends, output formats, and configuration options to accommodate +different benchmarking needs and deployment environments. + +Example: +:: + # Run a benchmark against a model + guidellm benchmark run --target http://localhost:8000 --data dataset.json \\ + --profile sweep + + # Preprocess a dataset + guidellm preprocess dataset input.json output.json --processor gpt2 + + # Start a mock server for testing + guidellm mock-server --host 0.0.0.0 --port 8080 +""" + +from __future__ import annotations + import asyncio import codecs from pathlib import Path -from typing import get_args +from typing import Annotated, Union import click -from pydantic import ValidationError -from guidellm.backend import BackendType +try: + import uvloop + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = True +except ImportError: + uvloop = None + + HAS_UVLOOP: Annotated[ + bool, "Flag indicating if uvloop is available for event loop optimization" + ] = False + +from guidellm.backends import BackendType from guidellm.benchmark import ( + GenerativeConsoleBenchmarkerProgress, + InjectExtrasAggregator, ProfileType, + benchmark_generative_text, reimport_benchmarks_report, ) -from guidellm.benchmark.entrypoints import benchmark_with_scenario -from guidellm.benchmark.scenario import GenerativeTextScenario, get_builtin_scenarios +from guidellm.benchmark.scenario import ( + GenerativeTextScenario, +) +from guidellm.mock_server import MockServer, MockServerConfig from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset from guidellm.scheduler import StrategyType from guidellm.settings import print_config -from guidellm.utils import DefaultGroupHandler +from guidellm.utils import Console, DefaultGroupHandler, get_literal_vals from guidellm.utils import cli as cli_tools -STRATEGY_PROFILE_CHOICES = list( - set(list(get_args(ProfileType)) + list(get_args(StrategyType))) -) +__all__ = [ + "STRATEGY_PROFILE_CHOICES", + "benchmark", + "cli", + "config", + "dataset", + "decode_escaped_str", + "from_file", + "mock_server", + "preprocess", + "run", +] + +STRATEGY_PROFILE_CHOICES: Annotated[ + list[str], "Available strategy and profile choices for benchmark execution types" +] = list(get_literal_vals(Union[ProfileType, StrategyType])) + + +def decode_escaped_str(_ctx, _param, value): + """ + Decode escape sequences in Click option values. + + Click automatically escapes characters in option values, converting sequences + like "\\n" to "\\\\n". This function properly decodes these escape sequences + to their intended characters for use in CLI options. + + :param _ctx: Click context (unused) + :param _param: Click parameter (unused) + :param value: String value to decode escape sequences from + :return: Decoded string with proper escape sequences + :raises click.BadParameter: When escape sequence decoding fails + """ + if value is None: + return None + try: + return codecs.decode(value, "unicode_escape") + except Exception as e: + raise click.BadParameter(f"Could not decode escape sequences: {e}") from e @click.group() -@click.version_option(package_name="guidellm", message="guidellm version: %(version)s") def cli(): - pass + """ + Main entry point for the GuideLLM command-line interface. + + This is the root command group that organizes all GuideLLM CLI functionality + into logical subgroups for benchmarking, preprocessing, configuration, and + mock server operations. + """ @cli.group( @@ -36,7 +120,13 @@ def cli(): default="run", ) def benchmark(): - pass + """ + Benchmark command group for running and managing performance tests. + + This command group provides functionality to execute new benchmarks against + generative models and load previously saved benchmark reports for analysis. + Supports various benchmarking strategies, output formats, and backend types. + """ @benchmark.command( @@ -45,42 +135,65 @@ def benchmark(): context_settings={"auto_envvar_prefix": "GUIDELLM"}, ) @click.option( - "--scenario", - type=cli_tools.Union( - click.Path( - exists=True, - readable=True, - file_okay=True, - dir_okay=False, - path_type=Path, - ), - click.Choice(get_builtin_scenarios()), + "--target", + type=str, + help="The target path for the backend to run benchmarks against. For example, http://localhost:8000", +) +@click.option( + "--data", + type=str, + help=( + "The HuggingFace dataset ID, a path to a HuggingFace dataset, " + "a path to a data file csv, json, jsonl, or txt, " + "or a synthetic data config as a json or key=value string." ), +) +@click.option( + "--profile", + "--rate-type", # legacy alias + "profile", + type=click.Choice(STRATEGY_PROFILE_CHOICES), + help=( + "The type of benchmark to run. " + f"Supported types {', '.join(STRATEGY_PROFILE_CHOICES)}. " + ), +) +@click.option( + "--rate", default=None, help=( - "The name of a builtin scenario or path to a config file. " - "Missing values from the config will use defaults. " - "Options specified on the commandline will override the scenario." + "The rates to run the benchmark at. " + "Can be a single number or a comma-separated list of numbers. " + "For rate-type=sweep, this is the number of benchmarks it runs in the sweep. " + "For rate-type=concurrent, this is the number of concurrent requests. " + "For rate-type=async,constant,poisson, this is the rate requests per second. " + "For rate-type=synchronous,throughput, this must not be set." ), ) @click.option( - "--target", - type=str, - help="The target path for the backend to run benchmarks against. For example, http://localhost:8000", + "--random-seed", + default=GenerativeTextScenario.get_default("random_seed"), + type=int, + help="The random seed to use for benchmarking to ensure reproducibility.", ) +# Backend configuration @click.option( - "--backend-type", - type=click.Choice(list(get_args(BackendType))), + "--backend", + "--backend-type", # legacy alias + "backend", + type=click.Choice(list(get_literal_vals(BackendType))), help=( "The type of backend to use to run requests against. Defaults to 'openai_http'." - f" Supported types: {', '.join(get_args(BackendType))}" + f" Supported types: {', '.join(get_literal_vals(BackendType))}" ), - default=GenerativeTextScenario.get_default("backend_type"), + default="openai_http", ) @click.option( - "--backend-args", + "--backend-kwargs", + "--backend-args", # legacy alias + "backend_kwargs", callback=cli_tools.parse_json, - default=GenerativeTextScenario.get_default("backend_args"), + default=None, help=( "A JSON string containing any arguments to pass to the backend as a " "dict with **kwargs. Headers can be removed by setting their value to " @@ -90,16 +203,17 @@ def benchmark(): ) @click.option( "--model", - default=GenerativeTextScenario.get_default("model"), + default=None, type=str, help=( "The ID of the model to benchmark within the backend. " "If None provided (default), then it will use the first model available." ), ) +# Data configuration @click.option( "--processor", - default=GenerativeTextScenario.get_default("processor"), + default=None, type=str, help=( "The processor or tokenizer to use to calculate token counts for statistics " @@ -109,25 +223,16 @@ def benchmark(): ) @click.option( "--processor-args", - default=GenerativeTextScenario.get_default("processor_args"), + default=None, callback=cli_tools.parse_json, help=( "A JSON string containing any arguments to pass to the processor constructor " "as a dict with **kwargs." ), ) -@click.option( - "--data", - type=str, - help=( - "The HuggingFace dataset ID, a path to a HuggingFace dataset, " - "a path to a data file csv, json, jsonl, or txt, " - "or a synthetic data config as a json or key=value string." - ), -) @click.option( "--data-args", - default=GenerativeTextScenario.get_default("data_args"), + default=None, callback=cli_tools.parse_json, help=( "A JSON string containing any arguments to pass to the dataset creation " @@ -136,189 +241,226 @@ def benchmark(): ) @click.option( "--data-sampler", - default=GenerativeTextScenario.get_default("data_sampler"), + default=None, type=click.Choice(["random"]), help=( "The data sampler type to use. 'random' will add a random shuffle on the data. " "Defaults to None" ), ) +# Output configuration @click.option( - "--rate-type", - type=click.Choice(STRATEGY_PROFILE_CHOICES), + "--output-path", + type=click.Path(), + default=Path.cwd(), help=( - "The type of benchmark to run. " - f"Supported types {', '.join(STRATEGY_PROFILE_CHOICES)}. " + "The path to save the output formats to, if the format is a file type. " + "If it is a directory, it will save all output formats selected under it. " + "If it is a file, it will save the corresponding output format to that file. " + "Any output formats that were given that do not match the file extension will " + "be saved in the parent directory of the file path. " + "Defaults to the current working directory. " ), ) @click.option( - "--rate", - default=GenerativeTextScenario.get_default("rate"), + "--output-formats", + multiple=True, + type=str, + default=("console", "json"), # ("console", "json", "html", "csv") help=( - "The rates to run the benchmark at. " - "Can be a single number or a comma-separated list of numbers. " - "For rate-type=sweep, this is the number of benchmarks it runs in the sweep. " - "For rate-type=concurrent, this is the number of concurrent requests. " - "For rate-type=async,constant,poisson, this is the rate requests per second. " - "For rate-type=synchronous,throughput, this must not be set." + "The output formats to use for the benchmark results. " + "Defaults to console, json, html, and csv where the file formats " + "will be saved at the specified output path." ), ) @click.option( - "--max-seconds", - type=float, - default=GenerativeTextScenario.get_default("max_seconds"), - help=( - "The maximum number of seconds each benchmark can run for. " - "If None, will run until max_requests or the data is exhausted." - ), + "--disable-console-outputs", + is_flag=True, + help="Set this flag to disable console output", ) +# Updates configuration @click.option( - "--max-requests", - type=int, - default=GenerativeTextScenario.get_default("max_requests"), - help=( - "The maximum number of requests each benchmark can run for. " - "If None, will run until max_seconds or the data is exhausted." - ), + "--disable-progress", + is_flag=True, + help="Set this flag to disable progress updates to the console", +) +@click.option( + "--display-scheduler-stats", + is_flag=True, + help="Set this flag to display stats for the processes running the benchmarks", ) +# Aggregators configuration @click.option( - "--warmup-percent", + "--output-extras", + callback=cli_tools.parse_json, + help="A JSON string of extra data to save with the output benchmarks", +) +@click.option( + "--warmup", + "--warmup-percent", # legacy alias + "warmup", type=float, - default=GenerativeTextScenario.get_default("warmup_percent"), + default=None, help=( - "The percent of the benchmark (based on max-seconds, max-requets, " - "or lenth of dataset) to run as a warmup and not include in the final results. " - "Defaults to None." + "The specification around the number of requests to run before benchmarking. " + "If within (0, 1), then the percent of requests/time to use for warmup. " + "If >=1, then the number of requests or seconds to use for warmup." + "Whether it's requests/time used is dependent on which constraint is active. " + "Default None for no warmup." ), ) @click.option( - "--cooldown-percent", + "--cooldown", + "--cooldown-percent", # legacy alias + "cooldown", type=float, default=GenerativeTextScenario.get_default("cooldown_percent"), help=( - "The percent of the benchmark (based on max-seconds, max-requets, or lenth " - "of dataset) to run as a cooldown and not include in the final results. " - "Defaults to None." + "The specification around the number of requests to run after benchmarking. " + "If within (0, 1), then the percent of requests/time to use for cooldown. " + "If >=1, then the number of requests or seconds to use for cooldown." + "Whether it's requests/time used is dependent on which constraint is active. " + "Default None for no cooldown." ), ) @click.option( - "--disable-progress", - is_flag=True, - help="Set this flag to disable progress updates to the console", -) -@click.option( - "--display-scheduler-stats", - is_flag=True, - help="Set this flag to display stats for the processes running the benchmarks", -) -@click.option( - "--disable-console-outputs", - is_flag=True, - help="Set this flag to disable console output", -) -@click.option( - "--output-path", - type=click.Path(), - default=Path.cwd() / "benchmarks.json", + "--request-samples", + "--output-sampling", # legacy alias + "request_samples", + type=int, help=( - "The path to save the output to. If it is a directory, " - "it will save benchmarks.json under it. " - "Otherwise, json, yaml, csv, or html files are supported for output types " - "which will be read from the extension for the file path." + "The number of samples for each request status and each benchmark to save " + "in the output file. If None (default), will save all samples. " + "Defaults to 20." ), + default=20, ) +# Constraints configuration @click.option( - "--output-extras", - callback=cli_tools.parse_json, - help="A JSON string of extra data to save with the output benchmarks", + "--max-seconds", + type=float, + default=None, + help=( + "The maximum number of seconds each benchmark can run for. " + "If None, will run until max_requests or the data is exhausted." + ), ) @click.option( - "--output-sampling", + "--max-requests", type=int, + default=None, help=( - "The number of samples to save in the output file. " - "If None (default), will save all samples." + "The maximum number of requests each benchmark can run for. " + "If None, will run until max_seconds or the data is exhausted." ), - default=GenerativeTextScenario.get_default("output_sampling"), ) @click.option( - "--random-seed", - default=GenerativeTextScenario.get_default("random_seed"), + "--max-errors", type=int, - help="The random seed to use for benchmarking to ensure reproducibility.", + default=None, + help="Maximum number of errors allowed before stopping the benchmark", +) +@click.option( + "--max-error-rate", + type=float, + default=None, + help="Maximum error rate allowed before stopping the benchmark", +) +@click.option( + "--max-global-error-rate", + type=float, + default=None, + help="Maximum global error rate allowed across all benchmarks", ) def run( - scenario, target, - backend_type, - backend_args, + data, + profile, + rate, + random_seed, + # Backend Configuration + backend, + backend_kwargs, model, + # Data configuration processor, processor_args, - data, data_args, data_sampler, - rate_type, - rate, - max_seconds, - max_requests, - warmup_percent, - cooldown_percent, + # Output configuration + output_path, + output_formats, + # Updates configuration + disable_console_outputs, disable_progress, display_scheduler_stats, - disable_console_outputs, - output_path, + # Aggregators configuration output_extras, - output_sampling, - random_seed, + warmup, + cooldown, + request_samples, + # Constraints configuration + max_seconds, + max_requests, + max_errors, + max_error_rate, + max_global_error_rate, ): - click_ctx = click.get_current_context() - - overrides = cli_tools.set_if_not_default( - click_ctx, - target=target, - backend_type=backend_type, - backend_args=backend_args, - model=model, - processor=processor, - processor_args=processor_args, - data=data, - data_args=data_args, - data_sampler=data_sampler, - rate_type=rate_type, - rate=rate, - max_seconds=max_seconds, - max_requests=max_requests, - warmup_percent=warmup_percent, - cooldown_percent=cooldown_percent, - output_sampling=output_sampling, - random_seed=random_seed, - ) - - try: - # If a scenario file was specified read from it - if scenario is None: - _scenario = GenerativeTextScenario.model_validate(overrides) - elif isinstance(scenario, Path): - _scenario = GenerativeTextScenario.from_file(scenario, overrides) - else: # Only builtins can make it here; click will catch anything else - _scenario = GenerativeTextScenario.from_builtin(scenario, overrides) - except ValidationError as e: - # Translate pydantic valdation error to click argument error - errs = e.errors(include_url=False, include_context=True, include_input=True) - param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-") - raise click.BadParameter( - errs[0]["msg"], ctx=click_ctx, param_hint=param_name - ) from e + """ + Execute a generative text benchmark against a target model backend. + Runs comprehensive performance testing using various strategies and profiles, + collecting metrics on latency, throughput, error rates, and resource usage. + Supports multiple backends, data sources, output formats, and constraint types + for flexible benchmark configuration. + """ + if HAS_UVLOOP: + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.run( - benchmark_with_scenario( - scenario=_scenario, - show_progress=not disable_progress, - show_progress_scheduler_stats=display_scheduler_stats, - output_console=not disable_console_outputs, + benchmark_generative_text( + target=target, + data=data, + profile=profile, + rate=rate, + random_seed=random_seed, + # Backend configuration + backend=backend, + backend_kwargs=backend_kwargs, + model=model, + # Data configuration + processor=processor, + processor_args=processor_args, + data_args=data_args, + data_sampler=data_sampler, + # Output configuration output_path=output_path, - output_extras=output_extras, + output_formats=[ + fmt + for fmt in output_formats + if not disable_console_outputs or fmt != "console" + ], + # Updates configuration + progress=( + [ + GenerativeConsoleBenchmarkerProgress( + display_scheduler_stats=display_scheduler_stats + ) + ] + if not disable_progress + else None + ), + print_updates=not disable_console_outputs, + # Aggregators configuration + add_aggregators={"extras": InjectExtrasAggregator(extras=output_extras)}, + warmup=warmup, + cooldown=cooldown, + request_samples=request_samples, + # Constraints configuration + max_seconds=max_seconds, + max_requests=max_requests, + max_errors=max_errors, + max_error_rate=max_error_rate, + max_global_error_rate=max_global_error_rate, ) ) @@ -348,21 +490,14 @@ def run( ), ) def from_file(path, output_path): - reimport_benchmarks_report(path, output_path) - - -def decode_escaped_str(_ctx, _param, value): """ - Click auto adds characters. For example, when using --pad-char "\n", - it parses it as "\\n". This method decodes the string to handle escape - sequences correctly. + Load and optionally re-export a previously saved benchmark report. + + Imports benchmark results from a saved file and provides optional conversion + to different output formats. Supports JSON, YAML, and CSV export formats + based on the output file extension. """ - if value is None: - return None - try: - return codecs.decode(value, "unicode_escape") - except Exception as e: - raise click.BadParameter(f"Could not decode escape sequences: {e}") from e + reimport_benchmarks_report(path, output_path) @cli.command( @@ -373,12 +508,25 @@ def decode_escaped_str(_ctx, _param, value): ), ) def config(): + """ + Display available GuideLLM configuration environment variables. + + Prints a comprehensive list of all environment variables that can be used + to configure GuideLLM behavior, including their current values, defaults, + and descriptions. + """ print_config() @cli.group(help="General preprocessing tools and utilities.") def preprocess(): - pass + """ + Preprocessing command group for dataset preparation and transformation. + + This command group provides utilities for converting, processing, and + optimizing datasets for use in GuideLLM benchmarks. Includes functionality + for token count adjustments, format conversions, and data validation. + """ @preprocess.command( @@ -494,6 +642,13 @@ def dataset( hub_dataset_id, random_seed, ): + """ + Convert and process datasets for specific prompt and output token requirements. + + Transforms datasets to meet target token length specifications using various + strategies for handling short prompts and output length adjustments. Supports + multiple input formats and can optionally push results to Hugging Face Hub. + """ process_dataset( data=data, output_path=output_path, @@ -511,5 +666,121 @@ def dataset( ) +@cli.command(help="Start the GuideLLM mock OpenAI/vLLM server for testing.") +@click.option("--host", default="127.0.0.1", help="Host to bind the server to") +@click.option("--port", default=8000, type=int, help="Port to bind the server to") +@click.option("--workers", default=1, type=int, help="Number of worker processes") +@click.option( + "--model", default="llama-3.1-8b-instruct", help="The name of the model to mock" +) +@click.option("--processor", default=None, help="The processor to use for requests") +@click.option( + "--request-latency", + default=3, + type=float, + help="Request latency in seconds for non-streaming requests", +) +@click.option( + "--request-latency-std", + default=0, + type=float, + help=( + "Request latency standard deviation (normal distribution) " + "in seconds for non-streaming requests" + ), +) +@click.option( + "--ttft-ms", + default=150, + type=float, + help="Time to first token in milliseconds for streaming requests", +) +@click.option( + "--ttft-ms-std", + default=0, + type=float, + help=( + "Time to first token standard deviation (normal distribution) in milliseconds" + ), +) +@click.option( + "--itl-ms", + default=10, + type=float, + help="Inter token latency in milliseconds for streaming requests", +) +@click.option( + "--itl-ms-std", + default=0, + type=float, + help=( + "Inter token latency standard deviation (normal distribution) " + "in milliseconds for streaming requests" + ), +) +@click.option( + "--output-tokens", + default=128, + type=int, + help="Output tokens for streaming requests", +) +@click.option( + "--output-tokens-std", + default=0, + type=float, + help=( + "Output tokens standard deviation (normal distribution) for streaming requests" + ), +) +def mock_server( + host: str, + port: int, + workers: int, + model: str, + processor: str | None, + request_latency: float, + request_latency_std: float, + ttft_ms: float, + ttft_ms_std: float, + itl_ms: float, + itl_ms_std: float, + output_tokens: int, + output_tokens_std: float, +): + """ + Start a GuideLLM mock OpenAI/vLLM-compatible server for testing and development. + + Launches a mock server that simulates model inference with configurable latency + characteristics, token generation patterns, and response timing. Useful for + testing GuideLLM benchmarks without requiring actual model deployment or for + development scenarios requiring predictable server behavior. + """ + + config = MockServerConfig( + host=host, + port=port, + workers=workers, + model=model, + processor=processor, + request_latency=request_latency, + request_latency_std=request_latency_std, + ttft_ms=ttft_ms, + ttft_ms_std=ttft_ms_std, + itl_ms=itl_ms, + itl_ms_std=itl_ms_std, + output_tokens=output_tokens, + output_tokens_std=output_tokens_std, + ) + + server = MockServer(config) + console = Console() + console.print_update( + title="GuideLLM mock server starting...", + details=f"Listening on http://{host}:{port} for model {model}", + status="success", + ) + server.run() + + if __name__ == "__main__": cli() diff --git a/src/guidellm/benchmark/scenario.py b/src/guidellm/benchmark/scenario.py index 15e3cd81..3f84f868 100644 --- a/src/guidellm/benchmark/scenario.py +++ b/src/guidellm/benchmark/scenario.py @@ -11,9 +11,9 @@ PreTrainedTokenizerBase, ) -from guidellm.backend.backend import BackendType +from guidellm.backends import BackendType from guidellm.benchmark.profile import ProfileType -from guidellm.scheduler.strategy import StrategyType +from guidellm.scheduler import StrategyType from guidellm.utils import StandardBaseModel __ALL__ = ["Scenario", "GenerativeTextScenario", "get_builtin_scenarios"] diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py index da3464f9..48b41a49 100644 --- a/src/guidellm/logger.py +++ b/src/guidellm/logger.py @@ -71,8 +71,7 @@ def configure_logger(config: LoggingSettings = settings.logging): logger.add( sys.stdout, level=config.console_log_level.upper(), - format="{time:YY-MM-DD HH:mm:ss}|{level: <8} \ - |{name}:{function}:{line} - {message}", + format="{time} | {function} | {level} - {message}", ) if config.log_file or config.log_file_level: diff --git a/src/guidellm/objects/__init__.py b/src/guidellm/objects/__init__.py deleted file mode 100644 index f97f1ef3..00000000 --- a/src/guidellm/objects/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .statistics import ( - DistributionSummary, - Percentiles, - RunningStats, - StatusDistributionSummary, - TimeRunningStats, -) - -__all__ = [ - "DistributionSummary", - "Percentiles", - "RunningStats", - "StandardBaseModel", - "StatusBreakdown", - "StatusDistributionSummary", - "TimeRunningStats", -] diff --git a/src/guidellm/objects/statistics.py b/src/guidellm/objects/statistics.py deleted file mode 100644 index 8ba504be..00000000 --- a/src/guidellm/objects/statistics.py +++ /dev/null @@ -1,953 +0,0 @@ -import math -import time as timer -from collections import defaultdict -from typing import Any, Literal, Optional - -import numpy as np -from pydantic import Field, computed_field - -from guidellm.objects.pydantic import StandardBaseModel, StatusBreakdown - -__all__ = [ - "DistributionSummary", - "Percentiles", - "RunningStats", - "StatusDistributionSummary", - "TimeRunningStats", -] - - -class Percentiles(StandardBaseModel): - """ - A pydantic model representing the standard percentiles of a distribution. - """ - - p001: float = Field( - description="The 0.1th percentile of the distribution.", - ) - p01: float = Field( - description="The 1st percentile of the distribution.", - ) - p05: float = Field( - description="The 5th percentile of the distribution.", - ) - p10: float = Field( - description="The 10th percentile of the distribution.", - ) - p25: float = Field( - description="The 25th percentile of the distribution.", - ) - p50: float = Field( - description="The 50th percentile of the distribution.", - ) - p75: float = Field( - description="The 75th percentile of the distribution.", - ) - p90: float = Field( - description="The 90th percentile of the distribution.", - ) - p95: float = Field( - description="The 95th percentile of the distribution.", - ) - p99: float = Field( - description="The 99th percentile of the distribution.", - ) - p999: float = Field( - description="The 99.9th percentile of the distribution.", - ) - - -class DistributionSummary(StandardBaseModel): - """ - A pydantic model representing a statistical summary for a given - distribution of numerical values. - """ - - mean: float = Field( - description="The mean/average of the distribution.", - ) - median: float = Field( - description="The median of the distribution.", - ) - mode: float = Field( - description="The mode of the distribution.", - ) - variance: float = Field( - description="The variance of the distribution.", - ) - std_dev: float = Field( - description="The standard deviation of the distribution.", - ) - min: float = Field( - description="The minimum value of the distribution.", - ) - max: float = Field( - description="The maximum value of the distribution.", - ) - count: int = Field( - description="The number of values in the distribution.", - ) - total_sum: float = Field( - description="The total sum of the values in the distribution.", - ) - percentiles: Percentiles = Field( - description="The percentiles of the distribution.", - ) - cumulative_distribution_function: Optional[list[tuple[float, float]]] = Field( - description="The cumulative distribution function (CDF) of the distribution.", - default=None, - ) - - @staticmethod - def from_distribution_function( - distribution: list[tuple[float, float]], - include_cdf: bool = False, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of weighted numerical - values or a probability distribution function (PDF). - 1. If the distribution is a PDF, it is expected to be a list of tuples - where each tuple contains (value, probability). The sum of the - probabilities should be 1. If it is not, it will be normalized. - 2. If the distribution is a values distribution function, it is expected - to be a list of tuples where each tuple contains (value, weight). - The weights are normalized to a probability distribution function. - - :param distribution: A list of tuples representing the distribution. - Each tuple contains (value, weight) or (value, probability). - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - :return: An instance of DistributionSummary with calculated values. - """ - values, weights = zip(*distribution) if distribution else ([], []) - values = np.array(values) # type: ignore[assignment] - weights = np.array(weights) # type: ignore[assignment] - - # create the PDF - probabilities = weights / np.sum(weights) # type: ignore[operator] - pdf = np.column_stack((values, probabilities)) - pdf = pdf[np.argsort(pdf[:, 0])] - values = pdf[:, 0] # type: ignore[assignment] - probabilities = pdf[:, 1] - - # calculate the CDF - cumulative_probabilities = np.cumsum(probabilities) - cdf = np.column_stack((values, cumulative_probabilities)) - - # calculate statistics - mean = np.sum(values * probabilities).item() # type: ignore[attr-defined] - median = cdf[np.argmax(cdf[:, 1] >= 0.5), 0].item() if len(cdf) > 0 else 0 # noqa: PLR2004 - mode = values[np.argmax(probabilities)].item() if len(values) > 0 else 0 # type: ignore[call-overload] - variance = np.sum((values - mean) ** 2 * probabilities).item() # type: ignore[attr-defined] - std_dev = math.sqrt(variance) - minimum = values[0].item() if len(values) > 0 else 0 - maximum = values[-1].item() if len(values) > 0 else 0 - count = len(values) - total_sum = np.sum(values).item() # type: ignore[attr-defined] - - return DistributionSummary( - mean=mean, - median=median, - mode=mode, - variance=variance, - std_dev=std_dev, - min=minimum, - max=maximum, - count=count, - total_sum=total_sum, - percentiles=( - Percentiles( - p001=cdf[np.argmax(cdf[:, 1] >= 0.001), 0].item(), # noqa: PLR2004 - p01=cdf[np.argmax(cdf[:, 1] >= 0.01), 0].item(), # noqa: PLR2004 - p05=cdf[np.argmax(cdf[:, 1] >= 0.05), 0].item(), # noqa: PLR2004 - p10=cdf[np.argmax(cdf[:, 1] >= 0.1), 0].item(), # noqa: PLR2004 - p25=cdf[np.argmax(cdf[:, 1] >= 0.25), 0].item(), # noqa: PLR2004 - p50=cdf[np.argmax(cdf[:, 1] >= 0.50), 0].item(), # noqa: PLR2004 - p75=cdf[np.argmax(cdf[:, 1] >= 0.75), 0].item(), # noqa: PLR2004 - p90=cdf[np.argmax(cdf[:, 1] >= 0.9), 0].item(), # noqa: PLR2004 - p95=cdf[np.argmax(cdf[:, 1] >= 0.95), 0].item(), # noqa: PLR2004 - p99=cdf[np.argmax(cdf[:, 1] >= 0.99), 0].item(), # noqa: PLR2004 - p999=cdf[np.argmax(cdf[:, 1] >= 0.999), 0].item(), # noqa: PLR2004 - ) - if len(cdf) > 0 - else Percentiles( - p001=0, - p01=0, - p05=0, - p10=0, - p25=0, - p50=0, - p75=0, - p90=0, - p95=0, - p99=0, - p999=0, - ) - ), - cumulative_distribution_function=cdf.tolist() if include_cdf else None, - ) - - @staticmethod - def from_values( - values: list[float], - weights: Optional[list[float]] = None, - include_cdf: bool = False, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of numerical values. - This is a wrapper around from_distribution_function to handle the optional case - of including weights for the values. If weights are not provided, they are - automatically set to 1.0 for each value, so each value is equally weighted. - - :param values: A list of numerical values representing the distribution. - :param weights: A list of weights for each value in the distribution. - If not provided, all values are equally weighted. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - """ - if weights is None: - weights = [1.0] * len(values) - - if len(values) != len(weights): - raise ValueError( - "The length of values and weights must be the same.", - ) - - return DistributionSummary.from_distribution_function( - distribution=list(zip(values, weights)), - include_cdf=include_cdf, - ) - - @staticmethod - def from_request_times( - requests: list[tuple[float, float]], - distribution_type: Literal["concurrency", "rate"], - include_cdf: bool = False, - epsilon: float = 1e-6, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of request times. - Specifically, this is used to measure concurrency or rate of requests - given an input list containing the start and end time of each request. - This will first convert the request times into a distribution function - and then calculate the statistics with from_distribution_function. - - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...] - :param distribution_type: The type of distribution to calculate. - Either "concurrency" or "rate". - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of DistributionSummary with calculated values. - """ - if distribution_type == "concurrency": - # convert to delta changes based on when requests were running - events = [(start, 1) for start, _ in requests] + [ - (end, -1) for _, end in requests - ] - elif distribution_type == "rate": - # convert to events for when requests finished - global_start = min(start for start, _ in requests) if requests else 0 - events = [(global_start, 1)] + [(end, 1) for _, end in requests] - else: - raise ValueError( - f"Invalid distribution_type '{distribution_type}'. " - "Must be 'concurrency' or 'rate'." - ) - - # combine any events that are very close together - flattened_events: list[tuple[float, float]] = [] - for time, val in sorted(events): - last_time, last_val = ( - flattened_events[-1] if flattened_events else (None, None) - ) - - if ( - last_time is not None - and last_val is not None - and abs(last_time - time) <= epsilon - ): - flattened_events[-1] = (last_time, last_val + val) - else: - flattened_events.append((time, val)) - - if distribution_type == "concurrency": - # convert to the events over time measuring concurrency changes - events_over_time: list[tuple[float, float]] = [] - active = 0 - for time, delta in flattened_events: - active += delta # type: ignore [assignment] - events_over_time.append((time, active)) - - flattened_events = events_over_time - - # convert to value distribution function - distribution: dict[float, float] = defaultdict(float) - - for ind in range(len(flattened_events) - 1): - start_time, value = flattened_events[ind] - end_time, _ = flattened_events[ind + 1] - duration = end_time - start_time - - if distribution_type == "concurrency": - # weight the concurrency value by the duration - distribution[value] += duration - elif distribution_type == "rate": - # weight the rate value by the duration - rate = value / duration - distribution[rate] += duration - - distribution_list: list[tuple[float, float]] = sorted(distribution.items()) - - return DistributionSummary.from_distribution_function( - distribution=distribution_list, - include_cdf=include_cdf, - ) - - @staticmethod - def from_iterable_request_times( - requests: list[tuple[float, float]], - first_iter_times: list[float], - iter_counts: list[int], - first_iter_counts: Optional[list[int]] = None, - include_cdf: bool = False, - epsilon: float = 1e-6, - ) -> "DistributionSummary": - """ - Create a statistical summary for a given distribution of request times - for a request with iterable responses between the start and end. - For example, this is used to measure auto regressive requests where - a request is started and at some later point, iterative responses are - received. This will convert the request times and iterable values into - a distribution function and then calculate the statistics with - from_distribution_function. - - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...] - :param first_iter_times: A list of times when the first iteration of - each request was received. Must be the same length as requests. - :param iter_counts: A list of the total number of iterations for each - request that occurred starting at the first iteration and ending - at the request end time. Must be the same length as requests. - :param first_iter_counts: A list of the number of iterations to log - for the first iteration of each request. For example, when calculating - total number of tokens processed, this is set to the prompt tokens number. - If not provided, defaults to 1 for each request. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output DistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of DistributionSummary with calculated values. - """ - - if first_iter_counts is None: - first_iter_counts = [1] * len(requests) - - if ( - len(requests) != len(first_iter_times) - or len(requests) != len(iter_counts) - or len(requests) != len(first_iter_counts) - ): - raise ValueError( - "requests, first_iter_times, iter_counts, and first_iter_counts must" - "be the same length." - f"Given {len(requests)}, {len(first_iter_times)}, {len(iter_counts)}, " - f"{len(first_iter_counts)}", - ) - - # first break up the requests into individual iterable events - events = defaultdict(int) - global_start = min(start for start, _ in requests) if requests else 0 - global_end = max(end for _, end in requests) if requests else 0 - events[global_start] = 0 - events[global_end] = 0 - - for (_, end), first_iter, first_iter_count, total_count in zip( - requests, first_iter_times, first_iter_counts, iter_counts - ): - events[first_iter] += first_iter_count - - if total_count > 1: - iter_latency = (end - first_iter) / (total_count - 1) - for ind in range(1, total_count): - events[first_iter + ind * iter_latency] += 1 - - # combine any events that are very close together - flattened_events: list[tuple[float, int]] = [] - - for time, count in sorted(events.items()): - last_time, last_count = ( - flattened_events[-1] if flattened_events else (None, None) - ) - - if ( - last_time is not None - and last_count is not None - and abs(last_time - time) <= epsilon - ): - flattened_events[-1] = (last_time, last_count + count) - else: - flattened_events.append((time, count)) - - # convert to value distribution function - distribution: dict[float, float] = defaultdict(float) - - for ind in range(len(flattened_events) - 1): - start_time, count = flattened_events[ind] - end_time, _ = flattened_events[ind + 1] - duration = end_time - start_time - rate = count / duration - distribution[rate] += duration - - distribution_list = sorted(distribution.items()) - - return DistributionSummary.from_distribution_function( - distribution=distribution_list, - include_cdf=include_cdf, - ) - - -class StatusDistributionSummary( - StatusBreakdown[ - DistributionSummary, - DistributionSummary, - DistributionSummary, - DistributionSummary, - ] -): - """ - A pydantic model representing a statistical summary for a given - distribution of numerical values grouped by status. - Specifically used to represent the total, successful, incomplete, - and errored values for a benchmark or other statistical summary. - """ - - @staticmethod - def from_values( - value_types: list[Literal["successful", "incomplete", "error"]], - values: list[float], - weights: Optional[list[float]] = None, - include_cdf: bool = False, - ) -> "StatusDistributionSummary": - """ - Create a statistical summary by status for a given distribution of numerical - values. This is used to measure the distribution of values for different - statuses (e.g., successful, incomplete, error) and calculate the statistics - for each status. Weights are optional to weight the probability distribution - for each value by. If not provided, all values are equally weighted. - - :param value_types: A list of status types for each value in the distribution. - Must be one of 'successful', 'incomplete', or 'error'. - :param values: A list of numerical values representing the distribution. - Must be the same length as value_types. - :param weights: A list of weights for each value in the distribution. - If not provided, all values are equally weighted (set to 1). - Must be the same length as value_types. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output StatusDistributionSummary. - :return: An instance of StatusDistributionSummary with calculated values. - """ - if any( - type_ not in {"successful", "incomplete", "error"} for type_ in value_types - ): - raise ValueError( - "value_types must be one of 'successful', 'incomplete', or 'error'. " - f"Got {value_types} instead.", - ) - - if weights is None: - weights = [1.0] * len(values) - - if len(value_types) != len(values) or len(value_types) != len(weights): - raise ValueError( - "The length of value_types, values, and weights must be the same.", - ) - - _, successful_values, successful_weights = ( - zip(*successful) - if ( - successful := list( - filter( - lambda val: val[0] == "successful", - zip(value_types, values, weights), - ) - ) - ) - else ([], [], []) - ) - _, incomplete_values, incomplete_weights = ( - zip(*incomplete) - if ( - incomplete := list( - filter( - lambda val: val[0] == "incomplete", - zip(value_types, values, weights), - ) - ) - ) - else ([], [], []) - ) - _, errored_values, errored_weights = ( - zip(*errored) - if ( - errored := list( - filter( - lambda val: val[0] == "error", - zip(value_types, values, weights), - ) - ) - ) - else ([], [], []) - ) - - return StatusDistributionSummary( - total=DistributionSummary.from_values( - values, - weights, - include_cdf=include_cdf, - ), - successful=DistributionSummary.from_values( - successful_values, # type: ignore[arg-type] - successful_weights, # type: ignore[arg-type] - include_cdf=include_cdf, - ), - incomplete=DistributionSummary.from_values( - incomplete_values, # type: ignore[arg-type] - incomplete_weights, # type: ignore[arg-type] - include_cdf=include_cdf, - ), - errored=DistributionSummary.from_values( - errored_values, # type: ignore[arg-type] - errored_weights, # type: ignore[arg-type] - include_cdf=include_cdf, - ), - ) - - @staticmethod - def from_request_times( - request_types: list[Literal["successful", "incomplete", "error"]], - requests: list[tuple[float, float]], - distribution_type: Literal["concurrency", "rate"], - include_cdf: bool = False, - epsilon: float = 1e-6, - ) -> "StatusDistributionSummary": - """ - Create a statistical summary by status for given distribution of request times. - This is used to measure the distribution of request times for different statuses - (e.g., successful, incomplete, error) for concurrency and rates. - This will call into DistributionSummary.from_request_times to calculate - the statistics for each status. - - :param request_types: List of status types for each request in the distribution. - Must be one of 'successful', 'incomplete', or 'error'. - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...]. - Must be the same length as request_types. - :param distribution_type: The type of distribution to calculate. - Either "concurrency" or "rate". - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output StatusDistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of StatusDistributionSummary with calculated values. - """ - if distribution_type not in {"concurrency", "rate"}: - raise ValueError( - f"Invalid distribution_type '{distribution_type}'. " - "Must be 'concurrency' or 'rate'." - ) - - if any( - type_ not in {"successful", "incomplete", "error"} - for type_ in request_types - ): - raise ValueError( - "request_types must be one of 'successful', 'incomplete', or 'error'. " - f"Got {request_types} instead.", - ) - - if len(request_types) != len(requests): - raise ValueError( - "The length of request_types and requests must be the same. " - f"Got {len(request_types)} and {len(requests)} instead.", - ) - - _, successful_requests = ( - zip(*successful) - if ( - successful := list( - filter( - lambda val: val[0] == "successful", - zip(request_types, requests), - ) - ) - ) - else ([], []) - ) - _, incomplete_requests = ( - zip(*incomplete) - if ( - incomplete := list( - filter( - lambda val: val[0] == "incomplete", - zip(request_types, requests), - ) - ) - ) - else ([], []) - ) - _, errored_requests = ( - zip(*errored) - if ( - errored := list( - filter( - lambda val: val[0] == "error", - zip(request_types, requests), - ) - ) - ) - else ([], []) - ) - - return StatusDistributionSummary( - total=DistributionSummary.from_request_times( - requests, - distribution_type=distribution_type, - include_cdf=include_cdf, - epsilon=epsilon, - ), - successful=DistributionSummary.from_request_times( - successful_requests, # type: ignore[arg-type] - distribution_type=distribution_type, - include_cdf=include_cdf, - epsilon=epsilon, - ), - incomplete=DistributionSummary.from_request_times( - incomplete_requests, # type: ignore[arg-type] - distribution_type=distribution_type, - include_cdf=include_cdf, - epsilon=epsilon, - ), - errored=DistributionSummary.from_request_times( - errored_requests, # type: ignore[arg-type] - distribution_type=distribution_type, - include_cdf=include_cdf, - epsilon=epsilon, - ), - ) - - @staticmethod - def from_iterable_request_times( - request_types: list[Literal["successful", "incomplete", "error"]], - requests: list[tuple[float, float]], - first_iter_times: list[float], - iter_counts: Optional[list[int]] = None, - first_iter_counts: Optional[list[int]] = None, - include_cdf: bool = False, - epsilon: float = 1e-6, - ) -> "StatusDistributionSummary": - """ - Create a statistical summary by status for given distribution of request times - for a request with iterable responses between the start and end. - For example, this is used to measure auto regressive requests where - a request is started and at some later point, iterative responses are - received. This will call into DistributionSummary.from_iterable_request_times - to calculate the statistics for each status. - - :param request_types: List of status types for each request in the distribution. - Must be one of 'successful', 'incomplete', or 'error'. - :param requests: A list of tuples representing the start and end times of - each request. Example: [(start_1, end_1), (start_2, end_2), ...]. - Must be the same length as request_types. - :param first_iter_times: A list of times when the first iteration of - each request was received. Must be the same length as requests. - :param iter_counts: A list of the total number of iterations for each - request that occurred starting at the first iteration and ending - at the request end time. Must be the same length as requests. - If not provided, defaults to 1 for each request. - :param first_iter_counts: A list of the number of iterations to log - for the first iteration of each request. For example, when calculating - total number of tokens processed, this is set to the prompt tokens number. - If not provided, defaults to 1 for each request. - :param include_cdf: Whether to include the calculated cumulative distribution - function (CDF) in the output StatusDistributionSummary. - :param epsilon: The epsilon value for merging close events. - :return: An instance of StatusDistributionSummary with calculated values. - """ - if any( - type_ not in {"successful", "incomplete", "error"} - for type_ in request_types - ): - raise ValueError( - "request_types must be one of 'successful', 'incomplete', or 'error'. " - f"Got {request_types} instead.", - ) - - if iter_counts is None: - iter_counts = [1] * len(requests) - - if first_iter_counts is None: - first_iter_counts = [1] * len(requests) - - if ( - len(request_types) != len(requests) - or len(requests) != len(first_iter_times) - or len(requests) != len(iter_counts) - or len(requests) != len(first_iter_counts) - ): - raise ValueError( - "request_types, requests, first_iter_times, iter_counts, and " - "first_iter_counts must be the same length." - f"Given {len(request_types)}, {len(requests)}, " - f"{len(first_iter_times)}, {len(iter_counts)}, " - f"{len(first_iter_counts)}", - ) - - ( - _, - successful_requests, - successful_first_iter_times, - successful_iter_counts, - successful_first_iter_counts, - ) = ( - zip(*successful) - if ( - successful := list( - filter( - lambda val: val[0] == "successful", - zip( - request_types, - requests, - first_iter_times, - iter_counts, - first_iter_counts, - ), - ) - ) - ) - else ([], [], [], [], []) - ) - ( - _, - incomplete_requests, - incomplete_first_iter_times, - incomplete_iter_counts, - incomplete_first_iter_counts, - ) = ( - zip(*incomplete) - if ( - incomplete := list( - filter( - lambda val: val[0] == "incomplete", - zip( - request_types, - requests, - first_iter_times, - iter_counts, - first_iter_counts, - ), - ) - ) - ) - else ([], [], [], [], []) - ) - ( - _, - errored_requests, - errored_first_iter_times, - errored_iter_counts, - errored_first_iter_counts, - ) = ( - zip(*errored) - if ( - errored := list( - filter( - lambda val: val[0] == "error", - zip( - request_types, - requests, - first_iter_times, - iter_counts, - first_iter_counts, - ), - ) - ) - ) - else ([], [], [], [], []) - ) - - return StatusDistributionSummary( - total=DistributionSummary.from_iterable_request_times( - requests, - first_iter_times, - iter_counts, - first_iter_counts, - include_cdf=include_cdf, - epsilon=epsilon, - ), - successful=DistributionSummary.from_iterable_request_times( - successful_requests, # type: ignore[arg-type] - successful_first_iter_times, # type: ignore[arg-type] - successful_iter_counts, # type: ignore[arg-type] - successful_first_iter_counts, # type: ignore[arg-type] - include_cdf=include_cdf, - epsilon=epsilon, - ), - incomplete=DistributionSummary.from_iterable_request_times( - incomplete_requests, # type: ignore[arg-type] - incomplete_first_iter_times, # type: ignore[arg-type] - incomplete_iter_counts, # type: ignore[arg-type] - incomplete_first_iter_counts, # type: ignore[arg-type] - include_cdf=include_cdf, - epsilon=epsilon, - ), - errored=DistributionSummary.from_iterable_request_times( - errored_requests, # type: ignore[arg-type] - errored_first_iter_times, # type: ignore[arg-type] - errored_iter_counts, # type: ignore[arg-type] - errored_first_iter_counts, # type: ignore[arg-type] - include_cdf=include_cdf, - epsilon=epsilon, - ), - ) - - -class RunningStats(StandardBaseModel): - """ - Create a running statistics object to track the mean, rate, and other - statistics of a stream of values. - 1. The start time is set to the time the object is created. - 2. The count is set to 0. - 3. The total is set to 0. - 4. The last value is set to 0. - 5. The mean is calculated as the total / count. - """ - - start_time: float = Field( - default_factory=timer.time, - description=( - "The time the running statistics object was created. " - "This is used to calculate the rate of the statistics." - ), - ) - count: int = Field( - default=0, - description="The number of values added to the running statistics.", - ) - total: float = Field( - default=0.0, - description="The total sum of the values added to the running statistics.", - ) - last: float = Field( - default=0.0, - description="The last value added to the running statistics.", - ) - - @computed_field # type: ignore[misc] - @property - def mean(self) -> float: - """ - :return: The mean of the running statistics (total / count). - If count is 0, return 0.0. - """ - if self.count == 0: - return 0.0 - return self.total / self.count - - @computed_field # type: ignore[misc] - @property - def rate(self) -> float: - """ - :return: The rate of the running statistics - (total / (time.time() - start_time)). - If count is 0, return 0.0. - """ - if self.count == 0: - return 0.0 - return self.total / (timer.time() - self.start_time) - - def __add__(self, value: Any) -> float: - """ - Enable the use of the + operator to add a value to the running statistics. - - :param value: The value to add to the running statistics. - :return: The mean of the running statistics. - """ - if not isinstance(value, (int, float)): - raise ValueError( - f"Value must be an int or float, got {type(value)} instead.", - ) - - self.update(value) - - return self.mean - - def __iadd__(self, value: Any) -> "RunningStats": - """ - Enable the use of the += operator to add a value to the running statistics. - - :param value: The value to add to the running statistics. - :return: The running statistics object. - """ - if not isinstance(value, (int, float)): - raise ValueError( - f"Value must be an int or float, got {type(value)} instead.", - ) - - self.update(value) - - return self - - def update(self, value: float, count: int = 1) -> None: - """ - Update the running statistics with a new value. - - :param value: The new value to add to the running statistics. - :param count: The number of times to 'count' for the value. - If not provided, defaults to 1. - """ - self.count += count - self.total += value - self.last = value - - -class TimeRunningStats(RunningStats): - """ - Create a running statistics object to track the mean, rate, and other - statistics of a stream of time values. This is used to track time values - in milliseconds and seconds. - - Adds time specific computed_fields such as measurements in milliseconds and seconds. - """ - - @computed_field # type: ignore[misc] - @property - def total_ms(self) -> float: - """ - :return: The total time multiplied by 1000.0 to convert to milliseconds. - """ - return self.total * 1000.0 - - @computed_field # type: ignore[misc] - @property - def last_ms(self) -> float: - """ - :return: The last time multiplied by 1000.0 to convert to milliseconds. - """ - return self.last * 1000.0 - - @computed_field # type: ignore[misc] - @property - def mean_ms(self) -> float: - """ - :return: The mean time multiplied by 1000.0 to convert to milliseconds. - """ - return self.mean * 1000.0 - - @computed_field # type: ignore[misc] - @property - def rate_ms(self) -> float: - """ - :return: The rate of the running statistics multiplied by 1000.0 - to convert to milliseconds. - """ - return self.rate * 1000.0 diff --git a/src/guidellm/presentation/builder.py b/src/guidellm/presentation/builder.py index a27d7cec..6ea9c5c3 100644 --- a/src/guidellm/presentation/builder.py +++ b/src/guidellm/presentation/builder.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from guidellm.benchmark.benchmark import GenerativeBenchmark + from guidellm.benchmark import GenerativeBenchmark -from .data_models import BenchmarkDatum, RunInfo, WorkloadDetails +from guidellm.presentation.data_models import BenchmarkDatum, RunInfo, WorkloadDetails class UIDataBuilder: diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py index 989ca8ab..9036636a 100644 --- a/src/guidellm/presentation/data_models.py +++ b/src/guidellm/presentation/data_models.py @@ -6,9 +6,9 @@ from pydantic import BaseModel, computed_field if TYPE_CHECKING: - from guidellm.benchmark.benchmark import GenerativeBenchmark + from guidellm.benchmark import GenerativeBenchmark -from guidellm.objects.statistics import DistributionSummary +from guidellm.utils import DistributionSummary class Bucket(BaseModel): diff --git a/src/guidellm/request/loader.py b/src/guidellm/request/loader.py index 1c875046..607a7455 100644 --- a/src/guidellm/request/loader.py +++ b/src/guidellm/request/loader.py @@ -12,9 +12,9 @@ from transformers import PreTrainedTokenizerBase # type: ignore[import] from guidellm.dataset import ColumnInputTypes, load_dataset -from guidellm.objects import StandardBaseModel from guidellm.request.request import GenerationRequest from guidellm.settings import settings +from guidellm.utils import StandardBaseModel __all__ = [ "GenerativeRequestLoader", diff --git a/src/guidellm/request/request.py b/src/guidellm/request/request.py index 81c8cabd..bf4e59fb 100644 --- a/src/guidellm/request/request.py +++ b/src/guidellm/request/request.py @@ -3,7 +3,7 @@ from pydantic import Field -from guidellm.objects.pydantic import StandardBaseModel +from guidellm.utils import StandardBaseModel __all__ = ["GenerationRequest"] diff --git a/src/guidellm/settings.py b/src/guidellm/settings.py index 72178425..20d9ff96 100644 --- a/src/guidellm/settings.py +++ b/src/guidellm/settings.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import json -import os from collections.abc import Sequence from enum import Enum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -46,8 +47,8 @@ class LoggingSettings(BaseModel): disabled: bool = False clear_loggers: bool = True console_log_level: str = "WARNING" - log_file: Optional[str] = None - log_file_level: Optional[str] = None + log_file: str | None = None + log_file_level: str | None = None class DatasetSettings(BaseModel): @@ -80,11 +81,11 @@ class OpenAISettings(BaseModel): for OpenAI server based pathways """ - api_key: Optional[str] = None - bearer_token: Optional[str] = None - headers: Optional[dict[str, str]] = None - organization: Optional[str] = None - project: Optional[str] = None + api_key: str | None = None + bearer_token: str | None = None + headers: dict[str, str] | None = None + organization: str | None = None + project: str | None = None base_url: str = "http://localhost:8000" max_output_tokens: int = 16384 verify: bool = True @@ -131,24 +132,30 @@ class Settings(BaseSettings): request_http2: bool = True # Scheduler settings + mp_context_type: Literal["spawn", "fork", "forkserver"] | None = "fork" + mp_serialization: Literal["dict", "sequence"] | None = "dict" + mp_encoding: ( + Literal["msgpack", "msgspec"] + | None + | list[Literal["msgpack", "msgspec"] | None] + ) = ["msgspec", "msgpack", None] + mp_messaging_object: Literal["queue", "manager_queue", "pipe"] = "queue" + mp_requests_send_buffer_size: int = 1 + mp_poll_interval: float = 0.1 + mp_max_pending_buffer_percent: float = 0.5 + mp_max_worker_buffer_percent: float = 0.2 max_concurrency: int = 512 - max_worker_processes: int = Field( - # use number of CPUs - 1, but at least 10 - default_factory=lambda: max((os.cpu_count() or 1) - 1, 10) - ) - min_queued_requests: int = 20 - scheduler_start_delay: float = 5 + max_worker_processes: int = 10 + scheduler_start_delay_non_distributed: float = 1.0 + constraint_error_window_size: float = 30 + constraint_error_min_processed: float = 30 # Data settings dataset: DatasetSettings = DatasetSettings() # Request/stats settings - preferred_prompt_tokens_source: Optional[ - Literal["request", "response", "local"] - ] = "response" - preferred_output_tokens_source: Optional[ - Literal["request", "response", "local"] - ] = "response" + preferred_prompt_tokens_source: Literal["request", "response"] = "response" + preferred_output_tokens_source: Literal["request", "response"] = "response" preferred_backend: Literal["openai"] = "openai" preferred_route: Literal["text_completions", "chat_completions"] = ( "text_completions" diff --git a/src/guidellm/utils/typing.py b/src/guidellm/utils/typing.py new file mode 100644 index 00000000..8146ea1e --- /dev/null +++ b/src/guidellm/utils/typing.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import Annotated, Literal, Union, get_args, get_origin + +# Backwards compatibility for Python <3.10 +try: + from types import UnionType # type: ignore[attr-defined] +except ImportError: + UnionType = Union + +# Backwards compatibility for Python <3.12 +try: + from typing import TypeAliasType # type: ignore[attr-defined] +except ImportError: + from typing_extensions import TypeAliasType + + +__all__ = ["get_literal_vals"] + + +def get_literal_vals(alias) -> frozenset[str]: + """Extract all literal values from a (possibly nested) type alias.""" + + def resolve(alias) -> Iterator[str]: + origin = get_origin(alias) + + # Base case: Literal types + if origin is Literal: + for literal_val in get_args(alias): + yield str(literal_val) + # Unwrap Annotated type + elif origin is Annotated: + yield from resolve(get_args(alias)[0]) + # Unwrap TypeAliasTypes + elif isinstance(alias, TypeAliasType): + yield from resolve(alias.__value__) + # Iterate over unions + elif origin in (Union, UnionType): + for arg in get_args(alias): + yield from resolve(arg) + # Fallback + else: + yield str(alias) + + return frozenset(resolve(alias)) diff --git a/tests/integration/scheduler/__init__.py b/tests/integration/scheduler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/scheduler/test_scheduler.py b/tests/integration/scheduler/test_scheduler.py new file mode 100644 index 00000000..51abf59b --- /dev/null +++ b/tests/integration/scheduler/test_scheduler.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import asyncio +import random +import uuid +from collections import defaultdict +from functools import wraps +from typing import Any + +import pytest +from pydantic import BaseModel, Field + +from guidellm.scheduler import ( + BackendInterface, + ConstraintInitializer, + Environment, + MaxNumberConstraint, + NonDistributedEnvironment, + ScheduledRequestInfo, + Scheduler, + SchedulerState, + SchedulingStrategy, + SynchronousStrategy, +) + + +def async_timeout(delay: float): + """Decorator to add timeout to async test functions.""" + + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockRequest(BaseModel): + payload: str + id_: str = Field(default_factory=lambda: str(uuid.uuid4())) + + +class MockBackend(BackendInterface): + """Mock backend for integration testing with predictable responses.""" + + def __init__( + self, + processes_limit_value: int | None = None, + requests_limit_value: int | None = None, + error_rate: float = 0.2, + response_delay: float = 0.0, + ): + self._processes_limit = processes_limit_value + self._requests_limit = requests_limit_value + self._error_rate = error_rate + self._response_delay = response_delay + + @property + def processes_limit(self) -> int | None: + return self._processes_limit + + @property + def requests_limit(self) -> int | None: + return self._requests_limit + + def info(self) -> dict[str, Any]: + return {"type": "mock_integration", "delay": self._response_delay} + + async def process_startup(self): + pass + + async def validate(self): + pass + + async def process_shutdown(self): + pass + + async def resolve(self, request: MockRequest, request_info, request_history): + """Return predictable response based on input request.""" + await asyncio.sleep(self._response_delay) + + if ( + self._error_rate + and self._error_rate > 0 + and random.random() < self._error_rate + ): + raise RuntimeError(f"mock_error_for_{request.payload}") + + yield f"response_for_{request.payload}", request_info + + +@pytest.mark.smoke +@pytest.mark.asyncio +@async_timeout(10.0) +@pytest.mark.parametrize( + ("strategy", "env", "constraint_inits"), + [ + ( + SynchronousStrategy(), + NonDistributedEnvironment(), + {"max_number": MaxNumberConstraint(max_num=100)}, + ), + ], +) +async def test_scheduler_run_integration( + strategy: SchedulingStrategy, + env: Environment, + constraint_inits: dict[str, ConstraintInitializer], +): + """Integration test for full scheduler workflow.""" + # Clear singleton state + if hasattr(Scheduler, "singleton_instance"): + Scheduler.singleton_instance = None + + scheduler = Scheduler() + constraints = { + key: init.create_constraint() for key, init in constraint_inits.items() + } + received_updates = defaultdict(list) + received_responses = [] + last_state = None + num_requests = 50 + + async for resp, req, info, state in scheduler.run( + requests=[MockRequest(payload=f"req_{ind}") for ind in range(num_requests)], + backend=MockBackend(), + strategy=strategy, + env=env, + **constraints, + ): + assert req is not None + assert isinstance(req, MockRequest) + assert isinstance(info, ScheduledRequestInfo) + assert info.status != "cancelled" + assert isinstance(state, SchedulerState) + if info.status == "completed": + assert resp == f"response_for_{req.payload}" + received_responses.append(resp) + elif info.status == "errored": + assert resp is None + assert info.error is not None + assert info.error == f"mock_error_for_{req.payload}" + received_responses.append(info.error) + + if len(received_updates[req.payload]) < 3: + received_updates[req.payload].append(info.status) + last_state = state + + assert len(received_updates) == num_requests + assert len(received_responses) == constraints["max_number"].max_num + assert last_state.created_requests == constraints["max_number"].max_num + assert last_state.queued_requests == 0 + assert last_state.processing_requests == 0 + assert last_state.processed_requests == constraints["max_number"].max_num + assert last_state.cancelled_requests == 0 + assert ( + last_state.successful_requests + last_state.errored_requests + ) == constraints["max_number"].max_num + + def _request_indices(): + while True: + yield from range(num_requests) + + for index, req, statuses, resp in zip( + _request_indices(), + received_updates.keys(), + received_updates.values(), + received_responses, + ): + assert req == f"req_{index}" + assert resp in (f"response_for_{req}", f"mock_error_for_{req}") + assert statuses in ( + ["queued", "in_progress", "completed"], + ["queued", "in_progress", "errored"], + ) diff --git a/tests/integration/scheduler/test_worker_group.py b/tests/integration/scheduler/test_worker_group.py new file mode 100644 index 00000000..c3be2b99 --- /dev/null +++ b/tests/integration/scheduler/test_worker_group.py @@ -0,0 +1,181 @@ +""" +Integration tests for WorkerProcessGroup. + +Tests the complete lifecycle of the worker group with real multiprocessing +worker processes and a mock backend. Validates end-to-end functionality +across different scheduling strategies and constraints. +""" + +from __future__ import annotations + +import asyncio +import random +import time +from collections import defaultdict +from functools import wraps +from typing import Any + +import pytest + +from guidellm.scheduler import ( + AsyncConstantStrategy, + AsyncPoissonStrategy, + BackendInterface, + ConcurrentStrategy, + MaxDurationConstraint, + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, + MaxNumberConstraint, + MeasuredRequestTimings, + SynchronousStrategy, + ThroughputStrategy, + WorkerProcessGroup, +) +from guidellm.scheduler.constraints import ConstraintInitializer +from guidellm.scheduler.strategies import SchedulingStrategy + + +def async_timeout(delay): + def decorator(func): + @wraps(func) + async def new_func(*args, **kwargs): + return await asyncio.wait_for(func(*args, **kwargs), timeout=delay) + + return new_func + + return decorator + + +class MockRequestTimings(MeasuredRequestTimings): + """Mock timing implementation for integration testing.""" + + +class MockBackend(BackendInterface): + """Mock backend for integration testing with predictable responses.""" + + def __init__( + self, + processes_limit_value: int | None = None, + requests_limit_value: int | None = None, + error_rate: float = 0.2, + response_delay: float = 0.0, + ): + self._processes_limit = processes_limit_value + self._requests_limit = requests_limit_value + self._error_rate = error_rate + self._response_delay = response_delay + + @property + def processes_limit(self) -> int | None: + return self._processes_limit + + @property + def requests_limit(self) -> int | None: + return self._requests_limit + + def info(self) -> dict[str, Any]: + return {"type": "mock_integration", "delay": self._response_delay} + + async def process_startup(self): + pass + + async def validate(self): + pass + + async def process_shutdown(self): + pass + + async def resolve(self, request, request_info, request_history): + """Return predictable response based on input request.""" + # Simulate processing time + await asyncio.sleep(self._response_delay) + + if ( + self._error_rate + and self._error_rate > 0 + and random.random() < self._error_rate + ): + raise RuntimeError("Mock error for testing") + + yield f"response_for_{request}", request_info + + +class TestWorkerGroup: + @pytest.mark.smoke + @pytest.mark.asyncio + @async_timeout(5) + @pytest.mark.parametrize( + "strategy", + [ + SynchronousStrategy(), + ConcurrentStrategy(streams=10), + ThroughputStrategy(max_concurrency=20), + AsyncConstantStrategy(rate=1000.0), + AsyncPoissonStrategy(rate=1000.0), + ], + ) + @pytest.mark.parametrize( + "constraints_inits", + [ + {"max_num": MaxNumberConstraint(max_num=100)}, + {"max_duration": MaxDurationConstraint(max_duration=0.5)}, + {"max_errors": MaxErrorsConstraint(max_errors=20)}, + {"max_error_rate": MaxErrorRateConstraint(max_error_rate=0.1)}, + {"max_global_error_rate": MaxGlobalErrorRateConstraint(max_error_rate=0.1)}, + ], + ) + async def test_lifecycle( + self, + strategy: SchedulingStrategy, + constraints_inits: dict[str, ConstraintInitializer], + ): + """Test comprehensive lifecycle with different strategies and constraints.""" + # Setup + backend = MockBackend(response_delay=0.01, processes_limit_value=1) + requests = [f"request_{ind}" for ind in range(1000)] + group = WorkerProcessGroup( + backend=backend, + requests=requests, + strategy=strategy, + constraints={ + key: init.create_constraint() for key, init in constraints_inits.items() + }, + infinite_requests=False, + ) + + try: + # Create processes + await group.create_processes() + assert group.processes is not None + assert len(group.processes) > 0 + assert group.mp_context is not None + + # Start processing + start_time = time.time() + 0.1 + await group.start(start_time) + actual_start = time.time() + assert actual_start == pytest.approx(start_time) + + # Validate scheduler state + assert group.scheduler_state is not None + assert group.scheduler_state.start_time == start_time + assert group.scheduler_state.num_processes == len(group.processes) + + # Collect all request updates + received_updates = defaultdict(list) + received_responses = [] + + async for ( + response, + request, + request_info, + _state, + ) in group.request_updates(): + received_updates[request].append(request_info.status) + if response is not None: + received_responses.append(response) + finally: + # Clean shutdown + exceptions = await group.shutdown() + assert len(exceptions) == 0, f"Shutdown errors: {exceptions}" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 92bb89e1..e69de29b 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,195 +0,0 @@ -import json -from collections.abc import AsyncIterable -from typing import Any, Literal, Optional -from unittest.mock import MagicMock, patch - -import httpx -import pytest -import respx - -from guidellm.backends import ResponseSummary, StreamingTextResponse - -from .mock_backend import MockBackend - - -@pytest.fixture -def mock_auto_tokenizer(): - with patch("transformers.AutoTokenizer.from_pretrained") as mock_from_pretrained: - - def _fake_tokenize(text: str) -> list[int]: - tokens = text.split() - return [0] * len(tokens) - - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize = MagicMock(side_effect=_fake_tokenize) - mock_from_pretrained.return_value = mock_tokenizer - yield mock_tokenizer - - -@pytest.fixture -def mock_backend(request): - params = request.param if hasattr(request, "param") else {} - kwargs = {} - - for key in ("model", "target", "iter_delay"): - if key in params: - kwargs[key] = params[key] - - return MockBackend(**kwargs) - - -class MockCompletionsIter(AsyncIterable): - def __init__( - self, - type_: Literal["text", "chat"], - prompt: str, - output_token_count: Optional[int], - target: Optional[str] = None, - model: Optional[str] = None, - iter_delay: Optional[float] = None, - ): - self._type = type_ - self._backend = MockBackend( - model=model, - target=target, - iter_delay=iter_delay, - ) - self._prompt = prompt - self._output_token_count = output_token_count - - async def __aiter__(self): - async for token_iter in ( - self._backend.text_completions( - prompt=self._prompt, output_token_count=self._output_token_count - ) - if self._type == "text" - else self._backend.chat_completions( - content=self._prompt, output_token_count=self._output_token_count - ) - ): - if ( - isinstance(token_iter, StreamingTextResponse) - and token_iter.type_ == "start" - ): - continue - - data: dict[str, Any] - - if isinstance(token_iter, StreamingTextResponse): - if self._type == "text": - data = { - "choices": [ - { - "index": token_iter.iter_count, - "text": token_iter.delta, - } - ] - } - elif self._type == "chat": - data = { - "choices": [ - { - "index": token_iter.iter_count, - "delta": {"content": token_iter.delta}, - } - ] - } - else: - raise ValueError("Invalid type for mock completions") - elif isinstance(token_iter, ResponseSummary): - data = { - "usage": { - "prompt_tokens": ( - len(self._prompt.split()) + self._prompt.count(" ") - ), - "completion_tokens": token_iter.response_output_tokens, - } - } - else: - raise ValueError("Invalid token_iter type") - - yield f"data: {json.dumps(data)}\n".encode() - - yield b"data: [DONE]\n" - - -@pytest.fixture -def httpx_openai_mock(request): - params = request.param if hasattr(request, "param") else {} - model = params.get("model", "mock-model") - target = params.get("target", "http://target.mock") - iter_delay = params.get("iter_delay", None) - - with respx.mock(assert_all_mocked=True, assert_all_called=False) as mock_router: - - async def _mock_completions_response(request) -> AsyncIterable[str]: - headers = request.headers - payload = json.loads(request.content) - - assert headers["Content-Type"] == "application/json" - assert payload["model"] == model - assert payload["stream"] is True - assert payload["stream_options"] == {"include_usage": True} - assert payload["prompt"] is not None - assert len(payload["prompt"]) > 0 - assert payload["max_completion_tokens"] > 0 - assert payload["max_tokens"] > 0 - - return httpx.Response( # type: ignore - 200, - stream=MockCompletionsIter( # type: ignore - type_="text", - prompt=payload["prompt"], - output_token_count=( - payload["max_completion_tokens"] - if payload.get("ignore_eos", False) - else None - ), - target=target, - model=model, - iter_delay=iter_delay, - ), - ) - - async def _mock_chat_completions_response(request): - headers = request.headers - payload = json.loads(request.content) - - assert headers["Content-Type"] == "application/json" - assert payload["model"] == model - assert payload["stream"] is True - assert payload["stream_options"] == {"include_usage": True} - assert payload["messages"] is not None - assert len(payload["messages"]) > 0 - assert payload["max_completion_tokens"] > 0 - assert payload["max_tokens"] > 0 - - return httpx.Response( # type: ignore - 200, - stream=MockCompletionsIter( # type: ignore - type_="chat", - prompt=payload["messages"][0]["content"], - output_token_count=( - payload["max_completion_tokens"] - if payload.get("ignore_eos", False) - else None - ), - target=target, - model=model, - iter_delay=iter_delay, - ), - ) - - mock_router.route(method="GET", path="/v1/models").mock( - return_value=httpx.Response( - 200, json={"data": [{"id": model} if model else {"id": "mock-model"}]} - ) - ) - mock_router.route(method="POST", path="/v1/completions").mock( - side_effect=_mock_completions_response # type: ignore - ) - mock_router.route(method="POST", path="/v1/chat/completions").mock( - side_effect=_mock_chat_completions_response - ) - - yield mock_router diff --git a/tests/unit/mock_backend.py b/tests/unit/mock_backend.py index 6080a9d1..5ac069a8 100644 --- a/tests/unit/mock_backend.py +++ b/tests/unit/mock_backend.py @@ -1,172 +1,184 @@ +""" +Mock backend implementation for testing purposes. +""" + import asyncio import random import time -from collections.abc import AsyncGenerator -from pathlib import Path -from typing import Any, Optional, Union - -from lorem.text import TextLorem # type: ignore -from PIL import Image - -from guidellm.backends import ( - Backend, - RequestArgs, - ResponseSummary, - StreamingTextResponse, +from collections.abc import AsyncIterator +from typing import Any, Optional + +from lorem.text import TextLorem + +from guidellm.backend.backend import Backend +from guidellm.backend.objects import ( + GenerationRequest, + GenerationRequestTimings, + GenerationResponse, ) +from guidellm.scheduler import ScheduledRequestInfo -@Backend.register("mock") # type: ignore +@Backend.register("mock") class MockBackend(Backend): + """ + Mock backend for testing that simulates text generation. + + Provides predictable responses with configurable delays and token counts + for testing the backend interface without requiring an actual LLM service. + """ + def __init__( self, - model: Optional[str] = "mock-model", - target: Optional[str] = "mock-target", + target: str = "mock-target", + model: str = "mock-model", iter_delay: Optional[float] = None, ): - super().__init__(type_="mock") # type: ignore + """ + Initialize mock backend. + + :param model: Model name to simulate. + :param target: Target URL to simulate. + :param iter_delay: Delay between iterations in seconds. + """ + super().__init__(type_="mock") # type: ignore [reportCallIssue] self._model = model self._target = target self._iter_delay = iter_delay + self._in_process = False @property def target(self) -> str: - return self._target # type: ignore + """Target URL for the mock backend.""" + return self._target @property def model(self) -> Optional[str]: + """Model name for the mock backend.""" return self._model - @property def info(self) -> dict[str, Any]: - return {} - - async def reset(self) -> None: - pass - - async def prepare_multiprocessing(self): - pass - - async def check_setup(self): - pass - - async def available_models(self) -> list[str]: - return [self.model] # type: ignore + """ + Return mock backend configuration information. + """ + return { + "type": "mock", + "model": self._model, + "target": self._target, + "iter_delay": self._iter_delay, + } + + async def process_startup(self) -> None: + """ + Initialize the mock backend process. + """ + self._in_process = True + + async def process_shutdown(self) -> None: + """ + Shutdown the mock backend process. + """ + self._in_process = False + + async def validate(self) -> None: + """ + Validate the mock backend configuration. + """ + if not self._in_process: + raise RuntimeError("Backend not started up for process") + + async def default_model(self) -> Optional[str]: + """ + Return the default model for the mock backend. + """ + return self._model - async def text_completions( # type: ignore + async def resolve( self, - prompt: Union[str, list[str]], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - if not isinstance(prompt, str) or not prompt: - raise ValueError("Prompt must be a non-empty string") - - async for response in self._text_prompt_response_generator( - prompt, - request_id, - prompt_token_count, - output_token_count, - ): - yield response - - async def chat_completions( # type: ignore - self, - content: Union[ - str, - list[Union[str, dict[str, Union[str, dict[str, str]]], Path, Image.Image]], - Any, - ], - request_id: Optional[str] = None, - prompt_token_count: Optional[int] = None, - output_token_count: Optional[int] = None, - raw_content: bool = False, - **kwargs, - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - if not isinstance(content, str) or not content: - raise ValueError("Content must be a non-empty string") - - async for response in self._text_prompt_response_generator( - content, - request_id, - prompt_token_count, - output_token_count, - ): - yield response - - async def _text_prompt_response_generator( - self, - prompt: str, - request_id: Optional[str], - prompt_token_count: Optional[int], - output_token_count: Optional[int], - ) -> AsyncGenerator[Union[StreamingTextResponse, ResponseSummary], None]: - tokens = self._get_tokens(output_token_count) - start_time = time.time() - - yield StreamingTextResponse( - type_="start", + request: GenerationRequest, + request_info: ScheduledRequestInfo, + history: Optional[list[tuple[GenerationRequest, GenerationResponse]]] = None, + ) -> AsyncIterator[tuple[GenerationResponse, ScheduledRequestInfo]]: + """ + Process a generation request and yield progressive responses. + + ### WRITTEN BY AI ### + """ + if not self._in_process: + raise RuntimeError("Backend not started up for process") + + if history is not None: + raise NotImplementedError( + "Multi-turn requests not supported in mock backend" + ) + + # Extract token counts from request + prompt_tokens = request.stats.get("prompt_tokens") + output_tokens = request.constraints.get("output_tokens") + + # Generate mock tokens + tokens = self._get_tokens(output_tokens) + + # Initialize response + response = GenerationResponse( + request_id=request.request_id, + request_args={ + "request_type": request.request_type, + "output_token_count": output_tokens, + **request.params, + }, value="", - start_time=start_time, - first_iter_time=None, - iter_count=0, - delta="", - time=start_time, - request_id=request_id, + request_prompt_tokens=prompt_tokens, + request_output_tokens=output_tokens, ) - first_iter_time = None - last_iter_time = None + # Initialize timings + request_info.request_timings = GenerationRequestTimings() + request_info.request_timings.request_start = time.time() + # Generate response iteratively for index, token in enumerate(tokens): if self._iter_delay: await asyncio.sleep(self._iter_delay) - if first_iter_time is None: - first_iter_time = time.time() - - yield StreamingTextResponse( - type_="iter", - value="".join(tokens[: index + 1]), - start_time=start_time, - first_iter_time=first_iter_time, - iter_count=index + 1, - delta=token, - time=time.time(), - request_id=request_id, - ) + if request_info.request_timings.first_iteration is None: + request_info.request_timings.first_iteration = time.time() - last_iter_time = time.time() - - yield ResponseSummary( - value="".join(tokens), - request_args=RequestArgs( - target=self.target, - headers={}, - params={}, - payload={"prompt": prompt, "output_token_count": output_token_count}, - ), - iterations=len(tokens), - start_time=start_time, - end_time=time.time(), - first_iter_time=first_iter_time, - last_iter_time=last_iter_time, - request_prompt_tokens=prompt_token_count, - request_output_tokens=output_token_count, - response_prompt_tokens=len(prompt.split()) + prompt.count(" "), - response_output_tokens=len(tokens), - request_id=request_id, + response.value += token # type: ignore [reportOperatorIssue] + response.delta = token + response.iterations = index + 1 + request_info.request_timings.last_iteration = time.time() + + yield response, request_info + + # Final response with usage stats + request_info.request_timings.request_end = time.time() + response.response_prompt_tokens = prompt_tokens or self._estimate_prompt_tokens( + str(request.content) ) + response.response_output_tokens = len(tokens) + response.delta = None + + yield response, request_info + + @staticmethod + def _estimate_prompt_tokens(content: str) -> int: + """ + Estimate prompt tokens from content. + """ + # Simple word-based token estimation + return len(str(content).split()) @staticmethod def _get_tokens(token_count: Optional[int] = None) -> list[str]: + """ + Generate mock tokens for response. + """ if token_count is None: token_count = random.randint(8, 512) words = TextLorem(srange=(token_count, token_count)).sentence().split() - tokens = [] # type: ignore + tokens = [] for word in words: if len(tokens) == token_count - 1: diff --git a/tests/unit/mock_benchmark.py b/tests/unit/mock_benchmark.py index 81364fa1..d846767d 100644 --- a/tests/unit/mock_benchmark.py +++ b/tests/unit/mock_benchmark.py @@ -1,271 +1,152 @@ +"""Mock benchmark objects for unit testing.""" + +from guidellm.backend import GenerationRequestTimings from guidellm.benchmark import ( - BenchmarkArgs, - BenchmarkRunStats, + BenchmarkSchedulerStats, GenerativeBenchmark, - GenerativeTextErrorStats, - GenerativeTextResponseStats, - SynchronousProfile, + GenerativeMetrics, + GenerativeRequestStats, ) -from guidellm.objects import StatusBreakdown -from guidellm.request import GenerativeRequestLoaderDescription -from guidellm.scheduler import ( - GenerativeRequestsWorkerDescription, - SchedulerRequestInfo, - SynchronousStrategy, +from guidellm.benchmark.objects import BenchmarkerDict, SchedulerDict +from guidellm.benchmark.profile import SynchronousProfile +from guidellm.scheduler import ScheduledRequestInfo, SchedulerState, SynchronousStrategy +from guidellm.utils import ( + DistributionSummary, + Percentiles, + StandardBaseDict, + StatusBreakdown, + StatusDistributionSummary, ) __all__ = ["mock_generative_benchmark"] +def _create_mock_percentiles() -> Percentiles: + """Create mock percentiles for testing.""" + return Percentiles( + p001=0.1, + p01=1.0, + p05=5.0, + p10=10.0, + p25=25.0, + p50=50.0, + p75=75.0, + p90=90.0, + p95=95.0, + p99=99.0, + p999=99.9, + ) + + +def _create_mock_distribution() -> DistributionSummary: + """Create mock distribution summary for testing.""" + return DistributionSummary( + mean=50.0, + median=50.0, + mode=50.0, + variance=10.0, + std_dev=3.16, + min=10.0, + max=100.0, + count=100, + total_sum=5000.0, + percentiles=_create_mock_percentiles(), + ) + + +def _create_status_dist() -> StatusDistributionSummary: + """Create mock status distribution summary for testing.""" + dist = _create_mock_distribution() + return StatusDistributionSummary( + successful=dist, + incomplete=dist, + errored=dist, + total=dist, + ) + + def mock_generative_benchmark() -> GenerativeBenchmark: - return GenerativeBenchmark.from_stats( - run_id="fa4a92c1-9a1d-4c83-b237-83fcc7971bd3", - successful=[ - GenerativeTextResponseStats( - request_id="181a63e2-dc26-4268-9cfc-2ed9279aae63", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=True, - errored=False, - canceled=False, - targeted_start_time=1744728125.203447, - queued_time=1744728125.204123, - dequeued_time=1744728125.2048807, - scheduled_time=1744728125.2048993, - worker_start=1744728125.2049701, - request_start=1744728125.2052872, - request_end=1744728126.7004411, - worker_end=1744728126.701175, - process_id=0, - ), - prompt="such a sacrifice to her advantage as years of gratitude cannot enough acknowledge. By this time she is actually with them! If such goodness does not make her miserable now, she will never deserve to be happy! What a meeting for her, when she first sees my aunt! We must endeavour to forget all that has passed on either side, said Jane I hope and trust they will yet be happy. His consenting to marry her is a proof, I will believe, that he is come to a right way of thinking. Their mutual affection will steady them; and I flatter myself they will settle so quietly, and live in so rational a manner", # noqa: E501 - output=", as to make their long life together very comfortable and very useful. I feel, if they and the honourable Mr. Thorpe, who still lives amongst us, should be all I need, I could perfectly rest happy. Writes to meet them in that kind of obedience which is necessary and honourable, and such", # noqa: E501 - prompt_tokens=128, - output_tokens=64, - start_time=1744728125.2052872, - end_time=1744728126.7004411, - first_token_time=1744728125.2473357, - last_token_time=1744728126.699908, - ), - GenerativeTextResponseStats( - request_id="8a7846d5-7624-420d-a269-831e568a848f", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=True, - errored=False, - canceled=False, - targeted_start_time=1744728125.204613, - queued_time=1744728125.2047558, - dequeued_time=1744728126.7025175, - scheduled_time=1744728126.7025256, - worker_start=1744728126.702579, - request_start=1744728126.7027814, - request_end=1744728128.1961868, - worker_end=1744728128.196895, - process_id=0, - ), - prompt="a reconciliation; and, after a little further resistance on the part of his aunt, her resentment gave way, either to her affection for him, or her curiosity to see how his wife conducted herself; and she condescended to wait on them at Pemberley, in spite of that pollution which its woods had received, not merely from the presence of such a mistress, but the visits of her uncle and aunt from the city. With the Gardiners they were always on the most intimate terms. Darcy, as well as Elizabeth, really loved them; and they were both ever sensible of the warmest gratitude towards the persons who,", # noqa: E501 - output=" in their own days of poverty, had been so hotel and hospitable to a young couple leaving Pemberley. Till the size of Mr. Bennet\u2019s salary had been altered, the blessing of their friendship was much more greatly needed by the family than it appeared after that event.\n- Mr. Darcy soon deserved", # noqa: E501 - prompt_tokens=128, - output_tokens=64, - start_time=1744728126.7027814, - end_time=1744728128.1961868, - first_token_time=1744728126.7526379, - last_token_time=1744728128.1956792, - ), - GenerativeTextResponseStats( - request_id="4cde0e6c-4531-4e59-aac1-07bc8b6e4139", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=True, - errored=False, - canceled=False, - targeted_start_time=1744728126.7031465, - queued_time=1744728126.7034643, - dequeued_time=1744728128.198447, - scheduled_time=1744728128.1984534, - worker_start=1744728128.198509, - request_start=1744728128.1986883, - request_end=1744728129.6919055, - worker_end=1744728129.692606, - process_id=0, - ), - prompt="struck her, that _she_ was selected from among her sisters as worthy of being the mistress of Hunsford Parsonage, and of assisting to form a quadrille table at Rosings, in the absence of more eligible visitors. The idea soon reached to conviction, as she observed his increasing civilities towards herself, and heard his frequent attempt at a compliment on her wit and vivacity; and though more astonished than gratified herself by this effect of her charms, it was not long before her mother gave her to understand that the probability of their marriage was exceedingly agreeable to _her_. Elizabeth, however, did not choose", # noqa: E501 - output=" to improve this conversation into a prophecy, and her mother would hardly take on herself to announce so important a phenomenon. At last he was to drive to Hunsford from Meryton on Sunday; they staid for an hour at eight o'clock, and the following day appeared to be hung up on the walls of", # noqa: E501 - prompt_tokens=128, - output_tokens=64, - start_time=1744728128.1986883, - end_time=1744728129.6919055, - first_token_time=1744728128.2481627, - last_token_time=1744728129.6914039, - ), - GenerativeTextResponseStats( - request_id="a95b96be-05d4-4130-b0dd-9528c01c9909", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=True, - errored=False, - canceled=False, - targeted_start_time=1744728128.1987216, - queued_time=1744728128.1991177, - dequeued_time=1744728129.6953137, - scheduled_time=1744728129.695318, - worker_start=1744728129.695379, - request_start=1744728129.6955585, - request_end=1744728131.187553, - worker_end=1744728131.188169, - process_id=0, - ), - prompt="were comfortable on this subject. Day after day passed away without bringing any other tidings of him than the report which shortly prevailed in Meryton of his coming no more to Netherfield the whole winter; a report which highly incensed Mrs. Bennet, and which she never failed to contradict as a most scandalous falsehood. Even Elizabeth began to fear not that Bingley was indifferent but that his sisters would be successful in keeping him away. Unwilling as she was to admit an idea so destructive to Jane s happiness, and so dishonourable to the stability of her lover, she could not prevent its frequently recurring", # noqa: E501 - output=" during these indefinite disputes; and was often seriously engaged in blaming her sisters for increasing a suspense which might only be caused by their own inattention to a subject of so much moment. Whether she had really made that impression on the s+.ayers, or whether she had merely imagined it, she could decide no farther, for", # noqa: E501 - prompt_tokens=128, - output_tokens=64, - start_time=1744728129.6955585, - end_time=1744728131.187553, - first_token_time=1744728129.7438853, - last_token_time=1744728131.187019, - ), - GenerativeTextResponseStats( - request_id="714b751c-bbfe-4b2a-a0af-7c1bf2c224ae", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=True, - errored=False, - canceled=False, - targeted_start_time=1744728129.6975086, - queued_time=1744728129.6978767, - dequeued_time=1744728131.190093, - scheduled_time=1744728131.190101, - worker_start=1744728131.1901798, - request_start=1744728131.1904676, - request_end=1744728132.6833503, - worker_end=1744728132.6839745, - process_id=0, - ), - prompt="? cried Elizabeth, brightening up for a moment. Upon my word, said Mrs. Gardiner, I begin to be of your uncle s opinion. It is really too great a violation of decency, honour, and interest, for him to be guilty of it. I cannot think so very ill of Wickham. Can you, yourself, Lizzie, so wholly give him up, as to believe him capable of it? Not perhaps of neglecting his own interest. But of every other neglect I can believe him capable. If, indeed, it should be so! But I dare not hope it. Why should they not go on", # noqa: E501 - output=" together? This is still a motive incapable of being denied. He has such a faculty of pleasing, and you know how much she likes him. \nQuestion: What made elder sisters the center of their families?\nSometimes early this would be discussed in the family circle, but that was a very exceptional treatment.\nThank you,", # noqa: E501 - prompt_tokens=128, - output_tokens=64, - start_time=1744728131.1904676, - end_time=1744728132.6833503, - first_token_time=1744728131.2394557, - last_token_time=1744728132.6828275, - ), - GenerativeTextResponseStats( - request_id="ef73ae8a-4c8f-4c88-b303-cfff152ce378", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=True, - errored=False, - canceled=False, - targeted_start_time=1744728131.1891043, - queued_time=1744728131.1893764, - dequeued_time=1744728132.6859632, - scheduled_time=1744728132.6859682, - worker_start=1744728132.6860242, - request_start=1744728132.6862206, - request_end=1744728134.1805167, - worker_end=1744728134.1813161, - process_id=0, - ), - prompt="was. But her commendation, though costing her some trouble, could by no means satisfy Mr. Collins, and he was very soon obliged to take her Ladyship s praise into his own hands. Sir William stayed only a week at Hunsford; but his visit was long enough to convince him of his daughter s being most comfortably settled, and of her possessing such a husband and such a neighbour as were not often met with. While Sir William was with them, Mr. Collins devoted his mornings to driving him out in his gig, and showing him the country but when he went away, the whole family returned to their usual employments", # noqa: E501 - output=", and the sides of the family in which he was more particularly interested, to their respective places in the establishment. Here Jane was occasionally up as a substitute to her indolent sister, in her matron s stead, but was more frequently left idle, and with her hours of quietness, the unwelcome intrusion", # noqa: E501 - prompt_tokens=128, - output_tokens=64, - start_time=1744728132.6862206, - end_time=1744728134.1805167, - first_token_time=1744728132.7354612, - last_token_time=1744728134.1797993, - ), - ], - errored=[], - incomplete=[ - GenerativeTextErrorStats( - request_id="1b3def04-ca81-4f59-a56c-452a069d91af", - request_type="text_completions", - scheduler_info=SchedulerRequestInfo( - requested=True, - completed=False, - errored=True, - canceled=True, - targeted_start_time=1744728132.686177, - queued_time=1744728132.6866345, - dequeued_time=1744728134.1831052, - scheduled_time=1744728134.1831107, - worker_start=1744728134.183183, - request_start=1744728134.183544, - request_end=1744728135.2031732, - worker_end=1744728135.2033112, - process_id=0, - ), - prompt="is to tempt anyone to our humble abode. Our plain manner of living, our small rooms, and few domestics, and the little we see of the world, must make Hunsford extremely dull to a young lady like yourself; but I hope you will believe us grateful for the condescension, and that we have done everything in our power to prevent you spending your time unpleasantly. Elizabeth was eager with her thanks and assurances of happiness. She had spent six weeks with great enjoyment; and the pleasure of being with Charlotte, and the kind attention she had received, must make _her_ feel the obliged. Mr. Collins", # noqa: E501 - output=", who certainly had an eye to Elizabeth's manner, was glad _he was not to lose the curiosity she had given, and requested her away_ , _for the politeness of her conciliating manner would", # noqa: E501 - prompt_tokens=128, - output_tokens=43, - start_time=1744728134.183544, - end_time=1744728135.2031732, - first_token_time=1744728134.2323751, - last_token_time=1744728135.1950455, - error="TimeoutError: The request timed out before completing.", - ) - ], - args=BenchmarkArgs( - profile=SynchronousProfile(), - strategy_index=0, + """Create a minimal mock GenerativeBenchmark for testing purposes.""" + return GenerativeBenchmark( + run_id="test-run-gen", + run_index=0, + scheduler=SchedulerDict( strategy=SynchronousStrategy(), - max_number=None, - max_duration=10.0, - warmup_number=None, - warmup_duration=None, - cooldown_number=None, - cooldown_duration=None, + constraints={}, + state=SchedulerState(node_id=0, num_processes=1), ), - run_stats=BenchmarkRunStats( - start_time=1744728125.0772898, - end_time=1744728135.8407037, + benchmarker=BenchmarkerDict( + profile=SynchronousProfile.create("synchronous", rate=None), + requests={}, + backend={}, + environment={}, + aggregators={}, + ), + env_args=StandardBaseDict(), + extras=StandardBaseDict(), + run_stats=BenchmarkSchedulerStats( + start_time=1, + end_time=2, requests_made=StatusBreakdown( - successful=6, + successful=1, + incomplete=0, errored=0, - incomplete=1, - total=7, + total=1, ), - queued_time_avg=1.2821388585226876, - scheduled_time_delay_avg=7.96999250139509e-6, - scheduled_time_sleep_avg=0.0, - worker_start_delay_avg=6.399835859026228e-5, - worker_time_avg=1.4266603674207414, - worker_start_time_targeted_delay_avg=1.2825865745544434, - request_start_time_delay_avg=0.6414163964135307, - request_start_time_targeted_delay_avg=1.2827096836907523, - request_time_delay_avg=0.0004316908972603934, - request_time_avg=1.426228676523481, + queued_time_avg=0.1, + worker_resolve_start_delay_avg=0.1, + worker_resolve_time_avg=0.1, + worker_resolve_end_delay_avg=0.1, + finalized_delay_avg=0.1, + worker_targeted_start_delay_avg=0.1, + request_start_delay_avg=0.1, + request_time_avg=0.1, + request_targeted_delay_avg=0.1, + ), + start_time=1000.0, + end_time=2000.0, + metrics=GenerativeMetrics( + requests_per_second=_create_status_dist(), + request_concurrency=_create_status_dist(), + request_latency=_create_status_dist(), + prompt_token_count=_create_status_dist(), + output_token_count=_create_status_dist(), + total_token_count=_create_status_dist(), + time_to_first_token_ms=_create_status_dist(), + time_per_output_token_ms=_create_status_dist(), + inter_token_latency_ms=_create_status_dist(), + output_tokens_per_second=_create_status_dist(), + tokens_per_second=_create_status_dist(), ), - worker=GenerativeRequestsWorkerDescription( - backend_type="openai_http", - backend_target="http://localhost:8000", - backend_model="neuralmagic/Qwen2.5-7B-quantized.w8a8", - backend_info={ - "max_output_tokens": 16384, - "timeout": 300, - "http2": True, - "authorization": False, - "organization": None, - "project": None, - "text_completions_path": "/v1/completions", - "chat_completions_path": "/v1/chat/completions", - }, + request_totals=StatusBreakdown( + successful=1, + incomplete=0, + errored=0, + total=1, ), - requests_loader=GenerativeRequestLoaderDescription( - data='{"prompt_tokens": 128, "output_tokens": 64}', - data_args=None, - processor="neuralmagic/Qwen2.5-7B-quantized.w8a8", - processor_args=None, + requests=StatusBreakdown( + successful=[ + GenerativeRequestStats( + scheduler_info=ScheduledRequestInfo( + request_timings=GenerationRequestTimings( + request_start=1, + first_iteration=2, + last_iteration=6, + request_end=6, + ) + ), + request_id="a", + request_type="text_completions", + prompt="p", + request_args={}, + output="o", + iterations=1, + prompt_tokens=1, + output_tokens=2, + ) + ], + incomplete=[], + errored=[], + total=None, ), - extras={}, ) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py deleted file mode 100644 index 63beb512..00000000 --- a/tests/unit/test_cli.py +++ /dev/null @@ -1,105 +0,0 @@ -""" -Unit tests for CLI functionality, specifically the version flag. -""" - -import importlib.metadata -import re - -import pytest -from click.testing import CliRunner - -from guidellm.__main__ import cli - - -@pytest.mark.smoke -def test_version_flag_long(): - """Test that --version flag works correctly.""" - runner = CliRunner() - result = runner.invoke(cli, ["--version"]) - - assert result.exit_code == 0 - assert "guidellm version:" in result.output - assert result.output.strip().startswith("guidellm version:") - - -@pytest.mark.smoke -def test_version_flag_displays_actual_version(): - """Test that --version displays the actual version from version.py.""" - runner = CliRunner() - result = runner.invoke(cli, ["--version"]) - - assert result.exit_code == 0 - - version_pattern = r"guidellm version: \d+\.\d+" - assert re.search(version_pattern, result.output) - - -@pytest.mark.smoke -def test_version_flag_exits_cleanly(): - """Test that --version exits without processing other commands.""" - runner = CliRunner() - result = runner.invoke(cli, ["--version", "benchmark"]) - - assert result.exit_code == 0 - assert "guidellm version:" in result.output - assert "Commands to run a new benchmark" not in result.output - - -@pytest.mark.smoke -def test_help_shows_version_option(): - """Test that --help shows the --version option.""" - runner = CliRunner() - result = runner.invoke(cli, ["--help"]) - - assert result.exit_code == 0 - assert "--version" in result.output - assert "Show the version and exit" in result.output - - -@pytest.mark.smoke -def test_other_commands_still_work(): - """Test that other CLI commands still work after adding version flag.""" - runner = CliRunner() - result = runner.invoke(cli, ["--help"]) - - assert result.exit_code == 0 - assert "benchmark" in result.output - assert "config" in result.output - assert "preprocess" in result.output - - -@pytest.mark.smoke -def test_version_flag_case_sensitivity(): - """Test that --version flag is case sensitive.""" - runner = CliRunner() - - result = runner.invoke(cli, ["--version"]) - assert result.exit_code == 0 - assert "guidellm version:" in result.output - - # --VERSION should not work - result = runner.invoke(cli, ["--VERSION"]) - assert result.exit_code != 0 - assert "No such option" in result.output - - -@pytest.mark.integration -def test_version_integration_with_actual_version(): - """Integration test to verify version matches importlib.metadata.""" - try: - actual_version = importlib.metadata.version("guidellm") - - runner = CliRunner() - result = runner.invoke(cli, ["--version"]) - - assert result.exit_code == 0 - expected_output = f"guidellm version: {actual_version}" - assert expected_output in result.output - except importlib.metadata.PackageNotFoundError: - # If package is not installed, the CLI should show an error - # This is expected behavior when the package isn't properly installed - runner = CliRunner() - result = runner.invoke(cli, ["--version"]) - - # Click will handle the error when package is not found - assert result.exit_code != 0 diff --git a/tests/unit/test_config.py b/tests/unit/test_settings.py similarity index 100% rename from tests/unit/test_config.py rename to tests/unit/test_settings.py diff --git a/tests/unit/utils/test_typing.py b/tests/unit/utils/test_typing.py new file mode 100644 index 00000000..fafa8765 --- /dev/null +++ b/tests/unit/utils/test_typing.py @@ -0,0 +1,123 @@ +""" +Test suite for the typing utilities module. +""" + +from typing import Annotated, Literal, Union + +import pytest +from typing_extensions import TypeAlias + +from guidellm.utils.typing import get_literal_vals + +# Local type definitions to avoid imports from other modules +LocalProfileType = Literal["synchronous", "async", "concurrent", "throughput", "sweep"] +LocalStrategyType = Annotated[ + Literal["synchronous", "concurrent", "throughput", "constant", "poisson"], + "Valid strategy type identifiers for scheduling request patterns", +] +StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType] + + +class TestGetLiteralVals: + """Test cases for the get_literal_vals function.""" + + @pytest.mark.sanity + def test_profile_type(self): + """ + Test extracting values from ProfileType. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(LocalProfileType) + expected = frozenset( + {"synchronous", "async", "concurrent", "throughput", "sweep"} + ) + assert result == expected + + @pytest.mark.sanity + def test_strategy_type(self): + """ + Test extracting values from StrategyType. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(LocalStrategyType) + expected = frozenset( + {"synchronous", "concurrent", "throughput", "constant", "poisson"} + ) + assert result == expected + + @pytest.mark.smoke + def test_inline_union_type(self): + """ + Test extracting values from inline union of ProfileType | StrategyType. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Union[LocalProfileType, LocalStrategyType]) + expected = frozenset( + { + "synchronous", + "async", + "concurrent", + "throughput", + "constant", + "poisson", + "sweep", + } + ) + assert result == expected + + @pytest.mark.smoke + def test_type_alias(self): + """ + Test extracting values from type alias union. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(StrategyProfileType) + expected = frozenset( + { + "synchronous", + "async", + "concurrent", + "throughput", + "constant", + "poisson", + "sweep", + } + ) + assert result == expected + + @pytest.mark.sanity + def test_single_literal(self): + """ + Test extracting values from single Literal type. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Literal["test"]) + expected = frozenset({"test"}) + assert result == expected + + @pytest.mark.sanity + def test_multi_literal(self): + """ + Test extracting values from multi-value Literal type. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Literal["test", "test2"]) + expected = frozenset({"test", "test2"}) + assert result == expected + + @pytest.mark.smoke + def test_literal_union(self): + """ + Test extracting values from union of Literal types. + + ### WRITTEN BY AI ### + """ + result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]]) + expected = frozenset({"test", "test2", "test3"}) + assert result == expected From 6d0d4c24c361158d747356f7d7f3de1884697d53 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 19 Sep 2025 13:06:54 +0000 Subject: [PATCH 4/6] add in the perf extras Signed-off-by: Mark Kurtz --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 966a032b..29ae92c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,12 @@ dependencies = [ ] [project.optional-dependencies] +perf = [ + "orjson", + "msgpack", + "msgspec", + "uvloop", +] dev = [ # build "build>=1.0.0", From ab5466b1f5affd146fa69dbfc63cf31c6a1c82f0 Mon Sep 17 00:00:00 2001 From: Jared O'Connell <46976761+jaredoconnell@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:31:45 -0400 Subject: [PATCH 5/6] [GuideLLM Refactor] Fix from-file (#366) ## Summary This PR ports the new functionality from `benchmark run` to `benchmark from-file`, and does so in a way that reuses as much code as practical to have one source of truth. ## Details - Fixes from-file by making it to use the new output format. - Moves code related to the new output formats to separate functions that are called from both benchmark entrypoints. - Moves additional chunks of code out of the large benchmark run entrypoint function for modularity. ## Test Plan Run a benchmark with an output of json or yaml, and use `from-file` to re-import it and export it. You can select any output type supported by `benchmark run`. `guidellm benchmark from-file ./result.json --output-formats console` `guidellm benchmark from-file ./result.yaml --output-formats yaml` ## Related Issues --- - [x] "I certify that all code in this PR is my own, except as noted below." ## Use of AI - [x] Includes AI-assisted code completion - [ ] Includes code generated by an AI application - [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`) --------- Signed-off-by: Jared O'Connell --- src/guidellm/__main__.py | 37 +++-- src/guidellm/benchmark/entrypoints.py | 204 +++++++++++++++++--------- 2 files changed, 158 insertions(+), 83 deletions(-) diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 675003a9..9d85346b 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -473,23 +473,30 @@ def run( ) @click.option( "--output-path", - type=click.Path(file_okay=True, dir_okay=True, exists=False), - default=None, - is_flag=False, - flag_value=Path.cwd() / "benchmarks_reexported.json", + type=click.Path(), + default=Path.cwd(), + help=( + "Allows re-exporting the benchmarks to other formats. " + "The path to save the output formats to, if the format is a file type. " + "If it is a directory, it will save all output formats selected under it. " + "If it is a file, it will save the corresponding output format to that file. " + "Any output formats that were given that do not match the file extension will " + "be saved in the parent directory of the file path. " + "Defaults to the current working directory. " + ), +) +@click.option( + "--output-formats", + multiple=True, + type=str, + default=("console", "json"), # ("console", "json", "html", "csv") help=( - "Allows re-exporting the benchmarks to another format. " - "The path to save the output to. If it is a directory, " - "it will save benchmarks.json under it. " - "Otherwise, json, yaml, or csv files are supported for output types " - "which will be read from the extension for the file path. " - "This input is optional. If the output path flag is not provided, " - "the benchmarks will not be reexported. If the flag is present but " - "no value is specified, it will default to the current directory " - "with the file name `benchmarks_reexported.json`." + "The output formats to use for the benchmark results. " + "Defaults to console, json, html, and csv where the file formats " + "will be saved at the specified output path." ), ) -def from_file(path, output_path): +def from_file(path, output_path, output_formats): """ Load and optionally re-export a previously saved benchmark report. @@ -497,7 +504,7 @@ def from_file(path, output_path): to different output formats. Supports JSON, YAML, and CSV export formats based on the output file extension. """ - reimport_benchmarks_report(path, output_path) + asyncio.run(reimport_benchmarks_report(path, output_path, output_formats)) @cli.command( diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index 60077ee8..828402d8 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -26,7 +26,6 @@ from guidellm.benchmark.benchmarker import Benchmarker from guidellm.benchmark.objects import GenerativeBenchmark, GenerativeBenchmarksReport from guidellm.benchmark.output import ( - GenerativeBenchmarkerConsole, GenerativeBenchmarkerOutput, ) from guidellm.benchmark.profile import Profile, ProfileType @@ -53,6 +52,97 @@ _CURRENT_WORKING_DIR = Path.cwd() +# Data types + +DataType = ( + Iterable[str] + | Iterable[dict[str, Any]] + | Dataset + | DatasetDict + | IterableDataset + | IterableDatasetDict + | str + | Path +) + +OutputFormatType = ( + tuple[str, ...] + | list[str] + | dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput] + | None +) + + +# Helper functions + +async def initialize_backend( + backend: BackendType | Backend, + target: str, + model: str | None, + backend_kwargs: dict[str, Any] | None, +) -> Backend: + backend = ( + Backend.create( + backend, target=target, model=model, **(backend_kwargs or {}) + ) + if not isinstance(backend, Backend) + else backend + ) + await backend.process_startup() + await backend.validate() + return backend + + +async def resolve_profile( + constraint_inputs: dict[str, int | float], + profile: Profile | str | None, + rate: list[float] | None, + random_seed: int, + constraints: dict[str, ConstraintInitializer | Any], +): + for key, val in constraint_inputs.items(): + if val is not None: + constraints[key] = val + if not isinstance(profile, Profile): + if isinstance(profile, str): + profile = Profile.create( + rate_type=profile, + rate=rate, + random_seed=random_seed, + constraints={**constraints}, + ) + else: + raise ValueError(f"Expected string for profile; got {type(profile)}") + + elif constraints: + raise ValueError( + "Constraints must be empty when providing a Profile instance. " + f"Provided constraints: {constraints} ; provided profile: {profile}" + ) + return profile + +async def resolve_output_formats( + output_formats: OutputFormatType, + output_path: str | Path | None, +) -> dict[str, GenerativeBenchmarkerOutput]: + output_formats = GenerativeBenchmarkerOutput.resolve( + output_formats=(output_formats or {}), output_path=output_path + ) + return output_formats + +async def finalize_outputs( + report: GenerativeBenchmarksReport, + resolved_output_formats: dict[str, GenerativeBenchmarkerOutput] +): + output_format_results = {} + for key, output in resolved_output_formats.items(): + output_result = await output.finalize(report) + output_format_results[key] = output_result + return output_format_results + + +# Complete entrypoints + async def benchmark_with_scenario(scenario: Scenario, **kwargs): """ Run a benchmark using a scenario and specify any extra arguments @@ -67,16 +157,7 @@ async def benchmark_with_scenario(scenario: Scenario, **kwargs): # @validate_call(config={"arbitrary_types_allowed": True}) async def benchmark_generative_text( # noqa: C901 target: str, - data: ( - Iterable[str] - | Iterable[dict[str, Any]] - | Dataset - | DatasetDict - | IterableDataset - | IterableDatasetDict - | str - | Path - ), + data: DataType, profile: StrategyType | ProfileType | Profile, rate: float | list[float] | None = None, random_seed: int = 42, @@ -91,12 +172,7 @@ async def benchmark_generative_text( # noqa: C901 data_sampler: Literal["random"] | None = None, # Output configuration output_path: str | Path | None = _CURRENT_WORKING_DIR, - output_formats: ( - tuple[str, ...] - | list[str] - | dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput] - | None - ) = ("console", "json", "html", "csv"), + output_formats: OutputFormatType = ("console", "json", "html", "csv"), # Updates configuration progress: tuple[str, ...] | list[str] | list[BenchmarkerProgress] | None = None, print_updates: bool = False, @@ -120,16 +196,7 @@ async def benchmark_generative_text( # noqa: C901 with console.print_update_step( title=f"Initializing backend {backend}" ) as console_step: - backend = ( - Backend.create( - backend, target=target, model=model, **(backend_kwargs or {}) - ) - if not isinstance(backend, Backend) - else backend - ) - console_step.update(f"{backend.__class__.__name__} backend initialized") - await backend.process_startup() - await backend.validate() + backend = await initialize_backend(backend, target, model, backend_kwargs) console_step.finish( title=f"{backend.__class__.__name__} backend initialized", details=backend.info, @@ -190,27 +257,19 @@ async def benchmark_generative_text( # noqa: C901 with console.print_update_step( title=f"Resolving profile {profile}" ) as console_step: - for key, val in { - "max_seconds": max_seconds, - "max_requests": max_requests, - "max_errors": max_errors, - "max_error_rate": max_error_rate, - "max_global_error_rate": max_global_error_rate, - }.items(): - if val is not None: - constraints[key] = val - if not isinstance(profile, Profile): - profile = Profile.create( - rate_type=profile, - rate=rate, - random_seed=random_seed, - constraints={**constraints}, - ) - elif constraints: - raise ValueError( - "Constraints must be empty when providing a Profile instance. " - f"Provided constraints: {constraints} ; provided profile: {profile}" - ) + profile = await resolve_profile( + { + "max_seconds": max_seconds, + "max_requests": max_requests, + "max_errors": max_errors, + "max_error_rate": max_error_rate, + "max_global_error_rate": max_global_error_rate, + }, + profile, + rate, + random_seed, + constraints, + ) console_step.finish( title=f"{profile.__class__.__name__} profile resolved", details=InfoMixin.extract_from_obj(profile), @@ -237,12 +296,10 @@ async def benchmark_generative_text( # noqa: C901 ) with console.print_update_step(title="Resolving output formats") as console_step: - output_formats = GenerativeBenchmarkerOutput.resolve( - output_formats=(output_formats or {}), output_path=output_path - ) + resolved_output_formats = await resolve_output_formats(output_formats, output_path) console_step.finish( title="Output formats resolved", - details={key: str(val) for key, val in output_formats.items()}, + details={key: str(val) for key, val in resolved_output_formats.items()}, status_level="success", ) @@ -278,14 +335,11 @@ async def benchmark_generative_text( # noqa: C901 if benchmark: report.benchmarks.append(benchmark) - output_format_results = {} - for key, output in output_formats.items(): - output_result = await output.finalize(report) - output_format_results[key] = output_result + output_format_results = await finalize_outputs(report, resolved_output_formats) console.print("\n\n") console.print_update( - title=f"Benchmarking complete, generated {len(report.benchmarks)} benchmark(s)", + title=f"Benchmarking complete; generated {len(report.benchmarks)} benchmark(s)", status="success", ) for key, value in output_format_results.items(): @@ -294,20 +348,34 @@ async def benchmark_generative_text( # noqa: C901 return report, output_format_results -def reimport_benchmarks_report(file: Path, output_path: Path | None) -> None: +async def reimport_benchmarks_report( + file: Path, + output_path: Path | None, + output_formats: OutputFormatType = ("console", "json", "html", "csv"), +) -> tuple[GenerativeBenchmarksReport, dict[str, Any]]: """ The command-line entry point for re-importing and displaying an - existing benchmarks report. Can also specify + existing benchmarks report. Can also specify an output format. Assumes the file provided exists. """ - report = GenerativeBenchmarksReport.load_file(file) - console_output = GenerativeBenchmarkerConsole() - console_output.finalize(report) console = Console() + with console.print_update_step( + title=f"Loading benchmarks from {file}" + ) as console_step: + report = GenerativeBenchmarksReport.load_file(file) + console_step.finish(f"Import of old benchmarks complete; loaded {len(report.benchmarks)} benchmark(s)") + + with console.print_update_step(title="Resolving output formats") as console_step: + resolved_output_formats = await resolve_output_formats(output_formats, output_path) + console_step.finish( + title="Output formats resolved", + details={key: str(val) for key, val in resolved_output_formats.items()}, + status_level="success", + ) - if output_path: - with console.print_update_step( - title=f"Saving benchmarks report to {output_path}..." - ) as console_step: - saved_path = report.save_file(output_path) - console_step.finish(title=f"Benchmarks report saved to {saved_path}") + output_format_results = await finalize_outputs(report, resolved_output_formats) + + for key, value in output_format_results.items(): + console.print_update(title=f" {key:<8}: {value}", status="debug") + + return report, output_format_results From 78615f74af5936806f323eecc94709f2ee3317ee Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Thu, 25 Sep 2025 11:13:50 -0400 Subject: [PATCH 6/6] [GuideLLM Refactor] Entrypoint: Reintroduce changes from main (#363) ## Summary Reintroduces a few changes from main --------- Signed-off-by: Samuel Monson --- src/guidellm/__main__.py | 1 + src/guidellm/logger.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 9d85346b..13a748d5 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -104,6 +104,7 @@ def decode_escaped_str(_ctx, _param, value): @click.group() +@click.version_option(package_name="guidellm", message="guidellm version: %(version)s") def cli(): """ Main entry point for the GuideLLM command-line interface. diff --git a/src/guidellm/logger.py b/src/guidellm/logger.py index 48b41a49..70259bad 100644 --- a/src/guidellm/logger.py +++ b/src/guidellm/logger.py @@ -71,7 +71,8 @@ def configure_logger(config: LoggingSettings = settings.logging): logger.add( sys.stdout, level=config.console_log_level.upper(), - format="{time} | {function} | {level} - {message}", + format="{time:YY-MM-DD HH:mm:ss}|{level: <8} \ + |{name}:{function}:{line} - {message}" ) if config.log_file or config.log_file_level: