From 542af1b93a0f317e0766c6fd33ec8a8e3262adea Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 18 Dec 2024 16:19:03 -0600 Subject: [PATCH 1/4] Add ChatVLLM() --- README.md | 1 + chatlas/__init__.py | 2 + chatlas/_anthropic.py | 6 +- chatlas/_openai.py | 9 ++- chatlas/_vllm.py | 146 ++++++++++++++++++++++++++++++++++++++++++ docs/_quarto.yml | 1 + 6 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 chatlas/_vllm.py diff --git a/README.md b/README.md index a6687a40..458ffb23 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ pip install -U git+https://github.com/posit-dev/chatlas * Ollama local models: [`ChatOllama()`](https://posit-dev.github.io/chatlas/reference/ChatOllama.html). * OpenAI: [`ChatOpenAI()`](https://posit-dev.github.io/chatlas/reference/ChatOpenAI.html). * perplexity.ai: [`ChatPerplexity()`](https://posit-dev.github.io/chatlas/reference/ChatPerplexity.html). +* vLLM: [`ChatVLLM()`](https://posit-dev.github.io/chatlas/reference/ChatVLLM.html). It also supports the following enterprise cloud providers: diff --git a/chatlas/__init__.py b/chatlas/__init__.py index ef921da5..5e967897 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -13,6 +13,7 @@ from ._tokens import token_usage from ._tools import Tool from ._turn import Turn +from ._vllm import ChatVLLM __all__ = ( "ChatAnthropic", @@ -24,6 +25,7 @@ "ChatOpenAI", "ChatAzureOpenAI", "ChatPerplexity", + "ChatVLLM", "Chat", "content_image_file", "content_image_plot", diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index bdb6cbb6..50a9a99a 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -286,7 +286,7 @@ def _chat_perform_args( kwargs: Optional["SubmitInputArgs"] = None, ) -> "SubmitInputArgs": tool_schemas = [ - self._anthropic_tool_schema(tool.schema) for tool in tools.values() + self._tool_schema_json(tool.schema) for tool in tools.values() ] # If data extraction is requested, add a "mock" tool with parameters inferred from the data model @@ -306,7 +306,7 @@ def _structured_tool_call(**kwargs: Any): }, } - tool_schemas.append(self._anthropic_tool_schema(data_model_tool.schema)) + tool_schemas.append(self._tool_schema_json(data_model_tool.schema)) if stream: stream = False @@ -430,7 +430,7 @@ def _as_content_block(content: Content) -> "ContentBlockParam": raise ValueError(f"Unknown content type: {type(content)}") @staticmethod - def _anthropic_tool_schema(schema: "ChatCompletionToolParam") -> "ToolParam": + def _tool_schema_json(schema: "ChatCompletionToolParam") -> "ToolParam": fn = schema["function"] name = fn["name"] diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 370ffb74..122c35f8 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -28,6 +28,7 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, + ChatCompletionToolParam, ) from openai.types.chat.chat_completion_assistant_message_param import ( ContentArrayOfContentPart, @@ -288,7 +289,7 @@ def _chat_perform_args( data_model: Optional[type[BaseModel]] = None, kwargs: Optional["SubmitInputArgs"] = None, ) -> "SubmitInputArgs": - tool_schemas = [tool.schema for tool in tools.values()] + tool_schemas = [self._tool_schema_json(tool.schema) for tool in tools.values()] kwargs_full: "SubmitInputArgs" = { "stream": stream, @@ -454,6 +455,12 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]: return res + @staticmethod + def _tool_schema_json( + schema: "ChatCompletionToolParam", + ) -> "ChatCompletionToolParam": + return schema + def _as_turn( self, 
completion: "ChatCompletion", has_data_model: bool ) -> Turn[ChatCompletion]: diff --git a/chatlas/_vllm.py b/chatlas/_vllm.py new file mode 100644 index 00000000..928cc426 --- /dev/null +++ b/chatlas/_vllm.py @@ -0,0 +1,146 @@ +import os +from typing import TYPE_CHECKING, Optional + +import requests + +from ._chat import Chat +from ._openai import OpenAIProvider +from ._turn import Turn, normalize_turns + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionToolParam + + from .types.openai import ChatClientArgs + + +def ChatVLLM( + *, + base_url: str, + system_prompt: Optional[str] = None, + turns: Optional[list[Turn]] = None, + model: Optional[str] = None, + api_key: Optional[str] = None, + seed: Optional[int] = None, + kwargs: Optional["ChatClientArgs"] = None, +) -> Chat: + """ + Chat with a model hosted by vLLM + + [vLLM](https://docs.vllm.ai/en/latest/) is an open source library that + provides an efficient and convenient LLMs model server. You can use + `ChatVLLM()` to connect to endpoints powered by vLLM. + + Prerequisites + ------------- + + ::: {.callout-note} + ## vLLM runtime + + `ChatVLLM` requires a vLLM server to be running somewhere (either on your + machine or a remote server). If you want to run a vLLM server locally, see + the [vLLM documentation](https://docs.vllm.ai/en/v0.5.3/getting_started/quickstart.html). + ::: + + ::: {.callout-note} + ## Python requirements + + `ChatVLLM` requires the `openai` package (e.g., `pip install openai`). + ::: + + + Parameters + ---------- + base_url + A system prompt to set the behavior of the assistant. + system_prompt + Optional system prompt to prepend to conversation. + turns + A list of turns to start the chat with (i.e., continuing a previous + conversation). If not provided, the conversation begins from scratch. Do + not provide non-`None` values for both `turns` and `system_prompt`. Each + message in the list should be a dictionary with at least `role` (usually + `system`, `user`, or `assistant`, but `tool` is also possible). Normally + there is also a `content` field, which is a string. + model + Model identifier to use. + seed + Random seed for reproducibility. + api_key + API key for authentication. If not provided, the `VLLM_API_KEY` environment + variable will be used. + kwargs + Additional arguments to pass to the LLM client. + + Returns: + Chat instance configured for vLLM + """ + + if api_key is None: + api_key = get_vllm_key() + + if model is None: + models = get_vllm_models(base_url, api_key) + available_models = ", ".join(models) + raise ValueError(f"Must specify model. 
Available models: {available_models}") + + return Chat( + provider=VLLMProvider( + base_url=base_url, + model=model, + seed=seed, + api_key=api_key, + kwargs=kwargs, + ), + turns=normalize_turns( + turns or [], + system_prompt, + ), + ) + + +class VLLMProvider(OpenAIProvider): + def __init__( + self, + base_url: str, + model: str, + seed: int | None, + api_key: str | None, + kwargs: Optional["ChatClientArgs"], + ): + self.base_url = base_url + self.model = model + self.seed = seed + self.api_key = api_key + self.kwargs = kwargs + + # Just like OpenAI but no strict + @staticmethod + def _tool_schema_json( + schema: "ChatCompletionToolParam", + ) -> "ChatCompletionToolParam": + schema["function"]["strict"] = False + return schema + + +def get_vllm_key() -> str: + key = os.getenv("VLLM_API_KEY", os.getenv("VLLM_KEY")) + if not key: + raise ValueError("VLLM_API_KEY environment variable not set") + return key + + +def get_vllm_models(base_url: str, api_key: Optional[str] = None) -> list[str]: + if api_key is None: + api_key = get_vllm_key() + + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get(f"{base_url}/v1/models", headers=headers) + response.raise_for_status() + data = response.json() + + return [model["id"] for model in data["data"]] + + +# def chat_vllm_test(**kwargs) -> Chat: +# """Create a test chat instance with default parameters.""" +# return ChatVLLM(base_url="https://llm.nrp-nautilus.io/", model="llama3", **kwargs) diff --git a/docs/_quarto.yml b/docs/_quarto.yml index d937cc52..eba43105 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -69,6 +69,7 @@ quartodoc: - ChatOllama - ChatOpenAI - ChatPerplexity + - ChatVLLM - title: The chat object desc: Methods and attributes available on a chat instance contents: From 5033d067b1b0f78653bf2c11649184a7e40dfc8c Mon Sep 17 00:00:00 2001 From: Carson Date: Wed, 18 Dec 2024 16:28:54 -0600 Subject: [PATCH 2/4] update changelog --- CHANGELOG.md | 2 ++ chatlas/_vllm.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 616af04a..ce810351 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* Adds vLLM support via a new `ChatVLLM` class. (#24) + ### Bug fixes * `ChatOllama` no longer fails when a `OPENAI_API_KEY` environment variable is not set. 
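For reference, a minimal sketch of how the `ChatVLLM()` constructor announced in the changelog entry above might be called. The base URL, model name, and API key are illustrative assumptions, and a vLLM server exposing an OpenAI-compatible API is assumed to already be running at that address.

```python
from chatlas import ChatVLLM

# Placeholder values: point base_url at your own vLLM server and pick a model
# that is actually loaded on it. The key can also come from VLLM_API_KEY.
chat = ChatVLLM(
    base_url="http://localhost:8000/v1",
    model="meta-llama/Llama-3.1-8B-Instruct",
    api_key="not-a-real-key",
    system_prompt="Be as terse as possible; no punctuation",
)

chat.chat("What is 1 + 1?")
```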
diff --git a/chatlas/_vllm.py b/chatlas/_vllm.py index 928cc426..a84f44ec 100644 --- a/chatlas/_vllm.py +++ b/chatlas/_vllm.py @@ -103,8 +103,8 @@ def __init__( self, base_url: str, model: str, - seed: int | None, - api_key: str | None, + seed: Optional[int], + api_key: Optional[str], kwargs: Optional["ChatClientArgs"], ): self.base_url = base_url From fb86ef9d50838735cd3f1607a7dc1e0fff69b271 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 12 Aug 2025 17:17:50 -0500 Subject: [PATCH 3/4] cleanup --- chatlas/__init__.py | 4 +- chatlas/_provider_vllm.py | 189 ++++++++++++++------------------------ chatlas/_vllm.py | 146 ----------------------------- 3 files changed, 71 insertions(+), 268 deletions(-) delete mode 100644 chatlas/_vllm.py diff --git a/chatlas/__init__.py b/chatlas/__init__.py index 54ded830..2c505022 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -19,10 +19,10 @@ from ._provider_perplexity import ChatPerplexity from ._provider_portkey import ChatPortkey from ._provider_snowflake import ChatSnowflake +from ._provider_vllm import ChatVllm from ._tokens import token_usage from ._tools import Tool, ToolRejectError from ._turn import Turn -from ._vllm import ChatVLLM try: from ._version import version as __version__ @@ -44,10 +44,10 @@ "ChatOpenRouter", "ChatAzureOpenAI", "ChatPerplexity", - "ChatVLLM", "ChatPortkey", "ChatSnowflake", "ChatVertex", + "ChatVllm", "Chat", "content_image_file", "content_image_plot", diff --git a/chatlas/_provider_vllm.py b/chatlas/_provider_vllm.py index 7f7fb53f..780d27aa 100644 --- a/chatlas/_provider_vllm.py +++ b/chatlas/_provider_vllm.py @@ -1,15 +1,15 @@ -from __future__ import annotations - import os from typing import TYPE_CHECKING, Optional +import requests + from ._chat import Chat from ._provider_openai import OpenAIProvider -from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: - from ._provider_openai import ChatCompletion - from .types.openai import ChatClientArgs, SubmitInputArgs + from openai.types.chat import ChatCompletionToolParam + + from .types.openai import ChatClientArgs def ChatVllm( @@ -18,11 +18,11 @@ def ChatVllm( system_prompt: Optional[str] = None, model: Optional[str] = None, api_key: Optional[str] = None, - seed: Optional[int] | MISSING_TYPE = MISSING, + seed: Optional[int] = None, kwargs: Optional["ChatClientArgs"] = None, -) -> Chat["SubmitInputArgs", ChatCompletion]: +) -> Chat: """ - Chat with a model hosted by vLLM. + Chat with a model hosted by vLLM [vLLM](https://docs.vllm.ai/en/latest/) is an open source library that provides an efficient and convenient LLMs model server. You can use @@ -32,147 +32,96 @@ def ChatVllm( ------------- ::: {.callout-note} - ## vLLM Server + ## vLLM runtime - You need access to a running vLLM server instance. vLLM provides - OpenAI-compatible API endpoints, so this function works with any - vLLM deployment that exposes the `/v1/chat/completions` endpoint. + `ChatVllm` requires a vLLM server to be running somewhere (either on your + machine or a remote server). If you want to run a vLLM server locally, see + the [vLLM documentation](https://docs.vllm.ai/en/v0.5.3/getting_started/quickstart.html). ::: - Examples - -------- + ::: {.callout-note} + ## Python requirements - ```python - import os - from chatlas import ChatVllm + `ChatVllm` requires the `openai` package (e.g., `pip install openai`). 
+ ::: - # Connect to a vLLM server - chat = ChatVllm( - base_url="http://localhost:8000/v1", - model="meta-llama/Llama-2-7b-chat-hf", - api_key=os.getenv("VLLM_API_KEY"), # Optional, depends on server config - ) - chat.chat("What is the capital of France?") - ``` Parameters ---------- base_url - The base URL of the vLLM server endpoint. This should include the - `/v1` path if the server follows OpenAI API conventions. - system_prompt A system prompt to set the behavior of the assistant. + system_prompt + Optional system prompt to prepend to conversation. + turns + A list of turns to start the chat with (i.e., continuing a previous + conversation). If not provided, the conversation begins from scratch. Do + not provide non-`None` values for both `turns` and `system_prompt`. Each + message in the list should be a dictionary with at least `role` (usually + `system`, `user`, or `assistant`, but `tool` is also possible). Normally + there is also a `content` field, which is a string. model - The model to use for the chat. If None, you may need to specify - the model name that's loaded on your vLLM server. - api_key - The API key to use for authentication. Some vLLM deployments may - not require authentication. You can set the `VLLM_API_KEY` - environment variable instead of passing it directly. + Model identifier to use. seed - Optional integer seed that vLLM uses to try and make output more - reproducible. + Random seed for reproducibility. + api_key + API key for authentication. If not provided, the `VLLM_API_KEY` environment + variable will be used. kwargs - Additional arguments to pass to the `openai.OpenAI()` client constructor. - - Returns - ------- - Chat - A chat object that retains the state of the conversation. - - Note - ---- - This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with - the defaults tweaked for vLLM endpoints. - - Note - ---- - vLLM servers are OpenAI-compatible, so this provider uses the same underlying - client as OpenAI but configured for your vLLM endpoint. Some advanced OpenAI - features may not be available depending on your vLLM server configuration. - - Note - ---- - Pasting an API key into a chat constructor (e.g., `ChatVllm(api_key="...")`) - is the simplest way to get started, and is fine for interactive use, but is - problematic for code that may be shared with others. - - Instead, consider using environment variables or a configuration file to manage - your credentials. One popular way to manage credentials is to use a `.env` file - to store your credentials, and then use the `python-dotenv` package to load them - into your environment. - - ```shell - pip install python-dotenv - ``` - - ```shell - # .env - VLLM_API_KEY=... - ``` - - ```python - from chatlas import ChatVllm - from dotenv import load_dotenv - - load_dotenv() - chat = ChatVllm(base_url="http://localhost:8000/v1") - chat.console() - ``` - - Another, more general, solution is to load your environment variables into the shell - before starting Python (maybe in a `.bashrc`, `.zshrc`, etc. file): - - ```shell - export VLLM_API_KEY=... - ``` + Additional arguments to pass to the LLM client. + + Returns: + Chat instance configured for vLLM """ - if api_key is None: - api_key = os.getenv("VLLM_API_KEY") - if isinstance(seed, MISSING_TYPE): - seed = 1014 if is_testing() else None + if api_key is None: + api_key = get_vllm_key() if model is None: - raise ValueError( - "Must specify model. vLLM servers can host different models, so you need to " - "specify which one to use. 
Check your vLLM server's /v1/models endpoint " - "to see available models." - ) + models = get_vllm_models(base_url, api_key) + available_models = ", ".join(models) + raise ValueError(f"Must specify model. Available models: {available_models}") return Chat( - provider=VllmProvider( - api_key=api_key, - model=model, + provider=VLLMProvider( base_url=base_url, + model=model, seed=seed, - name="vLLM", + api_key=api_key, kwargs=kwargs, ), system_prompt=system_prompt, ) -class VllmProvider(OpenAIProvider): - """ - Provider for vLLM endpoints. +class VLLMProvider(OpenAIProvider): + # Just like OpenAI but no strict + @staticmethod + def _tool_schema_json( + schema: "ChatCompletionToolParam", + ) -> "ChatCompletionToolParam": + schema["function"]["strict"] = False + return schema - vLLM is OpenAI-compatible but may have some differences in tool handling - and other advanced features. - """ - def _chat_perform_args(self, *args, **kwargs): - """ - Customize request arguments for vLLM compatibility. +def get_vllm_key() -> str: + key = os.getenv("VLLM_API_KEY", os.getenv("VLLM_KEY")) + if not key: + raise ValueError("VLLM_API_KEY environment variable not set") + return key + + +def get_vllm_models(base_url: str, api_key: Optional[str] = None) -> list[str]: + if api_key is None: + api_key = get_vllm_key() + + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get(f"{base_url}/v1/models", headers=headers) + response.raise_for_status() + data = response.json() - vLLM may not support all OpenAI features like stream_options, - so we remove potentially unsupported parameters. - """ - # Get the base arguments from OpenAI provider - result = super()._chat_perform_args(*args, **kwargs) + return [model["id"] for model in data["data"]] - # Remove stream_options if present (some vLLM versions don't support it) - if "stream_options" in result: - del result["stream_options"] - return result +# def chat_vllm_test(**kwargs) -> Chat: +# """Create a test chat instance with default parameters.""" +# return ChatVllm(base_url="https://llm.nrp-nautilus.io/", model="llama3", **kwargs) diff --git a/chatlas/_vllm.py b/chatlas/_vllm.py deleted file mode 100644 index a84f44ec..00000000 --- a/chatlas/_vllm.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -from typing import TYPE_CHECKING, Optional - -import requests - -from ._chat import Chat -from ._openai import OpenAIProvider -from ._turn import Turn, normalize_turns - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionToolParam - - from .types.openai import ChatClientArgs - - -def ChatVLLM( - *, - base_url: str, - system_prompt: Optional[str] = None, - turns: Optional[list[Turn]] = None, - model: Optional[str] = None, - api_key: Optional[str] = None, - seed: Optional[int] = None, - kwargs: Optional["ChatClientArgs"] = None, -) -> Chat: - """ - Chat with a model hosted by vLLM - - [vLLM](https://docs.vllm.ai/en/latest/) is an open source library that - provides an efficient and convenient LLMs model server. You can use - `ChatVLLM()` to connect to endpoints powered by vLLM. - - Prerequisites - ------------- - - ::: {.callout-note} - ## vLLM runtime - - `ChatVLLM` requires a vLLM server to be running somewhere (either on your - machine or a remote server). If you want to run a vLLM server locally, see - the [vLLM documentation](https://docs.vllm.ai/en/v0.5.3/getting_started/quickstart.html). - ::: - - ::: {.callout-note} - ## Python requirements - - `ChatVLLM` requires the `openai` package (e.g., `pip install openai`). 
- ::: - - - Parameters - ---------- - base_url - A system prompt to set the behavior of the assistant. - system_prompt - Optional system prompt to prepend to conversation. - turns - A list of turns to start the chat with (i.e., continuing a previous - conversation). If not provided, the conversation begins from scratch. Do - not provide non-`None` values for both `turns` and `system_prompt`. Each - message in the list should be a dictionary with at least `role` (usually - `system`, `user`, or `assistant`, but `tool` is also possible). Normally - there is also a `content` field, which is a string. - model - Model identifier to use. - seed - Random seed for reproducibility. - api_key - API key for authentication. If not provided, the `VLLM_API_KEY` environment - variable will be used. - kwargs - Additional arguments to pass to the LLM client. - - Returns: - Chat instance configured for vLLM - """ - - if api_key is None: - api_key = get_vllm_key() - - if model is None: - models = get_vllm_models(base_url, api_key) - available_models = ", ".join(models) - raise ValueError(f"Must specify model. Available models: {available_models}") - - return Chat( - provider=VLLMProvider( - base_url=base_url, - model=model, - seed=seed, - api_key=api_key, - kwargs=kwargs, - ), - turns=normalize_turns( - turns or [], - system_prompt, - ), - ) - - -class VLLMProvider(OpenAIProvider): - def __init__( - self, - base_url: str, - model: str, - seed: Optional[int], - api_key: Optional[str], - kwargs: Optional["ChatClientArgs"], - ): - self.base_url = base_url - self.model = model - self.seed = seed - self.api_key = api_key - self.kwargs = kwargs - - # Just like OpenAI but no strict - @staticmethod - def _tool_schema_json( - schema: "ChatCompletionToolParam", - ) -> "ChatCompletionToolParam": - schema["function"]["strict"] = False - return schema - - -def get_vllm_key() -> str: - key = os.getenv("VLLM_API_KEY", os.getenv("VLLM_KEY")) - if not key: - raise ValueError("VLLM_API_KEY environment variable not set") - return key - - -def get_vllm_models(base_url: str, api_key: Optional[str] = None) -> list[str]: - if api_key is None: - api_key = get_vllm_key() - - headers = {"Authorization": f"Bearer {api_key}"} - response = requests.get(f"{base_url}/v1/models", headers=headers) - response.raise_for_status() - data = response.json() - - return [model["id"] for model in data["data"]] - - -# def chat_vllm_test(**kwargs) -> Chat: -# """Create a test chat instance with default parameters.""" -# return ChatVLLM(base_url="https://llm.nrp-nautilus.io/", model="llama3", **kwargs) From 22ee0dd316af00b1ac125d3553f5ad396faa761f Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 12 Aug 2025 17:48:26 -0500 Subject: [PATCH 4/4] Add tests --- chatlas/_provider_portkey.py | 3 +- chatlas/_provider_vllm.py | 30 ++++------ tests/test_provider_vllm.py | 106 +++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 20 deletions(-) create mode 100644 tests/test_provider_vllm.py diff --git a/chatlas/_provider_portkey.py b/chatlas/_provider_portkey.py index f4d56a5f..6c63a342 100644 --- a/chatlas/_provider_portkey.py +++ b/chatlas/_provider_portkey.py @@ -78,11 +78,10 @@ def ChatPortkey( Chat A chat object that retains the state of the conversation. - Notes + Note ----- This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with the defaults tweaked for PortkeyAI. 
- """ if model is None: model = log_model_default("gpt-4.1") diff --git a/chatlas/_provider_vllm.py b/chatlas/_provider_vllm.py index 780d27aa..a04ddb00 100644 --- a/chatlas/_provider_vllm.py +++ b/chatlas/_provider_vllm.py @@ -36,29 +36,16 @@ def ChatVllm( `ChatVllm` requires a vLLM server to be running somewhere (either on your machine or a remote server). If you want to run a vLLM server locally, see - the [vLLM documentation](https://docs.vllm.ai/en/v0.5.3/getting_started/quickstart.html). - ::: - - ::: {.callout-note} - ## Python requirements - - `ChatVllm` requires the `openai` package (e.g., `pip install openai`). + the [vLLM documentation](https://docs.vllm.ai/en/stable/getting_started/quickstart.html). ::: Parameters ---------- base_url - A system prompt to set the behavior of the assistant. + Base URL of the vLLM server (e.g., "http://localhost:8000/v1"). system_prompt - Optional system prompt to prepend to conversation. - turns - A list of turns to start the chat with (i.e., continuing a previous - conversation). If not provided, the conversation begins from scratch. Do - not provide non-`None` values for both `turns` and `system_prompt`. Each - message in the list should be a dictionary with at least `role` (usually - `system`, `user`, or `assistant`, but `tool` is also possible). Normally - there is also a `content` field, which is a string. + A system prompt to set the behavior of the assistant. model Model identifier to use. seed Random seed for reproducibility. api_key API key for authentication. If not provided, the `VLLM_API_KEY` environment variable will be used. kwargs Additional arguments to pass to the LLM client. - Returns: - Chat instance configured for vLLM + Returns + ------- + Chat + A chat object that retains the state of the conversation. + + Note + ----- + This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with + the defaults tweaked for vLLM.
""" if api_key is None: diff --git a/tests/test_provider_vllm.py b/tests/test_provider_vllm.py new file mode 100644 index 00000000..ba0dde39 --- /dev/null +++ b/tests/test_provider_vllm.py @@ -0,0 +1,106 @@ +import os + +import pytest + +do_test = os.getenv("TEST_VLLM", "true") +if do_test.lower() == "false": + pytest.skip("Skipping vLLM tests", allow_module_level=True) + +from chatlas import ChatVllm + +from .conftest import ( + assert_tools_async, + assert_tools_simple, + assert_turns_existing, + assert_turns_system, +) + + +def test_vllm_simple_request(): + # This test assumes you have a vLLM server running locally + # Skip if TEST_VLLM_BASE_URL is not set + base_url = os.getenv("TEST_VLLM_BASE_URL") + if base_url is None: + pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests") + + model = os.getenv("TEST_VLLM_MODEL", "llama3") + + chat = ChatVllm( + base_url=base_url, + model=model, + system_prompt="Be as terse as possible; no punctuation", + ) + chat.chat("What is 1 + 1?") + turn = chat.get_last_turn() + assert turn is not None + assert turn.tokens is not None + assert len(turn.tokens) == 3 + assert turn.tokens[0] >= 10 # More lenient assertion for vLLM + assert turn.finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_vllm_simple_streaming_request(): + base_url = os.getenv("TEST_VLLM_BASE_URL") + if base_url is None: + pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests") + + model = os.getenv("TEST_VLLM_MODEL", "llama3") + + chat = ChatVllm( + base_url=base_url, + model=model, + system_prompt="Be as terse as possible; no punctuation", + ) + res = [] + async for x in await chat.stream_async("What is 1 + 1?"): + res.append(x) + assert "2" in "".join(res) + turn = chat.get_last_turn() + assert turn is not None + assert turn.finish_reason == "stop" + + +def test_vllm_respects_turns_interface(): + base_url = os.getenv("TEST_VLLM_BASE_URL") + if base_url is None: + pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests") + + model = os.getenv("TEST_VLLM_MODEL", "llama3") + + def chat_fun(**kwargs): + return ChatVllm(base_url=base_url, model=model, **kwargs) + + assert_turns_system(chat_fun) + assert_turns_existing(chat_fun) + + +def test_vllm_tool_variations(): + base_url = os.getenv("TEST_VLLM_BASE_URL") + if base_url is None: + pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests") + + model = os.getenv("TEST_VLLM_MODEL", "llama3") + + def chat_fun(**kwargs): + return ChatVllm(base_url=base_url, model=model, **kwargs) + + assert_tools_simple(chat_fun) + + +@pytest.mark.asyncio +async def test_vllm_tool_variations_async(): + base_url = os.getenv("TEST_VLLM_BASE_URL") + if base_url is None: + pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests") + + model = os.getenv("TEST_VLLM_MODEL", "llama3") + + def chat_fun(**kwargs): + return ChatVllm(base_url=base_url, model=model, **kwargs) + + await assert_tools_async(chat_fun) + + +# Note: vLLM support for data extraction and images depends on the specific model +# and configuration, so we skip those tests for now \ No newline at end of file