Skip to content

Commit cb3c30f

Browse files
authored
preserve tool call results in AiResponse.merge (#4866)
1 parent 6495440 commit cb3c30f

File tree

3 files changed

+91
-47
lines changed

3 files changed

+91
-47
lines changed

.changeset/tricky-months-sit.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@effect/ai": patch
3+
---
4+
5+
preserve tool call results in AiResponse.merge

packages/ai/ai/src/AiLanguageModel.ts

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,14 @@ export const make: <Config>(
358358
if (Option.isNone(spanTransformer)) {
359359
return stream
360360
}
361-
let lastResponse: AiResponse.AiResponse | undefined
361+
let finalResponse = AiResponse.empty
362362
return stream.pipe(
363363
Stream.map((response) => {
364-
lastResponse = response
364+
finalResponse = AiResponse.merge(finalResponse, response)
365365
return response
366366
}),
367367
Stream.ensuring(Effect.sync(() => {
368-
spanTransformer.value({ ...options, response: lastResponse! })
368+
spanTransformer.value({ ...options, response: finalResponse })
369369
}))
370370
)
371371
})),
@@ -470,49 +470,49 @@ const getDescription = (ast: AST.AST): string => {
470470
return AST.DescriptionAnnotationId in annotations ? annotations[AST.DescriptionAnnotationId] as string : ""
471471
}
472472

