From cda8caabada2b1a262b4cb7a8f08a2b5b89d3d0a Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Wed, 18 Jun 2025 11:22:45 +0200 Subject: [PATCH 1/2] chore: add example accuracy test --- package-lock.json | 3 +- package.json | 8 +- .../accuracy/1-step/simple-find-query.test.ts | 20 ++ tests/accuracy/models/gemini.ts | 82 ++++++++ tests/accuracy/models/index.ts | 11 + tests/accuracy/models/model.ts | 12 ++ tests/accuracy/test-sdk.ts | 190 ++++++++++++++++++ 7 files changed, 323 insertions(+), 3 deletions(-) create mode 100644 tests/accuracy/1-step/simple-find-query.test.ts create mode 100644 tests/accuracy/models/gemini.ts create mode 100644 tests/accuracy/models/index.ts create mode 100644 tests/accuracy/models/model.ts create mode 100644 tests/accuracy/test-sdk.ts diff --git a/package-lock.json b/package-lock.json index 0a20e60f..c4a43d60 100644 --- a/package-lock.json +++ b/package-lock.json @@ -54,7 +54,8 @@ "tsx": "^4.19.3", "typescript": "^5.8.2", "typescript-eslint": "^8.29.1", - "yaml": "^2.7.1" + "yaml": "^2.7.1", + "zod-to-json-schema": "^3.24.5" }, "engines": { "node": ">=20.10.0" diff --git a/package.json b/package.json index e9603e7a..99983d46 100644 --- a/package.json +++ b/package.json @@ -29,7 +29,10 @@ "check:types": "tsc --noEmit --project tsconfig.json", "reformat": "prettier --write .", "generate": "./scripts/generate.sh", - "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage" + "test": "npm run test:unit && npm run test:integration", + "test:accuracy": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathPattern=tests/accuracy", + "test:unit": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathPattern=tests/unit", + "test:integration": "node --experimental-vm-modules node_modules/jest/bin/jest.js --coverage --testPathPattern=tests/integration" }, "license": "Apache-2.0", "devDependencies": { @@ -57,7 +60,8 @@ "tsx": "^4.19.3", "typescript": "^5.8.2", "typescript-eslint": "^8.29.1", - "yaml": "^2.7.1" + "yaml": "^2.7.1", + "zod-to-json-schema": "^3.24.5" }, "dependencies": { "@modelcontextprotocol/sdk": "^1.11.2", diff --git a/tests/accuracy/1-step/simple-find-query.test.ts b/tests/accuracy/1-step/simple-find-query.test.ts new file mode 100644 index 00000000..46b46d28 --- /dev/null +++ b/tests/accuracy/1-step/simple-find-query.test.ts @@ -0,0 +1,20 @@ +import { describeAccuracyTest } from "../test-sdk.js"; + +describeAccuracyTest("1 step find queries", ({ prompt }) => { + prompt("find all users in database 'my' and collection 'users'", (tool) => { + tool("find").verifyCalled({ database: "my", collection: "users", limit: 10 }); + }); + + prompt("find all red cards in database 'production' and collection 'cars'", (tool) => { + tool("find").verifyCalled({ filter: { color: "red" }, database: "production", collection: "cars", limit: 10 }); + }); + + prompt("get 100 books in database 'prod' and collection 'books' where the author is J.R.R Tolkien", (tool) => { + tool("find").verifyCalled({ + filter: { author: "J.R.R Tolkien" }, + database: "prod", + collection: "books", + limit: 100, + }); + }); +}); diff --git a/tests/accuracy/models/gemini.ts b/tests/accuracy/models/gemini.ts new file mode 100644 index 00000000..6abf48fc --- /dev/null +++ b/tests/accuracy/models/gemini.ts @@ -0,0 +1,82 @@ +import { ModelFacade, ToolCall, ToolDefinition } from "./model.js"; + +type GeminiModel = "gemini-2.0-flash" | "gemini-1.5-flash"; + +export class GeminiModelFacade implements ModelFacade { + readonly name: GeminiModel; + + constructor(modelName: GeminiModel) { + this.name = modelName; + } + + available(): boolean { + return process.env.MONGODB_MCP_TEST_GEMINI_API_KEY !== undefined; + } + + async generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }> { + const toolDefinitions = tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: tool.parameters || {}, + })); + + const chatHistory = [{ role: "user", parts: [{ text: prompt }] }]; + const payload = { + contents: chatHistory, + tools: { + function_declarations: [toolDefinitions], + }, + }; + + const apiKey = process.env.MONGODB_MCP_TEST_GEMINI_API_KEY; + const apiUrl = `https://generativelanguage.googleapis.com/v1beta/models/${this.name}:generateContent?key=${apiKey}`; + + try { + const response = await fetch(apiUrl, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + + if (!response.ok) { + const errorData = await response.text(); + console.error(`[Gemini API Error] HTTP error! status: ${response.status}, data: ${errorData}`); + return { toolCall: [], text: `Gemini API error: ${response.status}` }; + } + + const result = (await response.json()) as { + candidates: Array<{ + content: { + parts: Array<{ + text?: string; + functionCall?: { + name: string; + args: Record; + }; + }>; + }; + }>; + }; + + if (result.candidates && result.candidates.length > 0) { + const firstPart = result.candidates[0]?.content.parts[0]; + if (firstPart?.functionCall) { + return { + toolCall: [ + { + name: firstPart.functionCall.name, + args: firstPart.functionCall.args, + }, + ], + }; + } else if (firstPart?.text) { + return { toolCall: [], text: firstPart.text }; + } + } + return { toolCall: [], text: "Gemini response was empty or unexpected." }; + } catch (error: unknown) { + console.error("[Gemini API Fetch Error]", error); + return { toolCall: [], text: `Error contacting Gemini LLM.` }; + } + } +} diff --git a/tests/accuracy/models/index.ts b/tests/accuracy/models/index.ts new file mode 100644 index 00000000..73a43cb7 --- /dev/null +++ b/tests/accuracy/models/index.ts @@ -0,0 +1,11 @@ +import { ModelFacade } from "./model.js"; +import { GeminiModelFacade } from "./gemini.js"; + +const ALL_MODELS: ModelFacade[] = [ + new GeminiModelFacade("gemini-2.0-flash"), + new GeminiModelFacade("gemini-1.5-flash"), +]; + +export function availableModels(): ModelFacade[] { + return ALL_MODELS.filter((model) => model.available()); +} diff --git a/tests/accuracy/models/model.ts b/tests/accuracy/models/model.ts new file mode 100644 index 00000000..069fef34 --- /dev/null +++ b/tests/accuracy/models/model.ts @@ -0,0 +1,12 @@ +export type ToolCall = { name: string; args: Record }; +export type ToolDefinition = { + name: string; + description: string; + parameters: Record; +}; + +export interface ModelFacade { + name: string; + available(): boolean; + generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }>; +} diff --git a/tests/accuracy/test-sdk.ts b/tests/accuracy/test-sdk.ts new file mode 100644 index 00000000..69cdbbac --- /dev/null +++ b/tests/accuracy/test-sdk.ts @@ -0,0 +1,190 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { McpServer, RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js"; +import { Server } from "../../src/server.js"; +import { Session } from "../../src/session.js"; +import { Telemetry } from "../../src/telemetry/telemetry.js"; +import { config, UserConfig } from "../../src/config.js"; +import { afterEach } from "node:test"; +import { availableModels } from "./models/index.js"; +import { ToolDefinition } from "./models/model.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; + +class ToolMock { + readonly name: string; + arguments: unknown; + returns: unknown; + wasCalledWith: unknown; + + constructor(name: string) { + this.name = name; + this.arguments = {}; + this.returns = {}; + } + + verifyCalled(args: unknown): this { + this.arguments = args; + return this; + } + + thenReturn(value: unknown): this { + this.returns = value; + return this; + } + + _wasCalledWith(args: unknown): this { + this.wasCalledWith = args; + return this; + } + + _verify(): void { + if (this.wasCalledWith) { + expect(this.wasCalledWith).toEqual(this.arguments); + } else { + expect(this.arguments).not.toBe(null); + } + } +} + +interface McpServerUnsafe { + mcpServer: McpServer; +} + +type AccuracyToolSetupFunction = (toolName: string) => ToolMock; +type AccuracyTestCaseFn = (tools: AccuracyToolSetupFunction) => void; +type AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => void; +type AccuracyTestSuite = { prompt: AccuracyItFn }; + +export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: AccuracyTestSuite) => void) { + const models = availableModels(); + if (models.length === 0) { + throw new Error("No models available for accuracy tests."); + } + + models.forEach((model) => { + describe(`${model.name}: ${useCase}`, () => { + let mcpServer: Server; + let mcpClient: Client; + let userConfig: UserConfig; + let session: Session; + let telemetry: Telemetry; + + beforeEach(async () => { + mcpClient = new Client( + { + name: "test-client", + version: "1.2.3", + }, + { + capabilities: {}, + } + ); + + userConfig = { ...config }; + session = new Session(userConfig); + telemetry = Telemetry.create(session, userConfig); + + mcpServer = new Server({ + session, + userConfig, + telemetry, + mcpServer: new McpServer({ + name: "test-server", + version: "5.2.3", + }), + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([mcpServer.connect(serverTransport), mcpClient.connect(clientTransport)]); + }); + + afterEach(async () => { + await Promise.all([mcpServer.close(), mcpClient.close()]); + }); + + const promptFn: AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => { + it(prompt, async () => { + const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer; + const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool }; + const toolDefinitions = Object.entries(tools).map(([toolName, tool]) => { + if (!tool.inputSchema) { + throw new Error(`Tool ${toolName} does not have an input schema defined.`); + } + + const toolForApi: ToolDefinition = { + name: toolName, + description: tool.description ?? "", + parameters: zodToJsonSchema(tool.inputSchema, { + target: "jsonSchema7", + allowedAdditionalProperties: undefined, + rejectedAdditionalProperties: undefined, + postProcess: (schema) => { + if (schema && typeof schema === "object") { + return { + ...schema, + $schema: undefined, + const: undefined, + additionalProperties: undefined, + }; + } + return schema; + }, + }), + }; + delete toolForApi.parameters.$schema; + return toolForApi; + }); + + const mocks: Array = []; + const toolFn: AccuracyToolSetupFunction = (toolName: string) => { + const mock = new ToolMock(toolName); + + const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer; + const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool }; + + if (tools[toolName] !== undefined) { + tools[toolName].callback = ((args: unknown) => { + mock._wasCalledWith(args); + return mock.returns; + }) as unknown as ToolCallback; + } + + mocks.push(mock); + return mock; + }; + + testCase(toolFn); + + const consumePromptUntilNoMoreCall = async (prompt: string[]) => { + const promptStr = prompt.join("\n"); + const response = await model.generateContent(promptStr, toolDefinitions); + + if (response.toolCall.length > 0) { + const toolCallResults = await Promise.all( + response.toolCall.map((tc) => + mcpClient.callTool({ + name: tc.name, + arguments: tc.args, + }) + ) + ); + const newPrompt = toolCallResults.flatMap((result) => + (result.content as Array<{ text: string }>).map((c) => c.text) + ); + + if (newPrompt.join("\n").trim().length > 0) { + return consumePromptUntilNoMoreCall(newPrompt); + } + } + }; + + await consumePromptUntilNoMoreCall([prompt]); + mocks.forEach((mock) => mock._verify()); + }); + }; + + testCaseFn({ prompt: promptFn }); + }); + }); +} From bbe2b838bfa70a924de1ee2ff51d232563160907 Mon Sep 17 00:00:00 2001 From: Kevin Mas Ruiz Date: Fri, 20 Jun 2025 14:20:13 +0200 Subject: [PATCH 2/2] chore: add plans for multistep actions --- .../1-step/simple-delete-query.test.ts | 11 +++ .../accuracy/1-step/simple-find-query.test.ts | 2 +- .../1-step/simple-update-query.test.ts | 12 +++ ...create-collection-with-sample-data.test.ts | 32 +++++++ tests/accuracy/models/gemini.ts | 69 ++++++++++++++- tests/accuracy/models/model.ts | 4 +- tests/accuracy/test-sdk.ts | 88 ++++++++++++++++--- 7 files changed, 203 insertions(+), 15 deletions(-) create mode 100644 tests/accuracy/1-step/simple-delete-query.test.ts create mode 100644 tests/accuracy/1-step/simple-update-query.test.ts create mode 100644 tests/accuracy/2-step/create-collection-with-sample-data.test.ts diff --git a/tests/accuracy/1-step/simple-delete-query.test.ts b/tests/accuracy/1-step/simple-delete-query.test.ts new file mode 100644 index 00000000..f0d85f03 --- /dev/null +++ b/tests/accuracy/1-step/simple-delete-query.test.ts @@ -0,0 +1,11 @@ +import { describeAccuracyTest } from "../test-sdk.js"; + +describeAccuracyTest("1 step delete queries", ({ prompt }) => { + prompt("delete all disabled users (disabled = true) in database 'my' and collection 'users'", (tool) => { + tool("delete-many").verifyCalled({ + database: "my", + collection: "users", + filter: { disabled: true }, + }); + }); +}); diff --git a/tests/accuracy/1-step/simple-find-query.test.ts b/tests/accuracy/1-step/simple-find-query.test.ts index 46b46d28..6113d3d2 100644 --- a/tests/accuracy/1-step/simple-find-query.test.ts +++ b/tests/accuracy/1-step/simple-find-query.test.ts @@ -5,7 +5,7 @@ describeAccuracyTest("1 step find queries", ({ prompt }) => { tool("find").verifyCalled({ database: "my", collection: "users", limit: 10 }); }); - prompt("find all red cards in database 'production' and collection 'cars'", (tool) => { + prompt("find all red cars in database 'production' and collection 'cars'", (tool) => { tool("find").verifyCalled({ filter: { color: "red" }, database: "production", collection: "cars", limit: 10 }); }); diff --git a/tests/accuracy/1-step/simple-update-query.test.ts b/tests/accuracy/1-step/simple-update-query.test.ts new file mode 100644 index 00000000..6580ca23 --- /dev/null +++ b/tests/accuracy/1-step/simple-update-query.test.ts @@ -0,0 +1,12 @@ +import { describeAccuracyTest } from "../test-sdk.js"; + +describeAccuracyTest("1 step update queries", ({ prompt }) => { + prompt("set all users with an empty email to disabled in database 'my' and collection 'users'", (tool) => { + tool("update-many").verifyCalled({ + database: "my", + collection: "users", + filter: { email: "" }, + update: { $set: { disabled: true } }, + }); + }); +}); diff --git a/tests/accuracy/2-step/create-collection-with-sample-data.test.ts b/tests/accuracy/2-step/create-collection-with-sample-data.test.ts new file mode 100644 index 00000000..2b6c39e5 --- /dev/null +++ b/tests/accuracy/2-step/create-collection-with-sample-data.test.ts @@ -0,0 +1,32 @@ +import { describeAccuracyTest } from "../test-sdk.js"; + +describeAccuracyTest("2 step create collection", ({ prompt }) => { + prompt( + ` + create a new collection named 'users' in database 'my' and afterwards create a sample document with the following data: + - username: "john_doe" + - email: test@mongodb.com + - password: "password123" + - disabled: false + `, + (tool) => { + tool("create-collection").verifyCalled({ + database: "my", + collection: "users", + }); + + tool("insert-many").verifyCalled({ + database: "my", + collection: "users", + documents: [ + { + username: "john_doe", + email: "test@mongodb.com", + password: "password123", + disabled: false, + }, + ], + }); + } + ); +}); diff --git a/tests/accuracy/models/gemini.ts b/tests/accuracy/models/gemini.ts index 6abf48fc..04818966 100644 --- a/tests/accuracy/models/gemini.ts +++ b/tests/accuracy/models/gemini.ts @@ -13,14 +13,79 @@ export class GeminiModelFacade implements ModelFacade { return process.env.MONGODB_MCP_TEST_GEMINI_API_KEY !== undefined; } - async generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }> { + async generatePlan(prompt: string, tools: ToolDefinition[]): Promise { + const planPrompt = `You are an expert MongoDB developer. Create a plan for the following task: \n ${prompt} \n Return the plan as a list of steps, as a JSON array. For example: [ "Step 1: ...", "Step 2: ...", "Step 3: ..." ]. Only return the JSON array, nothing else. Do not include any wrapper markdown or anything, just the plain JSON array.`; + const chatHistory = [{ role: "user", parts: [{ text: planPrompt }] }]; + + const apiKey = process.env.MONGODB_MCP_TEST_GEMINI_API_KEY; + const apiUrl = `https://generativelanguage.googleapis.com/v1beta/models/${this.name}:generateContent?key=${apiKey}`; + + const toolDefinitions = tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: tool.parameters || {}, + })); + + const payload = { + contents: chatHistory, + tools: { + function_declarations: [toolDefinitions], + }, + }; + + try { + const response = await fetch(apiUrl, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + + if (!response.ok) { + const errorData = await response.text(); + console.error(`[Gemini API Error] HTTP error! status: ${response.status}, data: ${errorData}`); + return []; + } + + const result = (await response.json()) as { + candidates: Array<{ + content: { + parts: Array<{ + text?: string; + functionCall?: { + name: string; + args: Record; + }; + }>; + }; + }>; + }; + + const responseString = result.candidates + .flatMap((candidate) => candidate.content.parts.map((part) => part.text || "")) + .join("") + .replace("```json", "") + .replace("```", ""); + + try { + return JSON.parse(responseString) as string[]; + } catch (parseError) { + console.error("[Gemini API JSON.parse Error]", responseString, parseError); + } + return []; + } catch (error: unknown) { + console.error("[Gemini API Fetch Error]", error); + return []; + } + } + + async generateContent(parts: string[], tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }> { const toolDefinitions = tools.map((tool) => ({ name: tool.name, description: tool.description, parameters: tool.parameters || {}, })); - const chatHistory = [{ role: "user", parts: [{ text: prompt }] }]; + const chatHistory = [{ role: "user", parts: parts.map((part) => ({ text: part })) }]; const payload = { contents: chatHistory, tools: { diff --git a/tests/accuracy/models/model.ts b/tests/accuracy/models/model.ts index 069fef34..2c3316ad 100644 --- a/tests/accuracy/models/model.ts +++ b/tests/accuracy/models/model.ts @@ -8,5 +8,7 @@ export type ToolDefinition = { export interface ModelFacade { name: string; available(): boolean; - generateContent(prompt: string, tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }>; + + generatePlan(prompt: string, tools: ToolDefinition[]): Promise; + generateContent(parts: string[], tools: ToolDefinition[]): Promise<{ toolCall: ToolCall[]; text?: string }>; } diff --git a/tests/accuracy/test-sdk.ts b/tests/accuracy/test-sdk.ts index 69cdbbac..e39dce38 100644 --- a/tests/accuracy/test-sdk.ts +++ b/tests/accuracy/test-sdk.ts @@ -10,6 +10,7 @@ import { availableModels } from "./models/index.js"; import { ToolDefinition } from "./models/model.js"; import { zodToJsonSchema } from "zod-to-json-schema"; +type ToolMockReturn = { content: Array<{ type: string; text: string }> }; class ToolMock { readonly name: string; arguments: unknown; @@ -27,7 +28,7 @@ class ToolMock { return this; } - thenReturn(value: unknown): this { + thenReturn(value: ToolMockReturn): this { this.returns = value; return this; } @@ -55,6 +56,36 @@ type AccuracyTestCaseFn = (tools: AccuracyToolSetupFunction) => void; type AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => void; type AccuracyTestSuite = { prompt: AccuracyItFn }; +type NonMockedCallError = { tool: string; args: unknown }; + +function logVerbose(...args: unknown[]): void { + if (process.env.MONGODB_MCP_TEST_VERBOSE === "true") { + console.log(...args); + } +} + +function printModelPlanIfVerbose(model: string, plan: string[]): void { + logVerbose(model, "📝: ", plan.join("\n")); +} + +function testPromptIsVerbose(model: string, prompt: string): void { + logVerbose(model, "📜: ", prompt); +} + +function modelSaidVerbose(model: string, response: string): void { + if (response.length > 0) { + logVerbose(model, "🗣️: ", response); + } +} + +function modelToolCalledVerbose(model: string, toolCall: string, args: unknown): void { + logVerbose(model, "🛠️: ", toolCall, JSON.stringify(args)); +} + +function toolCallsReturnedVerbose(model: string, answer: string): void { + logVerbose(model, "📋: ", answer); +} + export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: AccuracyTestSuite) => void) { const models = availableModels(); if (models.length === 0) { @@ -105,8 +136,13 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac const promptFn: AccuracyItFn = (prompt: string, testCase: AccuracyTestCaseFn) => { it(prompt, async () => { + testPromptIsVerbose(model.name, prompt); + const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer; const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool }; + const mockedTools = new Set(); + const nonMockedCallErrors = new Array(); + const toolDefinitions = Object.entries(tools).map(([toolName, tool]) => { if (!tool.inputSchema) { throw new Error(`Tool ${toolName} does not have an input schema defined.`); @@ -136,9 +172,14 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac return toolForApi; }); - const mocks: Array = []; + const plan = await model.generatePlan(prompt, toolDefinitions); + printModelPlanIfVerbose(model.name, plan); + + + const mocks: Array = []; const toolFn: AccuracyToolSetupFunction = (toolName: string) => { const mock = new ToolMock(toolName); + mockedTools.add(toolName); const mcpServerUnsafe = (mcpServer as unknown as McpServerUnsafe).mcpServer; const tools = mcpServerUnsafe["_registeredTools"] as { [toolName: string]: RegisteredTool }; @@ -146,7 +187,7 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac if (tools[toolName] !== undefined) { tools[toolName].callback = ((args: unknown) => { mock._wasCalledWith(args); - return mock.returns; + return Promise.resolve(mock.returns); }) as unknown as ToolCallback; } @@ -157,30 +198,55 @@ export function describeAccuracyTest(useCase: string, testCaseFn: (testSuite: Ac testCase(toolFn); const consumePromptUntilNoMoreCall = async (prompt: string[]) => { - const promptStr = prompt.join("\n"); - const response = await model.generateContent(promptStr, toolDefinitions); + const response = await model.generateContent(prompt, toolDefinitions); + modelSaidVerbose(model.name, response.text || ""); if (response.toolCall.length > 0) { const toolCallResults = await Promise.all( - response.toolCall.map((tc) => - mcpClient.callTool({ + response.toolCall.map((tc) => { + modelToolCalledVerbose(model.name, tc.name, tc.args); + + if (!mockedTools.has(tc.name)) { + nonMockedCallErrors.push({ tool: tc.name, args: tc.args }); + } + + return mcpClient.callTool({ name: tc.name, arguments: tc.args, - }) - ) + }); + }) ); - const newPrompt = toolCallResults.flatMap((result) => + + const responseParts = toolCallResults.flatMap((result) => (result.content as Array<{ text: string }>).map((c) => c.text) ); - if (newPrompt.join("\n").trim().length > 0) { + const newPrompt = prompt.concat(responseParts); + toolCallsReturnedVerbose(model.name, newPrompt.join("\n")); + + if (responseParts.length > 0) { return consumePromptUntilNoMoreCall(newPrompt); } } }; + for (const step of plan) { + await consumePromptUntilNoMoreCall([ step ]); + } + await consumePromptUntilNoMoreCall([prompt]); + mocks.forEach((mock) => mock._verify()); + if (nonMockedCallErrors.length > 0) { + for (const call of nonMockedCallErrors) { + console.error( + `Non-mocked tool call detected: ${call.tool} with args:`, + JSON.stringify(call.args, null, 2) + ); + } + + throw new Error("Non-mocked tool calls detected. Check the console for details."); + } }); };