
Commit e033684

chat.tokens() gains a values argument (#27)
* The `.tokens()` method now returns a list of ints, where each int represents the number of tokens each turn takes.
* Add `format` argument to `.tokens()`; make the default behavior the same as before.
* Improvements and tests.
* Update changelog.
* Update test expectation.
1 parent f5a300f commit e033684

File tree: 4 files changed, +142 -6 lines

* CHANGELOG.md
* chatlas/_chat.py
* tests/test_provider_openai.py
* tests/test_tokens.py

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### New features
 
+* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same (the result can be summed to determine the overall token cost of the conversation).
+
 ### Bug fixes
 
 * `ChatOllama` no longer fails when an `OPENAI_API_KEY` environment variable is not set.
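
To make the two modes concrete, here is a minimal usage sketch of the new argument (it assumes an `OPENAI_API_KEY` is set so that `ChatOpenAI` can make requests; the prompts and any resulting token numbers are purely illustrative):

    from chatlas import ChatOpenAI

    chat = ChatOpenAI()
    chat.chat("What is 1 + 1?")
    chat.chat("And what is 2 + 2?")

    # "cumulative" (the default): per-turn (input, output) pairs as reported by
    # the provider. Summing the reported counts gives the overall token cost of
    # the conversation. Turns without provider-reported counts appear as None.
    cumulative = chat.tokens(values="cumulative")

    # "discrete": one int per turn. Summing them estimates the cost of
    # submitting the current turns with the next request.
    discrete = chat.tokens(values="discrete")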

chatlas/_chat.py

Lines changed: 111 additions & 5 deletions
@@ -16,6 +16,7 @@
     Optional,
     Sequence,
     TypeVar,
+    overload,
 )
 
 from pydantic import BaseModel
@@ -176,17 +177,122 @@ def system_prompt(self, value: str | None):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-    def tokens(self) -> list[tuple[int, int] | None]:
+    @overload
+    def tokens(self) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["cumulative"],
+    ) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["discrete"],
+    ) -> list[int]: ...
+
+    def tokens(
+        self,
+        values: Literal["cumulative", "discrete"] = "cumulative",
+    ) -> list[int] | list[tuple[int, int] | None]:
         """
         Get the tokens for each turn in the chat.
 
+        Parameters
+        ----------
+        values
+            If "cumulative" (the default), the result can be summed to get the
+            chat's overall token usage (helpful for computing the overall cost
+            of the chat). If "discrete", the result can be summed to get the
+            number of tokens the turns will cost to generate the next response
+            (helpful for estimating the cost of the next response, or for
+            determining if you are about to exceed the token limit).
+
         Returns
         -------
-        list[tuple[int, int] | None]
-            A list of tuples, where each tuple contains the start and end token
-            indices for a turn.
+        list[int]
+            A list of token counts for each (non-system) turn in the chat. The
+            1st turn includes the token count for the system prompt (if any).
+
+        Raises
+        ------
+        ValueError
+            If the chat's turns (i.e., `.get_turns()`) are not in an expected
+            format. This may happen if the chat history is manually set (i.e.,
+            `.set_turns()`). In this case, you can inspect the "raw" token
+            values via the `.get_turns()` method (each turn has a `.tokens`
+            attribute).
         """
-        return [turn.tokens for turn in self._turns]
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        if values == "cumulative":
+            return [turn.tokens for turn in turns]
+
+        if len(turns) == 0:
+            return []
+
+        err_info = (
+            "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
+            "Consider getting the 'raw' token values via the `.get_turns()` method "
+            "(each turn has a `.tokens` attribute)."
+        )
+
+        # Sanity checks for the assumptions made to figure out user token counts
+        if len(turns) == 1:
+            raise ValueError(
+                "Expected at least two turns in the chat history. " + err_info
+            )
+
+        if len(turns) % 2 != 0:
+            raise ValueError(
+                "Expected an even number of turns in the chat history. " + err_info
+            )
+
+        if turns[0].role != "user":
+            raise ValueError(
+                "Expected the 1st non-system turn to have role='user'. " + err_info
+            )
+
+        if turns[1].role != "assistant":
+            raise ValueError(
+                "Expected the 2nd non-system turn to have role='assistant'. " + err_info
+            )
+
+        if turns[1].tokens is None:
+            raise ValueError(
+                "Expected the 1st assistant turn to contain token counts. " + err_info
+            )
+
+        res: list[int] = [
+            # Implied token count for the 1st user input
+            turns[1].tokens[0],
+            # The token count for the 1st assistant response
+            turns[1].tokens[1],
+        ]
+        for i in range(1, len(turns) - 1, 2):
+            ti = turns[i]
+            tj = turns[i + 2]
+            if ti.role != "assistant" or tj.role != "assistant":
+                raise ValueError(
+                    "Expected even turns to have role='assistant'. " + err_info
+                )
+            if ti.tokens is None or tj.tokens is None:
+                raise ValueError(
+                    "Expected role='assistant' turns to contain token counts. "
+                    + err_info
+                )
+            res.extend(
+                [
+                    # Implied token count for the user input
+                    tj.tokens[0] - sum(ti.tokens),
+                    # The token count for the assistant response
+                    tj.tokens[1],
+                ]
+            )
+
+        return res
 
     def app(
         self,
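
The "discrete" branch leans on the fact that each assistant turn stores cumulative (input, output) token counts for the request that produced it, so the size of a later user turn can only be recovered by subtraction: `tj.tokens[0] - sum(ti.tokens)`. A small worked sketch of that arithmetic, using the same made-up numbers as the new test below:

    # Cumulative (input, output) counts stored on two consecutive assistant turns.
    first_assistant = (2, 10)    # 2 prompt tokens (1st user turn), 10 completion tokens
    second_assistant = (14, 10)  # 14 prompt tokens: the prior history (2 + 10) plus the new user turn

    # Implied size of the 2nd user turn: subtract everything already accounted for.
    second_user = second_assistant[0] - sum(first_assistant)  # 14 - (2 + 10) == 2

    # Discrete per-turn counts, i.e. what tokens(values="discrete") returns here.
    discrete = [first_assistant[0], first_assistant[1], second_user, second_assistant[1]]
    assert discrete == [2, 10, 2, 10]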

tests/test_provider_openai.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def test_openai_simple_request():
     chat.chat("What is 1 + 1?")
     turn = chat.get_last_turn()
     assert turn is not None
-    assert turn.tokens == (27, 1)
+    assert turn.tokens == (27, 2)
     assert turn.finish_reason == "stop"

tests/test_tokens.py

Lines changed: 28 additions & 0 deletions
@@ -1,7 +1,35 @@
+from chatlas import ChatOpenAI, Turn
 from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
 from chatlas._tokens import token_usage, tokens_log, tokens_reset
 
 
+def test_tokens_method():
+    chat = ChatOpenAI()
+    assert chat.tokens(values="discrete") == []
+
+    chat = ChatOpenAI(
+        turns=[
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
+        ]
+    )
+
+    assert chat.tokens(values="discrete") == [2, 10]
+
+    chat = ChatOpenAI(
+        turns=[
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(14, 10)),
+        ]
+    )
+
+    assert chat.tokens(values="discrete") == [2, 10, 2, 10]
+
+    assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]
+
+
 def test_usage_is_none():
     tokens_reset()
     assert token_usage() is None