473-
const resolveParts = <Tools extends AiTool.Any>(options: {
473+
const resolveParts = Effect.fnUntraced(function*<Tools extends AiTool.Any>(options: {
474474
readonly response: AiResponse.AiResponse
475475
readonly toolkit: AiToolkit.ToHandler<Tools>
476476
readonly concurrency: Concurrency | undefined
477477
readonly method: string
478-
}) =>
479-
Effect.gen(function*() {
480-
const toolNames: Array<string> = []
481-
const toolParts = options.response.parts.filter(
482-
(part): part is AiResponse.ToolCallPart => {
483-
if (part._tag === "ToolCallPart") {
484-
toolNames.push(part.name)
485-
return true
486-
}
487-
return false
478+
}) {
479+
const toolNames: Array<string> = []
480+
const toolParts = options.response.parts.filter(
481+
(part): part is AiResponse.ToolCallPart => {
482+
if (part._tag === "ToolCallPart") {
483+
toolNames.push(part.name)
484+
return true
488485
}
489-
)
490-
if (toolParts.length === 0) {
491-
return new AiResponse.WithToolCallResults({
492-
parts: options.response.parts,
493-
results: constEmptyMap,
494-
encodedResults: constEmptyMap
495-
}, constDisableValidation)
486+
return false
496487
}
497-
yield* Effect.annotateCurrentSpan("toolCalls", toolNames)
498-
const results = new Map<AiResponse.ToolCallId, AiTool.Success<Tools>>()
499-
const encodedResults = new Map<AiResponse.ToolCallId, unknown>()
500-
yield* Effect.forEach(toolParts, (part) => {
501-
const id = part.id as AiResponse.ToolCallId
502-
const name = part.name as AiTool.Name<Tools>
503-
const params = part.params as AiTool.Parameters<Tools>
504-
const toolCall = options.toolkit.handle(name, params)
505-
return Effect.map(toolCall, ({ encodedResult, result }) => {
506-
results.set(id, result)
507-
encodedResults.set(id, encodedResult)
508-
})
509-
}, { concurrency: options.concurrency, discard: true })
488+
)
489+
if (toolParts.length === 0) {
510490
return new AiResponse.WithToolCallResults({
511491
parts: options.response.parts,
512-
results,
513-
encodedResults
492+
results: constEmptyMap,
493+
encodedResults: constEmptyMap
514494
}, constDisableValidation)
515-
})
495+
}
496+
yield* Effect.annotateCurrentSpan("toolCalls", toolNames)
497+
const results = new Map<AiResponse.ToolCallId, AiTool.Success<Tools>>()
498+
const encodedResults = new Map<AiResponse.ToolCallId, unknown>()
499+
const resolve = Effect.forEach(toolParts, (part) => {
500+
const id = part.id as AiResponse.ToolCallId
501+
const name = part.name as AiTool.Name<Tools>
502+
const params = part.params as AiTool.Parameters<Tools>
503+
const toolCall = options.toolkit.handle(name, params)
504+
return Effect.map(toolCall, ({ encodedResult, result }) => {
505+
results.set(id, result)
506+
encodedResults.set(id, encodedResult)
507+
})
508+
}, { concurrency: options.concurrency, discard: true })
509+
yield* resolve
510+
return new AiResponse.WithToolCallResults({
511+
parts: options.response.parts,
512+
results,
513+
encodedResults
514+
}, constDisableValidation)
515+
})
516516

517517
/**
518518
* Generate text using a large language model for the specified `prompt`.

packages/ai/ai/src/AiResponse.ts

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -757,15 +757,56 @@ export const merge: {
757757
(other: AiResponse) => (self: AiResponse) => AiResponse,
758758
(self: AiResponse, other: AiResponse) => AiResponse
759759
>(2, (self, other) => {
760-
if (other.parts.length === 0) {
761-
return new AiResponse({
762-
parts: self.parts
760+
const newParts = mergeParts(self, other)
761+
if (hasToolCallResults(self) && hasToolCallResults(other)) {
762+
return new WithToolCallResults({
763+
results: new Map([...self.results, ...other.results]) as any,
764+
encodedResults: new Map([...self.encodedResults, ...other.encodedResults]),
765+
parts: newParts
766+
}, constDisableValidation)
767+
} else if (hasToolCallResults(self)) {
768+
return new WithToolCallResults({
769+
results: self.results as any,
770+
encodedResults: self.encodedResults,
771+
parts: newParts
772+
}, constDisableValidation)
773+
} else if (hasToolCallResults(other)) {
774+
return new WithToolCallResults({
775+
results: other.results as any,
776+
encodedResults: other.encodedResults,
777+
parts: newParts
778+
}, constDisableValidation)
779+
} else if (isStructured(self) && isStructured(other)) {
780+
return new WithStructuredOutput({
781+
id: self.id,
782+
name: self.name,
783+
value: other.value,
784+
parts: newParts
785+
}, constDisableValidation)
786+
} else if (isStructured(self)) {
787+
return new WithStructuredOutput({
788+
id: self.id,
789+
name: self.name,
790+
value: self.value,
791+
parts: newParts
792+
}, constDisableValidation)
793+
} else if (isStructured(other)) {
794+
return new WithStructuredOutput({
795+
id: other.id,
796+
name: other.name,
797+
value: other.value,
798+
parts: newParts
763799
}, constDisableValidation)
764800
}
801+
return new AiResponse({ parts: newParts }, constDisableValidation)
802+
})
803+
804+
const mergeParts = (self: AiResponse, other: AiResponse): ReadonlyArray<Part> => {
805+
if (other.parts.length === 0) {
806+
return self.parts
807+
}
765808
if (self.parts.length === 0) {
766-
return new AiResponse({
767-
parts: other.parts
768-
}, constDisableValidation)
809+
return other.parts
769810
}
770811
const lastPart = self.parts[self.parts.length - 1]
771812
const newParts: Array<Part> = []
@@ -778,10 +819,8 @@ export const merge: {
778819
if (text.length > 0) {
779820
newParts.push(new TextPart({ text }, constDisableValidation))
780821
}
781-
return newParts.length === 0 ? self : new AiResponse({
782-
parts: [...self.parts.slice(0, self.parts.length - 1), ...newParts]
783-
}, constDisableValidation)
784-
})
822+
return newParts.length === 0 ? self.parts : [...self.parts.slice(0, self.parts.length - 1), ...newParts]
823+
}
785824

786825
/**
787826
* Adds the specified tool calls to the provided `AiResponse`.

0 commit comments

Comments
 (0)