Skip to content

Commit b29da23

Browse files
fix: set tools executed if them have not reported error (#253)
1 parent a5d404a commit b29da23

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

compose/tool_node.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,10 @@ func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...to
258258
})
259259

260260
ctx = setToolCallInfo(ctx, &toolCallInfo{toolCallID: task.callID})
261-
task.executed = true
262261
task.output, task.err = task.r.Invoke(ctx, task.arg, opts...)
262+
if task.err == nil {
263+
task.executed = true
264+
}
263265
}
264266

265267
func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...tool.Option) {
@@ -270,8 +272,10 @@ func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...to
270272
})
271273

272274
ctx = setToolCallInfo(ctx, &toolCallInfo{toolCallID: task.callID})
273-
task.executed = true
274275
task.sOutput, task.err = task.r.Stream(ctx, task.arg, opts...)
276+
if task.err == nil {
277+
task.executed = true
278+
}
275279
}
276280

277281
func sequentialRunToolCall(ctx context.Context, run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) {

compose/tool_node_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ func TestToolRerun(t *testing.T) {
720720
r, err := g.Compile(ctx, WithCheckPointStore(&inMemoryStore{m: map[string][]byte{}}))
721721
assert.NoError(t, err)
722722

723-
_, err = r.Invoke(ctx, &schema.Message{Role: schema.Assistant, ToolCalls: tc}, WithCheckPointID("1"))
723+
_, err = r.Stream(ctx, &schema.Message{Role: schema.Assistant, ToolCalls: tc}, WithCheckPointID("1"))
724724
info, ok := ExtractInterruptInfo(err)
725725
assert.True(t, ok)
726726
assert.Equal(t, []string{"tool node"}, info.RerunNodes)

0 commit comments

Comments
 (0)