diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml new file mode 100644 index 00000000..9851e230 --- /dev/null +++ b/.github/workflows/update-pricing.yml @@ -0,0 +1,38 @@ +name: Update Pricing + +on: + workflow_dispatch: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +jobs: + check-pricing: + name: Check for pricing updates in Ellmer + runs-on: ubuntu-latest + + steps: + + - name: Checkout current prices.json in chatlas + uses: actions/checkout@v4 + with: + sparse-checkout: /chatlas/data/prices.json + sparse-checkout-cone-mode: false + path: main + + - name: Get Ellmer prices.json + uses: actions/checkout@v4 + with: + sparse-checkout: /data-raw/prices.json + sparse-checkout-cone-mode: false + repository: tidyverse/ellmer + path: ellmer + + - name: Check for differences + run: | + echo "Checking for differences between prices.json files" + # `git diff --no-index` exits non-zero when the files differ, which would fail the step before the error annotation below is emitted + git diff --no-index --stat ellmer/data-raw/prices.json main/chatlas/data/prices.json || true + if [[ -n $(git diff --no-index --stat ellmer/data-raw/prices.json main/chatlas/data/prices.json) ]]; then + echo "Changes detected:" + echo "::error::Ellmer's prices.json does not match the current Chatlas prices.json" + exit 1 + fi diff --git a/CHANGELOG.md b/CHANGELOG.md index a3a627cc..c8b37e44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `.register_mcp_tools_http_stream_async()` and `.register_mcp_tools_stdio_async()`: for registering tools from an [MCP server](https://modelcontextprotocol.io/). (#39) * `.get_tools()` and `.set_tools()`: for fine-grained control over registered tools. (#39) * `.add_turn()`: to add `Turn`(s) to the current chat history. (#126) + * `.get_cost()`: to get the estimated cost of the chat. Only popular models are supported, but you can also supply your own token prices. (#106) * Tool functions passed to `.register_tool()` can now `yield` multiple results. (#39) * New content classes (`ContentToolResultImage` and `ContentToolResultResource`) were added, primarily to represent MCP tool results that include images/files. (#39) * A `Tool` can now be constructed from a pre-existing tool schema (via a new `__init__` method). (#39) @@ -27,11 +28,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changes +#### Breaking Changes + +* `Chat`'s `.tokens()` methods have been removed in favor of `.get_tokens()`, which returns both the cumulative and the new ("discrete") token counts for each turn. (#106) +* The base `Provider` class now includes a `name` and `model` property. For these to work properly, implementations should pass `name` and `model` along to the base class's `__init__()` method. (#106) + +#### Other Changes + * `Tool`'s constructor no longer takes a function as input. Use the new `.from_func()` method instead to create a `Tool` from a function. (#39) * `.register_tool()` now throws an exception when the tool has the same name as an already registered tool. Set the new `force` parameter to `True` to force the registration. (#39) ### Improvements +* `Chat`'s representation now includes cost information if it can be calculated. (#106) +* `token_usage()` now includes cost if it can be calculated. (#106) * `ChatOpenAI()` and `ChatGithub()` now default to GPT 4.1 (instead of 4o). (#115) * `ChatAnthropic()` now supports `content_image_url()`. (#112) * HTML styling improvements for `ContentToolResult` and `ContentToolRequest`.
(#39) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index b18383bd..7f81723c 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -178,11 +178,13 @@ class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]): def __init__( self, *, - max_tokens: int, + max_tokens: int = 4096, model: str, - api_key: str | None, + api_key: Optional[str] = None, + name: str = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ): + super().__init__(name=name, model=model) try: from anthropic import Anthropic, AsyncAnthropic except ImportError: @@ -190,8 +192,6 @@ def __init__( "`ChatAnthropic()` requires the `anthropic` package. " "You can install it with 'pip install anthropic'." ) - - self._model = model self._max_tokens = max_tokens kwargs_full: "ChatClientArgs" = { @@ -314,7 +314,7 @@ def _structured_tool_call(**kwargs: Any): kwargs_full: "SubmitInputArgs" = { "stream": stream, "messages": self._as_message_params(turns), - "model": self._model, + "model": self.model, "max_tokens": self._max_tokens, "tools": tool_schemas, **(kwargs or {}), @@ -712,10 +712,14 @@ def __init__( aws_region: str | None, aws_profile: str | None, aws_session_token: str | None, - max_tokens: int, + max_tokens: int = 4096, base_url: str | None, + name: str = "AnthropicBedrock", kwargs: Optional["ChatBedrockClientArgs"] = None, ): + + super().__init__(name=name, model=model) + self._max_tokens = max_tokens + try: from anthropic import AnthropicBedrock, AsyncAnthropicBedrock except ImportError: @@ -724,9 +728,6 @@ def __init__( "Install it with `pip install anthropic[bedrock]`." ) - self._model = model - self._max_tokens = max_tokens - kwargs_full: "ChatBedrockClientArgs" = { "aws_secret_key": aws_secret_key, "aws_access_key": aws_access_key, diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 62de9047..00c10bb6 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -44,6 +44,7 @@ from ._logging import log_tool_error from ._mcp_manager import MCPSessionManager from ._provider import Provider +from ._tokens import get_token_pricing from ._tools import Tool, ToolRejectError from ._turn import Turn, user_turn from ._typing_extensions import TypedDict @@ -54,6 +55,22 @@ class AnyTypeDict(TypedDict, total=False): pass +class TokensDict(TypedDict): + """ + A TypedDict representing the token counts for a turn in the chat. + `role` represents the role of the turn (i.e., "user" or "assistant"). + `tokens` represents the new tokens used in the turn. + `tokens_total` represents the total tokens used in the turn. + For example, if a new user input of 2 tokens is sent along with 10 tokens of context from prior turns (input and output), + then the `tokens_total` is 12. + """ + + role: Literal["user", "assistant"] + tokens: int + tokens_total: int + + SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict) """ A TypedDict representing the arguments that can be passed to the `.chat()` @@ -64,6 +81,8 @@ class AnyTypeDict(TypedDict, total=False): EchoOptions = Literal["output", "all", "none", "text"] +CostOptions = Literal["all", "last"] + class Chat(Generic[SubmitInputArgsT, CompletionT]): """ @@ -215,43 +234,14 @@ def system_prompt(self, value: str | None): if value is not None: self._turns.insert(0, Turn("system", value)) - @overload - def tokens(self) -> list[tuple[int, int] | None]: ... - - @overload - def tokens( - self, - values: Literal["cumulative"], - ) -> list[tuple[int, int] | None]: ...
- - @overload - def tokens( - self, - values: Literal["discrete"], - ) -> list[int]: ... - - def tokens( - self, - values: Literal["cumulative", "discrete"] = "discrete", - ) -> list[int] | list[tuple[int, int] | None]: + def get_tokens(self) -> list[TokensDict]: """ Get the tokens for each turn in the chat. - Parameters - ---------- - values - If "cumulative" (the default), the result can be summed to get the - chat's overall token usage (helpful for computing overall cost of - the chat). If "discrete", the result can be summed to get the number of - tokens the turns will cost to generate the next response (helpful - for estimating cost of the next response, or for determining if you - are about to exceed the token limit). - Returns ------- - list[int] - A list of token counts for each (non-system) turn in the chat. The - 1st turn includes the tokens count for the system prompt (if any). + list[TokensDict] + A list of dictionaries with the token counts for each (non-system) turn in the chat. Raises ------ @@ -265,9 +255,6 @@ def tokens( turns = self.get_turns(include_system_prompt=False) - if values == "cumulative": - return [turn.tokens for turn in turns] - if len(turns) == 0: return [] @@ -303,12 +290,21 @@ def tokens( "Expected the 1st assistant turn to contain token counts. " + err_info ) - res: list[int] = [ + res: list[TokensDict] = [ # Implied token count for the 1st user input - turns[1].tokens[0], + { + "role": "user", + "tokens": turns[1].tokens[0], + "tokens_total": turns[1].tokens[0], + }, # The token count for the 1st assistant response - turns[1].tokens[1], + { + "role": "assistant", + "tokens": turns[1].tokens[1], + "tokens_total": turns[1].tokens[1], + }, ] + for i in range(1, len(turns) - 1, 2): ti = turns[i] tj = turns[i + 2] @@ -323,15 +319,94 @@ def tokens( ) res.extend( [ - # Implied token count for the user input - tj.tokens[0] - sum(ti.tokens), - # The token count for the assistant response - tj.tokens[1], + { + "role": "user", + # Implied token count for the user input + "tokens": tj.tokens[0] - sum(ti.tokens), + # Total user tokens for the turn = new tokens + context sent + "tokens_total": tj.tokens[0], + }, + { + "role": "assistant", + # The token count for the assistant response + "tokens": tj.tokens[1], + # Total assistant tokens used in the turn + "tokens_total": tj.tokens[1], + }, ] ) return res + def get_cost( + self, + options: CostOptions = "all", + token_price: Optional[tuple[float, float]] = None, + ) -> float: + """ + Get the estimated cost of the chat. Note that this is a rough estimate: providers may change their pricing frequently and without notice. + + Parameters + ---------- + options + One of the following (default is "all"): + - `"all"`: Return the total cost of all turns in the chat. + - `"last"`: Return the cost of the last turn in the chat. + token_price + An optional tuple, (input_token_cost, output_token_cost), for supplying your own token prices. + - `input_token_cost`: The cost of input (user) tokens, in USD per million tokens. + - `output_token_cost`: The cost of output (assistant) tokens, in USD per million tokens. + + Returns + ------- + float + The cost of the chat, in USD.
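+ + Examples + -------- + Illustrative calls only; the values returned depend on the provider's current pricing (or the prices you supply): + + >>> chat.get_cost() # estimated cost (USD) of all turns so far + >>> chat.get_cost(options="last") # estimated cost of the last turn + >>> chat.get_cost(token_price=(2.0, 8.0)) # your own (input, output) prices, in USD per million tokens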
+ """ + + # Look up token cost for user and input tokens based on the provider and model + turns_tokens = self.get_tokens() + if token_price: + input_token_price = token_price[0] / 1e6 + output_token_price = token_price[1] / 1e6 + else: + price_token = get_token_pricing(self.provider.name, self.provider.model) + if not price_token: + raise KeyError( + f"We could not locate pricing information for model '{ self.provider.model }' from provider '{ self.provider.name }'. " + "If you know the pricing for this model, specify it in `token_price`." + ) + input_token_price = price_token["input"] / 1e6 + output_token_price = price_token["output"] / 1e6 + + if len(turns_tokens) == 0: + return 0.0 + + if options not in ("all", "last"): + raise ValueError( + f"Expected `options` to be one of 'all' or 'last', not '{ options }'" + ) + + if options == "all": + asst_tokens = sum( + u["tokens_total"] for u in turns_tokens if u["role"] == "assistant" + ) + user_tokens = sum( + u["tokens_total"] for u in turns_tokens if u["role"] == "user" + ) + cost = (asst_tokens * output_token_price) + ( + user_tokens * input_token_price + ) + return cost + + last_turn = turns_tokens[-1] + if last_turn["role"] == "assistant": + return last_turn["tokens"] * output_token_price + if last_turn["role"] == "user": + return last_turn["tokens_total"] * input_token_price + raise ValueError( + f"Expected last turn to have a role of 'user' or `'assistant'`, not '{ last_turn['role'] }'" + ) + def token_count( self, *args: Content | str, @@ -736,9 +811,9 @@ def stream( kwargs=kwargs, ) - def wrapper() -> Generator[ - str | ContentToolRequest | ContentToolResult, None, None - ]: + def wrapper() -> ( + Generator[str | ContentToolRequest | ContentToolResult, None, None] + ): with display: for chunk in generator: yield chunk @@ -800,9 +875,9 @@ async def stream_async( display = self._markdown_display(echo=echo) - async def wrapper() -> AsyncGenerator[ - str | ContentToolRequest | ContentToolResult, None - ]: + async def wrapper() -> ( + AsyncGenerator[str | ContentToolRequest | ContentToolResult, None] + ): with display: async for chunk in self._chat_impl_async( turn, @@ -2006,8 +2081,12 @@ def __str__(self): def __repr__(self): turns = self.get_turns(include_system_prompt=True) - tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens) - res = f"" + tokens = self.get_tokens() + cost = self.get_cost() + tokens_asst = sum(u["tokens_total"] for u in tokens if u["role"] == "assistant") + tokens_user = sum(u["tokens_total"] for u in tokens if u["role"] == "user") + + res = f"" for turn in turns: res += "\n" + turn.__repr__(indent=2) return res + "\n" diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index b6b0f9ab..5fef7bb2 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -92,6 +92,7 @@ def __init__( self, *, model: str, + name: str = "Databricks", workspace_client: Optional["WorkspaceClient"] = None, ): try: @@ -105,7 +106,8 @@ def __init__( import httpx from openai import AsyncOpenAI - self._model = model + super().__init__(name=name, model=model) + self._seed = None if workspace_client is None: diff --git a/chatlas/_github.py b/chatlas/_github.py index 5916dfad..16b88dfe 100644 --- a/chatlas/_github.py +++ b/chatlas/_github.py @@ -5,8 +5,8 @@ from ._chat import Chat from ._logging import log_model_default -from ._openai import ChatOpenAI -from ._utils import MISSING, MISSING_TYPE +from ._openai import OpenAIProvider +from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai 
import ChatCompletion @@ -121,11 +121,17 @@ def ChatGithub( if api_key is None: api_key = os.getenv("GITHUB_TOKEN", os.getenv("GITHUB_PAT")) - return ChatOpenAI( + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + return Chat( + provider=OpenAIProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="GitHub", + kwargs=kwargs, + ), system_prompt=system_prompt, - model=model, - api_key=api_key, - base_url=base_url, - seed=seed, - kwargs=kwargs, ) diff --git a/chatlas/_google.py b/chatlas/_google.py index 07fa6195..c8f070f8 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -138,6 +138,7 @@ def ChatGoogle( provider=GoogleProvider( model=model, api_key=api_key, + name="Google/Gemini", kwargs=kwargs, ), system_prompt=system_prompt, @@ -154,6 +155,7 @@ def __init__( *, model: str, api_key: str | None, + name: str = "Google/Gemini", kwargs: Optional["ChatClientArgs"], ): try: @@ -163,8 +165,7 @@ def __init__( f"The {self.__class__.__name__} class requires the `google-genai` package. " "Install it with `pip install google-genai`." ) - - self._model = model + super().__init__(name=name, model=model) kwargs_full: "ChatClientArgs" = { "api_key": api_key, @@ -256,7 +257,7 @@ def _chat_perform_args( from google.genai.types import Tool as GoogleTool kwargs_full: "SubmitInputArgs" = { - "model": self._model, + "model": self.model, "contents": cast("GoogleContent", self._google_contents(turns)), **(kwargs or {}), } @@ -592,6 +593,7 @@ def ChatVertex( provider=GoogleProvider( model=model, api_key=api_key, + name="Google/Vertex", kwargs=kwargs, ), system_prompt=system_prompt, diff --git a/chatlas/_groq.py b/chatlas/_groq.py index 712009ac..a775c607 100644 --- a/chatlas/_groq.py +++ b/chatlas/_groq.py @@ -5,8 +5,8 @@ from ._chat import Chat from ._logging import log_model_default -from ._openai import ChatOpenAI -from ._utils import MISSING, MISSING_TYPE +from ._openai import OpenAIProvider +from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -114,14 +114,21 @@ def ChatGroq( """ if model is None: model = log_model_default("llama3-8b-8192") + if api_key is None: api_key = os.getenv("GROQ_API_KEY") - return ChatOpenAI( + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + return Chat( + provider=OpenAIProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="Groq", + kwargs=kwargs, + ), system_prompt=system_prompt, - model=model, - api_key=api_key, - base_url=base_url, - seed=seed, - kwargs=kwargs, ) diff --git a/chatlas/_ollama.py b/chatlas/_ollama.py index ef85db36..7f9a4a74 100644 --- a/chatlas/_ollama.py +++ b/chatlas/_ollama.py @@ -7,7 +7,8 @@ import orjson from ._chat import Chat -from ._openai import ChatOpenAI +from ._openai import OpenAIProvider +from ._utils import MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -93,14 +94,19 @@ def ChatOllama( raise ValueError( f"Must specify model. 
Locally installed models: {', '.join(models)}" ) - - return ChatOpenAI( + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + return Chat( + provider=OpenAIProvider( + api_key="ollama", # ignored + model=model, + base_url=f"{base_url}/v1", + seed=seed, + name="Ollama", + kwargs=kwargs, + ), system_prompt=system_prompt, - api_key="ollama", # ignored - base_url=f"{base_url}/v1", - model=model, - seed=seed, - kwargs=kwargs, ) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 39d40a28..d336e3c1 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload import orjson +from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI from pydantic import BaseModel from ._chat import Chat @@ -175,11 +176,11 @@ def __init__( model: str, base_url: str = "https://api.openai.com/v1", seed: Optional[int] = None, + name: str = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ): - from openai import AsyncOpenAI, OpenAI + super().__init__(name=name, model=model) - self._model = model self._seed = seed kwargs_full: "ChatClientArgs" = { @@ -276,7 +277,7 @@ def _chat_perform_args( kwargs_full: "SubmitInputArgs" = { "stream": stream, "messages": self._as_message_param(turns), - "model": self._model, + "model": self.model, **(kwargs or {}), } @@ -648,15 +649,17 @@ def __init__( self, *, endpoint: Optional[str] = None, - deployment_id: Optional[str] = None, + deployment_id: str, api_version: Optional[str] = None, api_key: Optional[str] = None, seed: int | None = None, + name: str = "OpenAIAzure", + model: Optional[str] = "UnusedValue", kwargs: Optional["ChatAzureClientArgs"] = None, ): - from openai import AsyncAzureOpenAI, AzureOpenAI - self._model = deployment_id + super().__init__(name=name, model=deployment_id) + self._seed = seed kwargs_full: "ChatAzureClientArgs" = { diff --git a/chatlas/_perplexity.py b/chatlas/_perplexity.py index 90b58b0e..2de4ff01 100644 --- a/chatlas/_perplexity.py +++ b/chatlas/_perplexity.py @@ -5,8 +5,8 @@ from ._chat import Chat from ._logging import log_model_default -from ._openai import ChatOpenAI -from ._utils import MISSING, MISSING_TYPE +from ._openai import OpenAIProvider +from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -122,11 +122,17 @@ def ChatPerplexity( if api_key is None: api_key = os.getenv("PERPLEXITY_API_KEY") - return ChatOpenAI( + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + return Chat( + provider=OpenAIProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="Perplexity", + kwargs=kwargs, + ), system_prompt=system_prompt, - model=model, - api_key=api_key, - base_url=base_url, - seed=seed, - kwargs=kwargs, ) diff --git a/chatlas/_provider.py b/chatlas/_provider.py index d6af75bf..dd1a79aa 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -40,6 +40,24 @@ class Provider( directly. 
""" + def __init__(self, *, name: str, model: str): + self._name = name + self._model = model + + @property + def name(self): + """ + Get the name of the provider + """ + return self._name + + @property + def model(self): + """ + Get the model used by the provider + """ + return self._model + @overload @abstractmethod def chat_perform( diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 6a3678d3..07e62ac6 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -164,6 +164,7 @@ def __init__( password: Optional[str], private_key_file: Optional[str], private_key_file_pwd: Optional[str], + name: str = "Snowflake", kwargs: Optional[dict[str, "str | int"]], ): try: @@ -174,6 +175,7 @@ def __init__( "`ChatSnowflake()` requires the `snowflake-ml-python` package. " "Please install it via `pip install snowflake-ml-python`." ) + super().__init__(name=name, model=model) configs: dict[str, str | int] = drop_none( { @@ -187,8 +189,6 @@ def __init__( } ) - self._model = model - session = Session.builder.configs(configs).create() self._cortex_service = Root(session).cortex_inference_service @@ -303,7 +303,7 @@ def _complete_request( from snowflake.core.cortex.inference_service import CompleteRequest req = CompleteRequest( - model=self._model, + model=self.model, messages=self._as_request_messages(turns), stream=stream, ) diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 35bb0bd6..2a42bea2 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -1,9 +1,13 @@ from __future__ import annotations import copy +import importlib.resources as resources +import warnings from threading import Lock from typing import TYPE_CHECKING +import orjson + from ._logging import logger from ._typing_extensions import TypedDict @@ -17,8 +21,10 @@ class TokenUsage(TypedDict): """ name: str + model: str input: int output: int + cost: float | None class ThreadSafeTokenCounter: @@ -26,7 +32,9 @@ def __init__(self): self._lock = Lock() self._tokens: dict[str, TokenUsage] = {} - def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None: + def log_tokens( + self, name: str, model: str, input_tokens: int, output_tokens: int + ) -> None: logger.info( f"Provider '{name}' generated a response of {output_tokens} tokens " f"from an input of {input_tokens} tokens." @@ -36,12 +44,21 @@ def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None: if name not in self._tokens: self._tokens[name] = { "name": name, + "model": model, "input": input_tokens, "output": output_tokens, + "cost": compute_price(name, model, input_tokens, output_tokens), } else: self._tokens[name]["input"] += input_tokens self._tokens[name]["output"] += output_tokens + price = compute_price(name, model, input_tokens, output_tokens) + if price is not None: + cost = self._tokens[name]["cost"] + if cost is None: + self._tokens[name]["cost"] = price + else: + self._tokens[name]["cost"] = cost + price def get_usage(self) -> list[TokenUsage] | None: with self._lock: @@ -59,8 +76,7 @@ def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None: """ Log token usage for a provider in a thread-safe manner. 
""" - name = provider.__class__.__name__.replace("Provider", "") - _token_counter.log_tokens(name, tokens[0], tokens[1]) + _token_counter.log_tokens(provider.name, provider.model, tokens[0], tokens[1]) def tokens_reset() -> None: @@ -71,17 +87,89 @@ def tokens_reset() -> None: _token_counter = ThreadSafeTokenCounter() +class TokenPrice(TypedDict): + """ + Defines the necessary information to look up pricing for a given turn. + """ + + provider: str + """The provider name (e.g., "OpenAI", "Anthropic", etc.)""" + model: str + """The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.)""" + cached_input: float + """The cost per user token in USD per million tokens for cached input""" + input: float + """The cost per user token in USD per million tokens""" + output: float + """The cost per assistant token in USD per million tokens""" + + +# Load in pricing pulled from ellmer +f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8") +pricing_list: list[TokenPrice] = orjson.loads(f) + + +def get_token_pricing(name: str, model: str) -> TokenPrice | None: + """ + Get token pricing information given a provider name and model + + Note + ---- + Only a subset of providers and models and currently supported. + The pricing information derives from ellmer. + + Returns + ------- + TokenPrice | None + """ + result = next( + ( + item + for item in pricing_list + if item["provider"] == name and item["model"] == model + ), + None, + ) + if result is None: + warnings.warn( + f"Token pricing for the provider '{name}' and model '{model}' you selected is not available. " + "Please check the provider's documentation." + ) + + return result + + +def compute_price( + name: str, model: str, input_tokens: int, output_tokens: int +) -> float | None: + """ + Compute the cost of a turn. + + Returns + ------- + float | None + The cost of the turn in USD, or None if the cost could not be calculated. + """ + price = get_token_pricing(name, model) + if price is None: + return None + input_price = input_tokens * (price["input"] / 1e6) + output_price = output_tokens * (price["output"] / 1e6) + return input_price + output_price + + def token_usage() -> list[TokenUsage] | None: """ Report on token usage in the current session Call this function to find out the cumulative number of tokens that you - have sent and received in the current session. + have sent and received in the current session. The price will be shown if known Returns ------- list[TokenUsage] | None - A list of dictionaries with the following keys: "name", "input", and "output". + A list of dictionaries with the following keys: "name", "input", "output", and "cost". + If no cost data is available for the name/model combination chosen, then "cost" will be None. If no tokens have been logged, then None is returned. 
""" return _token_counter.get_usage() diff --git a/chatlas/data/prices.json b/chatlas/data/prices.json new file mode 100644 index 00000000..d9cd7eb9 --- /dev/null +++ b/chatlas/data/prices.json @@ -0,0 +1,264 @@ +[ + { + "provider": "OpenAI", + "model": "gpt-4.5-preview", + "cached_input": 37.5, + "input": 75, + "output": 150 + }, + { + "provider": "OpenAI", + "model": "gpt-4.5-preview-2025-02-27", + "cached_input": 37.5, + "input": 75, + "output": 150 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1", + "cached_input": 0.5, + "input": 2, + "output": 8 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-mini", + "cached_input": 0.1, + "input": 0.4, + "output": 1.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-nano", + "cached_input": 0.025, + "input": 0.1, + "output": 0.4 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-2025-04-14", + "cached_input": 0.5, + "input": 2, + "output": 8 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-mini-2025-04-14", + "cached_input": 0.1, + "input": 0.4, + "output": 1.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-nano-2025-04-14", + "cached_input": 0.025, + "input": 0.1, + "output": 0.4 + }, + { + "provider": "OpenAI", + "model": "gpt-4o", + "cached_input": 1.25, + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-2024-11-20", + "cached_input": 1.25, + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-2024-08-06", + "cached_input": 1.25, + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-2024-05-13", + "input": 5, + "output": 15 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini", + "cached_input": 0.075, + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini-2024-07-18", + "cached_input": 0.075, + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "o1", + "cached_input": 7.5, + "input": 15, + "output": 60 + }, + { + "provider": "OpenAI", + "model": "o1-2024-12-17", + "cached_input": 7.5, + "input": 15, + "output": 60 + }, + { + "provider": "OpenAI", + "model": "o1-preview-2024-09-12", + "cached_input": 7.5, + "input": 15, + "output": 60 + }, + { + "provider": "OpenAI", + "model": "o1-pro", + "input": 150, + "output": 600 + }, + { + "provider": "OpenAI", + "model": "o1-pro-2025-03-19", + "input": 150, + "output": 600 + }, + { + "provider": "OpenAI", + "model": "o3-mini", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "o3-mini-2025-01-31", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "o1-mini", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "o1-mini-2024-09-12", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini-search-preview", + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini-search-preview-2025-03-11", + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-search-preview", + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-search-preview-2025-03-11", + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "computer-use-preview", + "input": 3, + "output": 12 + }, + { + "provider": "OpenAI", + "model": "computer-use-preview-2025-03-11", + "input": 3, + "output": 12 + }, + { + "provider": "Anthropic", + "model": "claude-opus-4", + 
"cached_input": 1.5, + "input": 15, + "output": 75 + }, + { + "provider": "Anthropic", + "model": "claude-sonnet-4", + "cached_input": 0.3, + "input": 3, + "output": 15 + }, + { + "provider": "Anthropic", + "model": "claude-3-7-sonnet", + "cached_input": 0.3, + "input": 3, + "output": 15 + }, + { + "provider": "Anthropic", + "model": "claude-3-5-sonnet", + "cached_input": 0.3, + "input": 3, + "output": 15 + }, + { + "provider": "Anthropic", + "model": "claude-3-5-haiku", + "cached_input": 0.08, + "input": 0.8, + "output": 4 + }, + { + "provider": "Anthropic", + "model": "claude-3-opus", + "cached_input": 1.5, + "input": 15, + "output": 75 + }, + { + "provider": "Anthropic", + "model": "claude-3-haiku", + "cached_input": 0.03, + "input": 0.25, + "output": 1.25 + }, + { + "provider": "Google/Gemini", + "model": "gemini-2.0-flash", + "cached_input": 0.025, + "input": 0.1, + "output": 0.4 + }, + { + "provider": "Google/Gemini", + "model": "gemini-2.0-flash-lite", + "input": 0.075, + "output": 0.3 + }, + { + "provider": "Google/Gemini", + "model": "gemini-1.5-flash", + "input": 0.3, + "output": 0.075 + } +] diff --git a/tests/__snapshots__/test_chat.ambr b/tests/__snapshots__/test_chat.ambr index 88b3a275..f910d543 100644 --- a/tests/__snapshots__/test_chat.ambr +++ b/tests/__snapshots__/test_chat.ambr @@ -32,7 +32,7 @@ # --- # name: test_basic_repr ''' - + diff --git a/tests/test_chat.py b/tests/test_chat.py index 2a6f3bd7..8a942962 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,8 +2,6 @@ import tempfile import pytest -from pydantic import BaseModel - from chatlas import ( ChatOpenAI, ContentToolRequest, @@ -12,6 +10,7 @@ Turn, ) from chatlas._chat import ToolFailureWarning +from pydantic import BaseModel def test_simple_batch_chat(): @@ -31,10 +30,12 @@ async def test_simple_async_batch_chat(): def test_simple_streaming_chat(): chat = ChatOpenAI() - res = chat.stream(""" + res = chat.stream( + """ What are the canonical colors of the ROYGBIV rainbow? Put each colour on its own line. Don't use punctuation. - """) + """ + ) chunks = [chunk for chunk in res] assert len(chunks) > 2 result = "".join(chunks) @@ -49,10 +50,12 @@ def test_simple_streaming_chat(): @pytest.mark.asyncio async def test_simple_streaming_chat_async(): chat = ChatOpenAI() - res = await chat.stream_async(""" + res = await chat.stream_async( + """ What are the canonical colors of the ROYGBIV rainbow? Put each colour on its own line. Don't use punctuation. - """) + """ + ) chunks = [chunk async for chunk in res] assert len(chunks) > 2 result = "".join(chunks) @@ -288,3 +291,57 @@ def test_tool(user: str) -> str: assert str(response).lower() == "joe unknown hadley red" assert "Joe denied the request." in capsys.readouterr().out + + +def test_get_cost(): + chat = ChatOpenAI(api_key="fake_key") + chat.set_turns( + [ + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(2, 10)), + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(14, 10)), + ] + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Expected `options` to be one of 'all' or 'last', not 'bad_option'" + ), + ): + chat.get_cost(options="bad_option") + + # Checking that these have the right form vs. 
the actual calculation because the price may change + cost = chat.get_cost(options="all") + assert isinstance(cost, float) + assert cost > 0 + + last = chat.get_cost(options="last") + assert isinstance(last, float) + assert last > 0 + + assert cost > last + + byoc = (2.0, 3.0) + + cost2 = chat.get_cost(options="all", token_price=byoc) + assert cost2 == 0.000092 + + last2 = chat.get_cost(options="last", token_price=byoc) + assert last2 == 0.00003 + + chat2 = ChatOpenAI(api_key="fake_key", model="BADBAD") + chat2.set_turns( + [ + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(2, 10)), + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(14, 10)), + ] + ) + with pytest.raises( + KeyError, + match="We could not locate pricing information for model 'BADBAD' from provider 'OpenAI'. If you know the pricing for this model, specify it in `token_price`.", + ): + chat2.get_cost(options="all") diff --git a/tests/test_provider_azure.py b/tests/test_provider_azure.py index 8eb2850e..eef856c4 100644 --- a/tests/test_provider_azure.py +++ b/tests/test_provider_azure.py @@ -22,6 +22,7 @@ def test_azure_simple_request(): turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (27, 2) + assert chat.provider.name == "OpenAIAzure" @pytest.mark.asyncio diff --git a/tests/test_provider_google.py b/tests/test_provider_google.py index 42e98c2b..d9ea896f 100644 --- a/tests/test_provider_google.py +++ b/tests/test_provider_google.py @@ -2,7 +2,7 @@ import pytest import requests -from chatlas import ChatGoogle +from chatlas import ChatGoogle, ChatVertex from google.genai.errors import APIError from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential @@ -33,6 +33,31 @@ def test_google_simple_request(): assert turn is not None assert turn.tokens == (16, 2) assert turn.finish_reason == "STOP" + assert chat.provider.name == "Google/Gemini" + + +def test_vertex_simple_request(): + chat = ChatVertex( + system_prompt="Be as terse as possible; no punctuation", + ) + chat.chat("What is 1 + 1?") + turn = chat.get_last_turn() + assert turn is not None + assert turn.tokens == (16, 2) + assert turn.finish_reason == "STOP" + assert chat.provider.name == "Google/Vertex" + + +def test_name_setting(): + chat = ChatGoogle( + system_prompt="Be as terse as possible; no punctuation", + ) + assert chat.provider.name == "Google/Gemini" + + chat = ChatVertex( + system_prompt="Be as terse as possible; no punctuation", + ) + assert chat.provider.name == "Google/Vertex" @pytest.mark.asyncio diff --git a/tests/test_provider_snowflake.py b/tests/test_provider_snowflake.py index 4d7d899a..16006aad 100644 --- a/tests/test_provider_snowflake.py +++ b/tests/test_provider_snowflake.py @@ -25,6 +25,7 @@ def test_openai_simple_request(): chat.chat("What is 1 + 1?") turn = chat.get_last_turn() assert turn is not None + assert chat.provider.name == "Snowflake" # No token / finish_reason info available? 
# assert turn.tokens is not None diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 2fe8bf3c..0c7e2deb 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,11 +1,18 @@ +import pytest from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider -from chatlas._tokens import token_usage, tokens_log, tokens_reset +from chatlas._tokens import ( + compute_price, + get_token_pricing, + token_usage, + tokens_log, + tokens_reset, +) def test_tokens_method(): - chat = ChatOpenAI() - assert chat.tokens(values="discrete") == [] + chat = ChatOpenAI(api_key="fake_key") + assert len(chat.get_tokens()) == 0 chat = ChatOpenAI() chat.set_turns( @@ -15,7 +22,10 @@ def test_tokens_method(): ] ) - assert chat.tokens(values="discrete") == [2, 10] + assert chat.get_tokens() == [ + {"role": "user", "tokens": 2, "tokens_total": 2}, + {"role": "assistant", "tokens": 10, "tokens_total": 10}, + ] chat = ChatOpenAI() chat.set_turns( @@ -24,11 +34,15 @@ def test_tokens_method(): Turn(role="assistant", contents="Hello", tokens=(2, 10)), Turn(role="user", contents="Hi"), Turn(role="assistant", contents="Hello", tokens=(14, 10)), - ] + ], ) - assert chat.tokens(values="discrete") == [2, 10, 2, 10] - assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)] + assert chat.get_tokens() == [ + {"role": "user", "tokens": 2, "tokens_total": 2}, + {"role": "assistant", "tokens": 10, "tokens_total": 10}, + {"role": "user", "tokens": 2, "tokens_total": 14}, + {"role": "assistant", "tokens": 10, "tokens_total": 10}, + ] def test_token_count_method(): @@ -42,6 +56,35 @@ def test_token_count_method(): assert chat.token_count("What is 1 + 1?") == 9 +def test_get_token_prices(): + chat = ChatOpenAI(model="o1-mini") + pricing = get_token_pricing(chat.provider.name, chat.provider.model) + assert pricing["provider"] == "OpenAI" + assert pricing["model"] == "o1-mini" + assert isinstance(pricing["cached_input"], float) + assert isinstance(pricing["input"], float) + assert isinstance(pricing["output"], float) + + with pytest.warns( + match="Token pricing for the provider 'OpenAI' and model 'ABCD' you selected is not available. " + "Please check the provider's documentation." 
+ ): + chat = ChatOpenAI(model="ABCD") + pricing = get_token_pricing(chat.provider.name, chat.provider.model) + assert pricing is None + + +def test_compute_price(): + chat = ChatOpenAI(model="o1-mini") + price = compute_price(chat.provider.name, chat.provider.model, 10, 50) + assert isinstance(price, float) + assert price > 0 + + chat = ChatOpenAI(model="ABCD") + price = compute_price(chat.provider.name, chat.provider.model, 10, 50) + assert price is None + + def test_usage_is_none(): tokens_reset() assert token_usage() is None @@ -50,8 +93,7 @@ def test_usage_is_none(): def test_can_retrieve_and_log_tokens(): tokens_reset() - provider = OpenAIProvider(model="foo") - + provider = OpenAIProvider(api_key="fake_key", model="gpt-4.1") tokens_log(provider, (10, 50)) tokens_log(provider, (0, 10)) usage = token_usage() @@ -60,8 +102,11 @@ def test_can_retrieve_and_log_tokens(): assert usage[0]["name"] == "OpenAI" assert usage[0]["input"] == 10 assert usage[0]["output"] == 60 + assert usage[0]["cost"] is not None - provider2 = OpenAIAzureProvider(endpoint="foo", api_version="bar") + provider2 = OpenAIAzureProvider( + api_key="fake_key", endpoint="foo", deployment_id="test", api_version="bar" + ) tokens_log(provider2, (5, 25)) usage = token_usage() @@ -70,5 +115,6 @@ def test_can_retrieve_and_log_tokens(): assert usage[1]["name"] == "OpenAIAzure" assert usage[1]["input"] == 5 assert usage[1]["output"] == 25 + assert usage[1]["cost"] is None tokens_reset()
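For reference, a minimal sketch of how the new token and cost API introduced in this diff fits together (the API key, model, and token counts below are illustrative; real costs depend on the bundled pricing data or the prices you supply):

from chatlas import ChatOpenAI, Turn

# Build a chat with canned turns so no network calls are made.
chat = ChatOpenAI(api_key="fake_key", model="gpt-4.1")
chat.set_turns(
    [
        Turn(role="user", contents="Hi"),
        Turn(role="assistant", contents="Hello", tokens=(2, 10)),
    ]
)

# Per-turn accounting: new tokens vs. cumulative tokens for each turn.
chat.get_tokens()
# [{'role': 'user', 'tokens': 2, 'tokens_total': 2},
#  {'role': 'assistant', 'tokens': 10, 'tokens_total': 10}]

# Estimated cost in USD, from the bundled prices.json or from your own
# (input, output) prices per million tokens.
chat.get_cost()
chat.get_cost(options="last", token_price=(2.0, 8.0))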