Resolves: #76, Add get_cost to chat #106
Draft
elnelson575 wants to merge 28 commits into main from feature/get_cost
Changes from all commits (28 commits, all by elnelson575):

370edbf  Added get_tokens
30694be  changed type of token dict to TypedDict
c0c8e67  Adding name to Providers where pricing is supported
6ea4ece  Correct kwargs.get
56463d4  removing unused import
64d1b5c  Adding start of pricing, fixing test
00b9389  Initial token price fetching
471e67a  Correcting importlib
80635a1  Fixing ordering
df255a3  Fixing imports again
a445686  Ignoring type issues that don't exist
06cbfda  Test removing backticks from streaming test
b6b2126  Switched to _name convention
f7dbfba  Updating all classes and adding tests
965d406  Removing old test
75b8b1f  Correcting class
528a53c  Updating class spec
0b76317  Added token tests
70709fe  Adding cost test
c62a9dd  Updating classes
8fbe70f  Removing flaky tests that are prone to change in future based on mode…
b601fdb  Fixing import orders
952cfb4  Updating to tuples
07e1cdc  Importing OpenAI Pto replace Chat
c096a48  Stashing changes
dffaf72  Merge remote-tracking branch 'origin/main' into feature/get_cost
b5963cf  Updating repr test
da507eb  Fixing tests
@@ -44,6 +44,7 @@
 from ._logging import log_tool_error
 from ._mcp_manager import MCPSessionManager
 from ._provider import Provider
+from ._tokens import get_token_pricing
 from ._tools import Tool, ToolRejectError
 from ._turn import Turn, user_turn
 from ._typing_extensions import TypedDict
@@ -54,6 +55,16 @@ class AnyTypeDict(TypedDict, total=False):
     pass


+class TokensDict(TypedDict):
+    """
+    A TypedDict representing the token counts for a turn in the chat.
+    """
+
+    role: Literal["user", "assistant"]
+    tokens: int
+    tokens_total: int
+
+
 SubmitInputArgsT = TypeVar("SubmitInputArgsT", bound=AnyTypeDict)
 """
 A TypedDict representing the arguments that can be passed to the `.chat()`
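For orientation, here is a minimal sketch of the value shape this TypedDict describes, as produced by the new `get_tokens()` further down; the token counts are hypothetical:

```python
# Hypothetical get_tokens() result for a two-exchange chat.
# "tokens" counts the new tokens for that turn; "tokens_total" also counts
# the prior context resent with the request (see get_tokens() below).
tokens: list[TokensDict] = [
    {"role": "user", "tokens": 2, "tokens_total": 2},
    {"role": "assistant", "tokens": 10, "tokens_total": 10},
    {"role": "user", "tokens": 2, "tokens_total": 14},
    {"role": "assistant", "tokens": 8, "tokens_total": 8},
]
```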
@@ -64,6 +75,8 @@ class AnyTypeDict(TypedDict, total=False):


 EchoOptions = Literal["output", "all", "none", "text"]

+CostOptions = Literal["all", "last"]
+

 class Chat(Generic[SubmitInputArgsT, CompletionT]):
     """
@@ -190,43 +203,18 @@ def system_prompt(self, value: str | None):
         if value is not None:
             self._turns.insert(0, Turn("system", value))

-    @overload
-    def tokens(self) -> list[tuple[int, int] | None]: ...
-
-    @overload
-    def tokens(
-        self,
-        values: Literal["cumulative"],
-    ) -> list[tuple[int, int] | None]: ...
-
-    @overload
-    def tokens(
-        self,
-        values: Literal["discrete"],
-    ) -> list[int]: ...
-
-    def tokens(
-        self,
-        values: Literal["cumulative", "discrete"] = "discrete",
-    ) -> list[int] | list[tuple[int, int] | None]:
+    def get_tokens(self) -> list[TokensDict]:
         """
         Get the tokens for each turn in the chat.

-        Parameters
-        ----------
-        values
-            If "cumulative" (the default), the result can be summed to get the
-            chat's overall token usage (helpful for computing overall cost of
-            the chat). If "discrete", the result can be summed to get the number of
-            tokens the turns will cost to generate the next response (helpful
-            for estimating cost of the next response, or for determining if you
-            are about to exceed the token limit).
-
         Returns
         -------
-        list[int]
-            A list of token counts for each (non-system) turn in the chat. The
-            1st turn includes the tokens count for the system prompt (if any).
+        list[TokensDict]
+            A list of dictionaries with the token counts for each (non-system)
+            turn in the chat. `tokens` is the count of new tokens used in the
+            turn; `tokens_total` is the total token count for the turn. For
+            example, a new user input of 2 tokens sent along with 10 tokens of
+            context from prior turns (input and output) has a `tokens_total`
+            of 12.

         Raises
         ------
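A usage sketch of the new method; the constructor, prompt, and token counts are illustrative, not taken from this diff:

```python
from chatlas import ChatOpenAI  # assuming an OpenAI-backed chat

chat = ChatOpenAI()
chat.chat("What's 1 + 1?")
chat.get_tokens()
# Hypothetical output:
# [{'role': 'user', 'tokens': 2, 'tokens_total': 2},
#  {'role': 'assistant', 'tokens': 10, 'tokens_total': 10}]
```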
@@ -240,9 +228,6 @@ def tokens(
         turns = self.get_turns(include_system_prompt=False)

-        if values == "cumulative":
-            return [turn.tokens for turn in turns]
-
         if len(turns) == 0:
             return []
@@ -278,12 +263,21 @@ def tokens(
                 "Expected the 1st assistant turn to contain token counts. " + err_info
             )

-        res: list[int] = [
+        res: list[TokensDict] = [
             # Implied token count for the 1st user input
-            turns[1].tokens[0],
+            {
+                "role": "user",
+                "tokens": turns[1].tokens[0],
+                "tokens_total": turns[1].tokens[0],
+            },
             # The token count for the 1st assistant response
-            turns[1].tokens[1],
+            {
+                "role": "assistant",
+                "tokens": turns[1].tokens[1],
+                "tokens_total": turns[1].tokens[1],
+            },
         ]

         for i in range(1, len(turns) - 1, 2):
             ti = turns[i]
             tj = turns[i + 2]
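To make the implied-count arithmetic in the next hunk concrete, here is a hypothetical walk through one loop iteration (all numbers invented):

```python
# ti = turns[i]     -> an assistant turn with ti.tokens == (2, 10):
#                      the request cost 2 input tokens, the reply 10 output tokens.
# tj = turns[i + 2] -> the next assistant turn, with tj.tokens == (14, 8).
#
# Implied *new* user tokens for the user turn in between:
#     tj.tokens[0] - sum(ti.tokens)  ==  14 - (2 + 10)  ==  2
# while its "tokens_total" stays tj.tokens[0] == 14, i.e. the new tokens
# plus all prior context resent with the request.
```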
@@ -298,15 +292,91 @@ def tokens(
             )
             res.extend(
                 [
-                    # Implied token count for the user input
-                    tj.tokens[0] - sum(ti.tokens),
-                    # The token count for the assistant response
-                    tj.tokens[1],
+                    {
+                        "role": "user",
+                        # Implied token count for the user input
+                        "tokens": tj.tokens[0] - sum(ti.tokens),
+                        # Total user tokens for the turn = distinct new tokens + context sent
+                        "tokens_total": tj.tokens[0],
+                    },
+                    {
+                        "role": "assistant",
+                        # The token count for the assistant response
+                        "tokens": tj.tokens[1],
+                        # Total assistant tokens used in the turn
+                        "tokens_total": tj.tokens[1],
+                    },
                 ]
             )

         return res

+    def get_cost(
+        self,
+        options: CostOptions = "all",
+        token_price: Optional[tuple[float, float]] = None,
+    ) -> float:
+        """
+        Get the cost of the chat. Note that this is a rough estimate; providers
+        may change their pricing frequently and without notice.
+
+        Parameters
+        ----------
+        options
+            One of the following (default is "all"):
+            - `"all"`: Return the total cost of all turns in the chat.
+            - `"last"`: Return the cost of the last turn in the chat.
+        token_price
+            An optional tuple of the form (input_token_cost, output_token_cost)
+            for bringing your own cost information:
+            - `input_token_cost`: The cost per user token, in USD per million tokens.
+            - `output_token_cost`: The cost per assistant token, in USD per million tokens.
+
+        Returns
+        -------
+        float
+            The cost of the chat, in USD.
+        """
+
+        # Look up the cost per user and assistant token for the provider and model
+        turns_tokens = self.get_tokens()
+        if token_price:
+            input_token_price = token_price[0] / 1000000
+            output_token_price = token_price[1] / 1000000
+        else:
+            price_token = get_token_pricing(self.provider)
+            input_token_price = price_token["input"] / 1000000
+            output_token_price = price_token["output"] / 1000000
+
+        if not input_token_price and not output_token_price:
+            raise KeyError(
+                f"We could not locate provider '{self.provider.name}' and model "
+                f"'{self.provider.model}' in our pricing information. Please supply "
+                "your own token prices if you wish to use the cost function."
+            )
+
+        if len(turns_tokens) == 0:
+            return 0.0
+
+        if options == "last":
+            last_turn = turns_tokens[-1]
+            acc = 0.0
+            if last_turn["role"] == "assistant":
+                acc += last_turn["tokens"] * output_token_price
+            elif last_turn["role"] == "user":
+                acc += last_turn["tokens_total"] * input_token_price
+            else:
+                raise ValueError(f"Unrecognized role type {last_turn['role']}")
+            return acc
+
+        if options == "all":
+            asst_tokens = sum(
+                u["tokens_total"] for u in turns_tokens if u["role"] == "assistant"
+            )
+            user_tokens = sum(
+                u["tokens_total"] for u in turns_tokens if u["role"] == "user"
+            )
+            cost = (asst_tokens * output_token_price) + (
+                user_tokens * input_token_price
+            )
+            return cost
+
     def token_count(
         self,
         *args: Content | str,
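A usage sketch for `get_cost`, assuming an OpenAI-backed chat; the model name, prices, and printed values are illustrative only:

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI(model="gpt-4o")  # hypothetical model choice
chat.chat("What's the capital of France?")

chat.get_cost()                # total estimated spend, e.g. 0.000265 (USD)
chat.get_cost(options="last")  # cost of the last turn only

# Bring your own pricing: (input, output) in USD per million tokens.
chat.get_cost(token_price=(2.5, 10.0))
```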
@@ -708,9 +778,9 @@ def stream(
             kwargs=kwargs,
         )

-        def wrapper() -> Generator[
-            str | ContentToolRequest | ContentToolResult, None, None
-        ]:
+        def wrapper() -> (
+            Generator[str | ContentToolRequest | ContentToolResult, None, None]
+        ):
             with display:
                 for chunk in generator:
                     yield chunk
@@ -772,9 +842,9 @@ async def stream_async(

         display = self._markdown_display(echo=echo)

-        async def wrapper() -> AsyncGenerator[
-            str | ContentToolRequest | ContentToolResult, None
-        ]:
+        async def wrapper() -> (
+            AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]
+        ):
             with display:
                 async for chunk in self._chat_impl_async(
                     turn,
@@ -1976,10 +2046,15 @@ def __str__(self):
             res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n"
         return res

+    # TODO: Update this to get tokens and also provide cost; add provider and model
     def __repr__(self):
         turns = self.get_turns(include_system_prompt=True)
-        tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens)
-        res = f"<Chat turns={len(turns)} tokens={tokens}>"
+        tokens = self.get_tokens()
+        cost = self.get_cost()
+        tokens_asst = sum(u["tokens_total"] for u in tokens if u["role"] == "assistant")
+        tokens_user = sum(u["tokens_total"] for u in tokens if u["role"] == "user")
+
+        res = f"<Chat {self.provider.name}/{self.provider.model} turns={len(turns)} tokens={tokens_user}/{tokens_asst} ${round(cost, ndigits=2)}>"
         for turn in turns:
             res += "\n" + turn.__repr__(indent=2)
         return res + "\n"
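With those changes, the repr might render along these lines (provider name, model, and numbers hypothetical):

```python
print(repr(chat))
# <Chat OpenAI/gpt-4o turns=4 tokens=16/18 $0.0>
# i.e. <Chat {provider}/{model} turns={n} tokens={user}/{assistant} ${cost}>
```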
Review comment: Let's make sure to highlight breaking changes in CHANGELOG.md.