Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, cast, Union
import json
import logging
import re
from http import HTTPStatus

Expand Down Expand Up @@ -103,6 +102,24 @@ def _instrument(self, **kwargs):
),
"mcp.server.streamable_http",
)
# Try multiple response creation points
# Try direct wrapping instead of post-import hook
try:
wrap_function_wrapper(
"mcp.types",
"JSONRPCResponse.__init__",
self._jsonrpc_response_init_wrapper(tracer),
)
except Exception:
# Fallback to post-import hook
register_post_import_hook(
lambda _: wrap_function_wrapper(
"mcp.types",
"JSONRPCResponse.__init__",
self._jsonrpc_response_init_wrapper(tracer),
),
"mcp.types",
)
wrap_function_wrapper(
"mcp.shared.session",
"BaseSession.send_request",
Expand All @@ -112,6 +129,7 @@ def _instrument(self, **kwargs):
def _uninstrument(self, **kwargs):
unwrap("mcp.client.stdio", "stdio_client")
unwrap("mcp.server.stdio", "stdio_server")
unwrap("mcp.types", "JSONRPCResponse.__init__")
self._fastmcp_instrumentor.uninstrument()

def _transport_wrapper(self, tracer):
Expand All @@ -137,11 +155,9 @@ async def traced_method(
yield InstrumentedStreamReader(
read_stream, tracer
), InstrumentedStreamWriter(write_stream, tracer), get_session_id_callback
except Exception as e:
logging.warning(f"mcp instrumentation _transport_wrapper exception: {e}")
except Exception:
yield result
except Exception as e:
logging.warning(f"mcp instrumentation transport_wrapper exception: {e}")
except Exception:
yield result

return traced_method
Expand All @@ -167,6 +183,39 @@ def traced_method(

return traced_method

def _jsonrpc_response_init_wrapper(self, tracer):
@dont_throw
def traced_method(wrapped, instance, args, kwargs):
result_value = kwargs.get("result", None)
if result_value is None and len(args) > 1:
result_value = args[1]

if result_value is not None and isinstance(result_value, dict) and "content" in result_value:
with tracer.start_as_current_span("MCP_Tool_Response") as span:
# Serialize the result data
result_serialized = serialize(result_value)
span.set_attribute(SpanAttributes.MCP_RESPONSE_VALUE, f"{result_serialized}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _jsonrpc_response_init_wrapper, wrapping the serialized result in an f-string is redundant, and using args[0] as a fallback for id_value assumes a specific init signature. Please document or validate the expected parameter order.

Suggested change
span.set_attribute(SpanAttributes.MCP_RESPONSE_VALUE, f"{result_serialized}")
span.set_attribute(SpanAttributes.MCP_RESPONSE_VALUE, result_serialized)


# Set span status
if result_value.get("isError", False):
span.set_status(Status(StatusCode.ERROR, "Tool execution error"))
else:
span.set_status(Status(StatusCode.OK))

# Add request ID if available
id_value = kwargs.get("id", None)
if id_value is None and len(args) > 0:
id_value = args[0]

if id_value is not None:
span.set_attribute(SpanAttributes.MCP_REQUEST_ID, f"{id_value}")

# Call the original method
result = wrapped(*args, **kwargs)
return result

return traced_method

def patch_mcp_client(self, tracer: Tracer):
@dont_throw
async def traced_method(wrapped, instance, args, kwargs):
Expand Down Expand Up @@ -527,39 +576,18 @@ async def send(self, item: Any) -> Any:
else:
return await self.__wrapped__.send(item)

with self._tracer.start_as_current_span("ResponseStreamWriter") as span:
if hasattr(request, "result"):
span.set_attribute(
SpanAttributes.MCP_RESPONSE_VALUE, f"{serialize(request.result)}"
)
if "isError" in request.result:
if request.result["isError"] is True:
span.set_status(
Status(
StatusCode.ERROR,
f"{request.result['content'][0]['text']}",
)
)
error_type = get_error_type(
request.result["content"][0]["text"]
)
if error_type is not None:
span.set_attribute(ERROR_TYPE, error_type)
if hasattr(request, "id"):
span.set_attribute(SpanAttributes.MCP_REQUEST_ID, f"{request.id}")

if not isinstance(request, JSONRPCRequest):
return await self.__wrapped__.send(item)
meta = None
if not request.params:
request.params = {}
meta = request.params.setdefault("_meta", {})

propagate.get_global_textmap().inject(meta)
if not isinstance(request, JSONRPCRequest):
return await self.__wrapped__.send(item)
meta = None
if not request.params:
request.params = {}
meta = request.params.setdefault("_meta", {})

propagate.get_global_textmap().inject(meta)
return await self.__wrapped__.send(item)


@dataclass(slots=True, frozen=True)
@dataclass
class ItemWithContext:
item: Any
ctx: context.Context
Expand All @@ -579,9 +607,10 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:

@dont_throw
async def send(self, item: Any) -> Any:
# Removed RequestStreamWriter span creation - we don't need low-level protocol spans
ctx = context.get_current()
return await self.__wrapped__.send(ItemWithContext(item, ctx))
# Create ResponseStreamWriter span for server-side responses
with self._tracer.start_as_current_span("ResponseStreamWriter") as _:
ctx = context.get_current()
return await self.__wrapped__.send(ItemWithContext(item, ctx))


class ContextAttachingStreamReader(ObjectProxy): # type: ignore
Expand Down