Skip to content

Commit d3811d5

Browse files
committed
feat(mcp-client): add support for server-provided returnDirect, with client-level default
- Add `returnDirect` configuration to client properties - Set default toolAnnotations in client by merging server and client values - Add returnDirect handling to MCPToolBack
1 parent a5685a1 commit d3811d5

File tree

9 files changed

+108
-16
lines changed

9 files changed

+108
-16
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
3737
import org.springframework.boot.context.properties.EnableConfigurationProperties;
3838
import org.springframework.context.annotation.Bean;
39+
import org.springframework.context.annotation.Import;
3940
import org.springframework.util.CollectionUtils;
4041

4142
/**
@@ -108,6 +109,7 @@
108109
@EnableConfigurationProperties(McpClientCommonProperties.class)
109110
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
110111
matchIfMissing = true)
112+
@Import(McpCompositeClientProperties.class)
111113
public class McpClientAutoConfiguration {
112114

113115
/**
@@ -146,7 +148,7 @@ private String connectedClientName(String clientName, String serverConnectionNam
146148
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
147149
matchIfMissing = true)
148150
public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientConfigurer,
149-
McpClientCommonProperties commonProperties,
151+
McpClientCommonProperties commonProperties, McpCompositeClientProperties mcpCompositeClientProperties,
150152
ObjectProvider<List<NamedClientMcpTransport>> transportsProvider) {
151153

152154
List<McpSyncClient> mcpSyncClients = new ArrayList<>();
@@ -165,7 +167,11 @@ public List<McpSyncClient> mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC
165167
.requestTimeout(commonProperties.getRequestTimeout());
166168

167169
spec = mcpSyncClientConfigurer.configure(namedTransport.name(), spec);
168-
170+
spec.toolAnnotationsHandler(name -> {
171+
// set returnDirect in client level
172+
boolean returnDirect = mcpCompositeClientProperties.getReturnDirect(namedTransport.name());
173+
return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect);
174+
});
169175
var client = spec.build();
170176

171177
if (commonProperties.isInitialized()) {
@@ -213,7 +219,7 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider<McpSyncClientCust
213219
@Bean
214220
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
215221
public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncClientConfigurer,
216-
McpClientCommonProperties commonProperties,
222+
McpClientCommonProperties commonProperties, McpCompositeClientProperties mcpCompositeClientProperties,
217223
ObjectProvider<List<NamedClientMcpTransport>> transportsProvider) {
218224

219225
List<McpAsyncClient> mcpAsyncClients = new ArrayList<>();
@@ -232,7 +238,11 @@ public List<McpAsyncClient> mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncCli
232238
.requestTimeout(commonProperties.getRequestTimeout());
233239

234240
spec = mcpAsyncClientConfigurer.configure(namedTransport.name(), spec);
235-
241+
spec.toolAnnotationsHandler(name -> {
242+
// set returnDirect in client level
243+
boolean returnDirect = mcpCompositeClientProperties.getReturnDirect(namedTransport.name());
244+
return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect);
245+
});
236246
var client = spec.build();
237247

238248
if (commonProperties.isInitialized()) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package org.springframework.ai.mcp.client.common.autoconfigure;
2+
3+
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpSseClientProperties;
4+
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStdioClientProperties;
5+
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties;
6+
import org.springframework.beans.factory.ObjectProvider;
7+
import org.springframework.context.annotation.Configuration;
8+
9+
@Configuration
10+
public class McpCompositeClientProperties {
11+
12+
private final ObjectProvider<McpSseClientProperties> sseClientPropertiesObjectProvider;
13+
14+
private final ObjectProvider<McpStdioClientProperties> stdioClientPropertiesObjectProvider;
15+
16+
private final ObjectProvider<McpStreamableHttpClientProperties> streamableHttpClientPropertiesObjectProvider;
17+
18+
public McpCompositeClientProperties(ObjectProvider<McpSseClientProperties> sseClientPropertiesObjectProvider,
19+
ObjectProvider<McpStdioClientProperties> stdioClientPropertiesObjectProvider,
20+
ObjectProvider<McpStreamableHttpClientProperties> streamableHttpClientPropertiesObjectProvider) {
21+
this.sseClientPropertiesObjectProvider = sseClientPropertiesObjectProvider;
22+
this.stdioClientPropertiesObjectProvider = stdioClientPropertiesObjectProvider;
23+
this.streamableHttpClientPropertiesObjectProvider = streamableHttpClientPropertiesObjectProvider;
24+
}
25+
26+
public boolean getReturnDirect(String connectionName) {
27+
McpSseClientProperties sseClientProperties = sseClientPropertiesObjectProvider.getIfAvailable();
28+
if (sseClientProperties != null && sseClientProperties.getConnections().containsKey(connectionName)) {
29+
return sseClientProperties.getConnections().get(connectionName).returnDirect();
30+
}
31+
McpStdioClientProperties stdioClientProperties = stdioClientPropertiesObjectProvider.getIfAvailable();
32+
if (stdioClientProperties != null && stdioClientProperties.getConnections().containsKey(connectionName)) {
33+
return stdioClientProperties.getConnections().get(connectionName).returnDirect();
34+
}
35+
McpStreamableHttpClientProperties streamableHttpClientProperties = streamableHttpClientPropertiesObjectProvider
36+
.getIfAvailable();
37+
if (streamableHttpClientProperties != null
38+
&& streamableHttpClientProperties.getConnections().containsKey(connectionName)) {
39+
return streamableHttpClientProperties.getConnections().get(connectionName).returnDirect();
40+
}
41+
return false;
42+
}
43+
44+
}

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientProperties.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ public Map<String, SseParameters> getConnections() {
6969
* @param url the URL endpoint for SSE communication with the MCP server
7070
* @param sseEndpoint the SSE endpoint for the MCP server
7171
*/
72-
public record SseParameters(String url, String sseEndpoint) {
72+
public record SseParameters(String url, String sseEndpoint, boolean returnDirect) {
73+
7374
}
7475

7576
}

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStdioClientProperties.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ public record Parameters(
124124
/**
125125
* Map of environment variables for the server process.
126126
*/
127-
@JsonProperty("env") Map<String, String> env) {
127+
@JsonProperty("env") Map<String, String> env,
128+
129+
@JsonProperty("returnDirect") boolean returnDirect) {
128130

129131
public ServerParameters toServerParameters() {
130132
return ServerParameters.builder(this.command()).args(this.args()).env(this.env()).build();

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpStreamableHttpClientProperties.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ public Map<String, ConnectionParameters> getConnections() {
6868
* @param url the URL endpoint for Streamable Http communication with the MCP server
6969
* @param endpoint the endpoint for the MCP server
7070
*/
71-
public record ConnectionParameters(String url, String endpoint) {
71+
public record ConnectionParameters(String url, String endpoint, boolean returnDirect) {
72+
7273
}
7374

7475
}

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/properties/McpSseClientPropertiesTests.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ void connectionWithNullUrl() {
105105
void sseParametersRecord() {
106106
String url = "http://test-server:8080/events";
107107
String sseUrl = "/sse";
108-
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl);
108+
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, sseUrl, false);
109109

110110
assertThat(params.url()).isEqualTo(url);
111111
assertThat(params.sseEndpoint()).isEqualTo(sseUrl);
@@ -114,7 +114,7 @@ void sseParametersRecord() {
114114
@Test
115115
void sseParametersRecordWithNullSseEndpoint() {
116116
String url = "http://test-server:8080/events";
117-
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null);
117+
McpSseClientProperties.SseParameters params = new McpSseClientProperties.SseParameters(url, null, false);
118118

119119
assertThat(params.url()).isEqualTo(url);
120120
assertThat(params.sseEndpoint()).isNull();
@@ -150,21 +150,21 @@ void connectionMapManipulation() {
150150

151151
// Add a connection
152152
connections.put("server1",
153-
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse"));
153+
new McpSseClientProperties.SseParameters("http://localhost:8080/events", "/sse", false));
154154
assertThat(properties.getConnections()).hasSize(1);
155155
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080/events");
156156
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/sse");
157157

158158
// Add another connection
159159
connections.put("server2",
160-
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null));
160+
new McpSseClientProperties.SseParameters("http://otherserver:8081/events", null, false));
161161
assertThat(properties.getConnections()).hasSize(2);
162162
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081/events");
163163
assertThat(properties.getConnections().get("server2").sseEndpoint()).isNull();
164164

165165
// Replace a connection
166166
connections.put("server1",
167-
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events"));
167+
new McpSseClientProperties.SseParameters("http://newserver:8082/events", "/events", false));
168168
assertThat(properties.getConnections()).hasSize(2);
169169
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://newserver:8082/events");
170170
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
@@ -209,13 +209,15 @@ void specialCharactersInConnectionName() {
209209
void connectionWithSseEndpoint() {
210210
this.contextRunner
211211
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
212-
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events")
212+
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events",
213+
"spring.ai.mcp.client.sse.connections.server1.return-direct=true")
213214
.run(context -> {
214215
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
215216
assertThat(properties.getConnections()).hasSize(1);
216217
assertThat(properties.getConnections()).containsKey("server1");
217218
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080");
218219
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
220+
assertThat(properties.getConnections().get("server1").returnDirect()).isEqualTo(true);
219221
});
220222
}
221223

@@ -224,16 +226,20 @@ void multipleConnectionsWithSseEndpoint() {
224226
this.contextRunner
225227
.withPropertyValues("spring.ai.mcp.client.sse.connections.server1.url=http://localhost:8080",
226228
"spring.ai.mcp.client.sse.connections.server1.sse-endpoint=/events",
229+
"spring.ai.mcp.client.sse.connections.server1.return-direct=true",
227230
"spring.ai.mcp.client.sse.connections.server2.url=http://otherserver:8081",
228-
"spring.ai.mcp.client.sse.connections.server2.sse-endpoint=/sse")
231+
"spring.ai.mcp.client.sse.connections.server2.sse-endpoint=/sse",
232+
"spring.ai.mcp.client.sse.connections.server2.return-direct=false")
229233
.run(context -> {
230234
McpSseClientProperties properties = context.getBean(McpSseClientProperties.class);
231235
assertThat(properties.getConnections()).hasSize(2);
232236
assertThat(properties.getConnections()).containsKeys("server1", "server2");
233237
assertThat(properties.getConnections().get("server1").url()).isEqualTo("http://localhost:8080");
234238
assertThat(properties.getConnections().get("server1").sseEndpoint()).isEqualTo("/events");
239+
assertThat(properties.getConnections().get("server1").returnDirect()).isEqualTo(true);
235240
assertThat(properties.getConnections().get("server2").url()).isEqualTo("http://otherserver:8081");
236241
assertThat(properties.getConnections().get("server2").sseEndpoint()).isEqualTo("/sse");
242+
assertThat(properties.getConnections().get("server2").returnDirect()).isEqualTo(false);
237243
});
238244
}
239245

mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
3131
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
3232
import io.modelcontextprotocol.server.McpSyncServerExchange;
33-
import io.modelcontextprotocol.server.McpTransportContext;
3433
import io.modelcontextprotocol.spec.McpSchema;
3534
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
3635
import io.modelcontextprotocol.spec.McpSchema.Role;
@@ -203,6 +202,7 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal
203202
.name(toolCallback.getToolDefinition().name())
204203
.description(toolCallback.getToolDefinition().description())
205204
.inputSchema(toolCallback.getToolDefinition().inputSchema())
205+
.annotations(toToolAnnotations(toolCallback))
206206
.build();
207207

208208
return new SharedSyncToolSpecification(tool, (exchangeOrContext, request) -> {
@@ -222,6 +222,11 @@ private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCal
222222
});
223223
}
224224

225+
private static McpSchema.ToolAnnotations toToolAnnotations(ToolCallback toolCallback) {
226+
Boolean returnDirect = toolCallback.getToolMetadata().returnDirect();
227+
return new McpSchema.ToolAnnotations(null, null, null, null, null, returnDirect);
228+
}
229+
225230
/**
226231
* Retrieves the MCP exchange object from the provided tool context if it exists.
227232
* @param toolContext the tool context from which to retrieve the MCP exchange

mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717
package org.springframework.ai.mcp;
1818

1919
import io.modelcontextprotocol.client.McpSyncClient;
20+
import io.modelcontextprotocol.spec.McpSchema;
2021
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
2122
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
2223
import io.modelcontextprotocol.spec.McpSchema.Tool;
2324
import java.util.Map;
25+
import java.util.Optional;
26+
2427
import org.slf4j.Logger;
2528
import org.slf4j.LoggerFactory;
2629

@@ -30,6 +33,8 @@
3033
import org.springframework.ai.tool.definition.DefaultToolDefinition;
3134
import org.springframework.ai.tool.definition.ToolDefinition;
3235
import org.springframework.ai.tool.execution.ToolExecutionException;
36+
import org.springframework.ai.tool.metadata.DefaultToolMetadata;
37+
import org.springframework.ai.tool.metadata.ToolMetadata;
3338

3439
/**
3540
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
@@ -80,6 +85,24 @@ public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
8085

8186
}
8287

88+
/**
89+
* Returns the tool metadata for the MCP tool.
90+
* <p>
91+
* The tool metadata includes:
92+
* <ul>
93+
* <li>The tool's return direct flag from the MCP definition</li>
94+
* </ul>
95+
* @return the tool metadata
96+
*/
97+
@Override
98+
public ToolMetadata getToolMetadata() {
99+
Boolean returnDirect = Optional.ofNullable(tool.annotations())
100+
.map(McpSchema.ToolAnnotations::returnDirect)
101+
.orElse(false);
102+
103+
return DefaultToolMetadata.builder().returnDirect(returnDirect).build();
104+
}
105+
83106
/**
84107
* Returns a Spring AI tool definition adapted from the MCP tool.
85108
* <p>

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ static Stream<ChatModel> openAiCompatibleApis() {
6868
.openAiApi(OpenAiApi.builder()
6969
.baseUrl("https://api.groq.com/openai")
7070
.apiKey(System.getenv("GROQ_API_KEY"))
71-
.build())
71+
.build())
7272
.defaultOptions(forModelName("llama3-8b-8192"))
7373
.build());
7474
}

0 commit comments

Comments
 (0)