Skip to content

Commit 837f2f7

Browse files
Feat/wdz/tool arg handler (#257)
feat: add ToolArgumentsHandler to tool node
1 parent b29da23 commit 837f2f7

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

compose/tool_node.go

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ func withExecutedTools(executedTools map[string]string) ToolsNodeOption {
6464
// Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error)
6565
// Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error)
6666
type ToolsNode struct {
67-
tuple *toolsTuple
68-
unknownToolHandler func(ctx context.Context, name, input string) (string, error)
69-
executeSequentially bool
67+
tuple *toolsTuple
68+
unknownToolHandler func(ctx context.Context, name, input string) (string, error)
69+
executeSequentially bool
70+
toolArgumentsHandler func(ctx context.Context, name, input string) (string, error)
7071
}
7172

7273
// ToolsNodeConfig is the config for ToolsNode.
@@ -91,6 +92,17 @@ type ToolsNodeConfig struct {
9192
// When set to true, tool calls will be executed one after another in the order they appear in the input message.
9293
// When set to false (default), tool calls will be executed in parallel.
9394
ExecuteSequentially bool
95+
96+
// ToolArgumentsHandler allows handling of tool arguments before execution.
97+
// When provided, this function will be called for each tool call to process the arguments.
98+
// Parameters:
99+
// - ctx: The context for the tool call
100+
// - name: The name of the tool being called
101+
// - arguments: The original arguments string for the tool
102+
// Returns:
103+
// - string: The processed arguments string to be used for tool execution
104+
// - error: Any error that occurred during preprocessing
105+
ToolArgumentsHandler func(ctx context.Context, name, arguments string) (string, error)
94106
}
95107

96108
// NewToolNode creates a new ToolsNode.
@@ -107,9 +119,10 @@ func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error)
107119
}
108120

109121
return &ToolsNode{
110-
tuple: tuple,
111-
unknownToolHandler: conf.UnknownToolsHandler,
112-
executeSequentially: conf.ExecuteSequentially,
122+
tuple: tuple,
123+
unknownToolHandler: conf.UnknownToolsHandler,
124+
executeSequentially: conf.ExecuteSequentially,
125+
toolArgumentsHandler: conf.ToolArgumentsHandler,
113126
}, nil
114127
}
115128

@@ -191,7 +204,7 @@ type toolCallTask struct {
191204
err error
192205
}
193206

194-
func (tn *ToolsNode) genToolCallTasks(tuple *toolsTuple, input *schema.Message, executedTools map[string]string) ([]toolCallTask, error) {
207+
func (tn *ToolsNode) genToolCallTasks(ctx context.Context, tuple *toolsTuple, input *schema.Message, executedTools map[string]string, isStream bool) ([]toolCallTask, error) {
195208
if input.Role != schema.Assistant {
196209
return nil, fmt.Errorf("expected message role is Assistant, got %s", input.Role)
197210
}
@@ -210,7 +223,11 @@ func (tn *ToolsNode) genToolCallTasks(tuple *toolsTuple, input *schema.Message,
210223
toolCallTasks[i].arg = toolCall.Function.Arguments
211224
toolCallTasks[i].callID = toolCall.ID
212225
toolCallTasks[i].executed = true
213-
toolCallTasks[i].output = result
226+
if isStream {
227+
toolCallTasks[i].sOutput = schema.StreamReaderFromArray([]string{result})
228+
} else {
229+
toolCallTasks[i].output = result
230+
}
214231
continue
215232
}
216233
index, ok := tuple.indexes[toolCall.Function.Name]
@@ -223,8 +240,16 @@ func (tn *ToolsNode) genToolCallTasks(tuple *toolsTuple, input *schema.Message,
223240
toolCallTasks[i].r = tuple.rps[index]
224241
toolCallTasks[i].meta = tuple.meta[index]
225242
toolCallTasks[i].name = toolCall.Function.Name
226-
toolCallTasks[i].arg = toolCall.Function.Arguments
227243
toolCallTasks[i].callID = toolCall.ID
244+
if tn.toolArgumentsHandler != nil {
245+
arg, err := tn.toolArgumentsHandler(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
246+
if err != nil {
247+
return nil, fmt.Errorf("failed to executed tool[name:%s arguments:%s] arguments handler: %w", toolCall.Function.Name, toolCall.Function.Arguments, err)
248+
}
249+
toolCallTasks[i].arg = arg
250+
} else {
251+
toolCallTasks[i].arg = toolCall.Function.Arguments
252+
}
228253
}
229254
}
230255

@@ -280,6 +305,9 @@ func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...to
280305

281306
func sequentialRunToolCall(ctx context.Context, run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) {
282307
for i := 0; i < len(tasks); i++ {
308+
if tasks[i].executed {
309+
continue
310+
}
283311
run(ctx, &tasks[i], opts...)
284312
}
285313
}
@@ -294,6 +322,9 @@ func parallelRunToolCall(ctx context.Context,
294322

295323
var wg sync.WaitGroup
296324
for i := 1; i < len(tasks); i++ {
325+
if tasks[i].executed {
326+
continue
327+
}
297328
wg.Add(1)
298329
go func(ctx_ context.Context, t *toolCallTask, opts ...tool.Option) {
299330
defer wg.Done()
@@ -326,7 +357,7 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message,
326357
}
327358
}
328359

329-
tasks, err := tn.genToolCallTasks(tuple, input, opt.executedTools)
360+
tasks, err := tn.genToolCallTasks(ctx, tuple, input, opt.executedTools, false)
330361
if err != nil {
331362
return nil, err
332363
}
@@ -385,7 +416,7 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message,
385416
}
386417
}
387418

388-
tasks, err := tn.genToolCallTasks(tuple, input, opt.executedTools)
419+
tasks, err := tn.genToolCallTasks(ctx, tuple, input, opt.executedTools, true)
389420
if err != nil {
390421
return nil, err
391422
}

compose/tool_node_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,9 @@ func TestToolRerun(t *testing.T) {
734734
},
735735
}, info.RerunNodesExtra["tool node"])
736736

737-
result, err := r.Invoke(ctx, nil, WithCheckPointID("1"))
737+
sr, err := r.Stream(ctx, nil, WithCheckPointID("1"))
738+
assert.NoError(t, err)
739+
result, err := concatStreamReader(sr)
738740
assert.NoError(t, err)
739741
assert.Equal(t, "tool1 input: inputtool2 input: inputtool3 input: inputtool4 input: input", result)
740742
}

0 commit comments

Comments
 (0)