@@ -89,6 +89,7 @@ def __init__(
        self.provider = provider
        self._turns: list[Turn] = list(turns or [])
        self._tools: dict[str, Tool] = {}
+        self.token_limits: Optional[tuple[int, int]] = None
        self._echo_options: EchoOptions = {
            "rich_markdown": {},
            "rich_console": {},
@@ -381,6 +382,121 @@ async def token_count_async(
            data_model=data_model,
        )

+    def set_token_limits(
+        self,
+        context_window: int,
+        max_tokens: int,
+    ):
+        """
+        Set a limit on the number of tokens that can be sent to the model.
+
+        By default, the size of the chat history is unbounded -- it keeps
+        growing as you submit more input. This can be wasteful if you don't
+        need to keep the entire chat history around, and can also lead to
+        errors if the chat history gets too large for the model to handle.
+
+        This method allows you to set a limit on the number of tokens that can
+        be sent to the model. If the limit is exceeded, the chat history will be
+        truncated to fit within the limit (i.e., the oldest turns will be
+        dropped).
+
+        Note that many models publish a context window as well as a maximum
+        output token limit. For example:
+
+        <https://platform.openai.com/docs/models/gp#gpt-4o-realtime>
+        <https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>
+
+        Also, since the context window is the maximum number of input + output
+        tokens, the maximum number of tokens that can be sent to the model in a
+        single request is `context_window - max_tokens`.
+
+        Parameters
+        ----------
+        context_window
+            The maximum number of tokens that can be sent to the model.
+        max_tokens
+            The maximum number of tokens that the model is allowed to generate
+            in a single response.
+
+        Note
+        ----
+        This method uses `.token_count()` to estimate the token count for new input
+        before truncating the chat history. This is an estimate, so it may not be
+        perfect. Moreover, chat models based on `ChatOpenAI()` currently do not
+        take the tool loop into account when estimating token counts. This means
+        that if your input will trigger many tool calls and/or the tool results are
+        large, it's recommended to set a conservative limit on the `context_window`.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+        chat.set_token_limits(200000, 8192)
+        ```
+        """
+        if max_tokens >= context_window:
+            raise ValueError("`max_tokens` must be less than the `context_window`.")
+        self.token_limits = (context_window, max_tokens)
+
+    def _maybe_drop_turns(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ):
+        """
+        Drop turns from the chat history if they exceed the token limits.
+        """
+
+        # Do nothing if token limits are not set
+        if self.token_limits is None:
+            return None
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        # Do nothing if this is the first turn
+        if len(turns) == 0:
+            return None
+
+        last_turn = turns[-1]
+
+        # Sanity checks (i.e., when about to submit new input, the last turn should
+        # be from the assistant and should contain token counts)
+        if last_turn.role != "assistant":
+            raise ValueError(
+                "Expected the last turn to be from the assistant. Please report this issue."
+            )
+
+        if last_turn.tokens is None:
+            raise ValueError(
+                "Can't impose token limits since assistant turns don't contain token counts. "
+                "Please report this issue and consider setting `.token_limits` to `None`."
+            )
+
+        context_window, max_tokens = self.token_limits
+        max_input_size = context_window - max_tokens
+
+        # Estimate the token count for the (new) user turn
+        input_tokens = self.token_count(*args, data_model=data_model)
+
+        # Do nothing if current history size plus input size is within the limit
+        remaining_tokens = max_input_size - input_tokens
+        if sum(last_turn.tokens) < remaining_tokens:
+            return None
+
+        tokens = self.tokens()
+
+        # Drop the oldest turns until the rest (plus the new input) fit within the token limits
+        # TODO: we also need to account for the fact that dropping part of a tool loop is problematic
+        while sum(tokens) >= remaining_tokens:
+            del turns[:2]
+            del tokens[:2]
+
+        self.set_turns(turns)
+
+        return None
+
    def app(
        self,
        *,
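
For reference, here is a minimal standalone sketch of the drop-oldest-pairs arithmetic that `_maybe_drop_turns()` performs above, using plain integer token counts instead of real `Turn` objects (the helper name and the numbers are hypothetical, not part of this diff):

```python
def drop_oldest_pairs(
    per_turn_tokens: list[int],
    new_input_tokens: int,
    context_window: int,
    max_tokens: int,
) -> list[int]:
    """Return the per-turn token counts that survive truncation."""
    # The budget for a single request is the context window minus the room
    # reserved for the model's output, minus the estimated new input.
    remaining = (context_window - max_tokens) - new_input_tokens
    tokens = list(per_turn_tokens)
    # Drop the oldest user/assistant pair until the rest fits the budget.
    while tokens and sum(tokens) >= remaining:
        del tokens[:2]
    return tokens


# Eight turns of history at ~40k tokens each, a 30k-token new input,
# and Claude-like limits: the two oldest pairs get dropped.
print(drop_oldest_pairs([40_000] * 8, 30_000, 200_000, 8_192))
```
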
@@ -531,6 +647,8 @@ def chat(
            A (consumed) response from the chat. Apply `str()` to this object to
            get the text content of the response.
        """
+        self._maybe_drop_turns(*args)
+
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)
@@ -581,6 +699,9 @@ async def chat_async(
            A (consumed) response from the chat. Apply `str()` to this object to
            get the text content of the response.
        """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)
@@ -627,6 +748,8 @@ def stream(
            An (unconsumed) response from the chat. Iterate over this object to
            consume the response.
        """
+        self._maybe_drop_turns(*args)
+
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)
@@ -672,6 +795,9 @@ async def stream_async(
            An (unconsumed) response from the chat. Iterate over this object to
            consume the response.
        """
+        # TODO: async version?
+        self._maybe_drop_turns(*args)
+
        turn = user_turn(*args)

        display = self._markdown_display(echo=echo)
@@ -715,6 +841,7 @@ def extract_data(
        dict[str, Any]
            The extracted data.
        """
+        self._maybe_drop_turns(*args, data_model=data_model)

        display = self._markdown_display(echo=echo)

@@ -775,6 +902,8 @@ async def extract_data_async(
        dict[str, Any]
            The extracted data.
        """
+        # TODO: async version?
+        self._maybe_drop_turns(*args, data_model=data_model)

        display = self._markdown_display(echo=echo)

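
As a usage-level illustration of the hooks added above, this sketch shows a long-running loop in which the history is eventually truncated instead of growing without bound (it assumes `ChatAnthropic` credentials are configured; the prompts and the exact truncation points are illustrative, since they depend on the provider's real token counts):

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
chat.set_token_limits(200_000, 8_192)

for i in range(100):
    # Each call runs _maybe_drop_turns() first, so once the accumulated
    # history plus the estimated new input would exceed
    # context_window - max_tokens, the oldest turns are silently dropped.
    chat.chat(f"Summarize document chunk {i} in detail.")

# The turn list stays bounded rather than growing indefinitely.
print(len(chat.get_turns()))
```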