From 370edbf4d6570ee6eb231c252651966a58fc261b Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 11 Jun 2025 10:13:13 -0400 Subject: [PATCH 01/59] Added get_tokens --- chatlas/_chat.py | 87 ++++++++++++++++++++------------------------ tests/test_tokens.py | 26 +++++++++---- 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index bda821ca..0d04de3e 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -188,43 +188,18 @@ def system_prompt(self, value: str | None): if value is not None: self._turns.insert(0, Turn("system", value)) - @overload - def tokens(self) -> list[tuple[int, int] | None]: ... - - @overload - def tokens( - self, - values: Literal["cumulative"], - ) -> list[tuple[int, int] | None]: ... - - @overload - def tokens( - self, - values: Literal["discrete"], - ) -> list[int]: ... - - def tokens( - self, - values: Literal["cumulative", "discrete"] = "discrete", - ) -> list[int] | list[tuple[int, int] | None]: + def get_tokens(self) -> list[dict[str, int | str]]: """ Get the tokens for each turn in the chat. - Parameters - ---------- - values - If "cumulative" (the default), the result can be summed to get the - chat's overall token usage (helpful for computing overall cost of - the chat). If "discrete", the result can be summed to get the number of - tokens the turns will cost to generate the next response (helpful - for estimating cost of the next response, or for determining if you - are about to exceed the token limit). - Returns ------- - list[int] - A list of token counts for each (non-system) turn in the chat. The - 1st turn includes the tokens count for the system prompt (if any). + list[dict[str, str | int]] + A list of dictionaries with the token counts for each (non-system) turn + in the chat. + `tokens` represents the new tokens used in the turn. + `tokens_total` represents the total tokens used in the turn. + Ex. 
A new user input of 2 tokens is sent, plus 10 tokens of context from prior turns (input and output) would have a `tokens_total` of 12. Raises ------ @@ -238,9 +213,6 @@ def tokens( turns = self.get_turns(include_system_prompt=False) - if values == "cumulative": - return [turn.tokens for turn in turns] - if len(turns) == 0: return [] @@ -276,12 +248,21 @@ def tokens( "Expected the 1st assistant turn to contain token counts. " + err_info ) - res: list[int] = [ + res: list[dict[str, int | str]] = [ # Implied token count for the 1st user input - turns[1].tokens[0], + { + "role": "user", + "tokens": turns[1].tokens[0], + "tokens_total": turns[1].tokens[0], + }, # The token count for the 1st assistant response - turns[1].tokens[1], + { + "role": "assistant", + "tokens": turns[1].tokens[1], + "tokens_total": turns[1].tokens[1], + }, ] + for i in range(1, len(turns) - 1, 2): ti = turns[i] tj = turns[i + 2] @@ -296,10 +277,20 @@ def tokens( ) res.extend( [ - # Implied token count for the user input - tj.tokens[0] - sum(ti.tokens), - # The token count for the assistant response - tj.tokens[1], + { + "role": "user", + # Implied token count for the user input + "tokens": tj.tokens[0] - sum(ti.tokens), + # Total tokens = Total User Tokens for the Trn = Distinct new tokens + context sent + "tokens_total": tj.tokens[0], + }, + { + "role": "assistant", + # The token count for the assistant response + "tokens": tj.tokens[1], + # Total tokens = Total Assistant tokens used in the turn + "tokens_total": tj.tokens[1], + }, ] ) @@ -706,9 +697,9 @@ def stream( kwargs=kwargs, ) - def wrapper() -> Generator[ - str | ContentToolRequest | ContentToolResult, None, None - ]: + def wrapper() -> ( + Generator[str | ContentToolRequest | ContentToolResult, None, None] + ): with display: for chunk in generator: yield chunk @@ -770,9 +761,9 @@ async def stream_async( display = self._markdown_display(echo=echo) - async def wrapper() -> AsyncGenerator[ - str | ContentToolRequest | ContentToolResult, 
None - ]: + async def wrapper() -> ( + AsyncGenerator[str | ContentToolRequest | ContentToolResult, None] + ): with display: async for chunk in self._chat_impl_async( turn, diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 47d260e8..8dc80ffa 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -4,8 +4,8 @@ def test_tokens_method(): - chat = ChatOpenAI() - assert chat.tokens(values="discrete") == [] + chat = ChatOpenAI(api_key="fake_key") + assert len(chat.get_tokens()) == 0 chat = ChatOpenAI( turns=[ @@ -14,19 +14,27 @@ def test_tokens_method(): ] ) - assert chat.tokens(values="discrete") == [2, 10] + assert chat.get_tokens() == [ + {"role": "user", "tokens": 2, "tokens_total": 2}, + {"role": "assistant", "tokens": 10, "tokens_total": 10}, + ] chat = ChatOpenAI( + api_key="fake_key", turns=[ Turn(role="user", contents="Hi"), Turn(role="assistant", contents="Hello", tokens=(2, 10)), Turn(role="user", contents="Hi"), Turn(role="assistant", contents="Hello", tokens=(14, 10)), - ] + ], ) - assert chat.tokens(values="discrete") == [2, 10, 2, 10] - assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)] + assert chat.get_tokens() == [ + {"role": "user", "tokens": 2, "tokens_total": 2}, + {"role": "assistant", "tokens": 10, "tokens_total": 10}, + {"role": "user", "tokens": 2, "tokens_total": 14}, + {"role": "assistant", "tokens": 10, "tokens_total": 10}, + ] def test_token_count_method(): @@ -48,7 +56,7 @@ def test_usage_is_none(): def test_can_retrieve_and_log_tokens(): tokens_reset() - provider = OpenAIProvider(model="foo") + provider = OpenAIProvider(api_key="fake_key", model="foo") tokens_log(provider, (10, 50)) tokens_log(provider, (0, 10)) @@ -59,7 +67,9 @@ def test_can_retrieve_and_log_tokens(): assert usage[0]["input"] == 10 assert usage[0]["output"] == 60 - provider2 = OpenAIAzureProvider(endpoint="foo", api_version="bar") + provider2 = OpenAIAzureProvider( + api_key="fake_key", endpoint="foo", api_version="bar" + ) 
tokens_log(provider2, (5, 25)) usage = token_usage() From 30694be127d9d292967fa92f25e06829dacd6756 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 12 Jun 2025 12:32:56 -0400 Subject: [PATCH 02/59] changed type of token dict to TypedDict --- chatlas/_chat.py | 55 ++++++++- chatlas/prices.json | 264 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 316 insertions(+), 3 deletions(-) create mode 100644 chatlas/prices.json diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 0d04de3e..0fd5c1c5 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -52,6 +52,15 @@ class AnyTypeDict(TypedDict, total=False): pass +class TokensDict(TypedDict): + """ + A TypedDict representing the token counts for a turn in the chat. + """ + + role: Literal["user", "assistant"] + tokens: int + tokens_total: int + SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict) """ @@ -63,6 +72,8 @@ class AnyTypeDict(TypedDict, total=False): EchoOptions = Literal["output", "all", "none", "text"] +CostOptions = Literal["all", "last"] + class Chat(Generic[SubmitInputArgsT, CompletionT]): """ @@ -188,13 +199,13 @@ def system_prompt(self, value: str | None): if value is not None: self._turns.insert(0, Turn("system", value)) - def get_tokens(self) -> list[dict[str, int | str]]: + def get_tokens(self) -> list[TokensDict]: """ Get the tokens for each turn in the chat. Returns ------- - list[dict[str, str | int]] + list[TokensDict] A list of dictionaries with the token counts for each (non-system) turn in the chat. `tokens` represents the new tokens used in the turn. @@ -248,7 +259,7 @@ def get_tokens(self) -> list[dict[str, int | str]]: "Expected the 1st assistant turn to contain token counts. 
" + err_info ) - res: list[dict[str, int | str]] = [ + res: list[TokensDict] = [ # Implied token count for the 1st user input { "role": "user", @@ -296,6 +307,44 @@ def get_tokens(self) -> list[dict[str, int | str]]: return res + def get_cost( + self, + options: CostOptions = "all", + ) -> float: + """ + Get the cost of the chat. + + Parameters + ---------- + options + One of the following (default is "all"): + - `"all"`: Return the total cost of all turns in the chat. + - `"last"`: Return the cost of the last turn in the chat. + + Returns + ------- + float + The cost of the chat, in USD. + """ + + # Look up token cost for user and input tokens based on the provider and model + + if options == "last": + # Multiply last user token count by user token cost + # Multiply last assistant token count by assistant token cost + # Add + # Return + + if options == "all": + # Get all the user token counts + # Get all the assistant token counts + # Multiply all the user token counts by the user token cost + # Multiply all the assistant token counts by the assistant token cost + # Add them together and return + + + + def token_count( self, *args: Content | str, diff --git a/chatlas/prices.json b/chatlas/prices.json new file mode 100644 index 00000000..d9cd7eb9 --- /dev/null +++ b/chatlas/prices.json @@ -0,0 +1,264 @@ +[ + { + "provider": "OpenAI", + "model": "gpt-4.5-preview", + "cached_input": 37.5, + "input": 75, + "output": 150 + }, + { + "provider": "OpenAI", + "model": "gpt-4.5-preview-2025-02-27", + "cached_input": 37.5, + "input": 75, + "output": 150 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1", + "cached_input": 0.5, + "input": 2, + "output": 8 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-mini", + "cached_input": 0.1, + "input": 0.4, + "output": 1.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-nano", + "cached_input": 0.025, + "input": 0.1, + "output": 0.4 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-2025-04-14", + "cached_input": 0.5, + 
"input": 2, + "output": 8 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-mini-2025-04-14", + "cached_input": 0.1, + "input": 0.4, + "output": 1.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4.1-nano-2025-04-14", + "cached_input": 0.025, + "input": 0.1, + "output": 0.4 + }, + { + "provider": "OpenAI", + "model": "gpt-4o", + "cached_input": 1.25, + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-2024-11-20", + "cached_input": 1.25, + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-2024-08-06", + "cached_input": 1.25, + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-2024-05-13", + "input": 5, + "output": 15 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini", + "cached_input": 0.075, + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini-2024-07-18", + "cached_input": 0.075, + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "o1", + "cached_input": 7.5, + "input": 15, + "output": 60 + }, + { + "provider": "OpenAI", + "model": "o1-2024-12-17", + "cached_input": 7.5, + "input": 15, + "output": 60 + }, + { + "provider": "OpenAI", + "model": "o1-preview-2024-09-12", + "cached_input": 7.5, + "input": 15, + "output": 60 + }, + { + "provider": "OpenAI", + "model": "o1-pro", + "input": 150, + "output": 600 + }, + { + "provider": "OpenAI", + "model": "o1-pro-2025-03-19", + "input": 150, + "output": 600 + }, + { + "provider": "OpenAI", + "model": "o3-mini", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "o3-mini-2025-01-31", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "o1-mini", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": "o1-mini-2024-09-12", + "cached_input": 0.55, + "input": 1.1, + "output": 4.4 + }, + { + "provider": "OpenAI", + "model": 
"gpt-4o-mini-search-preview", + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-mini-search-preview-2025-03-11", + "input": 0.15, + "output": 0.6 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-search-preview", + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "gpt-4o-search-preview-2025-03-11", + "input": 2.5, + "output": 10 + }, + { + "provider": "OpenAI", + "model": "computer-use-preview", + "input": 3, + "output": 12 + }, + { + "provider": "OpenAI", + "model": "computer-use-preview-2025-03-11", + "input": 3, + "output": 12 + }, + { + "provider": "Anthropic", + "model": "claude-opus-4", + "cached_input": 1.5, + "input": 15, + "output": 75 + }, + { + "provider": "Anthropic", + "model": "claude-sonnet-4", + "cached_input": 0.3, + "input": 3, + "output": 15 + }, + { + "provider": "Anthropic", + "model": "claude-3-7-sonnet", + "cached_input": 0.3, + "input": 3, + "output": 15 + }, + { + "provider": "Anthropic", + "model": "claude-3-5-sonnet", + "cached_input": 0.3, + "input": 3, + "output": 15 + }, + { + "provider": "Anthropic", + "model": "claude-3-5-haiku", + "cached_input": 0.08, + "input": 0.8, + "output": 4 + }, + { + "provider": "Anthropic", + "model": "claude-3-opus", + "cached_input": 1.5, + "input": 15, + "output": 75 + }, + { + "provider": "Anthropic", + "model": "claude-3-haiku", + "cached_input": 0.03, + "input": 0.25, + "output": 1.25 + }, + { + "provider": "Google/Gemini", + "model": "gemini-2.0-flash", + "cached_input": 0.025, + "input": 0.1, + "output": 0.4 + }, + { + "provider": "Google/Gemini", + "model": "gemini-2.0-flash-lite", + "input": 0.075, + "output": 0.3 + }, + { + "provider": "Google/Gemini", + "model": "gemini-1.5-flash", + "input": 0.3, + "output": 0.075 + } +] From c0c8e67dc05455396a81188b92b1b7242df0b7ce Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Mon, 23 Jun 2025 09:36:49 -0400 Subject: [PATCH 03/59] Adding name to Providers where pricing is supported --- 
chatlas/_anthropic.py | 3 ++- chatlas/_chat.py | 27 ++++++++++++++++----- chatlas/_google.py | 4 ++++ chatlas/_openai.py | 2 ++ chatlas/_tokens.py | 44 ++++++++++++++++++++++++++++++++--- tests/test_provider_google.py | 25 +++++++++++++++++++- tests/test_tokens.py | 12 ++++++++-- 7 files changed, 104 insertions(+), 13 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index 994d483a..3b71797e 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -200,7 +200,7 @@ def __init__( "`ChatAnthropic()` requires the `anthropic` package. " "You can install it with 'pip install anthropic'." ) - + self.name = "Anthropic" self._model = model self._max_tokens = max_tokens @@ -711,6 +711,7 @@ def ChatBedrockAnthropic( base_url=base_url, kwargs=kwargs, ), + pricing_provider="Anthropic Bedrock", turns=normalize_turns( turns or [], system_prompt, diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 0fd5c1c5..95cb806a 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -52,6 +52,7 @@ class AnyTypeDict(TypedDict, total=False): pass + class TokensDict(TypedDict): """ A TypedDict representing the token counts for a turn in the chat. @@ -93,6 +94,7 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]): def __init__( self, provider: Provider, + pricing_provider: str | None = None, turns: Optional[Sequence[Turn]] = None, ): """ @@ -117,6 +119,20 @@ def __init__( "css_styles": {}, } + def get_provider_name(self) -> str: + """ + Get the name of the provider. + + Returns + ------- + str + The name of the provider. + """ + # Use the class name of the provider, removing "Provider" suffix + # Ex. "OpenAIProvider" -> "ChatOpenAI" + # Ex. 
"AnthropicProvider" -> "ChatAnthropic" + name = self.provider.__class__.__name__.replace("Provider", "") + def get_turns( self, *, @@ -308,8 +324,8 @@ def get_tokens(self) -> list[TokensDict]: return res def get_cost( - self, - options: CostOptions = "all", + self, + options: CostOptions = "all", ) -> float: """ Get the cost of the chat. @@ -328,12 +344,13 @@ def get_cost( """ # Look up token cost for user and input tokens based on the provider and model - + if options == "last": # Multiply last user token count by user token cost # Multiply last assistant token count by assistant token cost # Add # Return + print("hi") if options == "all": # Get all the user token counts @@ -341,9 +358,7 @@ def get_cost( # Multiply all the user token counts by the user token cost # Multiply all the assistant token counts by the assistant token cost # Add them together and return - - - + print("hello") def token_count( self, diff --git a/chatlas/_google.py b/chatlas/_google.py index f9f5b35d..532cf621 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -176,6 +176,10 @@ def __init__( ) self._model = model + if kwargs.get("vertexai"): + self.name = "Google/Vertex" + else: + self.name = "Google/Gemini" kwargs_full: "ChatClientArgs" = { "api_key": api_key, diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 84bcd169..f3490af7 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -192,6 +192,7 @@ def __init__( self._model = model self._seed = seed + self.name = "OpenAI" kwargs_full: "ChatClientArgs" = { "api_key": api_key, @@ -677,6 +678,7 @@ def __init__( self._model = deployment_id self._seed = seed + self.name = "OpenAIAzure" kwargs_full: "ChatAzureClientArgs" = { "azure_endpoint": endpoint, diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 35bb0bd6..9ff9c080 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -55,11 +55,21 @@ def get_usage(self) -> list[TokenUsage] | None: _token_counter = ThreadSafeTokenCounter() -def tokens_log(provider: 
"Provider", tokens: tuple[int, int]) -> None: +def tokens_log(provider: Provider, tokens: tuple[int, int]) -> None: """ Log token usage for a provider in a thread-safe manner. """ - name = provider.__class__.__name__.replace("Provider", "") + + if hasattr(provider, "name"): + # Use the provider's name if it has one + name = provider.name + else: + # Fallback to class name if provider does not have a name attribute + logger.info( + f"Provider {provider.__class__.__name__} does not have a 'name' attribute. " + "Using class name instead." + ) + name = provider.__class__.__name__.replace("Provider", "") _token_counter.log_tokens(name, tokens[0], tokens[1]) @@ -71,12 +81,38 @@ def tokens_reset() -> None: _token_counter = ThreadSafeTokenCounter() +def get_token_cost( + name: str, model: str, input_tokens: int, output_tokens: int +) -> float | None: + """ + Get the cost of tokens for a given provider and model. + + Parameters + ---------- + name : Provider + The provider instance. + model : str + The model name. + input_tokens : int + The number of input tokens. + output_tokens : int + The number of output tokens. + + Returns + ------- + float + The cost of the tokens, or None if the cost is not known. + """ + + # return get_token_cost(provider.__name__, model, input_tokens, output_tokens) + + def token_usage() -> list[TokenUsage] | None: """ Report on token usage in the current session Call this function to find out the cumulative number of tokens that you - have sent and received in the current session. + have sent and received in the current session. The price will be shown if known Returns ------- @@ -84,4 +120,6 @@ def token_usage() -> list[TokenUsage] | None: A list of dictionaries with the following keys: "name", "input", and "output". If no tokens have been logged, then None is returned. 
""" + _token_counter.get_usage() + return _token_counter.get_usage() diff --git a/tests/test_provider_google.py b/tests/test_provider_google.py index 42e98c2b..1352b9d6 100644 --- a/tests/test_provider_google.py +++ b/tests/test_provider_google.py @@ -2,7 +2,7 @@ import pytest import requests -from chatlas import ChatGoogle +from chatlas import ChatGoogle, ChatVertex from google.genai.errors import APIError from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential @@ -35,6 +35,29 @@ def test_google_simple_request(): assert turn.finish_reason == "STOP" +def test_vertex_simple_request(): + chat = ChatVertex( + system_prompt="Be as terse as possible; no punctuation", + ) + chat.chat("What is 1 + 1?") + turn = chat.get_last_turn() + assert turn is not None + assert turn.tokens == (16, 2) + assert turn.finish_reason == "STOP" + + +def test_name_setting(): + chat = ChatGoogle( + system_prompt="Be as terse as possible; no punctuation", + ) + assert chat.provider.name == "Google/Gemini" + + chat = ChatVertex( + system_prompt="Be as terse as possible; no punctuation", + ) + assert chat.provider.name == "Google/Vertex" + + @pytest.mark.asyncio async def test_google_simple_streaming_request(): chat = ChatGoogle( diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 8dc80ffa..5b20edff 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,6 +1,7 @@ from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider from chatlas._tokens import token_usage, tokens_log, tokens_reset +from pytest import MonkeyPatch def test_tokens_method(): @@ -57,7 +58,6 @@ def test_can_retrieve_and_log_tokens(): tokens_reset() provider = OpenAIProvider(api_key="fake_key", model="foo") - tokens_log(provider, (10, 50)) tokens_log(provider, (0, 10)) usage = token_usage() @@ -71,11 +71,19 @@ def test_can_retrieve_and_log_tokens(): api_key="fake_key", endpoint="foo", api_version="bar" ) + # 
Check that the provider has a name attribute to start + assert provider2.name == "OpenAIAzure" + + delattr(provider2, "name") # Ensure no name attribute + assert hasattr(provider2, "name") is False + tokens_log(provider2, (5, 25)) usage = token_usage() assert usage is not None assert len(usage) == 2 - assert usage[1]["name"] == "OpenAIAzure" + assert ( + usage[1]["name"] == "OpenAIAzure" + ) # Check that the name is set correctly for Providers without a name pre-set assert usage[1]["input"] == 5 assert usage[1]["output"] == 25 From 6ea4ece7408bbea55e2e22fb6f13e26441efa646 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Mon, 23 Jun 2025 14:54:46 -0400 Subject: [PATCH 04/59] Correct kwargs.get --- chatlas/_google.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlas/_google.py b/chatlas/_google.py index 532cf621..8f00839c 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -176,7 +176,7 @@ def __init__( ) self._model = model - if kwargs.get("vertexai"): + if kwargs and kwargs.get("vertexai"): self.name = "Google/Vertex" else: self.name = "Google/Gemini" From 56463d4ae3e6b4f6fd6a6e70e03bae74b428f6cb Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Mon, 23 Jun 2025 15:11:49 -0400 Subject: [PATCH 05/59] removing unused import --- chatlas/_chat.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 95cb806a..5374ad5e 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -94,7 +94,6 @@ class Chat(Generic[SubmitInputArgsT, CompletionT]): def __init__( self, provider: Provider, - pricing_provider: str | None = None, turns: Optional[Sequence[Turn]] = None, ): """ @@ -119,20 +118,6 @@ def __init__( "css_styles": {}, } - def get_provider_name(self) -> str: - """ - Get the name of the provider. - - Returns - ------- - str - The name of the provider. - """ - # Use the class name of the provider, removing "Provider" suffix - # Ex. "OpenAIProvider" -> "ChatOpenAI" - # Ex. 
"AnthropicProvider" -> "ChatAnthropic" - name = self.provider.__class__.__name__.replace("Provider", "") - def get_turns( self, *, From 64d1b5cd6200645d277a275f849633e859b2f99d Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 10:43:47 -0400 Subject: [PATCH 06/59] Adding start of pricing, fixing test --- chatlas/_anthropic.py | 1 - chatlas/_chat.py | 46 +++++++++++++++++++++++++++++++++- chatlas/{ => data}/prices.json | 0 tests/test_tokens.py | 1 - 4 files changed, 45 insertions(+), 3 deletions(-) rename chatlas/{ => data}/prices.json (100%) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index 3b71797e..59920cd1 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -711,7 +711,6 @@ def ChatBedrockAnthropic( base_url=base_url, kwargs=kwargs, ), - pricing_provider="Anthropic Bedrock", turns=normalize_turns( turns or [], system_prompt, diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 5374ad5e..7dd3cc0b 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -6,6 +6,7 @@ import sys import traceback import warnings + from pathlib import Path from threading import Thread from typing import ( @@ -23,7 +24,8 @@ TypeVar, overload, ) - +import orjson +import importlib from pydantic import BaseModel from ._callbacks import CallbackManager @@ -48,6 +50,13 @@ from ._typing_extensions import TypedDict from ._utils import html_escape, wrap_async +f = ( + importlib.resources.files("chatlas") + .joinpath("data/prices.json") + .read_text(encoding="utf-8") +) +prices_json = orjson.loads(f) + class AnyTypeDict(TypedDict, total=False): pass @@ -308,6 +317,37 @@ def get_tokens(self) -> list[TokensDict]: return res + def get_token_pricing(self) -> dict[str, str | float]: + """ + Get the token pricing for the chat if available based on the prices.json file. + + Returns + ------- + dict[str, str | float] + A dictionary with the token pricing for the chat. The keys are: + - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). 
+ - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). + - `"input"`: The cost per user token in USD. + - `"output"`: The cost per assistant token in USD. + """ + if not self.provider.name or self.provider.name not in prices_json: + warnings.warn( + f"Token pricing for this provider is not available. " + "Please check the provider's documentation." + ) + return {} + result = next( + ( + item + for item in prices_json + if item["provider"] == self.provider.name + and item["model"] == self.provider.model + ), + None, + ) + print(result) + return result + def get_cost( self, options: CostOptions = "all", @@ -329,6 +369,10 @@ def get_cost( """ # Look up token cost for user and input tokens based on the provider and model + turns = self.get_turns(include_system_prompt=False) + + if len(turns) == 0: + return 0.0 if options == "last": # Multiply last user token count by user token cost diff --git a/chatlas/prices.json b/chatlas/data/prices.json similarity index 100% rename from chatlas/prices.json rename to chatlas/data/prices.json diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 5b20edff..06b665e5 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,7 +1,6 @@ from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider from chatlas._tokens import token_usage, tokens_log, tokens_reset -from pytest import MonkeyPatch def test_tokens_method(): From 00b9389c9340edaeccafda2e22d2a22f3bb17b4b Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 13:02:26 -0400 Subject: [PATCH 07/59] Initial token price fetching --- chatlas/_chat.py | 17 +++++++++++------ tests/test_chat.py | 20 ++++++++++++++++---- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 7dd3cc0b..77b85909 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -330,10 +330,9 @@ def get_token_pricing(self) -> dict[str, str | float]: - 
`"input"`: The cost per user token in USD. - `"output"`: The cost per assistant token in USD. """ - if not self.provider.name or self.provider.name not in prices_json: + if not self.provider.name: warnings.warn( - f"Token pricing for this provider is not available. " - "Please check the provider's documentation." + f"Please specify a provider name to access pricing information." ) return {} result = next( @@ -341,11 +340,17 @@ def get_token_pricing(self) -> dict[str, str | float]: item for item in prices_json if item["provider"] == self.provider.name - and item["model"] == self.provider.model + and item["model"] == self.provider._model ), - None, + {}, ) - print(result) + + if not result: + warnings.warn( + f"Token pricing for the provider and model you selected is not available. " + "Please check the provider's documentation." + ) + return result def get_cost( diff --git a/tests/test_chat.py b/tests/test_chat.py index eb368e8b..ee79bfe8 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -19,6 +19,14 @@ def test_simple_batch_chat(): assert str(response) == "2" +def test_import_prices(): + print("Starting") + chat = ChatOpenAI() + print("Provider: ", chat.provider.name, chat.provider._model) + print("Pricing result: ", chat.get_token_pricing()) + print("DONE") + + @pytest.mark.asyncio async def test_simple_async_batch_chat(): chat = ChatOpenAI() @@ -30,10 +38,12 @@ async def test_simple_async_batch_chat(): def test_simple_streaming_chat(): chat = ChatOpenAI() - res = chat.stream(""" + res = chat.stream( + """ What are the canonical colors of the ROYGBIV rainbow? Put each colour on its own line. Don't use punctuation. 
- """) + """ + ) chunks = [chunk for chunk in res] assert len(chunks) > 2 result = "".join(chunks) @@ -48,10 +58,12 @@ def test_simple_streaming_chat(): @pytest.mark.asyncio async def test_simple_streaming_chat_async(): chat = ChatOpenAI() - res = await chat.stream_async(""" + res = await chat.stream_async( + """ What are the canonical colors of the ROYGBIV rainbow? Put each colour on its own line. Don't use punctuation. - """) + """ + ) chunks = [chunk async for chunk in res] assert len(chunks) > 2 result = "".join(chunks) From 471e67a921b86c020654573068c1ab6eb9956e69 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 13:15:05 -0400 Subject: [PATCH 08/59] Correcting importlib --- chatlas/_chat.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 77b85909..cd6367d1 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -25,7 +25,7 @@ overload, ) import orjson -import importlib +import importlib.resources as resources from pydantic import BaseModel from ._callbacks import CallbackManager @@ -50,11 +50,7 @@ from ._typing_extensions import TypedDict from ._utils import html_escape, wrap_async -f = ( - importlib.resources.files("chatlas") - .joinpath("data/prices.json") - .read_text(encoding="utf-8") -) +f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8") prices_json = orjson.loads(f) From 80635a17d044636a9deac7f4a1c4582268952300 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 13:21:23 -0400 Subject: [PATCH 09/59] Fixing ordering --- chatlas/_chat.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index cd6367d1..83fb2f2c 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -24,8 +24,9 @@ TypeVar, overload, ) -import orjson import importlib.resources as resources + +import orjson from pydantic import BaseModel from ._callbacks import CallbackManager @@ -328,7 +329,7 @@ def 
get_token_pricing(self) -> dict[str, str | float]: """ if not self.provider.name: warnings.warn( - f"Please specify a provider name to access pricing information." + "Please specify a provider name to access pricing information." ) return {} result = next( @@ -343,7 +344,7 @@ def get_token_pricing(self) -> dict[str, str | float]: if not result: warnings.warn( - f"Token pricing for the provider and model you selected is not available. " + "Token pricing for the provider and model you selected is not available. " "Please check the provider's documentation." ) From df255a3219f2897e38fc0ea975cf76bb4c1f408b Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 13:28:41 -0400 Subject: [PATCH 10/59] Fixing imports again --- chatlas/_chat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 83fb2f2c..8ec304b3 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -1,12 +1,12 @@ from __future__ import annotations import copy +import importlib.resources as resources import inspect import os import sys import traceback import warnings - from pathlib import Path from threading import Thread from typing import ( @@ -24,7 +24,6 @@ TypeVar, overload, ) -import importlib.resources as resources import orjson from pydantic import BaseModel From a445686a50a3229e829135dc8c4bc313879ef109 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 13:46:22 -0400 Subject: [PATCH 11/59] Ignoring type issues that don't exist --- chatlas/_chat.py | 7 ++++--- chatlas/_tokens.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 8ec304b3..b2e6f06e 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -326,7 +326,7 @@ def get_token_pricing(self) -> dict[str, str | float]: - `"input"`: The cost per user token in USD. - `"output"`: The cost per assistant token in USD. 
""" - if not self.provider.name: + if not self.provider.name: # type: ignore warnings.warn( "Please specify a provider name to access pricing information." ) @@ -335,8 +335,8 @@ def get_token_pricing(self) -> dict[str, str | float]: ( item for item in prices_json - if item["provider"] == self.provider.name - and item["model"] == self.provider._model + if item["provider"] == self.provider.name # type: ignore + and item["model"] == self.provider._model # type: ignore ), {}, ) @@ -389,6 +389,7 @@ def get_cost( # Multiply all the assistant token counts by the assistant token cost # Add them together and return print("hello") + return 0.0 def token_count( self, diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 9ff9c080..662a7aed 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -62,7 +62,7 @@ def tokens_log(provider: Provider, tokens: tuple[int, int]) -> None: if hasattr(provider, "name"): # Use the provider's name if it has one - name = provider.name + name = provider.name # type: ignore else: # Fallback to class name if provider does not have a name attribute logger.info( From 06cbfda5a311bda8640da4eedaf7b2098dabb573 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 24 Jun 2025 14:25:18 -0400 Subject: [PATCH 12/59] Test removing backticks from streaming test --- tests/test_chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index ee79bfe8..5569181d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -47,11 +47,11 @@ def test_simple_streaming_chat(): chunks = [chunk for chunk in res] assert len(chunks) > 2 result = "".join(chunks) - res = re.sub(r"\s+", "", result).lower() + res = re.sub(r"[\s`]+", "", result).lower() assert res == "redorangeyellowgreenblueindigoviolet" turn = chat.get_last_turn() assert turn is not None - res = re.sub(r"\s+", "", turn.text).lower() + res = re.sub(r"[\s`]+", "", turn.text).lower() assert res == "redorangeyellowgreenblueindigoviolet" From 
b6b21267ce02f81ced552f7e144b3a9ee7ed433a Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 25 Jun 2025 12:35:28 -0400 Subject: [PATCH 13/59] Switched to _name convention --- chatlas/_anthropic.py | 10 +++++- chatlas/_chat.py | 42 +------------------------- chatlas/_databricks.py | 9 ++++++ chatlas/_google.py | 13 ++++++-- chatlas/_openai.py | 15 +++++++-- chatlas/_provider.py | 8 +++++ chatlas/_tokens.py | 52 ++++++++++++++++++++++++-------- tests/test_chat.py | 8 ----- tests/test_provider_anthropic.py | 2 ++ tests/test_tokens.py | 10 +++++- 10 files changed, 102 insertions(+), 67 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index 59920cd1..1244ab2a 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -200,7 +200,7 @@ def __init__( "`ChatAnthropic()` requires the `anthropic` package. " "You can install it with 'pip install anthropic'." ) - self.name = "Anthropic" + self._name = "Anthropic" self._model = model self._max_tokens = max_tokens @@ -213,6 +213,14 @@ def __init__( self._client = Anthropic(**kwargs_full) # type: ignore self._async_client = AsyncAnthropic(**kwargs_full) + @property + def name(self): + return self._name + + @property + def model(self): + return self._model + @overload def chat_perform( self, diff --git a/chatlas/_chat.py b/chatlas/_chat.py index b2e6f06e..b3f1cd80 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import importlib.resources as resources import inspect import os import sys @@ -25,7 +24,6 @@ overload, ) -import orjson from pydantic import BaseModel from ._callbacks import CallbackManager @@ -50,9 +48,6 @@ from ._typing_extensions import TypedDict from ._utils import html_escape, wrap_async -f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8") -prices_json = orjson.loads(f) - class AnyTypeDict(TypedDict, total=False): pass @@ -313,42 +308,7 @@ def get_tokens(self) -> list[TokensDict]: 
return res - def get_token_pricing(self) -> dict[str, str | float]: - """ - Get the token pricing for the chat if available based on the prices.json file. - - Returns - ------- - dict[str, str | float] - A dictionary with the token pricing for the chat. The keys are: - - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). - - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). - - `"input"`: The cost per user token in USD. - - `"output"`: The cost per assistant token in USD. - """ - if not self.provider.name: # type: ignore - warnings.warn( - "Please specify a provider name to access pricing information." - ) - return {} - result = next( - ( - item - for item in prices_json - if item["provider"] == self.provider.name # type: ignore - and item["model"] == self.provider._model # type: ignore - ), - {}, - ) - - if not result: - warnings.warn( - "Token pricing for the provider and model you selected is not available. " - "Please check the provider's documentation." 
- ) - - return result - + # TODO: Add another arg for if people want to by cost data def get_cost( self, options: CostOptions = "all", diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index 47fe903f..9f22cc50 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -118,6 +118,7 @@ def __init__( from openai import AsyncOpenAI self._model = model + self._name = "Databricks" self._seed = None if workspace_client is None: @@ -137,3 +138,11 @@ def __init__( api_key="no-token", # A placeholder to pass validations, this will not be used http_client=httpx.AsyncClient(auth=client._client.auth), ) + + @property + def name(self): + return self._name + + @property + def model(self): + return self._model diff --git a/chatlas/_google.py b/chatlas/_google.py index 8f00839c..9b359859 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -176,10 +176,11 @@ def __init__( ) self._model = model + if kwargs and kwargs.get("vertexai"): - self.name = "Google/Vertex" + self._name = "Google/Vertex" else: - self.name = "Google/Gemini" + self._name = "Google/Gemini" kwargs_full: "ChatClientArgs" = { "api_key": api_key, @@ -188,6 +189,14 @@ def __init__( self._client = genai.Client(**kwargs_full) + @property + def name(self): + return self._name + + @property + def model(self): + return self._model + @overload def chat_perform( self, diff --git a/chatlas/_openai.py b/chatlas/_openai.py index f3490af7..da28618a 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -192,7 +192,6 @@ def __init__( self._model = model self._seed = seed - self.name = "OpenAI" kwargs_full: "ChatClientArgs" = { "api_key": api_key, @@ -204,6 +203,10 @@ def __init__( self._client = OpenAI(**kwargs_full) # type: ignore self._async_client = AsyncOpenAI(**kwargs_full) + @property + def name(self): + return "OpenAI" + @overload def chat_perform( self, @@ -678,7 +681,7 @@ def __init__( self._model = deployment_id self._seed = seed - self.name = "OpenAIAzure" + self._name = "OpenAIAzure" 
kwargs_full: "ChatAzureClientArgs" = { "azure_endpoint": endpoint, @@ -691,6 +694,14 @@ def __init__( self._client = AzureOpenAI(**kwargs_full) # type: ignore self._async_client = AsyncAzureOpenAI(**kwargs_full) # type: ignore + @property + def name(self): + return self._name + + @property + def model(self): + return self._model + class InvalidJSONParameterWarning(RuntimeWarning): """ diff --git a/chatlas/_provider.py b/chatlas/_provider.py index d6af75bf..ddffc8dc 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -40,6 +40,14 @@ class Provider( directly. """ + @property + def name(self): + raise NotImplementedError("Name property must be implemented.") + + @property + def model(self): + raise NotImplementedError("Model property must be implemented.") + @overload @abstractmethod def chat_perform( diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 662a7aed..a2e9da82 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -1,6 +1,8 @@ from __future__ import annotations import copy +import warnings +import importlib.resources as resources from threading import Lock from typing import TYPE_CHECKING @@ -9,6 +11,7 @@ if TYPE_CHECKING: from ._provider import Provider +import orjson class TokenUsage(TypedDict): @@ -59,18 +62,7 @@ def tokens_log(provider: Provider, tokens: tuple[int, int]) -> None: """ Log token usage for a provider in a thread-safe manner. """ - - if hasattr(provider, "name"): - # Use the provider's name if it has one - name = provider.name # type: ignore - else: - # Fallback to class name if provider does not have a name attribute - logger.info( - f"Provider {provider.__class__.__name__} does not have a 'name' attribute. " - "Using class name instead." 
- ) - name = provider.__class__.__name__.replace("Provider", "") - _token_counter.log_tokens(name, tokens[0], tokens[1]) + _token_counter.log_tokens(provider.name, tokens[0], tokens[1]) def tokens_reset() -> None: @@ -81,6 +73,42 @@ def tokens_reset() -> None: _token_counter = ThreadSafeTokenCounter() +f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8") +prices_json = orjson.loads(f) + + +def get_token_pricing(provider: Provider) -> dict[str, str | float]: + """ + Get the token pricing for the chat if available based on the prices.json file. + + Returns + ------- + dict[str, str | float] + A dictionary with the token pricing for the chat. The keys are: + - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). + - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). + - `"input"`: The cost per user token in USD. + - `"output"`: The cost per assistant token in USD. + """ + result = next( + ( + item + for item in prices_json + if item["provider"] == provider.name + and item["model"] == provider._model # type: ignore + ), + {}, + ) + + if not result: + warnings.warn( + "Token pricing for the provider and model you selected is not available. " + "Please check the provider's documentation." 
+ ) + + return result + + def get_token_cost( name: str, model: str, input_tokens: int, output_tokens: int ) -> float | None: diff --git a/tests/test_chat.py b/tests/test_chat.py index 5569181d..ef84bd48 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -19,14 +19,6 @@ def test_simple_batch_chat(): assert str(response) == "2" -def test_import_prices(): - print("Starting") - chat = ChatOpenAI() - print("Provider: ", chat.provider.name, chat.provider._model) - print("Pricing result: ", chat.get_token_pricing()) - print("DONE") - - @pytest.mark.asyncio async def test_simple_async_batch_chat(): chat = ChatOpenAI() diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py index e165a880..1f72a54e 100644 --- a/tests/test_provider_anthropic.py +++ b/tests/test_provider_anthropic.py @@ -28,6 +28,8 @@ def test_anthropic_simple_request(): assert turn is not None assert turn.tokens == (26, 5) assert turn.finish_reason == "end_turn" + assert chat.provider.name == "Anthropic" + assert chat.provider.model == "claude-3-7-sonnet-latest" @pytest.mark.asyncio diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 06b665e5..3e4ee231 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,6 +1,6 @@ from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider -from chatlas._tokens import token_usage, tokens_log, tokens_reset +from chatlas._tokens import token_usage, tokens_log, tokens_reset, get_token_pricing def test_tokens_method(): @@ -48,6 +48,14 @@ def test_token_count_method(): assert chat.token_count("What is 1 + 1?") == 9 +def test_import_prices(): + chat = ChatOpenAI() + + print("Provider: ", chat.provider.name, chat.provider._model) + print("Pricing result: ", get_token_pricing(chat.provider)) + print("DONE") + + def test_usage_is_none(): tokens_reset() assert token_usage() is None From f7dbfbadbf5dd78fa1044384b2a4513af0b90d81 Mon Sep 17 00:00:00 2001 From: 
Liz Nelson Date: Wed, 25 Jun 2025 12:57:23 -0400 Subject: [PATCH 14/59] Updating all classes and adding tests --- chatlas/_anthropic.py | 9 +++++++++ chatlas/_snowflake.py | 9 +++++++++ chatlas/_tokens.py | 7 +++---- tests/test_provider_azure.py | 1 + tests/test_provider_databricks.py | 2 ++ tests/test_provider_google.py | 4 ++++ tests/test_provider_openai.py | 2 ++ tests/test_provider_snowflake.py | 2 ++ 8 files changed, 32 insertions(+), 4 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index 1244ab2a..e4683651 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -749,6 +749,7 @@ def __init__( ) self._model = model + self._name = "AnthropicBedrock" self._max_tokens = max_tokens kwargs_full: "ChatBedrockClientArgs" = { @@ -763,3 +764,11 @@ def __init__( self._client = AnthropicBedrock(**kwargs_full) # type: ignore self._async_client = AsyncAnthropicBedrock(**kwargs_full) # type: ignore + + @property + def name(self): + return self._name + + @property + def model(self): + return self._model diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 5a2d3bd1..4e90abd3 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -199,10 +199,19 @@ def __init__( ) self._model = model + self._name = "Snowflake" session = Session.builder.configs(configs).create() self._cortex_service = Root(session).cortex_inference_service + @property + def name(self): + return self._name + + @property + def model(self): + return self._model + @overload def chat_perform( self, diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index a2e9da82..e2da46e2 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -5,13 +5,13 @@ import importlib.resources as resources from threading import Lock from typing import TYPE_CHECKING +import orjson from ._logging import logger from ._typing_extensions import TypedDict if TYPE_CHECKING: from ._provider import Provider -import orjson class TokenUsage(TypedDict): @@ -94,15 +94,14 @@ def 
get_token_pricing(provider: Provider) -> dict[str, str | float]: ( item for item in prices_json - if item["provider"] == provider.name - and item["model"] == provider._model # type: ignore + if item["provider"] == provider.name and item["model"] == provider.model ), {}, ) if not result: warnings.warn( - "Token pricing for the provider and model you selected is not available. " + f"Token pricing for the provider '{provider.name}' and model '{provider.model}' you selected is not available. " "Please check the provider's documentation." ) diff --git a/tests/test_provider_azure.py b/tests/test_provider_azure.py index 8eb2850e..eef856c4 100644 --- a/tests/test_provider_azure.py +++ b/tests/test_provider_azure.py @@ -22,6 +22,7 @@ def test_azure_simple_request(): turn = chat.get_last_turn() assert turn is not None assert turn.tokens == (27, 2) + assert chat.provider.name == "OpenAIAzure" @pytest.mark.asyncio diff --git a/tests/test_provider_databricks.py b/tests/test_provider_databricks.py index 7b7ea747..8c9b66fd 100644 --- a/tests/test_provider_databricks.py +++ b/tests/test_provider_databricks.py @@ -16,6 +16,8 @@ def test_openai_simple_request(): assert turn.tokens[0] == 26 # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2. 
assert turn.finish_reason == "stop" + assert chat.provider.name == "Databricks" + assert chat.provider.model == "databricks-claude-3-7-sonnet" @pytest.mark.asyncio diff --git a/tests/test_provider_google.py b/tests/test_provider_google.py index 1352b9d6..37ffa5fc 100644 --- a/tests/test_provider_google.py +++ b/tests/test_provider_google.py @@ -33,6 +33,8 @@ def test_google_simple_request(): assert turn is not None assert turn.tokens == (16, 2) assert turn.finish_reason == "STOP" + assert chat.provider.name == "Google/Gemini" + assert chat.provider.model == "gemini-2.0-flash" def test_vertex_simple_request(): @@ -44,6 +46,8 @@ def test_vertex_simple_request(): assert turn is not None assert turn.tokens == (16, 2) assert turn.finish_reason == "STOP" + assert chat.provider.name == "Google/Vertex" + assert chat.provider.model == "gemini-2.0-flash" def test_name_setting(): diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py index 43c47a0d..79ba8162 100644 --- a/tests/test_provider_openai.py +++ b/tests/test_provider_openai.py @@ -28,6 +28,8 @@ def test_openai_simple_request(): assert turn.tokens[0] == 27 # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2. assert turn.finish_reason == "stop" + assert chat.provider.name == "OpenAI" + assert chat.provider.model == "gpt-4o" @pytest.mark.asyncio diff --git a/tests/test_provider_snowflake.py b/tests/test_provider_snowflake.py index 5148e39f..be575e19 100644 --- a/tests/test_provider_snowflake.py +++ b/tests/test_provider_snowflake.py @@ -24,6 +24,8 @@ def test_openai_simple_request(): chat.chat("What is 1 + 1?") turn = chat.get_last_turn() assert turn is not None + assert chat.provider.name == "Snowflake" + assert chat.provider.model == "claude-3-7-sonnet" # No token / finish_reason info available? 
# assert turn.tokens is not None From 965d4064d32a5c8cf21ab6595db1e00f684964be Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 25 Jun 2025 13:31:06 -0400 Subject: [PATCH 15/59] Removing old test --- tests/test_tokens.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 3e4ee231..db781350 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -78,19 +78,11 @@ def test_can_retrieve_and_log_tokens(): api_key="fake_key", endpoint="foo", api_version="bar" ) - # Check that the provider has a name attribute to start - assert provider2.name == "OpenAIAzure" - - delattr(provider2, "name") # Ensure no name attribute - assert hasattr(provider2, "name") is False - tokens_log(provider2, (5, 25)) usage = token_usage() assert usage is not None assert len(usage) == 2 - assert ( - usage[1]["name"] == "OpenAIAzure" - ) # Check that the name is set correctly for Providers without a name pre-set + assert usage[1]["name"] == "OpenAIAzure" assert usage[1]["input"] == 5 assert usage[1]["output"] == 25 From 75b8b1fb669f08bf472e54c08c88b0bb549edc32 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 25 Jun 2025 13:51:43 -0400 Subject: [PATCH 16/59] Correcting class --- chatlas/_openai.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index da28618a..a2635d92 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -190,6 +190,7 @@ def __init__( ): from openai import AsyncOpenAI, OpenAI + self._name = "OpenAI" self._model = model self._seed = seed @@ -205,7 +206,11 @@ def __init__( @property def name(self): - return "OpenAI" + return self._name + + @property + def model(self): + return self._model @overload def chat_perform( From 528a53cd1003ad09b09b9209f72c50aa4edc4cdf Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 26 Jun 2025 11:29:04 -0400 Subject: [PATCH 17/59] Updating class spec --- chatlas/_anthropic.py | 44 
++++++++--------------------------- chatlas/_chat.py | 52 +++++++++++++++++++++++++++++++----------- chatlas/_databricks.py | 18 ++++++--------- chatlas/_github.py | 1 + chatlas/_google.py | 21 ++++------------- chatlas/_groq.py | 1 + chatlas/_ollama.py | 1 + chatlas/_openai.py | 35 ++++++++++++---------------- chatlas/_perplexity.py | 1 + chatlas/_provider.py | 9 ++++++-- chatlas/_snowflake.py | 15 +++--------- chatlas/_tokens.py | 47 ++++++++++++++------------------------ tests/test_chat.py | 19 +++++++++++++-- tests/test_tokens.py | 19 ++++++++++----- 14 files changed, 135 insertions(+), 148 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index e4683651..dc9ab649 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -6,6 +6,7 @@ import orjson from pydantic import BaseModel +from anthropic import Anthropic, AsyncAnthropic, AnthropicBedrock, AsyncAnthropicBedrock from ._chat import Chat from ._content import ( @@ -62,6 +63,7 @@ def ChatAnthropic( model: "Optional[ModelParam]" = None, api_key: Optional[str] = None, max_tokens: int = 4096, + name: Literal["Anthropic"] = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ) -> Chat["SubmitInputArgs", Message]: """ @@ -191,17 +193,11 @@ def __init__( max_tokens: int, model: str, api_key: str | None, + name: str = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ): - try: - from anthropic import Anthropic, AsyncAnthropic - except ImportError: - raise ImportError( - "`ChatAnthropic()` requires the `anthropic` package. " - "You can install it with 'pip install anthropic'." 
- ) - self._name = "Anthropic" - self._model = model + super().__init__(name=name, model=model) + self._max_tokens = max_tokens kwargs_full: "ChatClientArgs" = { @@ -213,14 +209,6 @@ def __init__( self._client = Anthropic(**kwargs_full) # type: ignore self._async_client = AsyncAnthropic(**kwargs_full) - @property - def name(self): - return self._name - - @property - def model(self): - return self._model - @overload def chat_perform( self, @@ -330,7 +318,7 @@ def _structured_tool_call(**kwargs: Any): kwargs_full: "SubmitInputArgs" = { "stream": stream, "messages": self._as_message_params(turns), - "model": self._model, + "model": self.model, "max_tokens": self._max_tokens, "tools": tool_schemas, **(kwargs or {}), @@ -738,18 +726,12 @@ def __init__( aws_session_token: str | None, max_tokens: int, base_url: str | None, + name: str = "AnthropicBedrock", kwargs: Optional["ChatBedrockClientArgs"] = None, ): - try: - from anthropic import AnthropicBedrock, AsyncAnthropicBedrock - except ImportError: - raise ImportError( - "`ChatBedrockAnthropic()` requires the `anthropic` package. " - "Install it with `pip install anthropic[bedrock]`." 
- ) - self._model = model - self._name = "AnthropicBedrock" + super().__init__(name=name, model=model) + self._max_tokens = max_tokens kwargs_full: "ChatBedrockClientArgs" = { @@ -764,11 +746,3 @@ def __init__( self._client = AnthropicBedrock(**kwargs_full) # type: ignore self._async_client = AsyncAnthropicBedrock(**kwargs_full) # type: ignore - - @property - def name(self): - return self._name - - @property - def model(self): - return self._model diff --git a/chatlas/_chat.py b/chatlas/_chat.py index b3f1cd80..72758d2c 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -44,6 +44,7 @@ from ._logging import log_tool_error from ._provider import Provider from ._tools import Tool, ToolRejectError +from ._tokens import TokenPrice, get_token_pricing from ._turn import Turn, user_turn from ._typing_extensions import TypedDict from ._utils import html_escape, wrap_async @@ -312,6 +313,7 @@ def get_tokens(self) -> list[TokensDict]: def get_cost( self, options: CostOptions = "all", + tokenPrice: Optional[type[TokenPrice]] = None, ) -> float: """ Get the cost of the chat. @@ -322,6 +324,12 @@ def get_cost( One of the following (default is "all"): - `"all"`: Return the total cost of all turns in the chat. - `"last"`: Return the cost of the last turn in the chat. + tokenPrice | None + A dictionary with the token pricing for the chat. This can be specified if your provider and/or model does not currently have pricing in our `data/prices.json`. + - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). + - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). + - `"input"`: The cost per user token in USD per million tokens. + - `"output"`: The cost per assistant token in USD per million tokens. 
Returns ------- @@ -330,26 +338,44 @@ def get_cost( """ # Look up token cost for user and input tokens based on the provider and model - turns = self.get_turns(include_system_prompt=False) + turns_tokens = self.get_tokens() + if tokenPrice: + price_token = tokenPrice + else: + price_token = get_token_pricing(self.provider) - if len(turns) == 0: + if not price_token: + raise KeyError( + f"We could not locate provider ' { self.provider.name } ' and model '{ self.provider.model } ' in our pricing information. Please supply your own if you wish to use the cost function." + ) + + if len(turns_tokens) == 0: return 0.0 if options == "last": - # Multiply last user token count by user token cost - # Multiply last assistant token count by assistant token cost + # Get the last turn + last_turn = turns_tokens[len(turns_tokens) - 1] # Add - # Return - print("hi") + acc = 0.0 + if last_turn["role"] == "assistant": + acc += last_turn["tokens"] * (price_token["output"] / 1000000) + elif last_turn["role"] == "user": + acc += last_turn["tokens_total"] * (price_token["input"] / 1000000) + else: + raise ValueError(f"Unrecognized role type { last_turn['role'] }") + return acc if options == "all": - # Get all the user token counts - # Get all the assistant token counts - # Multiply all the user token counts by the user token cost - # Multiply all the assistant token counts by the assistant token cost - # Add them together and return - print("hello") - return 0.0 + asst_tokens = sum( + u["tokens_total"] for u in turns_tokens if u["role"] == "assistant" + ) + user_tokens = sum( + u["tokens_total"] for u in turns_tokens if u["role"] == "user" + ) + cost = (asst_tokens * (price_token["output"] / 1000000)) + ( + user_tokens * (price_token["input"] / 1000000) + ) + return cost def token_count( self, diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index 9f22cc50..f26bf0ed 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -1,11 +1,14 @@ from __future__ import 
annotations from typing import TYPE_CHECKING, Optional - +import httpx +from openai import AsyncOpenAI from ._chat import Chat from ._logging import log_model_default from ._openai import OpenAIProvider from ._turn import Turn, normalize_turns +from databricks.sdk import WorkspaceClient + if TYPE_CHECKING: from databricks.sdk import WorkspaceClient @@ -104,6 +107,7 @@ def __init__( self, *, model: str, + name: str = "Databricks", workspace_client: Optional["WorkspaceClient"] = None, ): try: @@ -117,8 +121,8 @@ def __init__( import httpx from openai import AsyncOpenAI - self._model = model - self._name = "Databricks" + super().__init__(name=name, model=model) + self._seed = None if workspace_client is None: @@ -138,11 +142,3 @@ def __init__( api_key="no-token", # A placeholder to pass validations, this will not be used http_client=httpx.AsyncClient(auth=client._client.auth), ) - - @property - def name(self): - return self._name - - @property - def model(self): - return self._model diff --git a/chatlas/_github.py b/chatlas/_github.py index f6a91968..f68ca5ee 100644 --- a/chatlas/_github.py +++ b/chatlas/_github.py @@ -137,5 +137,6 @@ def ChatGithub( api_key=api_key, base_url=base_url, seed=seed, + name="GitHub", kwargs=kwargs, ) diff --git a/chatlas/_google.py b/chatlas/_google.py index 9b359859..2ab6355f 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -146,6 +146,7 @@ def ChatGoogle( provider=GoogleProvider( model=model, api_key=api_key, + name="Google/Gemini", kwargs=kwargs, ), turns=normalize_turns( @@ -165,6 +166,7 @@ def __init__( *, model: str, api_key: str | None, + name: Optional[str] = "Google/Gemini", kwargs: Optional["ChatClientArgs"], ): try: @@ -174,13 +176,7 @@ def __init__( f"The {self.__class__.__name__} class requires the `google-genai` package. " "Install it with `pip install google-genai`." 
) - - self._model = model - - if kwargs and kwargs.get("vertexai"): - self._name = "Google/Vertex" - else: - self._name = "Google/Gemini" + super().__init__(name=name, model=model) kwargs_full: "ChatClientArgs" = { "api_key": api_key, @@ -189,14 +185,6 @@ def __init__( self._client = genai.Client(**kwargs_full) - @property - def name(self): - return self._name - - @property - def model(self): - return self._model - @overload def chat_perform( self, @@ -280,7 +268,7 @@ def _chat_perform_args( from google.genai.types import Tool as GoogleTool kwargs_full: "SubmitInputArgs" = { - "model": self._model, + "model": self.model, "contents": cast("GoogleContent", self._google_contents(turns)), **(kwargs or {}), } @@ -624,6 +612,7 @@ def ChatVertex( provider=GoogleProvider( model=model, api_key=api_key, + name="Google/Vertex", kwargs=kwargs, ), turns=normalize_turns( diff --git a/chatlas/_groq.py b/chatlas/_groq.py index 89e9bf57..95104a16 100644 --- a/chatlas/_groq.py +++ b/chatlas/_groq.py @@ -133,5 +133,6 @@ def ChatGroq( api_key=api_key, base_url=base_url, seed=seed, + name="Groq", kwargs=kwargs, ) diff --git a/chatlas/_ollama.py b/chatlas/_ollama.py index d0d87621..d6776ec5 100644 --- a/chatlas/_ollama.py +++ b/chatlas/_ollama.py @@ -110,6 +110,7 @@ def ChatOllama( base_url=f"{base_url}/v1", model=model, seed=seed, + name="Ollama", kwargs=kwargs, ) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index a2635d92..25e952e3 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -5,6 +5,8 @@ import orjson from pydantic import BaseModel +from openai import AsyncOpenAI, OpenAI + from ._chat import Chat from ._content import ( @@ -58,6 +60,7 @@ def ChatOpenAI( api_key: Optional[str] = None, base_url: str = "https://api.openai.com/v1", seed: int | None | MISSING_TYPE = MISSING, + name: Optional[str] = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ) -> Chat["SubmitInputArgs", ChatCompletion]: """ @@ -169,6 +172,7 @@ def ChatOpenAI( model=model, base_url=base_url, 
seed=seed, + name=name, kwargs=kwargs, ), turns=normalize_turns( @@ -179,6 +183,7 @@ def ChatOpenAI( class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict]): + def __init__( self, *, @@ -186,12 +191,13 @@ def __init__( model: str, base_url: str = "https://api.openai.com/v1", seed: Optional[int] = None, + name: str = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ): from openai import AsyncOpenAI, OpenAI - self._name = "OpenAI" - self._model = model + super().__init__(name=name, model=model) + self._seed = seed kwargs_full: "ChatClientArgs" = { @@ -204,14 +210,6 @@ def __init__( self._client = OpenAI(**kwargs_full) # type: ignore self._async_client = AsyncOpenAI(**kwargs_full) - @property - def name(self): - return self._name - - @property - def model(self): - return self._model - @overload def chat_perform( self, @@ -293,7 +291,7 @@ def _chat_perform_args( kwargs_full: "SubmitInputArgs" = { "stream": stream, "messages": self._as_message_param(turns), - "model": self._model, + "model": self.model, **(kwargs or {}), } @@ -672,6 +670,7 @@ def ChatAzureOpenAI( class OpenAIAzureProvider(OpenAIProvider): + def __init__( self, *, @@ -680,13 +679,15 @@ def __init__( api_version: Optional[str] = None, api_key: Optional[str] = None, seed: int | None = None, + name: str = "OpenAIAzure", + model: str = "UnusedValue", kwargs: Optional["ChatAzureClientArgs"] = None, ): from openai import AsyncAzureOpenAI, AzureOpenAI - self._model = deployment_id + super().__init__(name=name, model=deployment_id) + self._seed = seed - self._name = "OpenAIAzure" kwargs_full: "ChatAzureClientArgs" = { "azure_endpoint": endpoint, @@ -699,14 +700,6 @@ def __init__( self._client = AzureOpenAI(**kwargs_full) # type: ignore self._async_client = AsyncAzureOpenAI(**kwargs_full) # type: ignore - @property - def name(self): - return self._name - - @property - def model(self): - return self._model - class InvalidJSONParameterWarning(RuntimeWarning): """ diff --git 
a/chatlas/_perplexity.py b/chatlas/_perplexity.py index 673585e2..83f9d209 100644 --- a/chatlas/_perplexity.py +++ b/chatlas/_perplexity.py @@ -138,5 +138,6 @@ def ChatPerplexity( api_key=api_key, base_url=base_url, seed=seed, + name="Perplexity", kwargs=kwargs, ) diff --git a/chatlas/_provider.py b/chatlas/_provider.py index ddffc8dc..e37055ad 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -40,13 +40,18 @@ class Provider( directly. """ + # TODO: Add docstring for props + def __init__(self, *, name: str, model: str): + self._name = name + self._model = model + @property def name(self): - raise NotImplementedError("Name property must be implemented.") + return self._name @property def model(self): - raise NotImplementedError("Model property must be implemented.") + return self._model @overload @abstractmethod diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 4e90abd3..1ed474e4 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -175,6 +175,7 @@ def __init__( password: Optional[str], private_key_file: Optional[str], private_key_file_pwd: Optional[str], + name: Optional[str] = "Snowflake", kwargs: Optional[dict[str, "str | int"]], ): try: @@ -185,6 +186,7 @@ def __init__( "`ChatSnowflake()` requires the `snowflake-ml-python` package. " "Please install it via `pip install snowflake-ml-python`." 
) + super().__init__(name=name, model=model) configs: dict[str, str | int] = drop_none( { @@ -198,20 +200,9 @@ def __init__( } ) - self._model = model - self._name = "Snowflake" - session = Session.builder.configs(configs).create() self._cortex_service = Root(session).cortex_inference_service - @property - def name(self): - return self._name - - @property - def model(self): - return self._model - @overload def chat_perform( self, @@ -323,7 +314,7 @@ def _complete_request( from snowflake.core.cortex.inference_service import CompleteRequest req = CompleteRequest( - model=self._model, + model=self.model, messages=self._as_request_messages(turns), stream=stream, ) diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index e2da46e2..0d0f5fc2 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -73,8 +73,21 @@ def tokens_reset() -> None: _token_counter = ThreadSafeTokenCounter() +class TokenPrice(TypedDict): + """ + Defines the necessary information to look up pricing for a given turn. + """ + + provider: str + model: str + cached_input: float + input: float + output: float + + +# Load in pricing pulled from ellmer f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8") -prices_json = orjson.loads(f) +PricingList: list[TokenPrice] = orjson.loads(f) def get_token_pricing(provider: Provider) -> dict[str, str | float]: @@ -87,13 +100,13 @@ def get_token_pricing(provider: Provider) -> dict[str, str | float]: A dictionary with the token pricing for the chat. The keys are: - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). - - `"input"`: The cost per user token in USD. - - `"output"`: The cost per assistant token in USD. + - `"input"`: The cost per user token in USD per million tokens. + - `"output"`: The cost per assistant token in USD per million tokens. 
""" result = next( ( item - for item in prices_json + for item in PricingList if item["provider"] == provider.name and item["model"] == provider.model ), {}, @@ -108,32 +121,6 @@ def get_token_pricing(provider: Provider) -> dict[str, str | float]: return result -def get_token_cost( - name: str, model: str, input_tokens: int, output_tokens: int -) -> float | None: - """ - Get the cost of tokens for a given provider and model. - - Parameters - ---------- - name : Provider - The provider instance. - model : str - The model name. - input_tokens : int - The number of input tokens. - output_tokens : int - The number of output tokens. - - Returns - ------- - float - The cost of the tokens, or None if the cost is not known. - """ - - # return get_token_cost(provider.__name__, model, input_tokens, output_tokens) - - def token_usage() -> list[TokenUsage] | None: """ Report on token usage in the current session diff --git a/tests/test_chat.py b/tests/test_chat.py index ef84bd48..2f15a2f0 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -39,11 +39,11 @@ def test_simple_streaming_chat(): chunks = [chunk for chunk in res] assert len(chunks) > 2 result = "".join(chunks) - res = re.sub(r"[\s`]+", "", result).lower() + res = re.sub(r"\s+", "", result).lower() assert res == "redorangeyellowgreenblueindigoviolet" turn = chat.get_last_turn() assert turn is not None - res = re.sub(r"[\s`]+", "", turn.text).lower() + res = re.sub(r"\s+", "", turn.text).lower() assert res == "redorangeyellowgreenblueindigoviolet" @@ -283,3 +283,18 @@ def test_tool(user: str) -> str: assert str(response).lower() == "joe unknown hadley red" assert "Joe denied the request." 
in capsys.readouterr().out + + +def test_get_cost(): + chat = ChatOpenAI( + api_key="fake_key", + turns=[ + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(2, 10)), + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(14, 10)), + ], + ) + + cost = chat.get_cost(options="all") + print("COST", cost) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index db781350..271bb7dc 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,6 +1,11 @@ from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider -from chatlas._tokens import token_usage, tokens_log, tokens_reset, get_token_pricing +from chatlas._tokens import ( + token_usage, + tokens_log, + tokens_reset, + get_token_pricing, +) def test_tokens_method(): @@ -49,11 +54,13 @@ def test_token_count_method(): def test_import_prices(): - chat = ChatOpenAI() - - print("Provider: ", chat.provider.name, chat.provider._model) - print("Pricing result: ", get_token_pricing(chat.provider)) - print("DONE") + chat = ChatOpenAI(model="o1-mini") + pricing = get_token_pricing(chat.provider) + assert pricing["provider"] == "OpenAI" + assert pricing["model"] == "o1-mini" + assert isinstance(pricing["cached_input"], float) + assert isinstance(pricing["input"], float) + assert isinstance(pricing["output"], float) def test_usage_is_none(): From 0b76317e62588f98110ee2ddeb9776816467c36d Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 26 Jun 2025 11:47:12 -0400 Subject: [PATCH 18/59] Added token tests --- chatlas/_chat.py | 1 - tests/test_chat.py | 4 +++- tests/test_tokens.py | 12 +++++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 72758d2c..3eed5457 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -309,7 +309,6 @@ def get_tokens(self) -> list[TokensDict]: return res - # TODO: Add another arg for if people want 
to by cost data def get_cost( self, options: CostOptions = "all", diff --git a/tests/test_chat.py b/tests/test_chat.py index 2f15a2f0..a285ae67 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -295,6 +295,8 @@ def test_get_cost(): Turn(role="assistant", contents="Hello", tokens=(14, 10)), ], ) - cost = chat.get_cost(options="all") print("COST", cost) + + last = chat.get_cost(options="last") + print("Last:", last) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 271bb7dc..10c56cfe 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -6,6 +6,8 @@ tokens_reset, get_token_pricing, ) +import warnings +import pytest def test_tokens_method(): @@ -53,7 +55,7 @@ def test_token_count_method(): assert chat.token_count("What is 1 + 1?") == 9 -def test_import_prices(): +def test_get_token_prices(): chat = ChatOpenAI(model="o1-mini") pricing = get_token_pricing(chat.provider) assert pricing["provider"] == "OpenAI" @@ -62,6 +64,14 @@ def test_import_prices(): assert isinstance(pricing["input"], float) assert isinstance(pricing["output"], float) + with pytest.warns( + match="Token pricing for the provider 'NOPE' and model 'ABCD' you selected is not available. " + "Please check the provider's documentation." 
+ ): + chat = ChatOpenAI(model="ABCD", name="NOPE") + pricing = get_token_pricing(chat.provider) + assert pricing == {} + def test_usage_is_none(): tokens_reset() From 70709fe92ca8bf4dcde8a47fdc2196317509e0f6 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 26 Jun 2025 12:09:11 -0400 Subject: [PATCH 19/59] Adding cost test --- chatlas/_chat.py | 2 -- tests/test_chat.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 3eed5457..a4f3d634 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -352,9 +352,7 @@ def get_cost( return 0.0 if options == "last": - # Get the last turn last_turn = turns_tokens[len(turns_tokens) - 1] - # Add acc = 0.0 if last_turn["role"] == "assistant": acc += last_turn["tokens"] * (price_token["output"] / 1000000) diff --git a/tests/test_chat.py b/tests/test_chat.py index a285ae67..a5563681 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,6 +2,8 @@ import tempfile import pytest +from pydantic import BaseModel + from chatlas import ( ChatOpenAI, ContentToolRequest, @@ -10,7 +12,7 @@ Turn, ) from chatlas._chat import ToolFailureWarning -from pydantic import BaseModel +from chatlas._tokens import TokenPrice def test_simple_batch_chat(): @@ -295,8 +297,29 @@ def test_get_cost(): Turn(role="assistant", contents="Hello", tokens=(14, 10)), ], ) + # Checking that these have the right form vs. 
the actual calculation because the price may change cost = chat.get_cost(options="all") - print("COST", cost) + assert isinstance(cost, float) + assert cost > 0 last = chat.get_cost(options="last") - print("Last:", last) + assert isinstance(last, float) + assert last > 0 + + assert cost > last + + byoc = TokenPrice( + { + "provider": "fake", + "model": "fake", + "cached_input": 1.0, + "input": 2.0, + "output": 3.0, + } + ) + + cost2 = chat.get_cost(options="all", tokenPrice=byoc) + assert cost2 == 0.000092 + + last2 = chat.get_cost(options="last", tokenPrice=byoc) + assert last2 == 0.00003 From c62a9dd725085a298f8a5ba53cc4f7a8b58616f4 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 26 Jun 2025 13:41:58 -0400 Subject: [PATCH 20/59] Updating classes --- chatlas/_anthropic.py | 28 +++++++++++++++++++++------- chatlas/_openai.py | 9 +++------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index dc9ab649..814103a9 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -187,17 +187,24 @@ def ChatAnthropic( class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]): + def __init__( self, *, - max_tokens: int, + max_tokens: int = 4096, model: str, - api_key: str | None, + api_key: Optional[str] = None, name: str = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ): super().__init__(name=name, model=model) - + try: + from anthropic import Anthropic, AsyncAnthropic + except ImportError: + raise ImportError( + "`ChatAnthropic()` requires the `anthropic` package. " + "You can install it with 'pip install anthropic'." 
+ ) self._max_tokens = max_tokens kwargs_full: "ChatClientArgs" = { @@ -715,6 +722,7 @@ def ChatBedrockAnthropic( class AnthropicBedrockProvider(AnthropicProvider): + def __init__( self, *, @@ -724,15 +732,21 @@ def __init__( aws_region: str | None, aws_profile: str | None, aws_session_token: str | None, - max_tokens: int, + max_tokens: int = 4096, base_url: str | None, - name: str = "AnthropicBedrock", + name: Optional[str] = "AnthropicBedrock", kwargs: Optional["ChatBedrockClientArgs"] = None, ): - super().__init__(name=name, model=model) + super().__init__(name=name, model=model, max_tokens=max_tokens) - self._max_tokens = max_tokens + try: + from anthropic import AnthropicBedrock, AsyncAnthropicBedrock + except ImportError: + raise ImportError( + "`ChatBedrockAnthropic()` requires the `anthropic` package. " + "Install it with `pip install anthropic[bedrock]`." + ) kwargs_full: "ChatBedrockClientArgs" = { "aws_secret_key": aws_secret_key, diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 25e952e3..293e66b5 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -5,7 +5,7 @@ import orjson from pydantic import BaseModel -from openai import AsyncOpenAI, OpenAI +from openai import AsyncOpenAI, OpenAI, AsyncAzureOpenAI, AzureOpenAI from ._chat import Chat @@ -194,8 +194,6 @@ def __init__( name: str = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ): - from openai import AsyncOpenAI, OpenAI - super().__init__(name=name, model=model) self._seed = seed @@ -679,11 +677,10 @@ def __init__( api_version: Optional[str] = None, api_key: Optional[str] = None, seed: int | None = None, - name: str = "OpenAIAzure", - model: str = "UnusedValue", + name: Optional[str] = "OpenAIAzure", + model: Optional[str] = "UnusedValue", kwargs: Optional["ChatAzureClientArgs"] = None, ): - from openai import AsyncAzureOpenAI, AzureOpenAI super().__init__(name=name, model=deployment_id) From 8fbe70f9c472ee02a54376bcb9ebaedb4150f1ee Mon Sep 17 00:00:00 2001 From: Liz Nelson 
Date: Thu, 26 Jun 2025 13:43:19 -0400 Subject: [PATCH 21/59] Removing flaky tests that are prone to change in future based on model default changes --- tests/test_provider_anthropic.py | 2 -- tests/test_provider_databricks.py | 2 -- tests/test_provider_google.py | 2 -- tests/test_provider_openai.py | 2 -- tests/test_provider_snowflake.py | 1 - 5 files changed, 9 deletions(-) diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py index 1f72a54e..e165a880 100644 --- a/tests/test_provider_anthropic.py +++ b/tests/test_provider_anthropic.py @@ -28,8 +28,6 @@ def test_anthropic_simple_request(): assert turn is not None assert turn.tokens == (26, 5) assert turn.finish_reason == "end_turn" - assert chat.provider.name == "Anthropic" - assert chat.provider.model == "claude-3-7-sonnet-latest" @pytest.mark.asyncio diff --git a/tests/test_provider_databricks.py b/tests/test_provider_databricks.py index 8c9b66fd..7b7ea747 100644 --- a/tests/test_provider_databricks.py +++ b/tests/test_provider_databricks.py @@ -16,8 +16,6 @@ def test_openai_simple_request(): assert turn.tokens[0] == 26 # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2. 
assert turn.finish_reason == "stop" - assert chat.provider.name == "Databricks" - assert chat.provider.model == "databricks-claude-3-7-sonnet" @pytest.mark.asyncio diff --git a/tests/test_provider_google.py b/tests/test_provider_google.py index 37ffa5fc..d9ea896f 100644 --- a/tests/test_provider_google.py +++ b/tests/test_provider_google.py @@ -34,7 +34,6 @@ def test_google_simple_request(): assert turn.tokens == (16, 2) assert turn.finish_reason == "STOP" assert chat.provider.name == "Google/Gemini" - assert chat.provider.model == "gemini-2.0-flash" def test_vertex_simple_request(): @@ -47,7 +46,6 @@ def test_vertex_simple_request(): assert turn.tokens == (16, 2) assert turn.finish_reason == "STOP" assert chat.provider.name == "Google/Vertex" - assert chat.provider.model == "gemini-2.0-flash" def test_name_setting(): diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py index 79ba8162..43c47a0d 100644 --- a/tests/test_provider_openai.py +++ b/tests/test_provider_openai.py @@ -28,8 +28,6 @@ def test_openai_simple_request(): assert turn.tokens[0] == 27 # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2. assert turn.finish_reason == "stop" - assert chat.provider.name == "OpenAI" - assert chat.provider.model == "gpt-4o" @pytest.mark.asyncio diff --git a/tests/test_provider_snowflake.py b/tests/test_provider_snowflake.py index be575e19..dd5879e7 100644 --- a/tests/test_provider_snowflake.py +++ b/tests/test_provider_snowflake.py @@ -25,7 +25,6 @@ def test_openai_simple_request(): turn = chat.get_last_turn() assert turn is not None assert chat.provider.name == "Snowflake" - assert chat.provider.model == "claude-3-7-sonnet" # No token / finish_reason info available? 
# assert turn.tokens is not None From b601fdb4f7c0c18212b71340045d4b7780abf7c5 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 26 Jun 2025 15:30:02 -0400 Subject: [PATCH 22/59] Fixing import orders --- chatlas/_anthropic.py | 1 - chatlas/_chat.py | 4 +++- chatlas/_databricks.py | 7 +++---- chatlas/_openai.py | 3 +-- chatlas/_tokens.py | 4 +++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index 814103a9..e8878f68 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -6,7 +6,6 @@ import orjson from pydantic import BaseModel -from anthropic import Anthropic, AsyncAnthropic, AnthropicBedrock, AsyncAnthropicBedrock from ._chat import Chat from ._content import ( diff --git a/chatlas/_chat.py b/chatlas/_chat.py index a4f3d634..77195195 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -43,8 +43,8 @@ ) from ._logging import log_tool_error from ._provider import Provider -from ._tools import Tool, ToolRejectError from ._tokens import TokenPrice, get_token_pricing +from ._tools import Tool, ToolRejectError from ._turn import Turn, user_turn from ._typing_extensions import TypedDict from ._utils import html_escape, wrap_async @@ -309,6 +309,7 @@ def get_tokens(self) -> list[TokensDict]: return res + # TODO: BYOP to tuple (input, output) + add big 'ol disclaimer that WE DON'T KNOW def get_cost( self, options: CostOptions = "all", @@ -1689,6 +1690,7 @@ def __str__(self): res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n" return res + # TODO: Update this to get tokens and also provide cost add provider and model def __repr__(self): turns = self.get_turns(include_system_prompt=True) tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens) diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index f26bf0ed..d75108b7 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -1,14 +1,13 @@ from __future__ import annotations from typing import 
TYPE_CHECKING, Optional -import httpx -from openai import AsyncOpenAI + +from databricks.sdk import WorkspaceClient + from ._chat import Chat from ._logging import log_model_default from ._openai import OpenAIProvider from ._turn import Turn, normalize_turns -from databricks.sdk import WorkspaceClient - if TYPE_CHECKING: from databricks.sdk import WorkspaceClient diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 293e66b5..ba9a0fcf 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -4,9 +4,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload import orjson +from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI from pydantic import BaseModel -from openai import AsyncOpenAI, OpenAI, AsyncAzureOpenAI, AzureOpenAI - from ._chat import Chat from ._content import ( diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 0d0f5fc2..085fbd6b 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -1,10 +1,11 @@ from __future__ import annotations import copy -import warnings import importlib.resources as resources +import warnings from threading import Lock from typing import TYPE_CHECKING + import orjson from ._logging import logger @@ -121,6 +122,7 @@ def get_token_pricing(provider: Provider) -> dict[str, str | float]: return result +# TODO: Add price to this print def token_usage() -> list[TokenUsage] | None: """ Report on token usage in the current session From 952cfb46ce51b448b47b20bbba7a9ce6c801ff38 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Thu, 26 Jun 2025 15:50:29 -0400 Subject: [PATCH 23/59] Updating to tuples --- chatlas/_chat.py | 31 ++++++++++++++++--------------- tests/test_chat.py | 14 +++----------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 77195195..4fb1604b 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -43,7 +43,7 @@ ) from ._logging import log_tool_error from ._provider import Provider -from ._tokens import 
TokenPrice, get_token_pricing +from ._tokens import get_token_pricing from ._tools import Tool, ToolRejectError from ._turn import Turn, user_turn from ._typing_extensions import TypedDict @@ -313,7 +313,7 @@ def get_tokens(self) -> list[TokensDict]: def get_cost( self, options: CostOptions = "all", - tokenPrice: Optional[type[TokenPrice]] = None, + token_price: Optional[tuple[float, float]] = None, ) -> float: """ Get the cost of the chat. @@ -324,12 +324,10 @@ def get_cost( One of the following (default is "all"): - `"all"`: Return the total cost of all turns in the chat. - `"last"`: Return the cost of the last turn in the chat. - tokenPrice | None - A dictionary with the token pricing for the chat. This can be specified if your provider and/or model does not currently have pricing in our `data/prices.json`. - - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). - - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). - - `"input"`: The cost per user token in USD per million tokens. - - `"output"`: The cost per assistant token in USD per million tokens. + token_price + An optional tuple in the format of (input_token_cost, output_token_cost) for bringing your own cost information. + - `"input_token_cost"`: The cost per user token in USD per million tokens. + - `"output_token_cost"`: The cost per assistant token in USD per million tokens. 
Returns ------- @@ -339,12 +337,15 @@ def get_cost( # Look up token cost for user and input tokens based on the provider and model turns_tokens = self.get_tokens() - if tokenPrice: - price_token = tokenPrice + if token_price: + input_token_price = token_price[0] / 1000000 + output_token_price = token_price[1] / 1000000 else: price_token = get_token_pricing(self.provider) + input_token_price = price_token["input"] / 1000000 + output_token_price = price_token["output"] / 1000000 - if not price_token: + if not input_token_price and not output_token_price: raise KeyError( f"We could not locate provider ' { self.provider.name } ' and model '{ self.provider.model } ' in our pricing information. Please supply your own if you wish to use the cost function." ) @@ -356,9 +357,9 @@ def get_cost( last_turn = turns_tokens[len(turns_tokens) - 1] acc = 0.0 if last_turn["role"] == "assistant": - acc += last_turn["tokens"] * (price_token["output"] / 1000000) + acc += last_turn["tokens"] * output_token_price elif last_turn["role"] == "user": - acc += last_turn["tokens_total"] * (price_token["input"] / 1000000) + acc += last_turn["tokens_total"] * input_token_price else: raise ValueError(f"Unrecognized role type { last_turn['role'] }") return acc @@ -370,8 +371,8 @@ def get_cost( user_tokens = sum( u["tokens_total"] for u in turns_tokens if u["role"] == "user" ) - cost = (asst_tokens * (price_token["output"] / 1000000)) + ( - user_tokens * (price_token["input"] / 1000000) + cost = (asst_tokens * output_token_price) + ( + user_tokens * input_token_price ) return cost diff --git a/tests/test_chat.py b/tests/test_chat.py index a5563681..515579ff 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -308,18 +308,10 @@ def test_get_cost(): assert cost > last - byoc = TokenPrice( - { - "provider": "fake", - "model": "fake", - "cached_input": 1.0, - "input": 2.0, - "output": 3.0, - } - ) + byoc = (2.0, 3.0) - cost2 = chat.get_cost(options="all", tokenPrice=byoc) + cost2 = 
chat.get_cost(options="all", token_price=byoc) assert cost2 == 0.000092 - last2 = chat.get_cost(options="last", tokenPrice=byoc) + last2 = chat.get_cost(options="last", token_price=byoc) assert last2 == 0.00003 From 07e1cdcb288de6c370e8050795c628565e7c6b0f Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Fri, 27 Jun 2025 11:35:15 -0400 Subject: [PATCH 24/59] Importing OpenAI Pto replace Chat --- chatlas/_anthropic.py | 3 +-- chatlas/_chat.py | 3 +-- chatlas/_databricks.py | 2 +- chatlas/_github.py | 32 +++++++++++++++++++++----------- chatlas/_groq.py | 31 ++++++++++++++++++++----------- chatlas/_ollama.py | 30 ++++++++++++++++++++---------- chatlas/_openai.py | 4 +--- chatlas/_perplexity.py | 37 ++++++++++++++++++++++++++----------- chatlas/_tokens.py | 2 +- 9 files changed, 92 insertions(+), 52 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index e8878f68..83bae2f8 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -62,7 +62,6 @@ def ChatAnthropic( model: "Optional[ModelParam]" = None, api_key: Optional[str] = None, max_tokens: int = 4096, - name: Literal["Anthropic"] = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ) -> Chat["SubmitInputArgs", Message]: """ @@ -193,7 +192,7 @@ def __init__( max_tokens: int = 4096, model: str, api_key: Optional[str] = None, - name: str = "Anthropic", + name: Optional[str] = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ): super().__init__(name=name, model=model) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 4fb1604b..f5e195b2 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -309,14 +309,13 @@ def get_tokens(self) -> list[TokensDict]: return res - # TODO: BYOP to tuple (input, output) + add big 'ol disclaimer that WE DON'T KNOW def get_cost( self, options: CostOptions = "all", token_price: Optional[tuple[float, float]] = None, ) -> float: """ - Get the cost of the chat. + Get the cost of the chat. Note that this is a rough estimate. 
Providers may change their pricing frequently and without notice. Parameters ---------- diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index d75108b7..8e4dffab 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -106,7 +106,7 @@ def __init__( self, *, model: str, - name: str = "Databricks", + name: Optional[str] = "Databricks", workspace_client: Optional["WorkspaceClient"] = None, ): try: diff --git a/chatlas/_github.py b/chatlas/_github.py index f68ca5ee..1ce9d67e 100644 --- a/chatlas/_github.py +++ b/chatlas/_github.py @@ -5,9 +5,9 @@ from ._chat import Chat from ._logging import log_model_default -from ._openai import ChatOpenAI +from ._openai import OpenAIProvider, normalize_turns from ._turn import Turn -from ._utils import MISSING, MISSING_TYPE +from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -130,13 +130,23 @@ def ChatGithub( if api_key is None: api_key = os.getenv("GITHUB_PAT") - return ChatOpenAI( - system_prompt=system_prompt, - turns=turns, - model=model, - api_key=api_key, - base_url=base_url, - seed=seed, - name="GitHub", - kwargs=kwargs, + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + if model is None: + model = log_model_default("gpt-4o") + + return Chat( + provider=OpenAIProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="GitHub", + kwargs=kwargs, + ), + turns=normalize_turns( + turns or [], + system_prompt, + ), ) diff --git a/chatlas/_groq.py b/chatlas/_groq.py index 95104a16..0322fc4c 100644 --- a/chatlas/_groq.py +++ b/chatlas/_groq.py @@ -5,9 +5,9 @@ from ._chat import Chat from ._logging import log_model_default -from ._openai import ChatOpenAI +from ._openai import OpenAIProvider, normalize_turns from ._turn import Turn -from ._utils import MISSING, MISSING_TYPE +from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -125,14 
+125,23 @@ def ChatGroq( model = log_model_default("llama3-8b-8192") if api_key is None: api_key = os.getenv("GROQ_API_KEY") + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None - return ChatOpenAI( - system_prompt=system_prompt, - turns=turns, - model=model, - api_key=api_key, - base_url=base_url, - seed=seed, - name="Groq", - kwargs=kwargs, + if model is None: + model = log_model_default("gpt-4o") + + return Chat( + provider=OpenAIProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="Groq", + kwargs=kwargs, + ), + turns=normalize_turns( + turns or [], + system_prompt, + ), ) diff --git a/chatlas/_ollama.py b/chatlas/_ollama.py index d6776ec5..26d4f66f 100644 --- a/chatlas/_ollama.py +++ b/chatlas/_ollama.py @@ -7,8 +7,9 @@ import orjson from ._chat import Chat -from ._openai import ChatOpenAI +from ._openai import OpenAIProvider, log_model_default, normalize_turns from ._turn import Turn +from ._utils import MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -102,16 +103,25 @@ def ChatOllama( raise ValueError( f"Must specify model. 
Locally installed models: {', '.join(models)}" ) + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None - return ChatOpenAI( - system_prompt=system_prompt, - api_key="ollama", # ignored - turns=turns, - base_url=f"{base_url}/v1", - model=model, - seed=seed, - name="Ollama", - kwargs=kwargs, + if model is None: + model = log_model_default("gpt-4o") + + return Chat( + provider=OpenAIProvider( + api_key="ollama", # ignored + model=model, + base_url=f"{base_url}/v1", + seed=seed, + name="Ollama", + kwargs=kwargs, + ), + turns=normalize_turns( + turns or [], + system_prompt, + ), ) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index ba9a0fcf..37862974 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -59,7 +59,6 @@ def ChatOpenAI( api_key: Optional[str] = None, base_url: str = "https://api.openai.com/v1", seed: int | None | MISSING_TYPE = MISSING, - name: Optional[str] = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ) -> Chat["SubmitInputArgs", ChatCompletion]: """ @@ -171,7 +170,6 @@ def ChatOpenAI( model=model, base_url=base_url, seed=seed, - name=name, kwargs=kwargs, ), turns=normalize_turns( @@ -190,7 +188,7 @@ def __init__( model: str, base_url: str = "https://api.openai.com/v1", seed: Optional[int] = None, - name: str = "OpenAI", + name: Optional[str] = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ): super().__init__(name=name, model=model) diff --git a/chatlas/_perplexity.py b/chatlas/_perplexity.py index 83f9d209..5d33a468 100644 --- a/chatlas/_perplexity.py +++ b/chatlas/_perplexity.py @@ -5,9 +5,9 @@ from ._chat import Chat from ._logging import log_model_default -from ._openai import ChatOpenAI +from ._openai import OpenAIProvider, normalize_turns from ._turn import Turn -from ._utils import MISSING, MISSING_TYPE +from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: from ._openai import ChatCompletion @@ -131,13 +131,28 @@ def ChatPerplexity( if api_key is None: api_key = 
os.getenv("PERPLEXITY_API_KEY") - return ChatOpenAI( - system_prompt=system_prompt, - turns=turns, - model=model, - api_key=api_key, - base_url=base_url, - seed=seed, - name="Perplexity", - kwargs=kwargs, + if model is None: + model = log_model_default("gpt-4o") + if api_key is None: + api_key = os.getenv("GITHUB_PAT") + + if isinstance(seed, MISSING_TYPE): + seed = 1014 if is_testing() else None + + if model is None: + model = log_model_default("gpt-4o") + + return Chat( + provider=OpenAIProvider( + api_key=api_key, + model=model, + base_url=base_url, + seed=seed, + name="Perplexity", + kwargs=kwargs, + ), + turns=normalize_turns( + turns or [], + system_prompt, + ), ) diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 085fbd6b..51c3d128 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -91,7 +91,7 @@ class TokenPrice(TypedDict): PricingList: list[TokenPrice] = orjson.loads(f) -def get_token_pricing(provider: Provider) -> dict[str, str | float]: +def get_token_pricing(provider: Provider) -> TokenPrice | dict: """ Get the token pricing for the chat if available based on the prices.json file. 
From c096a4873f62b7f27e20aece8cfc5e210f7274db Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Fri, 27 Jun 2025 17:08:16 -0400 Subject: [PATCH 25/59] Stashing changes --- chatlas/_chat.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index f5e195b2..ae4a8267 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -294,7 +294,7 @@ def get_tokens(self) -> list[TokensDict]: "role": "user", # Implied token count for the user input "tokens": tj.tokens[0] - sum(ti.tokens), - # Total tokens = Total User Tokens for the Trn = Distinct new tokens + context sent + # Total tokens = Total User Tokens for the Turn = Distinct new tokens + context sent "tokens_total": tj.tokens[0], }, { @@ -1693,8 +1693,16 @@ def __str__(self): # TODO: Update this to get tokens and also provide cost add provider and model def __repr__(self): turns = self.get_turns(include_system_prompt=True) + tokens = self.get_tokens() + cost = self.get_cost() + # Sum tokens assistant + print("TURNS:", turns) + print("TOKENS:", tokens) + tokens_asst = 0 + tokens_user = 0 + # Sum tokens user tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens) - res = f"" + res = f"" for turn in turns: res += "\n" + turn.__repr__(indent=2) return res + "\n" From b5963cf0218bff83cc1b0f4ec3224be5bdaab1a3 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Fri, 27 Jun 2025 17:28:19 -0400 Subject: [PATCH 26/59] Updating repr test --- chatlas/_chat.py | 12 ++++-------- tests/__snapshots__/test_chat.ambr | 2 +- tests/test_chat.py | 5 ++--- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 2e96d035..90e7b007 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -2051,14 +2051,10 @@ def __repr__(self): turns = self.get_turns(include_system_prompt=True) tokens = self.get_tokens() cost = self.get_cost() - # Sum tokens assistant - print("TURNS:", turns) - print("TOKENS:", tokens) - tokens_asst = 0 - tokens_user 
= 0 - # Sum tokens user - tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens) - res = f"" + tokens_asst = sum(u["tokens_total"] for u in tokens if u["role"] == "assistant") + tokens_user = sum(u["tokens_total"] for u in tokens if u["role"] == "user") + + res = f"" for turn in turns: res += "\n" + turn.__repr__(indent=2) return res + "\n" diff --git a/tests/__snapshots__/test_chat.ambr b/tests/__snapshots__/test_chat.ambr index 88b3a275..f910d543 100644 --- a/tests/__snapshots__/test_chat.ambr +++ b/tests/__snapshots__/test_chat.ambr @@ -32,7 +32,7 @@ # --- # name: test_basic_repr ''' - + diff --git a/tests/test_chat.py b/tests/test_chat.py index 515579ff..eb6305c8 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -2,8 +2,6 @@ import tempfile import pytest -from pydantic import BaseModel - from chatlas import ( ChatOpenAI, ContentToolRequest, @@ -12,7 +10,7 @@ Turn, ) from chatlas._chat import ToolFailureWarning -from chatlas._tokens import TokenPrice +from pydantic import BaseModel def test_simple_batch_chat(): @@ -76,6 +74,7 @@ def test_basic_repr(snapshot): Turn("assistant", "2 3", tokens=(15, 5)), ], ) + print(repr(chat)) assert snapshot == repr(chat) From da507eb47215038c8fa3ae4c06be2b05e5073fc5 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Fri, 27 Jun 2025 17:37:52 -0400 Subject: [PATCH 27/59] Fixing tests --- tests/test_chat.py | 1 - tests/test_tokens.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index eb6305c8..3e0f64dc 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -74,7 +74,6 @@ def test_basic_repr(snapshot): Turn("assistant", "2 3", tokens=(15, 5)), ], ) - print(repr(chat)) assert snapshot == repr(chat) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 10c56cfe..d1ed1c85 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -65,10 +65,10 @@ def test_get_token_prices(): assert isinstance(pricing["output"], float) with 
pytest.warns( - match="Token pricing for the provider 'NOPE' and model 'ABCD' you selected is not available. " + match="Token pricing for the provider 'OpenAI' and model 'ABCD' you selected is not available. " "Please check the provider's documentation." ): - chat = ChatOpenAI(model="ABCD", name="NOPE") + chat = ChatOpenAI(model="ABCD") pricing = get_token_pricing(chat.provider) assert pricing == {} From 44d2de7eff9122a91a44070874715a9220a1eb5b Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Mon, 30 Jun 2025 16:34:20 -0400 Subject: [PATCH 28/59] Updating token_usage() --- chatlas/_chat.py | 11 +++++------ chatlas/_tokens.py | 29 ++++++++++++++++++++--------- tests/test_chat.py | 2 ++ tests/test_tokens.py | 13 +++++++------ 4 files changed, 34 insertions(+), 21 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 90e7b007..55caec3e 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -339,12 +339,12 @@ def get_cost( # Look up token cost for user and input tokens based on the provider and model turns_tokens = self.get_tokens() if token_price: - input_token_price = token_price[0] / 1000000 - output_token_price = token_price[1] / 1000000 + input_token_price = token_price[0] / 1e6 + output_token_price = token_price[1] / 1e6 else: - price_token = get_token_pricing(self.provider) - input_token_price = price_token["input"] / 1000000 - output_token_price = price_token["output"] / 1000000 + price_token = get_token_pricing(self.provider.name, self.provider.model) + input_token_price = price_token["input"] / 1e6 + output_token_price = price_token["output"] / 1e6 if not input_token_price and not output_token_price: raise KeyError( @@ -2046,7 +2046,6 @@ def __str__(self): res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n" return res - # TODO: Update this to get tokens and also provide cost add provider and model def __repr__(self): turns = self.get_turns(include_system_prompt=True) tokens = self.get_tokens() diff --git a/chatlas/_tokens.py 
b/chatlas/_tokens.py index 51c3d128..c181c70b 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -21,6 +21,7 @@ class TokenUsage(TypedDict): """ name: str + model: str input: int output: int @@ -30,7 +31,9 @@ def __init__(self): self._lock = Lock() self._tokens: dict[str, TokenUsage] = {} - def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None: + def log_tokens( + self, name: str, model: str, input_tokens: int, output_tokens: int + ) -> None: logger.info( f"Provider '{name}' generated a response of {output_tokens} tokens " f"from an input of {input_tokens} tokens." @@ -40,6 +43,7 @@ def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None: if name not in self._tokens: self._tokens[name] = { "name": name, + "model": model, "input": input_tokens, "output": output_tokens, } @@ -63,7 +67,7 @@ def tokens_log(provider: Provider, tokens: tuple[int, int]) -> None: """ Log token usage for a provider in a thread-safe manner. """ - _token_counter.log_tokens(provider.name, tokens[0], tokens[1]) + _token_counter.log_tokens(provider.name, provider.model, tokens[0], tokens[1]) def tokens_reset() -> None: @@ -91,7 +95,7 @@ class TokenPrice(TypedDict): PricingList: list[TokenPrice] = orjson.loads(f) -def get_token_pricing(provider: Provider) -> TokenPrice | dict: +def get_token_pricing(name: str, model: str) -> TokenPrice | dict: """ Get the token pricing for the chat if available based on the prices.json file. @@ -108,21 +112,19 @@ def get_token_pricing(provider: Provider) -> TokenPrice | dict: ( item for item in PricingList - if item["provider"] == provider.name and item["model"] == provider.model + if item["provider"] == name and item["model"] == model ), {}, ) - if not result: warnings.warn( - f"Token pricing for the provider '{provider.name}' and model '{provider.model}' you selected is not available. " + f"Token pricing for the provider '{name}' and model '{model}' you selected is not available. 
" "Please check the provider's documentation." ) return result -# TODO: Add price to this print def token_usage() -> list[TokenUsage] | None: """ Report on token usage in the current session @@ -136,6 +138,15 @@ def token_usage() -> list[TokenUsage] | None: A list of dictionaries with the following keys: "name", "input", and "output". If no tokens have been logged, then None is returned. """ - _token_counter.get_usage() + tokens = _token_counter.get_usage() + if tokens: + for item in tokens: + price = get_token_pricing(item["name"], item["model"]) + if price: + item["cost"] = item["input"] * (price["input"] / 1e6) + item[ + "output" + ] * (price["output"] / 1e6) + else: + item["cost"] = None - return _token_counter.get_usage() + return tokens diff --git a/tests/test_chat.py b/tests/test_chat.py index 3e0f64dc..805c6339 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -10,6 +10,7 @@ Turn, ) from chatlas._chat import ToolFailureWarning +from chatlas._tokens import token_usage from pydantic import BaseModel @@ -74,6 +75,7 @@ def test_basic_repr(snapshot): Turn("assistant", "2 3", tokens=(15, 5)), ], ) + print(token_usage()) assert snapshot == repr(chat) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index d1ed1c85..1bce4b51 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,13 +1,12 @@ +import pytest from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider from chatlas._tokens import ( + get_token_pricing, token_usage, tokens_log, tokens_reset, - get_token_pricing, ) -import warnings -import pytest def test_tokens_method(): @@ -57,7 +56,7 @@ def test_token_count_method(): def test_get_token_prices(): chat = ChatOpenAI(model="o1-mini") - pricing = get_token_pricing(chat.provider) + pricing = get_token_pricing(chat.provider.name, chat.provider.model) assert pricing["provider"] == "OpenAI" assert pricing["model"] == "o1-mini" assert 
isinstance(pricing["cached_input"], float) @@ -69,7 +68,7 @@ def test_get_token_prices(): "Please check the provider's documentation." ): chat = ChatOpenAI(model="ABCD") - pricing = get_token_pricing(chat.provider) + pricing = get_token_pricing(chat.provider.name, chat.provider.model) assert pricing == {} @@ -81,7 +80,7 @@ def test_usage_is_none(): def test_can_retrieve_and_log_tokens(): tokens_reset() - provider = OpenAIProvider(api_key="fake_key", model="foo") + provider = OpenAIProvider(api_key="fake_key", model="gpt-4.1") tokens_log(provider, (10, 50)) tokens_log(provider, (0, 10)) usage = token_usage() @@ -90,6 +89,7 @@ def test_can_retrieve_and_log_tokens(): assert usage[0]["name"] == "OpenAI" assert usage[0]["input"] == 10 assert usage[0]["output"] == 60 + assert usage[0]["cost"] is not None provider2 = OpenAIAzureProvider( api_key="fake_key", endpoint="foo", api_version="bar" @@ -102,5 +102,6 @@ def test_can_retrieve_and_log_tokens(): assert usage[1]["name"] == "OpenAIAzure" assert usage[1]["input"] == 5 assert usage[1]["output"] == 25 + assert usage[1]["cost"] is None tokens_reset() From ed9452191bfe158df1451cb589c9b7e15da06aea Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Mon, 30 Jun 2025 16:58:32 -0400 Subject: [PATCH 29/59] Updating tests and fixing type issues: --- chatlas/_anthropic.py | 4 ++-- chatlas/_databricks.py | 2 +- chatlas/_google.py | 2 +- chatlas/_openai.py | 6 +++--- chatlas/_snowflake.py | 2 +- chatlas/_tokens.py | 2 ++ tests/test_tokens.py | 2 +- 7 files changed, 11 insertions(+), 9 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index d1ebc778..4a9fe763 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -193,7 +193,7 @@ def __init__( max_tokens: int = 4096, model: str, api_key: Optional[str] = None, - name: Optional[str] = "Anthropic", + name: str = "Anthropic", kwargs: Optional["ChatClientArgs"] = None, ): super().__init__(name=name, model=model) @@ -738,7 +738,7 @@ def __init__( 
aws_session_token: str | None, max_tokens: int = 4096, base_url: str | None, - name: Optional[str] = "AnthropicBedrock", + name: str = "AnthropicBedrock", kwargs: Optional["ChatBedrockClientArgs"] = None, ): diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index 8e4dffab..d75108b7 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -106,7 +106,7 @@ def __init__( self, *, model: str, - name: Optional[str] = "Databricks", + name: str = "Databricks", workspace_client: Optional["WorkspaceClient"] = None, ): try: diff --git a/chatlas/_google.py b/chatlas/_google.py index 2ab6355f..5d3afe25 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -166,7 +166,7 @@ def __init__( *, model: str, api_key: str | None, - name: Optional[str] = "Google/Gemini", + name: str = "Google/Gemini", kwargs: Optional["ChatClientArgs"], ): try: diff --git a/chatlas/_openai.py b/chatlas/_openai.py index c6e17982..43ad9fb2 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -188,7 +188,7 @@ def __init__( model: str, base_url: str = "https://api.openai.com/v1", seed: Optional[int] = None, - name: Optional[str] = "OpenAI", + name: str = "OpenAI", kwargs: Optional["ChatClientArgs"] = None, ): super().__init__(name=name, model=model) @@ -673,11 +673,11 @@ def __init__( self, *, endpoint: Optional[str] = None, - deployment_id: Optional[str] = None, + deployment_id: str, api_version: Optional[str] = None, api_key: Optional[str] = None, seed: int | None = None, - name: Optional[str] = "OpenAIAzure", + name: str = "OpenAIAzure", model: Optional[str] = "UnusedValue", kwargs: Optional["ChatAzureClientArgs"] = None, ): diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 1ed474e4..95628438 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -175,7 +175,7 @@ def __init__( password: Optional[str], private_key_file: Optional[str], private_key_file_pwd: Optional[str], - name: Optional[str] = "Snowflake", + name: str = "Snowflake", kwargs: 
Optional[dict[str, "str | int"]], ): try: diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index c181c70b..bc5267a3 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -24,6 +24,7 @@ class TokenUsage(TypedDict): model: str input: int output: int + cost: float | None class ThreadSafeTokenCounter: @@ -46,6 +47,7 @@ def log_tokens( "model": model, "input": input_tokens, "output": output_tokens, + "cost": None, } else: self._tokens[name]["input"] += input_tokens diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 1bce4b51..d231f9f9 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -92,7 +92,7 @@ def test_can_retrieve_and_log_tokens(): assert usage[0]["cost"] is not None provider2 = OpenAIAzureProvider( - api_key="fake_key", endpoint="foo", api_version="bar" + api_key="fake_key", endpoint="foo", deployment_id="test", api_version="bar" ) tokens_log(provider2, (5, 25)) From 549e30cd03586c92bd89456e66583cbf66a89d58 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 10:07:30 -0400 Subject: [PATCH 30/59] Removing duplicate setting in github etc. 
models --- chatlas/_github.py | 3 --- chatlas/_groq.py | 3 --- chatlas/_perplexity.py | 8 -------- chatlas/_tokens.py | 3 ++- 4 files changed, 2 insertions(+), 15 deletions(-) diff --git a/chatlas/_github.py b/chatlas/_github.py index 7f27ce83..03d57b63 100644 --- a/chatlas/_github.py +++ b/chatlas/_github.py @@ -133,9 +133,6 @@ def ChatGithub( if isinstance(seed, MISSING_TYPE): seed = 1014 if is_testing() else None - if model is None: - model = log_model_default("gpt-4o") - return Chat( provider=OpenAIProvider( api_key=api_key, diff --git a/chatlas/_groq.py b/chatlas/_groq.py index 0322fc4c..ab897000 100644 --- a/chatlas/_groq.py +++ b/chatlas/_groq.py @@ -128,9 +128,6 @@ def ChatGroq( if isinstance(seed, MISSING_TYPE): seed = 1014 if is_testing() else None - if model is None: - model = log_model_default("gpt-4o") - return Chat( provider=OpenAIProvider( api_key=api_key, diff --git a/chatlas/_perplexity.py b/chatlas/_perplexity.py index 5d33a468..ea738de7 100644 --- a/chatlas/_perplexity.py +++ b/chatlas/_perplexity.py @@ -131,17 +131,9 @@ def ChatPerplexity( if api_key is None: api_key = os.getenv("PERPLEXITY_API_KEY") - if model is None: - model = log_model_default("gpt-4o") - if api_key is None: - api_key = os.getenv("GITHUB_PAT") - if isinstance(seed, MISSING_TYPE): seed = 1014 if is_testing() else None - if model is None: - model = log_model_default("gpt-4o") - return Chat( provider=OpenAIProvider( api_key=api_key, diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index bc5267a3..df5855d5 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -137,7 +137,8 @@ def token_usage() -> list[TokenUsage] | None: Returns ------- list[TokenUsage] | None - A list of dictionaries with the following keys: "name", "input", and "output". + A list of dictionaries with the following keys: "name", "input", "output", and "cost". + If no cost data is available for the name/model combination chosen, then "cost" will be None. 
If no tokens have been logged, then None is returned. """ tokens = _token_counter.get_usage() From 271e927e7505232132b1763d1b37b82f2f1fa539 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 10:10:29 -0400 Subject: [PATCH 31/59] Added properties docstring --- chatlas/_provider.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chatlas/_provider.py b/chatlas/_provider.py index e37055ad..dd1a79aa 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -40,17 +40,22 @@ class Provider( directly. """ - # TODO: Add docstring for props def __init__(self, *, name: str, model: str): self._name = name self._model = model @property def name(self): + """ + Get the name of the provider + """ return self._name @property def model(self): + """ + Get the model used by the provider + """ return self._model @overload From 8300471a527da4d7134d6c1f1c555c62e359154a Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 10:25:24 -0400 Subject: [PATCH 32/59] Updated changelog --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2a4f516..71ccd78b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features + +* Token pricing can now be looked up from our `prices.json` price list using `get_token_pricing()`. (#106) +* `Chat`'s representation will now include cost information if it can be calculated. (#106) +* `Chat` gains new `.get_cost()` method, making it easier to get the estimated cost of your chat. Use our pricing list or bring your own token prices. (#106) * `Chat` gains new `.register_mcp_tools_http_stream_async()` and `.register_mcp_tools_stdio_async()` methods, making it easy to register tools from a [MCP server](https://modelcontextprotocol.io/). (#39) * `Chat` gains new `.get_tools()`/`.set_tools()` methods -- making it possible to inspect and remove tools. 
(#39) * Tool functions passed to `.register_tool()` can now `yield` numerous results. (#39) @@ -20,6 +24,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changes +* `token_usage()` will include cost if it can be calculated. (#106) +* `log_tokens()` now requires a model and will provide cost information if it can be calculated. (#106) +* `Chat`'s `.tokens()` methods have been removed in favor of `.get_tokens()` which returns both cumulative tokens in the turn and discrete tokens. (#106) +* `Provider`s now require a name and model. Sensible defaults are supplied for our implemented providers. They are set during initialization as properties. (#106) * `Tool`'s constructor no longer takes a function as input. Use the new `.from_func()` method instead to create a `Tool` from a function. (#39) * `.register_tool()` now throws an exception when the tool has the same name as an already registered tool. Set the new `force` parameter to `True` to force the registration. 
(#39) From 2b43c7745a103396e519086fb4849a8178923bb1 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 10:26:07 -0400 Subject: [PATCH 33/59] Updating ollama chat --- chatlas/_ollama.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/chatlas/_ollama.py b/chatlas/_ollama.py index 26d4f66f..b138d016 100644 --- a/chatlas/_ollama.py +++ b/chatlas/_ollama.py @@ -106,9 +106,6 @@ def ChatOllama( if isinstance(seed, MISSING_TYPE): seed = 1014 if is_testing() else None - if model is None: - model = log_model_default("gpt-4o") - return Chat( provider=OpenAIProvider( api_key="ollama", # ignored From b29def363eeea16a88752ee0c8dc9c836d604ec6 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 10:54:54 -0400 Subject: [PATCH 34/59] Fix unused import --- chatlas/_ollama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlas/_ollama.py b/chatlas/_ollama.py index b138d016..a351c18c 100644 --- a/chatlas/_ollama.py +++ b/chatlas/_ollama.py @@ -7,7 +7,7 @@ import orjson from ._chat import Chat -from ._openai import OpenAIProvider, log_model_default, normalize_turns +from ._openai import OpenAIProvider, normalize_turns from ._turn import Turn from ._utils import MISSING_TYPE, is_testing From 1a7942c737068a93ed59fcc5e966ea58523b7fe9 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 11:55:39 -0400 Subject: [PATCH 35/59] Adding test workflow --- .github/workflows/update-pricing.yml | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/update-pricing.yml diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml new file mode 100644 index 00000000..7c017493 --- /dev/null +++ b/.github/workflows/update-pricing.yml @@ -0,0 +1,39 @@ +name: Update Pricing + +on: + workflow_dispatch: + +jobs: + check-pricing: + name: Check for pricing updates in Ellmer + runs-on: ubuntu-latest + + steps: + + - name: Checkout current prices.json in chatlas + uses: 
actions/checkout@v4 + with: + sparse-checkout: /chatlas/data/prices.json + sparse-checkout-cone-mode: false + path: main + + - name: Get Ellmer prices.json + uses: actions/checkout@v4 + with: + sparse-checkout: /data-raw/prices.json + sparse-checkout-cone-mode: false + repository: https://github.com/tidyverse/ellmer.git + path: ellmer + + - name: Check for differences + run: | + echo "Checking diff between prices.json" + git diff --no-index https://github.com/posit-dev/chatlas/blob/feature/get_cost/chatlas/data/prices.json https://github.com/tidyverse/ellmer/blob/main/data-raw/prices.json + + + + + + + + \ No newline at end of file From ad7fd6e81fb521144120fa3e7150410a05f39b95 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 14:12:16 -0400 Subject: [PATCH 36/59] Updating to test on pull so can view workflow --- .github/workflows/update-pricing.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index 7c017493..6a49c617 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -2,6 +2,8 @@ name: Update Pricing on: workflow_dispatch: + pull_request: + types: [opened, ready_for_review] jobs: check-pricing: From c607b67ccdddbb41db76c2fb8c005392b17c1e2b Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 14:14:36 -0400 Subject: [PATCH 37/59] Updating workflow triggers --- .github/workflows/update-pricing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index 6a49c617..ae78785e 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -3,7 +3,7 @@ name: Update Pricing on: workflow_dispatch: pull_request: - types: [opened, ready_for_review] + types: [opened, edited, ready_for_review] jobs: check-pricing: From 1542d0c52b18aba847924cc8fa2ea881927f190d Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: 
Tue, 1 Jul 2025 14:16:08 -0400 Subject: [PATCH 38/59] Updating workflow triggers --- .github/workflows/update-pricing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index ae78785e..63a38647 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -3,7 +3,7 @@ name: Update Pricing on: workflow_dispatch: pull_request: - types: [opened, edited, ready_for_review] + types: [opened, synchronize, reopened, ready_for_review] jobs: check-pricing: From a995e6e8160c612b0117edc5211a53b28603d727 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 14:18:16 -0400 Subject: [PATCH 39/59] Updating to try straight diff --- .github/workflows/update-pricing.yml | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index 63a38647..abab120c 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -12,21 +12,6 @@ jobs: steps: - - name: Checkout current prices.json in chatlas - uses: actions/checkout@v4 - with: - sparse-checkout: /chatlas/data/prices.json - sparse-checkout-cone-mode: false - path: main - - - name: Get Ellmer prices.json - uses: actions/checkout@v4 - with: - sparse-checkout: /data-raw/prices.json - sparse-checkout-cone-mode: false - repository: https://github.com/tidyverse/ellmer.git - path: ellmer - - name: Check for differences run: | echo "Checking diff between prices.json" From f9886271d08a11ce777b738e1bd4baaab833119e Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 14:34:02 -0400 Subject: [PATCH 40/59] Adding back previous and testing --- .github/workflows/update-pricing.yml | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index abab120c..2529b176 100644 --- 
a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -12,10 +12,25 @@ jobs: steps: + - name: Checkout current prices.json in chatlas + uses: actions/checkout@v4 + with: + sparse-checkout: /chatlas/data/prices.json + sparse-checkout-cone-mode: false + path: main + + - name: Get Ellmer prices.json + uses: actions/checkout@v4 + with: + sparse-checkout: /data-raw/prices.json + sparse-checkout-cone-mode: false + repository: tidyverse/ellmer + path: ellmer + - name: Check for differences run: | echo "Checking diff between prices.json" - git diff --no-index https://github.com/posit-dev/chatlas/blob/feature/get_cost/chatlas/data/prices.json https://github.com/tidyverse/ellmer/blob/main/data-raw/prices.json + git diff --no-index ellmer/data-raw/prices.json main/chatlas/data/prices.json From 7fd717acc023b0cbfd6f4c70a1a54807e6d40418 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 14:43:16 -0400 Subject: [PATCH 41/59] Update action + correct extra line --- .github/workflows/update-pricing.yml | 5 +++-- chatlas/_anthropic.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index 2529b176..6fa1e121 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -30,8 +30,9 @@ jobs: - name: Check for differences run: | echo "Checking diff between prices.json" - git diff --no-index ellmer/data-raw/prices.json main/chatlas/data/prices.json - + git diff --no-index --stat ellmer/data-raw/prices.json main/chatlas/data/prices.json + if [[ -n $(git diff chatlas/types) ]]; then + echo "Changes detected:" diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index 4a9fe763..a48db0ca 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -186,7 +186,6 @@ def ChatAnthropic( class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]): - def __init__( self, *, From 
05b7c2c522408b5acee52aa76528acbffd825d09 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 14:54:45 -0400 Subject: [PATCH 42/59] Minor corrections --- chatlas/_chat.py | 20 +++++++++----------- chatlas/_databricks.py | 2 -- chatlas/_tokens.py | 2 +- tests/test_chat.py | 1 - 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 55caec3e..caad4a85 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -354,17 +354,6 @@ def get_cost( if len(turns_tokens) == 0: return 0.0 - if options == "last": - last_turn = turns_tokens[len(turns_tokens) - 1] - acc = 0.0 - if last_turn["role"] == "assistant": - acc += last_turn["tokens"] * output_token_price - elif last_turn["role"] == "user": - acc += last_turn["tokens_total"] * input_token_price - else: - raise ValueError(f"Unrecognized role type { last_turn['role'] }") - return acc - if options == "all": asst_tokens = sum( u["tokens_total"] for u in turns_tokens if u["role"] == "assistant" @@ -377,6 +366,15 @@ def get_cost( ) return cost + last_turn = turns_tokens[-1] + if last_turn["role"] == "assistant": + return last_turn["tokens"] * output_token_price + if last_turn["role"] == "user": + return last_turn["tokens_total"] * input_token_price + raise ValueError( + f"Expected last turn to have a role of 'user' or `'assistant'`, not '{ last_turn['role'] }'" + ) + def token_count( self, *args: Content | str, diff --git a/chatlas/_databricks.py b/chatlas/_databricks.py index d75108b7..a0a3580e 100644 --- a/chatlas/_databricks.py +++ b/chatlas/_databricks.py @@ -2,8 +2,6 @@ from typing import TYPE_CHECKING, Optional -from databricks.sdk import WorkspaceClient - from ._chat import Chat from ._logging import log_model_default from ._openai import OpenAIProvider diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index df5855d5..ef627557 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -65,7 +65,7 @@ def get_usage(self) -> list[TokenUsage] | None: 
_token_counter = ThreadSafeTokenCounter() -def tokens_log(provider: Provider, tokens: tuple[int, int]) -> None: +def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None: """ Log token usage for a provider in a thread-safe manner. """ diff --git a/tests/test_chat.py b/tests/test_chat.py index 805c6339..5797808d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -75,7 +75,6 @@ def test_basic_repr(snapshot): Turn("assistant", "2 3", tokens=(15, 5)), ], ) - print(token_usage()) assert snapshot == repr(chat) From 4159fcb5f8792d9ab5d4e4fb42b7220430653b19 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 15:00:53 -0400 Subject: [PATCH 43/59] Updating github action + test --- .github/workflows/update-pricing.yml | 4 +++- chatlas/_tokens.py | 8 ++++---- tests/test_tokens.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index 6fa1e121..1c016bc9 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -33,7 +33,9 @@ jobs: git diff --no-index --stat ellmer/data-raw/prices.json main/chatlas/data/prices.json if [[ -n $(git diff chatlas/types) ]]; then echo "Changes detected:" - + echo "::error::Ellmer's prices.json does not match the current Chatlas prices.json" + exit 1 + fi diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index ef627557..d03bc22c 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -97,13 +97,13 @@ class TokenPrice(TypedDict): PricingList: list[TokenPrice] = orjson.loads(f) -def get_token_pricing(name: str, model: str) -> TokenPrice | dict: +def get_token_pricing(name: str, model: str) -> TokenPrice | None: """ Get the token pricing for the chat if available based on the prices.json file. Returns ------- - dict[str, str | float] + TokenPrice | None A dictionary with the token pricing for the chat. 
The keys are: - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). @@ -116,9 +116,9 @@ def get_token_pricing(name: str, model: str) -> TokenPrice | dict: for item in PricingList if item["provider"] == name and item["model"] == model ), - {}, + None, ) - if not result: + if result is None: warnings.warn( f"Token pricing for the provider '{name}' and model '{model}' you selected is not available. " "Please check the provider's documentation." diff --git a/tests/test_tokens.py b/tests/test_tokens.py index d231f9f9..a8f517ed 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -69,7 +69,7 @@ def test_get_token_prices(): ): chat = ChatOpenAI(model="ABCD") pricing = get_token_pricing(chat.provider.name, chat.provider.model) - assert pricing == {} + assert pricing is None def test_usage_is_none(): From 88d0aae2bc800f53c15ff79af0247677f4c341e1 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 15:06:52 -0400 Subject: [PATCH 44/59] Updating diff command --- .github/workflows/update-pricing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index 1c016bc9..b6f50b97 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -31,7 +31,7 @@ jobs: run: | echo "Checking diff between prices.json" git diff --no-index --stat ellmer/data-raw/prices.json main/chatlas/data/prices.json - if [[ -n $(git diff chatlas/types) ]]; then + if [[ -n $(git diff --no-index --stat ellmer/data-raw/prices.json main/chatlas/data/prices.json) ]]; then echo "Changes detected:" echo "::error::Ellmer's prices.json does not match the current Chatlas prices.json" exit 1 From 20b5ba077ec12a8ab201846a6cb8b4dee631db91 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 15:27:38 -0400 Subject: [PATCH 45/59] Testing a different prices.json --- 
chatlas/data/prices.json | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/chatlas/data/prices.json b/chatlas/data/prices.json index d9cd7eb9..9b0aaba4 100644 --- a/chatlas/data/prices.json +++ b/chatlas/data/prices.json @@ -1,4 +1,11 @@ [ + { + "provider": "TEST CHANGE", + "model": "gpt-4.5-preview", + "cached_input": 37.5, + "input": 75, + "output": 150 + }, { "provider": "OpenAI", "model": "gpt-4.5-preview", @@ -261,4 +268,4 @@ "input": 0.3, "output": 0.075 } -] +] \ No newline at end of file From b2162e69222a1c4bc3493ee516820bd9e458152e Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Tue, 1 Jul 2025 15:32:37 -0400 Subject: [PATCH 46/59] Prior test worked --- chatlas/data/prices.json | 7 ------- 1 file changed, 7 deletions(-) diff --git a/chatlas/data/prices.json b/chatlas/data/prices.json index 9b0aaba4..051ffa0c 100644 --- a/chatlas/data/prices.json +++ b/chatlas/data/prices.json @@ -1,11 +1,4 @@ [ - { - "provider": "TEST CHANGE", - "model": "gpt-4.5-preview", - "cached_input": 37.5, - "input": 75, - "output": 150 - }, { "provider": "OpenAI", "model": "gpt-4.5-preview", From fe4669c5a78057361014e0c860f326a00f1773d0 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 2 Jul 2025 09:41:07 -0400 Subject: [PATCH 47/59] Update tokens pricing scripts --- chatlas/_tokens.py | 53 +++++++++++++++++++++++++++++--------------- tests/test_tokens.py | 12 ++++++++++ 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index d03bc22c..4a416348 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -47,11 +47,18 @@ def log_tokens( "model": model, "input": input_tokens, "output": output_tokens, - "cost": None, + "cost": compute_price(name, model, input_tokens, output_tokens), } else: self._tokens[name]["input"] += input_tokens self._tokens[name]["output"] += output_tokens + price = compute_price(name, model, input_tokens, output_tokens) + if price is not None: + cost = self._tokens[name]["cost"] + 
if cost is None: + self._tokens[name]["cost"] = price + else: + self._tokens[name]["cost"] = cost + price def get_usage(self) -> list[TokenUsage] | None: with self._lock: @@ -83,6 +90,13 @@ def tokens_reset() -> None: class TokenPrice(TypedDict): """ Defines the necessary information to look up pricing for a given turn. + + The keys are: + - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). + - `"model"`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). + - `"cached_input"`: The cost per user token in USD per million tokens for cached input. + - `"input"`: The cost per user token in USD per million tokens. + - `"output"`: The cost per assistant token in USD per million tokens. """ provider: str @@ -104,11 +118,6 @@ def get_token_pricing(name: str, model: str) -> TokenPrice | None: Returns ------- TokenPrice | None - A dictionary with the token pricing for the chat. The keys are: - - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). - - `model`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). - - `"input"`: The cost per user token in USD per million tokens. - - `"output"`: The cost per assistant token in USD per million tokens. """ result = next( ( @@ -127,6 +136,25 @@ def get_token_pricing(name: str, model: str) -> TokenPrice | None: return result +def compute_price( + name: str, model: str, input_tokens: int, output_tokens: int +) -> float | None: + """ + Compute the cost of a turn. + + Returns + ------- + float | None + The cost of the turn in USD, or None if the cost could not be calculated. 
+ """ + price = get_token_pricing(name, model) + if price is None: + return None + input_price = input_tokens * (price["input"] / 1e6) + output_price = output_tokens * (price["output"] / 1e6) + return input_price + output_price + + def token_usage() -> list[TokenUsage] | None: """ Report on token usage in the current session @@ -141,15 +169,4 @@ def token_usage() -> list[TokenUsage] | None: If no cost data is available for the name/model combination chosen, then "cost" will be None. If no tokens have been logged, then None is returned. """ - tokens = _token_counter.get_usage() - if tokens: - for item in tokens: - price = get_token_pricing(item["name"], item["model"]) - if price: - item["cost"] = item["input"] * (price["input"] / 1e6) + item[ - "output" - ] * (price["output"] / 1e6) - else: - item["cost"] = None - - return tokens + return _token_counter.get_usage() diff --git a/tests/test_tokens.py b/tests/test_tokens.py index a8f517ed..7f0b3c00 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -2,6 +2,7 @@ from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider from chatlas._tokens import ( + compute_price, get_token_pricing, token_usage, tokens_log, @@ -72,6 +73,17 @@ def test_get_token_prices(): assert pricing is None +def test_compute_price(): + chat = ChatOpenAI(model="o1-mini") + price = compute_price(chat.provider.name, chat.provider.model, 10, 50) + assert isinstance(price, float) + assert price > 0 + + chat = ChatOpenAI(model="ABCD") + price = compute_price(chat.provider.name, chat.provider.model, 10, 50) + assert price is None + + def test_usage_is_none(): tokens_reset() assert token_usage() is None From 83f5c82bf2ea576271837e53c13b406017a984be Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 2 Jul 2025 10:02:32 -0400 Subject: [PATCH 48/59] PR edits --- CHANGELOG.md | 13 ++++++++----- chatlas/_tokens.py | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 13 
deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71ccd78b..dfbf7de4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features - -* Token pricing can now be looked up from our `prices.json` price list using `get_token_pricing()`. (#106) -* `Chat`'s representation will now include cost information if it can be calculated. (#106) +* Token pricing can now be looked up from our internally maintained price list using `get_token_pricing()`. (#106) * `Chat` gains new `.get_cost()` method, making it easier to get the estimated cost of your chat. Use our pricing list or bring your own token prices. (#106) * `Chat` gains new `.register_mcp_tools_http_stream_async()` and `.register_mcp_tools_stdio_async()` methods, making it easy to register tools from a [MCP server](https://modelcontextprotocol.io/). (#39) * `Chat` gains new `.get_tools()`/`.set_tools()` methods -- making it possible to inspect and remove tools. (#39) @@ -24,15 +22,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changes -* `token_usage()` will include cost if it can be calculated. (#106) -* `log_tokens()` now requires a model and will provide cost information if it can be calculated. (#106) +#### Breaking Changes + * `Chat`'s `.tokens()` methods have been removed in favor of `.get_tokens()` which returns both cumulative tokens in the turn and discrete tokens. (#106) * `Provider`s now require a name and model. Sensible defaults are supplied for our implemented providers. They are set during initialization as properties. (#106) + +#### Other Changes + * `Tool`'s constructor no longer takes a function as input. Use the new `.from_func()` method instead to create a `Tool` from a function. (#39) * `.register_tool()` now throws an exception when the tool has the same name as an already registered tool. 
Set the new `force` parameter to `True` to force the registration. (#39) ### Improvements +* `Chat`'s representation will now include cost information if it can be calculated. (#106) +* `token_usage()` will include cost if it can be calculated. (#106) * `ChatOpenAI()` and `ChatGithub()` now default to GPT 4.1 (instead of 4o). (#115) * `ChatAnthropic()` now supports `content_image_url()`. (#112) * HTML styling improvements for `ContentToolResult` and `ContentToolRequest`. (#39) diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index 4a416348..b1fa45de 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -90,20 +90,18 @@ def tokens_reset() -> None: class TokenPrice(TypedDict): """ Defines the necessary information to look up pricing for a given turn. - - The keys are: - - `"provider"`: The provider name (e.g., "OpenAI", "Anthropic", etc.). - - `"model"`: The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.). - - `"cached_input"`: The cost per user token in USD per million tokens for cached input. - - `"input"`: The cost per user token in USD per million tokens. - - `"output"`: The cost per assistant token in USD per million tokens. """ provider: str + """The provider name (e.g., "OpenAI", "Anthropic", etc.)""" model: str + """The model name (e.g., "gpt-3.5-turbo", "claude-2", etc.)""" cached_input: float + """The cost per user token in USD per million tokens for cached input""" input: float + """The cost per user token in USD per million tokens""" output: float + """The cost per assistant token in USD per million tokens""" # Load in pricing pulled from ellmer @@ -113,7 +111,12 @@ class TokenPrice(TypedDict): def get_token_pricing(name: str, model: str) -> TokenPrice | None: """ - Get the token pricing for the chat if available based on the prices.json file. + Get token pricing information given a provider name and model + + Note + ---- + Only a subset of providers and models are currently supported. + The pricing information derives from ellmer.
Returns ------- From 965128334b63f243431fb42683db52beb2b3c385 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 2 Jul 2025 10:52:54 -0400 Subject: [PATCH 49/59] Adding additional PR updates --- chatlas/_chat.py | 15 ++++++++++----- tests/test_chat.py | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 139f6c54..97e04560 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -368,17 +368,22 @@ def get_cost( output_token_price = token_price[1] / 1e6 else: price_token = get_token_pricing(self.provider.name, self.provider.model) + if not price_token: + raise KeyError( + f"We could not locate pricing information for model '{ self.provider.model }' from provider '{ self.provider.name }'. " + "If you know the pricing for this model, specify it in `token_price`." + ) input_token_price = price_token["input"] / 1e6 output_token_price = price_token["output"] / 1e6 - if not input_token_price and not output_token_price: - raise KeyError( - f"We could not locate provider ' { self.provider.name } ' and model '{ self.provider.model } ' in our pricing information. Please supply your own if you wish to use the cost function." - ) - if len(turns_tokens) == 0: return 0.0 + if options not in ("all", "last"): + raise ValueError( + f"Expected `options` to be one of 'all' or 'last', not '{ options }'" + ) + if options == "all": asst_tokens = sum( u["tokens_total"] for u in turns_tokens if u["role"] == "assistant" diff --git a/tests/test_chat.py b/tests/test_chat.py index 29a1b36a..8a942962 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -304,6 +304,14 @@ def test_get_cost(): ] ) + with pytest.raises( + ValueError, + match=re.escape( + "Expected `options` to be one of 'all' or 'last', not 'bad_option'" + ), + ): + chat.get_cost(options="bad_option") + # Checking that these have the right form vs. 
the actual calculation because the price may change cost = chat.get_cost(options="all") assert isinstance(cost, float) @@ -322,3 +330,18 @@ def test_get_cost(): last2 = chat.get_cost(options="last", token_price=byoc) assert last2 == 0.00003 + + chat2 = ChatOpenAI(api_key="fake_key", model="BADBAD") + chat2.set_turns( + [ + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(2, 10)), + Turn(role="user", contents="Hi"), + Turn(role="assistant", contents="Hello", tokens=(14, 10)), + ] + ) + with pytest.raises( + KeyError, + match="We could not locate pricing information for model 'BADBAD' from provider 'OpenAI'. If you know the pricing for this model, specify it in `token_price`.", + ): + chat2.get_cost(options="all") From 667f06dd31b543673b9c35943af8281d44b4bf81 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 2 Jul 2025 10:54:23 -0400 Subject: [PATCH 50/59] Adding back whitespace --- chatlas/data/prices.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlas/data/prices.json b/chatlas/data/prices.json index 051ffa0c..d9cd7eb9 100644 --- a/chatlas/data/prices.json +++ b/chatlas/data/prices.json @@ -261,4 +261,4 @@ "input": 0.3, "output": 0.075 } -] \ No newline at end of file +] From 0ec80155d83af3affe62cdcf55334d09d954e142 Mon Sep 17 00:00:00 2001 From: E Nelson Date: Wed, 2 Jul 2025 11:39:57 -0400 Subject: [PATCH 51/59] Update .github/workflows/update-pricing.yml Co-authored-by: Carson Sievert --- .github/workflows/update-pricing.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/update-pricing.yml b/.github/workflows/update-pricing.yml index b6f50b97..9851e230 100644 --- a/.github/workflows/update-pricing.yml +++ b/.github/workflows/update-pricing.yml @@ -36,9 +36,3 @@ jobs: echo "::error::Ellmer's prices.json does not match the current Chatlas prices.json" exit 1 fi - - - - - - \ No newline at end of file From 7834ab33d8528c43748763994d9e1925e9c45c74 Mon Sep 17 00:00:00 2001 
From: E Nelson Date: Wed, 2 Jul 2025 11:40:07 -0400 Subject: [PATCH 52/59] Update CHANGELOG.md Co-authored-by: Carson Sievert --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36fc4bba..206eb67e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * A `Tool` can now be constructed from a pre-existing tool schema (via a new `__init__` method). (#39) * The `Chat.app()` method gains a `host` parameter. (#122) * `ChatGithub()` now supports the more standard `GITHUB_TOKEN` environment variable for storing the API key. (#123) -* Token pricing can now be looked up from our internally maintained price list using `get_token_pricing()`. (#106) ### Breaking changes From f93966407e235b02d5d2700bcef81cfd494081be Mon Sep 17 00:00:00 2001 From: E Nelson Date: Wed, 2 Jul 2025 11:40:18 -0400 Subject: [PATCH 53/59] Update CHANGELOG.md Co-authored-by: Carson Sievert --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 206eb67e..6a007298 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### Breaking Changes * `Chat`'s `.tokens()` methods have been removed in favor of `.get_tokens()` which returns both cumulative tokens in the turn and discrete tokens. (#106) -* `Provider`s now require a name and model. Sensible defaults are supplied for our implemented providers. They are set during initialization as properties. (#106) +* The base `Provider` class now includes a `name` and `model` property. In order for them to work properly, implementations should pass a `name` and `model` along to the `__init__()` method. 
(#106) #### Other Changes From fbeb5c8d11c75e4b02f12e4546b8e95d70805c15 Mon Sep 17 00:00:00 2001 From: E Nelson Date: Wed, 2 Jul 2025 11:40:26 -0400 Subject: [PATCH 54/59] Update chatlas/_openai.py Co-authored-by: Carson Sievert --- chatlas/_openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 1632fbb6..f574b0b3 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -169,7 +169,6 @@ def ChatOpenAI( class OpenAIProvider(Provider[ChatCompletion, ChatCompletionChunk, ChatCompletionDict]): - def __init__( self, *, From 9542bbc3b4213df2425871457ba9137fdc06c249 Mon Sep 17 00:00:00 2001 From: E Nelson Date: Wed, 2 Jul 2025 11:48:59 -0400 Subject: [PATCH 55/59] Update chatlas/_openai.py Co-authored-by: Carson Sievert --- chatlas/_openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index f574b0b3..d336e3c1 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -645,7 +645,6 @@ def ChatAzureOpenAI( class OpenAIAzureProvider(OpenAIProvider): - def __init__( self, *, From 51098e5056a1ed70d0eb147412707bcf813b8ddd Mon Sep 17 00:00:00 2001 From: E Nelson Date: Wed, 2 Jul 2025 11:49:08 -0400 Subject: [PATCH 56/59] Update chatlas/_anthropic.py Co-authored-by: Carson Sievert --- chatlas/_anthropic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index cb7e826f..7f81723c 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -703,7 +703,6 @@ def ChatBedrockAnthropic( class AnthropicBedrockProvider(AnthropicProvider): - def __init__( self, *, From a78784376474bc399f12ff625c3166df56b82857 Mon Sep 17 00:00:00 2001 From: E Nelson Date: Wed, 2 Jul 2025 11:54:38 -0400 Subject: [PATCH 57/59] Update CHANGELOG.md Co-authored-by: Carson Sievert --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a007298..efc3b9c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md 
@@ -40,8 +40,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Improvements -* `Chat`'s representation will now include cost information if it can be calculated. (#106) -* `token_usage()` will include cost if it can be calculated. (#106) +* `Chat`'s representation now includes cost information if it can be calculated. (#106) +* `token_usage()` includes cost if it can be calculated. (#106) * `ChatOpenAI()` and `ChatGithub()` now default to GPT 4.1 (instead of 4o). (#115) * `ChatAnthropic()` now supports `content_image_url()`. (#112) * HTML styling improvements for `ContentToolResult` and `ContentToolRequest`. (#39) From ae6539c135a1e48c3c61c6399c16186355f477e7 Mon Sep 17 00:00:00 2001 From: Liz Nelson Date: Wed, 2 Jul 2025 11:54:55 -0400 Subject: [PATCH 58/59] PR Updates --- chatlas/_chat.py | 12 +++++++----- chatlas/_tokens.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 97e04560..00c10bb6 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -58,6 +58,12 @@ class AnyTypeDict(TypedDict, total=False): class TokensDict(TypedDict): """ A TypedDict representing the token counts for a turn in the chat. + This is used to represent the token counts for each turn in the chat. + `role` represents the role of the turn (i.e., "user" or "assistant"). + `tokens` represents the new tokens used in the turn. + `tokens_total` represents the total tokens used in the turn. + Ex. A new user input of 2 tokens is sent, plus 10 tokens of context from prior turns (input and output). + This would have a `tokens_total` of 12. """ role: Literal["user", "assistant"] @@ -235,11 +241,7 @@ def get_tokens(self) -> list[TokensDict]: Returns ------- list[TokensDict] - A list of dictionaries with the token counts for each (non-system) turn - in the chat. - `tokens` represents the new tokens used in the turn. - `tokens_total` represents the total tokens used in the turn. - Ex. 
A new user input of 2 tokens is sent, plus 10 tokens of context from prior turns (input and output) would have a `tokens_total` of 12. + A list of dictionaries with the token counts for each (non-system) turn Raises ------ diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py index b1fa45de..2a42bea2 100644 --- a/chatlas/_tokens.py +++ b/chatlas/_tokens.py @@ -106,7 +106,7 @@ class TokenPrice(TypedDict): # Load in pricing pulled from ellmer f = resources.files("chatlas").joinpath("data/prices.json").read_text(encoding="utf-8") -PricingList: list[TokenPrice] = orjson.loads(f) +pricing_list: list[TokenPrice] = orjson.loads(f) def get_token_pricing(name: str, model: str) -> TokenPrice | None: @@ -125,7 +125,7 @@ def get_token_pricing(name: str, model: str) -> TokenPrice | None: result = next( ( item - for item in PricingList + for item in pricing_list if item["provider"] == name and item["model"] == model ), None, From 80e8fb2136bd2fe85bd0c6ee138684f045292220 Mon Sep 17 00:00:00 2001 From: Carson Sievert Date: Wed, 2 Jul 2025 11:08:31 -0500 Subject: [PATCH 59/59] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index efc3b9c3..c8b37e44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `.register_mcp_tools_http_stream_async()` and `.register_mcp_tools_stdio_async()`: for registering tools from a [MCP server](https://modelcontextprotocol.io/). (#39) * `.get_tools()` and `.set_tools()`: for fine-grained control over registered tools. (#39) * `.add_turn()`: to add `Turn`(s) to the current chat history. (#126) - * `Chat` gains new `.get_cost()` method, making it easier to get the estimated cost of your chat. Use our pricing list or bring your own token prices. (#106) + * `.get_cost()`: to get the estimated cost of the chat. 
Only popular models are supported, but you can also supply your own token prices. (#106) * Tool functions passed to `.register_tool()` can now `yield` numerous results. (#39) * New content classes (`ContentToolResultImage` and `ContentToolResultResource`) were added, primarily to represent MCP tool results that include images/files. (#39) * A `Tool` can now be constructed from a pre-existing tool schema (via a new `__init__` method). (#39)