Skip to content

Resolves: #76, Add get_cost to chat #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
370edbf
Added get_tokens
elnelson575 Jun 11, 2025
30694be
changed type of token dict to TypedDict
elnelson575 Jun 12, 2025
c0c8e67
Adding name to Providers where pricing is supported
elnelson575 Jun 23, 2025
6ea4ece
Correct kwargs.get
elnelson575 Jun 23, 2025
56463d4
removing unused import
elnelson575 Jun 23, 2025
64d1b5c
Adding start of pricing, fixing test
elnelson575 Jun 24, 2025
00b9389
Initial token price fetching
elnelson575 Jun 24, 2025
471e67a
Correcting importlib
elnelson575 Jun 24, 2025
80635a1
Fixing ordering
elnelson575 Jun 24, 2025
df255a3
Fixing imports again
elnelson575 Jun 24, 2025
a445686
Ignoring type issues that don't exist
elnelson575 Jun 24, 2025
06cbfda
Test removing backticks from streaming test
elnelson575 Jun 24, 2025
b6b2126
Switched to _name convention
elnelson575 Jun 25, 2025
f7dbfba
Updating all classes and adding tests
elnelson575 Jun 25, 2025
965d406
Removing old test
elnelson575 Jun 25, 2025
75b8b1f
Correcting class
elnelson575 Jun 25, 2025
528a53c
Updating class spec
elnelson575 Jun 26, 2025
0b76317
Added token tests
elnelson575 Jun 26, 2025
70709fe
Adding cost test
elnelson575 Jun 26, 2025
c62a9dd
Updating classes
elnelson575 Jun 26, 2025
8fbe70f
Removing flaky tests that are prone to change in future based on mode…
elnelson575 Jun 26, 2025
b601fdb
Fixing import orders
elnelson575 Jun 26, 2025
952cfb4
Updating to tuples
elnelson575 Jun 26, 2025
07e1cdc
Importing OpenAI to replace Chat
elnelson575 Jun 27, 2025
c096a48
Stashing changes
elnelson575 Jun 27, 2025
dffaf72
Merge remote-tracking branch 'origin/main' into feature/get_cost
elnelson575 Jun 27, 2025
b5963cf
Updating repr test
elnelson575 Jun 27, 2025
da507eb
Fixing tests
elnelson575 Jun 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions chatlas/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,24 @@ def ChatAnthropic(


class AnthropicProvider(Provider[Message, RawMessageStreamEvent, Message]):

def __init__(
self,
*,
max_tokens: int,
max_tokens: int = 4096,
model: str,
api_key: str | None,
api_key: Optional[str] = None,
name: Optional[str] = "Anthropic",
kwargs: Optional["ChatClientArgs"] = None,
):
super().__init__(name=name, model=model)
try:
from anthropic import Anthropic, AsyncAnthropic
except ImportError:
raise ImportError(
"`ChatAnthropic()` requires the `anthropic` package. "
"You can install it with 'pip install anthropic'."
)

self._model = model
self._max_tokens = max_tokens

kwargs_full: "ChatClientArgs" = {
Expand Down Expand Up @@ -325,7 +326,7 @@ def _structured_tool_call(**kwargs: Any):
kwargs_full: "SubmitInputArgs" = {
"stream": stream,
"messages": self._as_message_params(turns),
"model": self._model,
"model": self.model,
"max_tokens": self._max_tokens,
"tools": tool_schemas,
**(kwargs or {}),
Expand Down Expand Up @@ -725,6 +726,7 @@ def ChatBedrockAnthropic(


class AnthropicBedrockProvider(AnthropicProvider):

def __init__(
self,
*,
Expand All @@ -734,10 +736,14 @@ def __init__(
aws_region: str | None,
aws_profile: str | None,
aws_session_token: str | None,
max_tokens: int,
max_tokens: int = 4096,
base_url: str | None,
name: Optional[str] = "AnthropicBedrock",
kwargs: Optional["ChatBedrockClientArgs"] = None,
):

super().__init__(name=name, model=model, max_tokens=max_tokens)

try:
from anthropic import AnthropicBedrock, AsyncAnthropicBedrock
except ImportError:
Expand All @@ -746,9 +752,6 @@ def __init__(
"Install it with `pip install anthropic[bedrock]`."
)

self._model = model
self._max_tokens = max_tokens

kwargs_full: "ChatBedrockClientArgs" = {
"aws_secret_key": aws_secret_key,
"aws_access_key": aws_access_key,
Expand Down
175 changes: 125 additions & 50 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from ._logging import log_tool_error
from ._mcp_manager import MCPSessionManager
from ._provider import Provider
from ._tokens import get_token_pricing
from ._tools import Tool, ToolRejectError
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict
Expand All @@ -54,6 +55,16 @@ class AnyTypeDict(TypedDict, total=False):
pass


class TokensDict(TypedDict):
    """
    A TypedDict representing the token counts for a turn in the chat.
    """

    # Whether the turn was sent by the user or produced by the assistant.
    role: Literal["user", "assistant"]
    # New tokens used in this turn only (excludes context from prior turns).
    tokens: int
    # Total tokens used in the turn, including context carried over from
    # prior turns (input and output).
    tokens_total: int


SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
"""
A TypedDict representing the arguments that can be passed to the `.chat()`
Expand All @@ -64,6 +75,8 @@ class AnyTypeDict(TypedDict, total=False):

EchoOptions = Literal["output", "all", "none", "text"]

CostOptions = Literal["all", "last"]


class Chat(Generic[SubmitInputArgsT, CompletionT]):
"""
Expand Down Expand Up @@ -190,43 +203,18 @@ def system_prompt(self, value: str | None):
if value is not None:
self._turns.insert(0, Turn("system", value))

@overload
def tokens(self) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
values: Literal["cumulative"],
) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
values: Literal["discrete"],
) -> list[int]: ...

def tokens(
self,
values: Literal["cumulative", "discrete"] = "discrete",
) -> list[int] | list[tuple[int, int] | None]:
def get_tokens(self) -> list[TokensDict]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure to highlight breaking changes in CHANGELOG.md

"""
Get the tokens for each turn in the chat.

Parameters
----------
values
If "cumulative" (the default), the result can be summed to get the
chat's overall token usage (helpful for computing overall cost of
the chat). If "discrete", the result can be summed to get the number of
tokens the turns will cost to generate the next response (helpful
for estimating cost of the next response, or for determining if you
are about to exceed the token limit).

Returns
-------
list[int]
A list of token counts for each (non-system) turn in the chat. The
1st turn includes the tokens count for the system prompt (if any).
list[TokensDict]
A list of dictionaries with the token counts for each (non-system) turn
in the chat.
`tokens` represents the new tokens used in the turn.
`tokens_total` represents the total tokens used in the turn.
Ex. If a new user input of 2 tokens is sent, plus 10 tokens of context from prior turns (input and output), the `tokens_total` would be 12.

Raises
------
Expand All @@ -240,9 +228,6 @@ def tokens(

turns = self.get_turns(include_system_prompt=False)

if values == "cumulative":
return [turn.tokens for turn in turns]

if len(turns) == 0:
return []

Expand Down Expand Up @@ -278,12 +263,21 @@ def tokens(
"Expected the 1st assistant turn to contain token counts. " + err_info
)

res: list[int] = [
res: list[TokensDict] = [
# Implied token count for the 1st user input
turns[1].tokens[0],
{
"role": "user",
"tokens": turns[1].tokens[0],
"tokens_total": turns[1].tokens[0],
},
# The token count for the 1st assistant response
turns[1].tokens[1],
{
"role": "assistant",
"tokens": turns[1].tokens[1],
"tokens_total": turns[1].tokens[1],
},
]

for i in range(1, len(turns) - 1, 2):
ti = turns[i]
tj = turns[i + 2]
Expand All @@ -298,15 +292,91 @@ def tokens(
)
res.extend(
[
# Implied token count for the user input
tj.tokens[0] - sum(ti.tokens),
# The token count for the assistant response
tj.tokens[1],
{
"role": "user",
# Implied token count for the user input
"tokens": tj.tokens[0] - sum(ti.tokens),
# Total tokens = Total User Tokens for the Turn = Distinct new tokens + context sent
"tokens_total": tj.tokens[0],
},
{
"role": "assistant",
# The token count for the assistant response
"tokens": tj.tokens[1],
# Total tokens = Total Assistant tokens used in the turn
"tokens_total": tj.tokens[1],
},
]
)

return res

def get_cost(
    self,
    options: CostOptions = "all",
    token_price: Optional[tuple[float, float]] = None,
) -> float:
    """
    Get the estimated cost of the chat, in USD.

    Note that this is a rough estimate. Providers may change their pricing
    frequently and without notice.

    Parameters
    ----------
    options
        One of the following (default is "all"):
        - `"all"`: Return the total cost of all turns in the chat.
        - `"last"`: Return the cost of the last turn in the chat.
    token_price
        An optional tuple in the format of
        (input_token_cost, output_token_cost) for bringing your own cost
        information.
        - `input_token_cost`: The cost per user token in USD per million tokens.
        - `output_token_cost`: The cost per assistant token in USD per million tokens.

    Returns
    -------
    float
        The cost of the chat, in USD.

    Raises
    ------
    KeyError
        If no pricing information is available for this provider/model and
        `token_price` was not supplied.
    ValueError
        If `options` is not one of `"all"` or `"last"`, or a turn has an
        unrecognized role.
    """
    # Provider pricing is quoted in USD per million tokens.
    tokens_per_price_unit = 1_000_000

    if token_price is not None:
        input_token_price = token_price[0] / tokens_per_price_unit
        output_token_price = token_price[1] / tokens_per_price_unit
    else:
        # Look up token cost for user and assistant tokens based on the
        # provider and model.
        price_token = get_token_pricing(self.provider)
        # Guard BEFORE dividing: a missing entry may be None/0, which would
        # otherwise raise TypeError instead of the intended KeyError.
        if not price_token.get("input") or not price_token.get("output"):
            raise KeyError(
                f"We could not locate provider '{self.provider.name}' and "
                f"model '{self.provider.model}' in our pricing information. "
                "Please supply your own if you wish to use the cost function."
            )
        input_token_price = price_token["input"] / tokens_per_price_unit
        output_token_price = price_token["output"] / tokens_per_price_unit

    turns_tokens = self.get_tokens()
    if not turns_tokens:
        return 0.0

    if options == "last":
        last_turn = turns_tokens[-1]
        if last_turn["role"] == "assistant":
            # Assistant cost is based on newly generated tokens only.
            return last_turn["tokens"] * output_token_price
        if last_turn["role"] == "user":
            # User cost includes the full context sent with the input.
            return last_turn["tokens_total"] * input_token_price
        raise ValueError(f"Unrecognized role type {last_turn['role']}")

    if options == "all":
        asst_tokens = sum(
            u["tokens_total"] for u in turns_tokens if u["role"] == "assistant"
        )
        user_tokens = sum(
            u["tokens_total"] for u in turns_tokens if u["role"] == "user"
        )
        return (asst_tokens * output_token_price) + (
            user_tokens * input_token_price
        )

    # Previously this method silently returned None for any other value.
    raise ValueError(f"Expected `options` to be 'all' or 'last', got {options!r}")

def token_count(
self,
*args: Content | str,
Expand Down Expand Up @@ -708,9 +778,9 @@ def stream(
kwargs=kwargs,
)

def wrapper() -> Generator[
str | ContentToolRequest | ContentToolResult, None, None
]:
def wrapper() -> (
Generator[str | ContentToolRequest | ContentToolResult, None, None]
):
with display:
for chunk in generator:
yield chunk
Expand Down Expand Up @@ -772,9 +842,9 @@ async def stream_async(

display = self._markdown_display(echo=echo)

async def wrapper() -> AsyncGenerator[
str | ContentToolRequest | ContentToolResult, None
]:
async def wrapper() -> (
AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]
):
with display:
async for chunk in self._chat_impl_async(
turn,
Expand Down Expand Up @@ -1976,10 +2046,15 @@ def __str__(self):
res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n"
return res

def __repr__(self):
    """
    Summarize the chat: provider/model, turn count, user/assistant token
    totals, estimated cost (when pricing is known), and each turn.
    """
    turns = self.get_turns(include_system_prompt=True)
    tokens = self.get_tokens()
    tokens_asst = sum(u["tokens_total"] for u in tokens if u["role"] == "assistant")
    tokens_user = sum(u["tokens_total"] for u in tokens if u["role"] == "user")

    res = (
        f"<Chat {self.provider.name}/{self.provider.model} "
        f"turns={len(turns)} tokens={tokens_user}/{tokens_asst}"
    )
    # A repr must never raise: get_cost() raises KeyError when no pricing
    # is available for this provider/model, so omit the cost in that case.
    try:
        res += f" ${round(self.get_cost(), ndigits=2)}"
    except (KeyError, ValueError):
        pass
    res += ">"

    for turn in turns:
        res += "\n" + turn.__repr__(indent=2)
    return res + "\n"
Expand Down
6 changes: 5 additions & 1 deletion chatlas/_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import TYPE_CHECKING, Optional

from databricks.sdk import WorkspaceClient

from ._chat import Chat
from ._logging import log_model_default
from ._openai import OpenAIProvider
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
self,
*,
model: str,
name: Optional[str] = "Databricks",
workspace_client: Optional["WorkspaceClient"] = None,
):
try:
Expand All @@ -117,7 +120,8 @@ def __init__(
import httpx
from openai import AsyncOpenAI

self._model = model
super().__init__(name=name, model=model)

self._seed = None

if workspace_client is None:
Expand Down
Loading
Loading