
Commit fd841fb

Add .set_token_limits() method to automatically drop old turns when specified limits are reached
1 parent 200e26c commit fd841fb

File tree

1 file changed: +129 -0 lines changed

chatlas/_chat.py

Lines changed: 129 additions & 0 deletions
@@ -89,6 +89,7 @@ def __init__(
         self.provider = provider
         self._turns: list[Turn] = list(turns or [])
         self._tools: dict[str, Tool] = {}
+        self.token_limits: Optional[tuple[int, int]] = None
         self._echo_options: EchoOptions = {
             "rich_markdown": {},
             "rich_console": {},
@@ -381,6 +382,121 @@ async def token_count_async(
             data_model=data_model,
         )

+    def set_token_limits(
+        self,
+        context_window: int,
+        max_tokens: int,
+    ):
+        """
+        Set a limit on the number of tokens that can be sent to the model.
+
+        By default, the size of the chat history is unbounded -- it keeps
+        growing as you submit more input. This can be wasteful if you don't
+        need to keep the entire chat history around, and can also lead to
+        errors if the chat history gets too large for the model to handle.
+
+        This method allows you to set a limit on the number of tokens that can
+        be sent to the model. If the limit is exceeded, the chat history will
+        be truncated to fit within the limit (i.e., the oldest turns will be
+        dropped).
+
+        Note that many models publish a context window as well as a maximum
+        output token limit. For example,
+
+        <https://platform.openai.com/docs/models/gp#gpt-4o-realtime>
+        <https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>
+
+        Also, since the context window is the maximum number of input + output
+        tokens, the maximum number of tokens that can be sent to the model in a
+        single request is `context_window - max_tokens`.
+
+        Parameters
+        ----------
+        context_window
+            The total number of tokens (input plus output) that the model can
+            handle in a single request.
+        max_tokens
+            The maximum number of tokens that the model is allowed to generate
+            in a single response.
+
+        Note
+        ----
+        This method uses `.token_count()` to estimate the token count for new
+        input before truncating the chat history. This is an estimate, so it
+        may not be perfect. Moreover, chat models based on `ChatOpenAI()`
+        currently do not take the tool loop into account when estimating token
+        counts. This means that if your input will trigger many tool calls,
+        and/or the tool results are large, it's recommended to set a
+        conservative `context_window` limit.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+        chat.set_token_limits(200000, 8192)
+        ```
+        """
+        if max_tokens >= context_window:
+            raise ValueError("`max_tokens` must be less than the `context_window`.")
+        self.token_limits = (context_window, max_tokens)
+
+    def _maybe_drop_turns(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ):
+        """
+        Drop turns from the chat history if they exceed the token limits.
+        """
+
+        # Do nothing if token limits are not set
+        if self.token_limits is None:
+            return None
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        # Do nothing if this is the first turn
+        if len(turns) == 0:
+            return None
+
+        last_turn = turns[-1]
+
+        # Sanity checks (i.e., when about to submit new input, the last turn should
+        # be from the assistant and should contain token counts)
+        if last_turn.role != "assistant":
+            raise ValueError(
+                "Expected the last turn to be from the assistant. Please report this issue."
+            )
+
+        if last_turn.tokens is None:
+            raise ValueError(
+                "Can't impose token limits since the last assistant turn doesn't contain "
+                "token counts. Please report this issue and consider setting "
+                "`.token_limits` to `None`."
+            )
+
+        context_window, max_tokens = self.token_limits
+        max_input_size = context_window - max_tokens
+
+        # Estimate the token count for the (new) user turn
+        input_tokens = self.token_count(*args, data_model=data_model)
+
+        # Do nothing if the current history size plus the input size is within the limit
+        remaining_tokens = max_input_size - input_tokens
+        if sum(last_turn.tokens) < remaining_tokens:
+            return None
+
+        tokens = self.tokens()
+
+        # Drop the oldest turns until the rest (plus the new input) fit within the token limits
+        # TODO: we also need to account for the fact that dropping part of a tool loop is problematic
+        while sum(tokens) >= remaining_tokens:
+            del turns[:2]
+            del tokens[:2]
+
+        self.set_turns(turns)
+
+        return None
+
     def app(
         self,
         *,
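For reference, the gatekeeping in `_maybe_drop_turns()` above boils down to a small token-budget calculation. Here is a minimal, standalone sketch of that arithmetic using the Claude 3.5 Sonnet limits from the docstring example; the per-turn counts are made-up numbers purely for illustration:

```python
# Made-up numbers, for illustration only.
context_window = 200_000  # total input + output tokens the model accepts
max_tokens = 8_192        # tokens reserved for the model's response

max_input_size = context_window - max_tokens      # 191_808 tokens of input at most
input_tokens = 1_500                              # estimated size of the new user turn
remaining_tokens = max_input_size - input_tokens  # 190_308 tokens left for history

# Per-turn token counts for the existing history (oldest first).
tokens = [3_000, 1_000, 120_000, 70_000]          # 194_000 tokens in total

# Drop the oldest user/assistant pair until the history fits alongside the new input.
while sum(tokens) >= remaining_tokens:
    del tokens[:2]

print(tokens)  # [120_000, 70_000] -> 190_000 tokens, which fits under 190_308
```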
@@ -531,6 +647,8 @@ def chat(
             A (consumed) response from the chat. Apply `str()` to this object to
             get the text content of the response.
         """
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -581,6 +699,9 @@ async def chat_async(
             A (consumed) response from the chat. Apply `str()` to this object to
             get the text content of the response.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -627,6 +748,8 @@ def stream(
             An (unconsumed) response from the chat. Iterate over this object to
             consume the response.
         """
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -672,6 +795,9 @@ async def stream_async(
             An (unconsumed) response from the chat. Iterate over this object to
             consume the response.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
         turn = user_turn(*args)

         display = self._markdown_display(echo=echo)
@@ -715,6 +841,7 @@ def extract_data(
         dict[str, Any]
             The extracted data.
         """
+        self._maybe_drop_turns(*args, data_model=data_model)

         display = self._markdown_display(echo=echo)

@@ -775,6 +902,8 @@ async def extract_data_async(
         dict[str, Any]
             The extracted data.
         """
+        # TODO: async version?
+        self._maybe_drop_turns(*args, data_model=data_model)

         display = self._markdown_display(echo=echo)
