Commit 200e26c

Add .token_count() for estimating input tokens (#23)

* Improvements to token usage reporting
* Update changelog
* Clean up docstring
* Make token_usage() a method, not a property (just in case we want parameters)
* Fix imports
* Rollback breaking changes
* Cleanup
* Doc improvements
* Add .token_count_async(); require the whole data_model
* Slightly more accurate/conservative token count for OpenAI
* Add tests
* Add note
* Tweak changelog
* Tweak docstring

1 parent e033684 · commit 200e26c

File tree

8 files changed: +279 −5 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### New features
 
 * `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same (the result can be summed to determine the overall token cost of the conversation).
+* `Chat` gains a `.token_count()` method to help estimate the token cost of new input. (#23)
 
 ### Bug fixes
 
 * `ChatOllama` no longer fails when an `OPENAI_API_KEY` environment variable is not set.
 * `ChatOpenAI` now correctly includes the relevant `detail` on `ContentImageRemote()` input.
+* `ChatGoogle` now correctly logs its `token_usage()`. (#23)
 
 
 ## [0.2.0] - 2024-12-11
```
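In practice, the new feature and the existing usage reporting pair up like this (a minimal sketch adapted from the docstring example added in `chatlas/_chat.py` below; the provider and model are illustrative, and an API key is assumed to be set):

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic()

# Estimate the token cost of an input before sending it...
print(chat.token_count("What is 2 + 2?"))

# ...then compare with the actual counts once the input has been sent
chat.chat("What is 2 + 2?", echo="none")
print(chat.token_usage())
```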

chatlas/_anthropic.py

Lines changed: 54 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@
 from ._provider import Provider
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
     from anthropic.types import (
@@ -380,6 +380,59 @@ async def stream_turn_async(self, completion, has_data_model, stream) -> Turn:
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+        res = self._client.messages.count_tokens(**kwargs)
+        return res.input_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+        res = await self._async_client.messages.count_tokens(**kwargs)
+        return res.input_tokens
+
+    def _token_count_args(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> dict[str, Any]:
+        turn = user_turn(*args)
+
+        kwargs = self._chat_perform_args(
+            stream=False,
+            turns=[turn],
+            tools=tools,
+            data_model=data_model,
+        )
+
+        args_to_keep = [
+            "messages",
+            "model",
+            "system",
+            "tools",
+            "tool_choice",
+        ]
+
+        return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
+
     def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]:
         messages: list["MessageParam"] = []
         for turn in turns:
```
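For Anthropic, the count comes from the server rather than a local heuristic: `_token_count_args()` rebuilds the request with `_chat_perform_args()`, keeps only the arguments the counting endpoint accepts, and hands them to `messages.count_tokens`. Stripped of the provider plumbing, the call reduces to roughly this sketch (the model name and message are illustrative; `ANTHROPIC_API_KEY` is assumed to be set):

```python
import anthropic

client = anthropic.Anthropic()

# The server-side counting endpoint that token_count() delegates to
res = client.messages.count_tokens(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "What is 1 + 1?"}],
)
print(res.input_tokens)  # 16 for this input in the test below
```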

chatlas/_chat.py

Lines changed: 87 additions & 0 deletions
````diff
@@ -294,6 +294,93 @@ def tokens(
 
         return res
 
+    def token_count(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> int:
+        """
+        Get an estimated token count for the given input.
+
+        Estimate the token size of input content. This can help determine whether input(s)
+        and/or conversation history (i.e., `.get_turns()`) should be reduced in size before
+        sending it to the model.
+
+        Parameters
+        ----------
+        args
+            The input to get a token count for.
+        data_model
+            If the input is meant for data extraction (i.e., `.extract_data()`), then
+            this should be the Pydantic model that describes the structure of the data to
+            extract.
+
+        Returns
+        -------
+        int
+            The token count for the input.
+
+        Note
+        ----
+        Remember that the token count is an estimate. Also, models based on
+        `ChatOpenAI()` currently do not take tools into account when
+        estimating token counts.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic()
+        # Estimate the token count before sending the input
+        print(chat.token_count("What is 2 + 2?"))
+
+        # Once input is sent, you can get the actual input and output
+        # token counts from the chat object
+        chat.chat("What is 2 + 2?", echo="none")
+        print(chat.token_usage())
+        ```
+        """
+
+        return self.provider.token_count(
+            *args,
+            tools=self._tools,
+            data_model=data_model,
+        )
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> int:
+        """
+        Get an estimated token count for the given input, asynchronously.
+
+        Estimate the token size of input content. This can help determine whether input(s)
+        and/or conversation history (i.e., `.get_turns()`) should be reduced in size before
+        sending it to the model.
+
+        Parameters
+        ----------
+        args
+            The input to get a token count for.
+        data_model
+            If this input is meant for data extraction (i.e., `.extract_data_async()`),
+            then this should be the Pydantic model that describes the structure of the data
+            to extract.
+
+        Returns
+        -------
+        int
+            The token count for the input.
+        """
+
+        return await self.provider.token_count_async(
+            *args,
+            tools=self._tools,
+            data_model=data_model,
+        )
+
     def app(
         self,
         *,
````
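The docstring above shows the synchronous flow; the async variant mirrors it one-to-one. A minimal sketch of `token_count_async()` in use (assumes `ANTHROPIC_API_KEY` is set):

```python
import asyncio

from chatlas import ChatAnthropic


async def main():
    chat = ChatAnthropic()
    # Same estimate as token_count(), without blocking the event loop
    print(await chat.token_count_async("What is 2 + 2?"))


asyncio.run(main())
```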

chatlas/_google.py

Lines changed: 53 additions & 1 deletion
```diff
@@ -17,8 +17,9 @@
 )
 from ._logging import log_model_default
 from ._provider import Provider
+from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
     from google.generativeai.types.content_types import (
@@ -332,6 +333,55 @@ async def stream_turn_async(
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ):
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+
+        res = self._client.count_tokens(**kwargs)
+        return res.total_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ):
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+
+        res = await self._client.count_tokens_async(**kwargs)
+        return res.total_tokens
+
+    def _token_count_args(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> dict[str, Any]:
+        turn = user_turn(*args)
+
+        kwargs = self._chat_perform_args(
+            stream=False,
+            turns=[turn],
+            tools=tools,
+            data_model=data_model,
+        )
+
+        args_to_keep = ["contents", "tools"]
+
+        return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
+
     def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
         contents: list["ContentDict"] = []
         for turn in turns:
@@ -421,6 +471,8 @@ def _as_turn(
             usage.candidates_token_count,
         )
 
+        tokens_log(self, tokens)
+
         finish = message.candidates[0].finish_reason
 
         return Turn(
```
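As with Anthropic, the count is delegated to the service: the filtered kwargs (`contents`, `tools`) are handed to the client's counting API. Outside of chatlas, the equivalent call is roughly this sketch (model name illustrative; a configured `GOOGLE_API_KEY` is assumed):

```python
import google.generativeai as genai

model = genai.GenerativeModel("gemini-1.5-flash")

# The counting API that token_count() reaches through self._client
res = model.count_tokens("What is 1 + 1?")
print(res.total_tokens)  # 9 for this input in the test below
```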

chatlas/_openai.py

Lines changed: 53 additions & 1 deletion
```diff
@@ -8,6 +8,7 @@
 from ._chat import Chat
 from ._content import (
     Content,
+    ContentImage,
     ContentImageInline,
     ContentImageRemote,
     ContentJson,
@@ -20,7 +21,7 @@
 from ._provider import Provider
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 from ._utils import MISSING, MISSING_TYPE, is_testing
 
 if TYPE_CHECKING:
@@ -351,6 +352,57 @@ async def stream_turn_async(self, completion, has_data_model, stream):
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "The tiktoken package is required for token counting. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        encoding = tiktoken.encoding_for_model(self._model)
+
+        turn = user_turn(*args)
+
+        # Count the tokens in image contents
+        image_tokens = sum(
+            self._image_token_count(x)
+            for x in turn.contents
+            if isinstance(x, ContentImage)
+        )
+
+        # For other contents, get the token count from the actual message param
+        other_contents = [x for x in turn.contents if not isinstance(x, ContentImage)]
+        other_full = self._as_message_param([Turn("user", other_contents)])
+        other_tokens = len(encoding.encode(str(other_full)))
+
+        return other_tokens + image_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        return self.token_count(*args, tools=tools, data_model=data_model)
+
+    @staticmethod
+    def _image_token_count(image: ContentImage) -> int:
+        if isinstance(image, ContentImageRemote) and image.detail == "low":
+            return 85
+        else:
+            # This is just the max token count for an image. The highest possible
+            # resolution is 768 x 2048, and 8 tiles of size 512px can fit inside.
+            # TODO: this is obviously a very conservative estimate and could be improved
+            # https://platform.openai.com/docs/guides/vision/calculating-costs
+            return 170 * 8 + 85
+
     @staticmethod
     def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
         from openai.types.chat import (
```
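Unlike the Anthropic and Google providers, which ask the server, the OpenAI provider estimates locally: `tiktoken` encodes the non-image message params, and each image is charged a flat amount. A standalone sketch of the same arithmetic (model name illustrative; the message payload here is a simplification of the real message param):

```python
import tiktoken

# Text: encode a stringified message payload and count its tokens
encoding = tiktoken.encoding_for_model("gpt-4o-mini")
message = str([{"role": "user", "content": "What is 1 + 1?"}])
text_tokens = len(encoding.encode(message))

# Images: a low-detail remote image costs a flat 85 tokens; anything else
# gets the conservative maximum of 8 tiles at 170 tokens plus the 85 base
image_tokens = 170 * 8 + 85  # 1445

print(text_tokens + image_tokens)
```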

chatlas/_provider.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 from pydantic import BaseModel
 
+from ._content import Content
 from ._tools import Tool
 from ._turn import Turn
 
@@ -141,3 +142,19 @@ def value_turn(
         completion: ChatCompletionT,
         has_data_model: bool,
     ) -> Turn: ...
+
+    @abstractmethod
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int: ...
+
+    @abstractmethod
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int: ...
```
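Because both methods are abstract, every `Provider` implementation must now supply them. A hypothetical class sketching just the shape of the new contract (the whitespace-based count is illustrative only, and the rest of the `Provider` interface is elided):

```python
from typing import Optional

from pydantic import BaseModel

from chatlas._content import Content
from chatlas._tools import Tool


class WordCountingProvider:  # stand-in; a real subclass derives from Provider
    def token_count(
        self,
        *args: Content | str,
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        # Crude local estimate: one token per whitespace-separated word
        return sum(len(str(arg).split()) for arg in args)

    async def token_count_async(
        self,
        *args: Content | str,
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        # A local estimate needs no I/O, so reuse the sync path
        return self.token_count(*args, tools=tools, data_model=data_model)
```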

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -48,6 +48,7 @@ dev = [
     "anthropic[bedrock]",
     "google-generativeai>=0.8.3",
     "numpy>1.24.4",
+    "tiktoken",
 ]
 docs = [
     "griffe>=1",
```

tests/test_tokens.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-from chatlas import ChatOpenAI, Turn
+from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn
 from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
 from chatlas._tokens import token_usage, tokens_log, tokens_reset
 
@@ -26,10 +26,20 @@ def test_tokens_method():
     )
 
     assert chat.tokens(values="discrete") == [2, 10, 2, 10]
-
     assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]
 
 
+def test_token_count_method():
+    chat = ChatOpenAI(model="gpt-4o-mini")
+    assert chat.token_count("What is 1 + 1?") == 31
+
+    chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+    assert chat.token_count("What is 1 + 1?") == 16
+
+    chat = ChatGoogle(model="gemini-1.5-flash")
+    assert chat.token_count("What is 1 + 1?") == 9
+
+
 def test_usage_is_none():
     tokens_reset()
     assert token_usage() is None
```
