Skip to content

Commit 3680197

Browse files
Teachability for any agent (#1091)
* Partial implementation * Partial implementation * Fixes * update tests * cleanup * update tests * comments * logging * wording * underscore * Extend notebook for teachable GPTAssistantAgent * Notebook for teachable GPTAssistantAgents * Update notebook * Update notebook * Update notebook * Update notebook * revert file * Update blog post and other documentation. * pre-commit * Address reviewer feedback. * Add new nb link to examples page. --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
1 parent 172df55 commit 3680197

File tree

13 files changed

+1719
-472
lines changed

13 files changed

+1719
-472
lines changed

.github/workflows/contrib-tests.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,11 @@ jobs:
155155
run: |
156156
python -m pip install --upgrade pip wheel
157157
pip install pytest
158-
- name: Install packages and dependencies for TeachableAgent
158+
- name: Install packages and dependencies for Teachability
159159
run: |
160160
pip install -e .[teachable]
161161
pip uninstall -y openai
162-
- name: Test TeachableAgent
162+
- name: Test Teachability
163163
if: matrix.python-version != '3.9' # diversify the python versions
164164
run: |
165165
pytest test/agentchat/contrib/test_teachable_agent.py

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ key_aoai.txt
165165
base_aoai.txt
166166
wolfram.txt
167167

168-
# DB on disk for TeachableAgent
168+
# DB on disk for Teachability
169169
tmp/
170170
test/my_tmp/*
171171

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from autogen.agentchat.assistant_agent import ConversableAgent
2+
3+
4+
class AgentCapability:
    """Common parent for composable capabilities that can be attached to an agent."""

    def __init__(self):
        """Create the capability. The base class itself stores no state."""

    def add_to_agent(self, agent: ConversableAgent):
        """
        Attach this capability to the given agent.

        Every capability subclass must override this method; the base version
        only raises. An override will typically call agent.register_hook() one
        or more times (teachability.py is an example).

        Args:
            agent: The agent that should receive the capability.

        Raises:
            NotImplementedError: Always, when called on the base class.
        """
        raise NotImplementedError()

autogen/agentchat/contrib/teachable_agent.py renamed to autogen/agentchat/contrib/capabilities/teachability.py

+116-150
Large diffs are not rendered by default.

autogen/agentchat/contrib/text_analyzer_agent.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,6 @@ def __init__(
2929
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
3030
for available options.
3131
To disable llm-based auto reply, set to False.
32-
teach_config (dict or None): Additional parameters used by TeachableAgent.
33-
To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
34-
- verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
35-
- reset_db (Optional, bool): True to clear the DB before starting. Default False.
36-
- path_to_db_dir (Optional, str): path to the directory where the DB is stored. Default "./tmp/teachable_agent_db"
37-
- prepopulate (Optional, int): True (default) to prepopulate the DB with a set of input-output pairs.
38-
- recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
39-
- max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
4032
**kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
4133
"""
4234
super().__init__(
@@ -56,7 +48,7 @@ def _analyze_in_reply(
5648
) -> Tuple[bool, Union[str, Dict, None]]:
5749
"""Analyzes the given text as instructed, and returns the analysis as a message.
5850
Assumes exactly two messages containing the text to analyze and the analysis instructions.
59-
See TeachableAgent.analyze for an example of how to use this method."""
51+
See Teachability.analyze for an example of how to use this method."""
6052
if self.llm_config is False:
6153
raise ValueError("TextAnalyzerAgent requires self.llm_config to be set in its base class.")
6254
if messages is None:

autogen/agentchat/conversable_agent.py

+66-1
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ def __init__(
152152
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
153153
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
154154

155+
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
156+
# New hookable methods should be added to this list as required to support new agent capabilities.
157+
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
158+
155159
def register_reply(
156160
self,
157161
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
@@ -757,7 +761,7 @@ def generate_code_execution_reply(
757761
else:
758762
messages_to_scan += 1
759763

760-
# iterate through the last n messages reversely
764+
# iterate through the last n messages in reverse
761765
# if code blocks are found, execute the code blocks and return the output
762766
# if no code blocks are found, continue
763767
for i in range(min(len(messages), messages_to_scan)):
@@ -1173,6 +1177,10 @@ def generate_reply(
11731177
if messages is None:
11741178
messages = self._oai_messages[sender]
11751179

1180+
# Call the hookable method that gives registered hooks a chance to process the last message.
1181+
# Message modifications do not affect the incoming messages or self._oai_messages.
1182+
messages = self.process_last_message(messages)
1183+
11761184
for reply_func_tuple in self._reply_func_list:
11771185
reply_func = reply_func_tuple["reply_func"]
11781186
if exclude and reply_func in exclude:
@@ -1225,6 +1233,10 @@ async def a_generate_reply(
12251233
if messages is None:
12261234
messages = self._oai_messages[sender]
12271235

1236+
# Call the hookable method that gives registered hooks a chance to process the last message.
1237+
# Message modifications do not affect the incoming messages or self._oai_messages.
1238+
messages = self.process_last_message(messages)
1239+
12281240
for reply_func_tuple in self._reply_func_list:
12291241
reply_func = reply_func_tuple["reply_func"]
12301242
if exclude and reply_func in exclude:
@@ -1757,3 +1769,56 @@ def _decorator(func: F) -> F:
17571769
return func
17581770

17591771
return _decorator
1772+
1773+
def register_hook(self, hookable_method: Callable, hook: Callable):
    """
    Registers a hook to be called by a hookable method, in order to add a capability to the agent.
    Registered hooks are kept in lists (one per hookable method), and are called in their order of registration.

    Args:
        hookable_method: A hookable method implemented by ConversableAgent. It must be a key of self.hook_lists.
        hook: A method implemented by a subclass of AgentCapability.

    Raises:
        AssertionError: If hookable_method is not a key of self.hook_lists,
            or if hook is already registered for that hookable method.
    """
    # Validate with explicit raises instead of `assert` statements: asserts are removed when
    # Python runs with -O, which would let a bad registration slip through (or fail later with
    # an unrelated KeyError). AssertionError is kept as the exception type so existing callers
    # that catch it keep working.
    if hookable_method not in self.hook_lists:
        raise AssertionError(f"{hookable_method} is not a hookable method.")
    hook_list = self.hook_lists[hookable_method]
    if hook in hook_list:
        raise AssertionError(f"{hook} is already registered as a hook.")
    hook_list.append(hook)
1786+
1787+
def process_last_message(self, messages):
    """
    Calls any registered capability hooks to use and potentially modify the text of the last message,
    as long as the last message is not a function call or exit command.

    The incoming list and the message dicts it contains are never modified. When a hook changes
    the text, a new list is returned whose last element is a new dict carrying the processed content.

    Args:
        messages: The list of message dicts to inspect, or None.

    Returns:
        The original `messages` object when nothing needs to change (no hooks registered, no
        messages, or a last message that is a function call, has a "context" key, has no string
        content, or is the "exit" command). Otherwise, a copy of the list with the last message
        replaced by an updated copy.
    """

    # If any required condition is not met, return the original message list.
    hook_list = self.hook_lists[self.process_last_message]
    if not hook_list:
        return messages  # No hooks registered.
    if not messages:
        return messages  # No message to process (None or an empty list).
    last_message = messages[-1]
    if "function_call" in last_message:
        return messages  # Last message is a function call.
    if "context" in last_message:
        return messages  # Last message contains a context key.
    if "content" not in last_message:
        return messages  # Last message has no content.
    user_text = last_message["content"]
    if not isinstance(user_text, str):
        return messages  # Last message content is not a string. TODO: Multimodal agents will use a dict here.
    if user_text == "exit":
        return messages  # Last message is an exit command.

    # Call each hook (in order of registration) to process the user's message.
    processed_user_text = user_text
    for hook in hook_list:
        processed_user_text = hook(processed_user_text)
    if processed_user_text == user_text:
        return messages  # No hooks actually modified the user's message.

    # Replace the last user message with the expanded one.
    # list.copy() is shallow, so the last dict is still shared with the caller's list;
    # copy that dict as well before changing its content, so the caller's data stays untouched.
    updated_last_message = dict(last_message)
    updated_last_message["content"] = processed_user_text
    messages = messages.copy()
    messages[-1] = updated_last_message
    return messages

0 commit comments

Comments
 (0)