diff --git a/CHANGELOG.md b/CHANGELOG.md
index a9fc7ac..35f678c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 * Added `ChatOpenRouter()` for chatting via [Open Router](https://openrouter.ai/). (#148)
 * Added `ChatHuggingFace()` for chatting via [Hugging Face](https://huggingface.co/). (#144)
 * Added `ChatPortkey()` for chatting via [Portkey AI](https://portkey.ai/). (#143)
+* Added `ChatVllm()` for chatting via [vLLM](https://docs.vllm.ai/en/latest/). (#24)
 
 ### Bug fixes
 
diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index b3b47f3..2c50502 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -19,6 +19,7 @@
 from ._provider_perplexity import ChatPerplexity
 from ._provider_portkey import ChatPortkey
 from ._provider_snowflake import ChatSnowflake
+from ._provider_vllm import ChatVllm
 from ._tokens import token_usage
 from ._tools import Tool, ToolRejectError
 from ._turn import Turn
@@ -46,6 +47,7 @@
     "ChatPortkey",
     "ChatSnowflake",
     "ChatVertex",
+    "ChatVllm",
     "Chat",
     "content_image_file",
     "content_image_plot",
diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py
index ab0961c..2261c25 100644
--- a/chatlas/_provider_anthropic.py
+++ b/chatlas/_provider_anthropic.py
@@ -286,7 +286,7 @@ def _chat_perform_args(
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
         tool_schemas = [
-            self._anthropic_tool_schema(tool.schema) for tool in tools.values()
+            self._tool_schema_json(tool.schema) for tool in tools.values()
         ]
 
         # If data extraction is requested, add a "mock" tool with parameters inferred from the data model
@@ -306,7 +306,7 @@ def _structured_tool_call(**kwargs: Any):
                 },
             }
 
-            tool_schemas.append(self._anthropic_tool_schema(data_model_tool.schema))
+            tool_schemas.append(self._tool_schema_json(data_model_tool.schema))
 
         if stream:
             stream = False
@@ -542,7 +542,7 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
         raise ValueError(f"Unknown content type: {type(content)}")
 
     @staticmethod
-    def _anthropic_tool_schema(schema: "ChatCompletionToolParam") -> "ToolParam":
+    def _tool_schema_json(schema: "ChatCompletionToolParam") -> "ToolParam":
         fn = schema["function"]
 
         name = fn["name"]
diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index 215bb28..5e33630 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -34,6 +34,7 @@
     ChatCompletion,
     ChatCompletionChunk,
     ChatCompletionMessageParam,
+    ChatCompletionToolParam,
 )
 from openai.types.chat.chat_completion_assistant_message_param import (
     ContentArrayOfContentPart,
@@ -276,7 +277,7 @@ def _chat_perform_args(
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
-        tool_schemas = [tool.schema for tool in tools.values()]
+        tool_schemas = [self._tool_schema_json(tool.schema) for tool in tools.values()]
 
         kwargs_full: "SubmitInputArgs" = {
             "stream": stream,
@@ -514,6 +515,12 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
 
         return res
 
+    @staticmethod
+    def _tool_schema_json(
+        schema: "ChatCompletionToolParam",
+    ) -> "ChatCompletionToolParam":
+        return schema
+
     def _as_turn(
         self, completion: "ChatCompletion", has_data_model: bool
     ) -> Turn[ChatCompletion]:
diff --git a/chatlas/_provider_portkey.py b/chatlas/_provider_portkey.py
index f4d56a5..6c63a34 100644
--- a/chatlas/_provider_portkey.py
+++ b/chatlas/_provider_portkey.py
@@ -78,11 +78,10 @@ def ChatPortkey(
     Chat
         A chat object that retains the state of the conversation.
 
-    Notes
+    Note
     -----
     This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with
     the defaults tweaked for PortkeyAI.
-
     """
     if model is None:
         model = log_model_default("gpt-4.1")
diff --git a/chatlas/_provider_vllm.py b/chatlas/_provider_vllm.py
new file mode 100644
index 0000000..a04ddb0
--- /dev/null
+++ b/chatlas/_provider_vllm.py
@@ -0,0 +1,121 @@
+import os
+from typing import TYPE_CHECKING, Optional
+
+import requests
+
+from ._chat import Chat
+from ._provider_openai import OpenAIProvider
+
+if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletionToolParam
+
+    from .types.openai import ChatClientArgs
+
+
+def ChatVllm(
+    *,
+    base_url: str,
+    system_prompt: Optional[str] = None,
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    seed: Optional[int] = None,
+    kwargs: Optional["ChatClientArgs"] = None,
+) -> Chat:
+    """
+    Chat with a model hosted by vLLM
+
+    [vLLM](https://docs.vllm.ai/en/latest/) is an open source library that
+    provides an efficient and convenient LLM model server. You can use
+    `ChatVllm()` to connect to endpoints powered by vLLM.
+
+    Prerequisites
+    -------------
+
+    ::: {.callout-note}
+    ## vLLM runtime
+
+    `ChatVllm` requires a vLLM server to be running somewhere (either on your
+    machine or a remote server). If you want to run a vLLM server locally, see
+    the [vLLM documentation](https://docs.vllm.ai/en/stable/getting_started/quickstart.html).
+    :::
+
+
+    Parameters
+    ----------
+    base_url
+        Base URL of the vLLM server (e.g., "http://localhost:8000/v1").
+    system_prompt
+        A system prompt to set the behavior of the assistant.
+    model
+        Model identifier to use.
+    api_key
+        API key for authentication. If not provided, the `VLLM_API_KEY` environment
+        variable will be used.
+    seed
+        Random seed for reproducibility.
+    kwargs
+        Additional arguments to pass to the LLM client.
+
+    Returns
+    -------
+    Chat
+        A chat object that retains the state of the conversation.
+
+    Note
+    -----
+    This function is a lightweight wrapper around [](`~chatlas.ChatOpenAI`) with
+    the defaults tweaked for vLLM.
+    """
+
+    if api_key is None:
+        api_key = get_vllm_key()
+
+    if model is None:
+        models = get_vllm_models(base_url, api_key)
+        available_models = ", ".join(models)
+        raise ValueError(f"Must specify model. Available models: {available_models}")
+
+    return Chat(
+        provider=VLLMProvider(
+            base_url=base_url,
+            model=model,
+            seed=seed,
+            api_key=api_key,
+            kwargs=kwargs,
+        ),
+        system_prompt=system_prompt,
+    )
+
+
+class VLLMProvider(OpenAIProvider):
+    # Just like OpenAI, but without strict tool schemas
+    @staticmethod
+    def _tool_schema_json(
+        schema: "ChatCompletionToolParam",
+    ) -> "ChatCompletionToolParam":
+        schema["function"]["strict"] = False
+        return schema
+
+
+def get_vllm_key() -> str:
+    key = os.getenv("VLLM_API_KEY", os.getenv("VLLM_KEY"))
+    if not key:
+        raise ValueError("VLLM_API_KEY environment variable not set")
+    return key
+
+
+def get_vllm_models(base_url: str, api_key: Optional[str] = None) -> list[str]:
+    if api_key is None:
+        api_key = get_vllm_key()
+
+    headers = {"Authorization": f"Bearer {api_key}"}
+    response = requests.get(f"{base_url}/models", headers=headers)  # base_url already includes /v1
+    response.raise_for_status()
+    data = response.json()
+
+    return [model["id"] for model in data["data"]]
+
+
+# def chat_vllm_test(**kwargs) -> Chat:
+#     """Create a test chat instance with default parameters."""
+#     return ChatVllm(base_url="https://llm.nrp-nautilus.io/", model="llama3", **kwargs)
diff --git a/docs/_quarto.yml b/docs/_quarto.yml
index 2394b46..43c44ee 100644
--- a/docs/_quarto.yml
+++ b/docs/_quarto.yml
@@ -129,6 +129,7 @@ quartodoc:
       - ChatPortkey
       - ChatSnowflake
       - ChatVertex
+      - ChatVllm
   - title: The chat object
     desc: Methods and attributes available on a chat instance
     contents:
diff --git a/tests/test_provider_vllm.py b/tests/test_provider_vllm.py
new file mode 100644
index 0000000..ba0dde3
--- /dev/null
+++ b/tests/test_provider_vllm.py
@@ -0,0 +1,106 @@
+import os
+
+import pytest
+
+do_test = os.getenv("TEST_VLLM", "true")
+if do_test.lower() == "false":
+    pytest.skip("Skipping vLLM tests", allow_module_level=True)
+
+from chatlas import ChatVllm
+
+from .conftest import (
+    assert_tools_async,
+    assert_tools_simple,
+    assert_turns_existing,
+    assert_turns_system,
+)
+
+
+def test_vllm_simple_request():
+    # This test assumes you have a vLLM server running locally
+    # Skip if TEST_VLLM_BASE_URL is not set
+    base_url = os.getenv("TEST_VLLM_BASE_URL")
+    if base_url is None:
+        pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")
+
+    model = os.getenv("TEST_VLLM_MODEL", "llama3")
+
+    chat = ChatVllm(
+        base_url=base_url,
+        model=model,
+        system_prompt="Be as terse as possible; no punctuation",
+    )
+    chat.chat("What is 1 + 1?")
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.tokens is not None
+    assert len(turn.tokens) == 3
+    assert turn.tokens[0] >= 10  # More lenient assertion for vLLM
+    assert turn.finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_vllm_simple_streaming_request():
+    base_url = os.getenv("TEST_VLLM_BASE_URL")
+    if base_url is None:
+        pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")
+
+    model = os.getenv("TEST_VLLM_MODEL", "llama3")
+
+    chat = ChatVllm(
+        base_url=base_url,
+        model=model,
+        system_prompt="Be as terse as possible; no punctuation",
+    )
+    res = []
+    async for x in await chat.stream_async("What is 1 + 1?"):
+        res.append(x)
+    assert "2" in "".join(res)
+    turn = chat.get_last_turn()
+    assert turn is not None
+    assert turn.finish_reason == "stop"
+
+
+def test_vllm_respects_turns_interface():
+    base_url = os.getenv("TEST_VLLM_BASE_URL")
+    if base_url is None:
+        pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")
+
+    model = os.getenv("TEST_VLLM_MODEL", "llama3")
+
+    def chat_fun(**kwargs):
+        return ChatVllm(base_url=base_url, model=model, **kwargs)
+
+    assert_turns_system(chat_fun)
+    assert_turns_existing(chat_fun)
+
+
+def test_vllm_tool_variations():
+    base_url = os.getenv("TEST_VLLM_BASE_URL")
+    if base_url is None:
+        pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")
+
+    model = os.getenv("TEST_VLLM_MODEL", "llama3")
+
+    def chat_fun(**kwargs):
+        return ChatVllm(base_url=base_url, model=model, **kwargs)
+
+    assert_tools_simple(chat_fun)
+
+
+@pytest.mark.asyncio
+async def test_vllm_tool_variations_async():
+    base_url = os.getenv("TEST_VLLM_BASE_URL")
+    if base_url is None:
+        pytest.skip("TEST_VLLM_BASE_URL is not set; skipping vLLM tests")
+
+    model = os.getenv("TEST_VLLM_MODEL", "llama3")
+
+    def chat_fun(**kwargs):
+        return ChatVllm(base_url=base_url, model=model, **kwargs)
+
+    await assert_tools_async(chat_fun)
+
+
+# Note: vLLM support for data extraction and images depends on the specific model
+# and configuration, so we skip those tests for now
\ No newline at end of file
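
Below is a minimal usage sketch of the new `ChatVllm()` entry point (not part of the patch). It assumes a vLLM server with the OpenAI-compatible API is already running at `http://localhost:8000/v1` and that the model id shown is one the server actually serves; the URL, model id, and dummy API key are all placeholders.

```python
import os

from chatlas import ChatVllm

# ChatVllm() falls back to the VLLM_API_KEY environment variable when api_key
# is not passed, and raises if neither is set, so a local no-auth server still
# needs a placeholder value.
os.environ.setdefault("VLLM_API_KEY", "not-used")

chat = ChatVllm(
    base_url="http://localhost:8000/v1",  # placeholder; must include the /v1 prefix
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    system_prompt="You are a terse assistant.",
)
chat.chat("What is 1 + 1?")
```

If `model=` is omitted, `ChatVllm()` queries the server's models endpoint via `get_vllm_models()` and raises a `ValueError` listing the available ids. Everything else (streaming, tool calling, turn management) is inherited from the OpenAI provider; `VLLMProvider` only overrides `_tool_schema_json()` to opt out of strict tool schemas. The new tests follow the same pattern: they skip unless `TEST_VLLM_BASE_URL` is set and read the model id from `TEST_VLLM_MODEL`.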