Skip to content

Commit a5d404a

Browse files
Feat/wdz/react interrupt (#252)
* feat: react agent support tool node rerun * feat: export IsInterruptRerun
1 parent 6b94b7b commit a5d404a

File tree

4 files changed

+18
-5
lines changed

4 files changed

+18
-5
lines changed

compose/graph_run.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ func (r *runner) resolveInterruptCompletedTasks(tempInfo *interruptTempInfo, com
395395
tempInfo.subGraphInterrupts[completedTasks[i].nodeKey] = info
396396
continue
397397
}
398-
extra, ok := isInterruptRerunError(completedTasks[i].err)
398+
extra, ok := IsInterruptRerunError(completedTasks[i].err)
399399
if ok {
400400
tempInfo.interruptRerunNodes = append(tempInfo.interruptRerunNodes, completedTasks[i].nodeKey)
401401
if extra != nil {

compose/interrupt.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (i *interruptAndRerun) Error() string {
4747
return fmt.Sprintf("interrupt and rerun: %v", i.Extra)
4848
}
4949

50-
func isInterruptRerunError(err error) (any, bool) {
50+
func IsInterruptRerunError(err error) (any, bool) {
5151
if errors.Is(err, InterruptAndRerun) {
5252
return nil, true
5353
}
@@ -117,7 +117,7 @@ func isInterruptError(err error) bool {
117117
if info := isSubGraphInterrupt(err); info != nil {
118118
return true
119119
}
120-
if _, ok := isInterruptRerunError(err); ok {
120+
if _, ok := IsInterruptRerunError(err); ok {
121121
return true
122122
}
123123

compose/tool_node.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message,
343343
rerun := false
344344
for i := 0; i < n; i++ {
345345
if tasks[i].err != nil {
346-
extra, ok := isInterruptRerunError(tasks[i].err)
346+
extra, ok := IsInterruptRerunError(tasks[i].err)
347347
if !ok {
348348
return nil, fmt.Errorf("failed to invoke tool[name:%s id:%s]: %w", tasks[i].name, tasks[i].callID, tasks[i].err)
349349
}
@@ -404,7 +404,7 @@ func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message,
404404
// check rerun
405405
for i := 0; i < n; i++ {
406406
if tasks[i].err != nil {
407-
extra, ok := isInterruptRerunError(tasks[i].err)
407+
extra, ok := IsInterruptRerunError(tasks[i].err)
408408
if !ok {
409409
return nil, fmt.Errorf("failed to stream tool call %s: %w", tasks[i].callID, tasks[i].err)
410410
}

flow/agent/react/react.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package react
1919
import (
2020
"context"
2121
"io"
22+
"sync"
2223

2324
"github.com/cloudwego/eino/components/model"
2425
"github.com/cloudwego/eino/compose"
@@ -153,6 +154,8 @@ type Agent struct {
153154
graphAddNodeOpts []compose.GraphAddNodeOpt
154155
}
155156

157+
var registerStateOnce sync.Once
158+
156159
// NewAgent creates a ReAct agent that feeds tool response into next round of Chat Model generation.
157160
//
158161
// IMPORTANT!! For models that don't output tool calls in the first streaming chunk (e.g. Claude)
@@ -167,6 +170,13 @@ func NewAgent(ctx context.Context, config *AgentConfig) (_ *Agent, err error) {
167170
messageModifier = config.MessageModifier
168171
)
169172

173+
registerStateOnce.Do(func() {
174+
err = compose.RegisterSerializableType[state]("_eino_react_state")
175+
})
176+
if err != nil {
177+
return
178+
}
179+
170180
if toolCallChecker == nil {
171181
toolCallChecker = firstChunkStreamToolCallChecker
172182
}
@@ -208,6 +218,9 @@ func NewAgent(ctx context.Context, config *AgentConfig) (_ *Agent, err error) {
208218
}
209219

210220
toolsNodePreHandle := func(ctx context.Context, input *schema.Message, state *state) (*schema.Message, error) {
221+
if input == nil {
222+
return state.Messages[len(state.Messages)-1], nil // used for rerun interrupt resume
223+
}
211224
state.Messages = append(state.Messages, input)
212225
state.ReturnDirectlyToolCallID = getReturnDirectlyToolCallID(input, config.ToolReturnDirectly)
213226
return input, nil

0 commit comments

Comments
 (0)