26
26
27
27
from pydantic import BaseModel
28
28
29
+ from ._callbacks import CallbackManager
29
30
from ._content import (
30
31
Content ,
31
32
ContentJson ,
42
43
)
43
44
from ._logging import log_tool_error
44
45
from ._provider import Provider
45
- from ._tools import Tool
46
+ from ._tools import Tool , ToolRejectError
46
47
from ._turn import Turn , user_turn
47
48
from ._typing_extensions import TypedDict
48
49
from ._utils import html_escape , wrap_async
@@ -96,6 +97,8 @@ def __init__(
96
97
self .provider = provider
97
98
self ._turns : list [Turn ] = list (turns or [])
98
99
self ._tools : dict [str , Tool ] = {}
100
+ self ._on_tool_request_callbacks = CallbackManager ()
101
+ self ._on_tool_result_callbacks = CallbackManager ()
99
102
self ._current_display : Optional [MarkdownDisplay ] = None
100
103
self ._echo_options : EchoDisplayOptions = {
101
104
"rich_markdown" : {},
@@ -988,6 +991,53 @@ def add(a: int, b: int) -> int:
988
991
tool = Tool (func , model = model )
989
992
self ._tools [tool .name ] = tool
990
993
994
+ def on_tool_request (self , callback : Callable [[ContentToolRequest ], None ]):
995
+ """
996
+ Register a callback for a tool request event.
997
+
998
+ A tool request event occurs when the assistant requests a tool to be
999
+ called on its behalf. Before invoking the tool, `on_tool_request`
1000
+ handlers are called with the relevant `ContentToolRequest` object. This
1001
+ is useful if you want to handle tool requests in a custom way, such as
1002
+ requiring logging them or requiring user approval before invoking the
1003
+ tool
1004
+
1005
+ Parameters
1006
+ ----------
1007
+ callback
1008
+ A function to be called when a tool request event occurs.
1009
+ This function must have a single argument, which will be the
1010
+ tool request (i.e., a `ContentToolRequest` object).
1011
+
1012
+ Returns
1013
+ -------
1014
+ A callable that can be used to remove the callback later.
1015
+ """
1016
+ return self ._on_tool_request_callbacks .add (callback )
1017
+
1018
+ def on_tool_result (self , callback : Callable [[ContentToolResult ], None ]):
1019
+ """
1020
+ Register a callback for a tool result event.
1021
+
1022
+ A tool result event occurs when a tool has been invoked and the
1023
+ result is ready to be provided to the assistant. After the tool
1024
+ has been invoked, `on_tool_result` handlers are called with the
1025
+ relevant `ContentToolResult` object. This is useful if you want to
1026
+ handle tool results in a custom way such as logging them.
1027
+
1028
+ Parameters
1029
+ ----------
1030
+ callback
1031
+ A function to be called when a tool result event occurs.
1032
+ This function must have a single argument, which will be the
1033
+ tool result (i.e., a `ContentToolResult` object).
1034
+
1035
+ Returns
1036
+ -------
1037
+ A callable that can be used to remove the callback later.
1038
+ """
1039
+ return self ._on_tool_result_callbacks .add (callback )
1040
+
991
1041
@property
992
1042
def current_display (self ) -> Optional [MarkdownDisplay ]:
993
1043
"""
@@ -1418,28 +1468,43 @@ def _invoke_tool(self, x: ContentToolRequest) -> ContentToolResult:
1418
1468
e = RuntimeError (f"Unknown tool: { x .name } " )
1419
1469
return ContentToolResult (value = None , error = e , request = x )
1420
1470
1421
- args = x .arguments
1422
-
1471
+ # First, invoke the request callbacks. If a ToolRejectError is raised,
1472
+ # treat it like a tool failure (i.e., gracefully handle it).
1473
+ result : ContentToolResult | None = None
1423
1474
try :
1424
- if isinstance (args , dict ):
1425
- result = func (** args )
1426
- else :
1427
- result = func (args )
1475
+ self ._on_tool_request_callbacks .invoke (x )
1476
+ except ToolRejectError as e :
1477
+ result = ContentToolResult (value = None , error = e , request = x )
1478
+
1479
+ # Invoke the tool (if it hasn't been rejected).
1480
+ if result is None :
1481
+ try :
1482
+ if isinstance (x .arguments , dict ):
1483
+ res = func (** x .arguments )
1484
+ else :
1485
+ res = func (x .arguments )
1428
1486
1429
- if not isinstance (result , ContentToolResult ):
1430
- result = ContentToolResult (value = result )
1487
+ if isinstance (res , ContentToolResult ):
1488
+ result = res
1489
+ else :
1490
+ result = ContentToolResult (value = res )
1491
+
1492
+ result .request = x
1493
+ except Exception as e :
1494
+ result = ContentToolResult (value = None , error = e , request = x )
1431
1495
1432
- result .request = x
1433
- return result
1434
- except Exception as e :
1496
+ # If we've captured an error, notify and log it.
1497
+ if result .error :
1435
1498
warnings .warn (
1436
1499
f"Calling tool '{ x .name } ' led to an error." ,
1437
1500
ToolFailureWarning ,
1438
1501
stacklevel = 2 ,
1439
1502
)
1440
1503
traceback .print_exc ()
1441
- log_tool_error (x .name , str (args ), e )
1442
- return ContentToolResult (value = None , error = e , request = x )
1504
+ log_tool_error (x .name , str (x .arguments ), result .error )
1505
+
1506
+ self ._on_tool_result_callbacks .invoke (result )
1507
+ return result
1443
1508
1444
1509
async def _invoke_tool_async (self , x : ContentToolRequest ) -> ContentToolResult :
1445
1510
tool_def = self ._tools .get (x .name , None )
@@ -1454,28 +1519,43 @@ async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
1454
1519
e = RuntimeError (f"Unknown tool: { x .name } " )
1455
1520
return ContentToolResult (value = None , error = e , request = x )
1456
1521
1457
- args = x .arguments
1458
-
1522
+ # First, invoke the request callbacks. If a ToolRejectError is raised,
1523
+ # treat it like a tool failure (i.e., gracefully handle it).
1524
+ result : ContentToolResult | None = None
1459
1525
try :
1460
- if isinstance (args , dict ):
1461
- result = await func (** args )
1462
- else :
1463
- result = await func (args )
1526
+ await self ._on_tool_request_callbacks .invoke_async (x )
1527
+ except ToolRejectError as e :
1528
+ result = ContentToolResult (value = None , error = e , request = x )
1529
+
1530
+ # Invoke the tool (if it hasn't been rejected).
1531
+ if result is None :
1532
+ try :
1533
+ if isinstance (x .arguments , dict ):
1534
+ res = await func (** x .arguments )
1535
+ else :
1536
+ res = await func (x .arguments )
1464
1537
1465
- if not isinstance (result , ContentToolResult ):
1466
- result = ContentToolResult (value = result )
1538
+ if isinstance (res , ContentToolResult ):
1539
+ result = res
1540
+ else :
1541
+ result = ContentToolResult (value = res )
1467
1542
1468
- result .request = x
1469
- return result
1470
- except Exception as e :
1543
+ result .request = x
1544
+ except Exception as e :
1545
+ result = ContentToolResult (value = None , error = e , request = x )
1546
+
1547
+ # If we've captured an error, notify and log it.
1548
+ if result .error :
1471
1549
warnings .warn (
1472
1550
f"Calling tool '{ x .name } ' led to an error." ,
1473
1551
ToolFailureWarning ,
1474
1552
stacklevel = 2 ,
1475
1553
)
1476
1554
traceback .print_exc ()
1477
- log_tool_error (x .name , str (args ), e )
1478
- return ContentToolResult (value = None , error = e , request = x )
1555
+ log_tool_error (x .name , str (x .arguments ), result .error )
1556
+
1557
+ await self ._on_tool_result_callbacks .invoke_async (result )
1558
+ return result
1479
1559
1480
1560
def _markdown_display (self , echo : EchoOptions ) -> ChatMarkdownDisplay :
1481
1561
"""
0 commit comments