Skip to content

Commit b781458

Browse files
authored
feat: Add ToolRejectError and tool request/result callbacks (#101)
* feat: Add ToolRejectError and tool request/result callbacks * Add to API reference and test for console output
1 parent 04bf10c commit b781458

File tree

8 files changed

+386
-31
lines changed

8 files changed

+386
-31
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
## [UNRELEASED]
1111

12+
### New features
13+
14+
* New `.on_tool_request()` and `.on_tool_result()` methods register callbacks that fire when a tool is requested or produces a result. These callbacks can be used to implement custom logging or other actions when tools are called, without modifying the tool function (#101).
15+
* New `ToolRejectError` exception can be thrown from tool request/result callbacks or from within a tool function itself to prevent the tool from executing. Moreover, this exception will provide some context for the the LLM to know that the tool didn't produce a result because it was rejected. (#101)
16+
1217
### Improvements
1318

14-
* The `CHATLAS_LOG` environment variable nows enable logs for the relevant model provider. It now also supports a lovel of `debug` in addition to `info`. (#97)
19+
* The `CHATLAS_LOG` environment variable now enables logs for the relevant model provider. It now also supports a level of `debug` in addition to `info`. (#97)
1520
* `ChatSnowflake()` now supports tool calling. (#98)
1621
* `Chat` instances can now be deep copied, which is useful for forking the chat session. (#96)
1722

chatlas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ._provider import Provider
1717
from ._snowflake import ChatSnowflake
1818
from ._tokens import token_usage
19-
from ._tools import Tool
19+
from ._tools import Tool, ToolRejectError
2020
from ._turn import Turn
2121

2222
try:
@@ -51,6 +51,7 @@
5151
"Provider",
5252
"token_usage",
5353
"Tool",
54+
"ToolRejectError",
5455
"Turn",
5556
"types",
5657
)

chatlas/_callbacks.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from collections import OrderedDict
2+
from typing import Any, Callable
3+
4+
from ._utils import is_async_callable
5+
6+
7+
class CallbackManager:
8+
def __init__(self) -> None:
9+
self._callbacks: dict[str, Callable[..., Any]] = OrderedDict()
10+
self._id: int = 1
11+
12+
def add(self, callback: Callable[..., Any]) -> Callable[[], None]:
13+
callback_id = self._next_id()
14+
self._callbacks[callback_id] = callback
15+
16+
def _rm_callback() -> None:
17+
self._callbacks.pop(callback_id, None)
18+
19+
return _rm_callback
20+
21+
def invoke(self, *args: Any, **kwargs: Any) -> None:
22+
if not self._callbacks:
23+
return
24+
25+
# Invoke in reverse insertion order
26+
for callback_id in reversed(list(self._callbacks.keys())):
27+
callback = self._callbacks[callback_id]
28+
if is_async_callable(callback):
29+
raise RuntimeError(
30+
"Can't use async callbacks with `.chat()`/`.stream()`."
31+
"Async callbacks can only be used with `.chat_async()`/`.stream_async()`."
32+
)
33+
callback(*args, **kwargs)
34+
35+
async def invoke_async(self, *args: Any, **kwargs: Any) -> None:
36+
if not self._callbacks:
37+
return
38+
39+
# Invoke in reverse insertion order
40+
for callback_id in reversed(list(self._callbacks.keys())):
41+
callback = self._callbacks[callback_id]
42+
if is_async_callable(callback):
43+
await callback(*args, **kwargs)
44+
else:
45+
callback(*args, **kwargs)
46+
47+
def count(self) -> int:
48+
return len(self._callbacks)
49+
50+
def get_callbacks(self) -> list[Callable[..., Any]]:
51+
return list(self._callbacks.values())
52+
53+
def _next_id(self) -> str:
54+
current_id = self._id
55+
self._id += 1
56+
return str(current_id)

chatlas/_chat.py

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from pydantic import BaseModel
2828

29+
from ._callbacks import CallbackManager
2930
from ._content import (
3031
Content,
3132
ContentJson,
@@ -42,7 +43,7 @@
4243
)
4344
from ._logging import log_tool_error
4445
from ._provider import Provider
45-
from ._tools import Tool
46+
from ._tools import Tool, ToolRejectError
4647
from ._turn import Turn, user_turn
4748
from ._typing_extensions import TypedDict
4849
from ._utils import html_escape, wrap_async
@@ -96,6 +97,8 @@ def __init__(
9697
self.provider = provider
9798
self._turns: list[Turn] = list(turns or [])
9899
self._tools: dict[str, Tool] = {}
100+
self._on_tool_request_callbacks = CallbackManager()
101+
self._on_tool_result_callbacks = CallbackManager()
99102
self._current_display: Optional[MarkdownDisplay] = None
100103
self._echo_options: EchoDisplayOptions = {
101104
"rich_markdown": {},
@@ -988,6 +991,53 @@ def add(a: int, b: int) -> int:
988991
tool = Tool(func, model=model)
989992
self._tools[tool.name] = tool
990993

994+
def on_tool_request(self, callback: Callable[[ContentToolRequest], None]):
995+
"""
996+
Register a callback for a tool request event.
997+
998+
A tool request event occurs when the assistant requests a tool to be
999+
called on its behalf. Before invoking the tool, `on_tool_request`
1000+
handlers are called with the relevant `ContentToolRequest` object. This
1001+
is useful if you want to handle tool requests in a custom way, such as
1002+
requiring logging them or requiring user approval before invoking the
1003+
tool
1004+
1005+
Parameters
1006+
----------
1007+
callback
1008+
A function to be called when a tool request event occurs.
1009+
This function must have a single argument, which will be the
1010+
tool request (i.e., a `ContentToolRequest` object).
1011+
1012+
Returns
1013+
-------
1014+
A callable that can be used to remove the callback later.
1015+
"""
1016+
return self._on_tool_request_callbacks.add(callback)
1017+
1018+
def on_tool_result(self, callback: Callable[[ContentToolResult], None]):
1019+
"""
1020+
Register a callback for a tool result event.
1021+
1022+
A tool result event occurs when a tool has been invoked and the
1023+
result is ready to be provided to the assistant. After the tool
1024+
has been invoked, `on_tool_result` handlers are called with the
1025+
relevant `ContentToolResult` object. This is useful if you want to
1026+
handle tool results in a custom way such as logging them.
1027+
1028+
Parameters
1029+
----------
1030+
callback
1031+
A function to be called when a tool result event occurs.
1032+
This function must have a single argument, which will be the
1033+
tool result (i.e., a `ContentToolResult` object).
1034+
1035+
Returns
1036+
-------
1037+
A callable that can be used to remove the callback later.
1038+
"""
1039+
return self._on_tool_result_callbacks.add(callback)
1040+
9911041
@property
9921042
def current_display(self) -> Optional[MarkdownDisplay]:
9931043
"""
@@ -1418,28 +1468,43 @@ def _invoke_tool(self, x: ContentToolRequest) -> ContentToolResult:
14181468
e = RuntimeError(f"Unknown tool: {x.name}")
14191469
return ContentToolResult(value=None, error=e, request=x)
14201470

1421-
args = x.arguments
1422-
1471+
# First, invoke the request callbacks. If a ToolRejectError is raised,
1472+
# treat it like a tool failure (i.e., gracefully handle it).
1473+
result: ContentToolResult | None = None
14231474
try:
1424-
if isinstance(args, dict):
1425-
result = func(**args)
1426-
else:
1427-
result = func(args)
1475+
self._on_tool_request_callbacks.invoke(x)
1476+
except ToolRejectError as e:
1477+
result = ContentToolResult(value=None, error=e, request=x)
1478+
1479+
# Invoke the tool (if it hasn't been rejected).
1480+
if result is None:
1481+
try:
1482+
if isinstance(x.arguments, dict):
1483+
res = func(**x.arguments)
1484+
else:
1485+
res = func(x.arguments)
14281486

1429-
if not isinstance(result, ContentToolResult):
1430-
result = ContentToolResult(value=result)
1487+
if isinstance(res, ContentToolResult):
1488+
result = res
1489+
else:
1490+
result = ContentToolResult(value=res)
1491+
1492+
result.request = x
1493+
except Exception as e:
1494+
result = ContentToolResult(value=None, error=e, request=x)
14311495

1432-
result.request = x
1433-
return result
1434-
except Exception as e:
1496+
# If we've captured an error, notify and log it.
1497+
if result.error:
14351498
warnings.warn(
14361499
f"Calling tool '{x.name}' led to an error.",
14371500
ToolFailureWarning,
14381501
stacklevel=2,
14391502
)
14401503
traceback.print_exc()
1441-
log_tool_error(x.name, str(args), e)
1442-
return ContentToolResult(value=None, error=e, request=x)
1504+
log_tool_error(x.name, str(x.arguments), result.error)
1505+
1506+
self._on_tool_result_callbacks.invoke(result)
1507+
return result
14431508

14441509
async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
14451510
tool_def = self._tools.get(x.name, None)
@@ -1454,28 +1519,43 @@ async def _invoke_tool_async(self, x: ContentToolRequest) -> ContentToolResult:
14541519
e = RuntimeError(f"Unknown tool: {x.name}")
14551520
return ContentToolResult(value=None, error=e, request=x)
14561521

1457-
args = x.arguments
1458-
1522+
# First, invoke the request callbacks. If a ToolRejectError is raised,
1523+
# treat it like a tool failure (i.e., gracefully handle it).
1524+
result: ContentToolResult | None = None
14591525
try:
1460-
if isinstance(args, dict):
1461-
result = await func(**args)
1462-
else:
1463-
result = await func(args)
1526+
await self._on_tool_request_callbacks.invoke_async(x)
1527+
except ToolRejectError as e:
1528+
result = ContentToolResult(value=None, error=e, request=x)
1529+
1530+
# Invoke the tool (if it hasn't been rejected).
1531+
if result is None:
1532+
try:
1533+
if isinstance(x.arguments, dict):
1534+
res = await func(**x.arguments)
1535+
else:
1536+
res = await func(x.arguments)
14641537

1465-
if not isinstance(result, ContentToolResult):
1466-
result = ContentToolResult(value=result)
1538+
if isinstance(res, ContentToolResult):
1539+
result = res
1540+
else:
1541+
result = ContentToolResult(value=res)
14671542

1468-
result.request = x
1469-
return result
1470-
except Exception as e:
1543+
result.request = x
1544+
except Exception as e:
1545+
result = ContentToolResult(value=None, error=e, request=x)
1546+
1547+
# If we've captured an error, notify and log it.
1548+
if result.error:
14711549
warnings.warn(
14721550
f"Calling tool '{x.name}' led to an error.",
14731551
ToolFailureWarning,
14741552
stacklevel=2,
14751553
)
14761554
traceback.print_exc()
1477-
log_tool_error(x.name, str(args), e)
1478-
return ContentToolResult(value=None, error=e, request=x)
1555+
log_tool_error(x.name, str(x.arguments), result.error)
1556+
1557+
await self._on_tool_result_callbacks.invoke_async(result)
1558+
return result
14791559

14801560
def _markdown_display(self, echo: EchoOptions) -> ChatMarkdownDisplay:
14811561
"""

chatlas/_tools.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
from . import _utils
1010

11-
__all__ = ("Tool",)
11+
__all__ = (
12+
"Tool",
13+
"ToolRejectError",
14+
)
1215

1316
if TYPE_CHECKING:
1417
from openai.types.chat import ChatCompletionToolParam
@@ -47,6 +50,61 @@ def __init__(
4750
self.name = self.schema["function"]["name"]
4851

4952

53+
class ToolRejectError(Exception):
54+
"""
55+
Error to represent a tool call being rejected.
56+
57+
This error is meant to be raised when an end user has chosen to deny a tool
58+
call. It can be raised in a tool function or in a `.on_tool_request()`
59+
callback registered via a :class:`~chatlas.Chat`. When used in the callback,
60+
the tool call is rejected before the tool function is invoked.
61+
62+
Parameters
63+
----------
64+
reason
65+
A string describing the reason for rejecting the tool call. This will be
66+
included in the error message passed to the LLM. In addition to the
67+
reason, the error message will also include "Tool call rejected." to
68+
indicate that the tool call was not processed.
69+
70+
Raises
71+
-------
72+
ToolRejectError
73+
An error with a message informing the LLM that the tool call was
74+
rejected (and the reason why).
75+
76+
Examples
77+
--------
78+
>>> import os
79+
>>> import chatlas as ctl
80+
>>>
81+
>>> chat = ctl.ChatOpenAI()
82+
>>>
83+
>>> def list_files():
84+
... "List files in the user's current directory"
85+
... while True:
86+
... allow = input(
87+
... "Would you like to allow access to your current directory? (yes/no): "
88+
... )
89+
... if allow.lower() == "yes":
90+
... return os.listdir(".")
91+
... elif allow.lower() == "no":
92+
... raise ctl.ToolRejectError(
93+
... "The user has chosen to disallow the tool call."
94+
... )
95+
... else:
96+
... print("Please answer with 'yes' or 'no'.")
97+
>>>
98+
>>> chat.register_tool(list_files)
99+
>>> chat.chat("What files are available in my current directory?")
100+
"""
101+
102+
def __init__(self, reason: str = "The user has chosen to disallow the tool call."):
103+
message = f"Tool call rejected. {reason}"
104+
super().__init__(message)
105+
self.message = message
106+
107+
50108
def func_to_schema(
51109
func: Callable[..., Any] | Callable[..., Awaitable[Any]],
52110
model: Optional[type[BaseModel]] = None,

docs/_quarto.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ quartodoc:
113113
desc: Add context to python function before registering it as a tool.
114114
contents:
115115
- Tool
116+
- ToolRejectError
116117
- title: Turns
117118
desc: A provider-agnostic representation of content generated during an assistant/user turn.
118119
contents:

0 commit comments

Comments
 (0)