diff --git a/fastapi_mcp/openapi/utils.py b/fastapi_mcp/openapi/utils.py index 1821d57..24eb80d 100644 --- a/fastapi_mcp/openapi/utils.py +++ b/fastapi_mcp/openapi/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Set, Optional def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: @@ -16,54 +16,144 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: return param_schema.get("type", "string") -def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]: +def resolve_schema_references( + schema_part: Dict[str, Any], + reference_schema: Dict[str, Any], + visited_refs: Optional[Set[str]] = None, + skip_components: bool = True, +) -> Dict[str, Any]: """ - Resolve schema references in OpenAPI schemas. + Resolve schema references with cycle detection and performance optimization. Args: schema_part: The part of the schema being processed that may contain references reference_schema: The complete schema used to resolve references from + visited_refs: Set of currently being resolved references (for cycle detection) + skip_components: Whether to skip processing the components section Returns: The schema with references resolved """ - # Make a copy to avoid modifying the input schema - schema_part = schema_part.copy() - - # Handle $ref directly in the schema - if "$ref" in schema_part: - ref_path = schema_part["$ref"] - # Standard OpenAPI references are in the format "#/components/schemas/ModelName" - if ref_path.startswith("#/components/schemas/"): - model_name = ref_path.split("/")[-1] - if "components" in reference_schema and "schemas" in reference_schema["components"]: - if model_name in reference_schema["components"]["schemas"]: - # Replace with the resolved schema - ref_schema = reference_schema["components"]["schemas"][model_name].copy() - # Remove the $ref key and merge with the original schema - schema_part.pop("$ref") - schema_part.update(ref_schema) + if visited_refs is None: + visited_refs = set() + + if not isinstance(schema_part, dict): + return schema_part + + part = schema_part.copy() + + if "$ref" in part: + ref_path = part["$ref"] + if ref_path in visited_refs: + return {"$ref": ref_path} + visited_refs.add(ref_path) + try: + if ref_path.startswith("#/components/schemas/"): + model_name = ref_path.split("/")[-1] + comps = reference_schema.get("components", {}).get("schemas", {}) + if model_name in comps: + ref_schema = comps[model_name] + resolved_ref = resolve_schema_references( + ref_schema, reference_schema, visited_refs, skip_components + ) + part.pop("$ref", None) + if isinstance(resolved_ref, dict) and "$ref" not in resolved_ref: + part.update(resolved_ref) + finally: + # Cleanup + visited_refs.discard(ref_path) + + return part # Recursively resolve references in all dictionary values - for key, value in schema_part.items(): + for key, value in list(part.items()): + if skip_components and key == "components": + continue + if isinstance(value, dict): - schema_part[key] = resolve_schema_references(value, reference_schema) + part[key] = resolve_schema_references(value, reference_schema, visited_refs, skip_components) elif isinstance(value, list): - # Only process list items that are dictionaries since only they can contain refs - schema_part[key] = [ - resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value + part[key] = [ + resolve_schema_references(item, reference_schema, visited_refs, skip_components) + if isinstance(item, dict) + else item + for item in value ] + return part + + +def resolve_schema_for_display( + schema: Dict[str, Any], + components: Dict[str, Any], + cache: Optional[Dict[str, Dict[str, Any]]] = None, + stack: Optional[Set[str]] = None, +) -> Dict[str, Any]: + """ + Resolve a specific schema for display with caching and cycle detection. + + This function is optimized for just-in-time resolution of specific schemas + without processing the entire components tree. + + Args: + schema: The schema to resolve + components: The components section containing schema definitions + cache: Cache for memoizing resolved schemas + stack: Stack for cycle detection + + Returns: + The resolved schema + """ + if not isinstance(schema, dict): + return schema - return schema_part + cache = cache or {} + stack = stack or set() + + # Handle direct $ref first + if "$ref" in schema: + ref = schema["$ref"] + + # Break the cycle + if ref in stack: + return {"$ref": ref} + + if ref in cache: + return cache[ref] + + if not ref.startswith("#/components/schemas/"): + return schema + + name = ref.split("/")[-1] + target = components.get("schemas", {}).get(name, {}) + stack.add(ref) + resolved = resolve_schema_for_display(target, components, cache, stack) + stack.remove(ref) + cache[ref] = resolved if isinstance(resolved, dict) else target + return cache[ref] + + # Recursively resolve the schema but don't descend into the components subtree + out = schema.copy() + for k, v in list(out.items()): + if k == "components": + continue + if isinstance(v, dict): + out[k] = resolve_schema_for_display(v, components, cache, stack) + elif isinstance(v, list): + out[k] = [resolve_schema_for_display(i, components, cache, stack) if isinstance(i, dict) else i for i in v] + return out def clean_schema_for_display(schema: Dict[str, Any]) -> Dict[str, Any]: """ Clean up a schema for display by removing internal fields. + + Args: schema: The schema to clean + + Returns: The cleaned schema """ @@ -104,9 +194,13 @@ def generate_example_from_schema(schema: Dict[str, Any]) -> Any: """ Generate a simple example response from a JSON schema. + + Args: schema: The JSON schema to generate an example from + + Returns: An example object based on the schema """ diff --git a/tests/test_mcp_execute_api_tool.py b/tests/test_mcp_execute_api_tool.py index cc05d34..a492f65 100644 --- a/tests/test_mcp_execute_api_tool.py +++ b/tests/test_mcp_execute_api_tool.py @@ -10,183 +10,150 @@ async def test_execute_api_tool_success(simple_fastapi_app: FastAPI): """Test successful execution of an API tool.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = {"id": 1, "name": "Test Item"} mock_response.status_code = 200 mock_response.text = '{"id": 1, "name": "Test Item"}' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "get_item" arguments = {"item_id": 1} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) assert result[0].text == '{\n "id": 1,\n "name": "Test Item"\n}' - + # Verify the HTTP client was called correctly - mock_client.get.assert_called_once_with( - "/items/1", - params={}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={}) @pytest.mark.asyncio async def test_execute_api_tool_with_query_params(simple_fastapi_app: FastAPI): """Test execution of an API tool with query parameters.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] mock_response.status_code = 200 mock_response.text = '[{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "list_items" arguments = {"skip": 0, "limit": 2} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify the HTTP client was called with query parameters - mock_client.get.assert_called_once_with( - "/items/", - params={"skip": 0, "limit": 2}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/", params={"skip": 0, "limit": 2}, headers={}) @pytest.mark.asyncio async def test_execute_api_tool_with_body(simple_fastapi_app: FastAPI): """Test execution of an API tool with request body.""" mcp = FastApiMCP(simple_fastapi_app) - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = {"id": 1, "name": "New Item"} mock_response.status_code = 200 mock_response.text = '{"id": 1, "name": "New Item"}' - + # Mock the HTTP client mock_client = AsyncMock() mock_client.post.return_value = mock_response - + # Test parameters tool_name = "create_item" arguments = { - "item": { - "id": 1, - "name": "New Item", - "price": 10.0, - "tags": ["tag1"], - "description": "New item description" - } + "item": {"id": 1, "name": "New Item", "price": 10.0, "tags": ["tag1"], "description": "New item description"} } - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify the HTTP client was called with the request body - mock_client.post.assert_called_once_with( - "/items/", - params={}, - headers={}, - json=arguments - ) + mock_client.post.assert_called_once_with("/items/", params={}, headers={}, json=arguments) @pytest.mark.asyncio async def test_execute_api_tool_with_non_ascii_chars(simple_fastapi_app: FastAPI): """Test execution of an API tool with non-ASCII characters.""" mcp = FastApiMCP(simple_fastapi_app) - + # Test data with both ASCII and non-ASCII characters test_data = { "id": 1, "name": "你好 World", # Chinese characters + ASCII "price": 10.0, "tags": ["tag1", "标签2"], # Chinese characters in tags - "description": "这是一个测试描述" # All Chinese characters + "description": "这是一个测试描述", # All Chinese characters } - + # Mock the HTTP client response mock_response = MagicMock() mock_response.json.return_value = test_data mock_response.status_code = 200 - mock_response.text = '{"id": 1, "name": "你好 World", "price": 10.0, "tags": ["tag1", "标签2"], "description": "这是一个测试描述"}' - + mock_response.text = ( + '{"id": 1, "name": "你好 World", "price": 10.0, "tags": ["tag1", "标签2"], "description": "这是一个测试描述"}' + ) + # Mock the HTTP client mock_client = AsyncMock() mock_client.get.return_value = mock_response - + # Test parameters tool_name = "get_item" arguments = {"item_id": 1} - + # Execute the tool - with patch.object(mcp, '_http_client', mock_client): + with patch.object(mcp, "_http_client", mock_client): result = await mcp._execute_api_tool( - client=mock_client, - tool_name=tool_name, - arguments=arguments, - operation_map=mcp.operation_map + client=mock_client, tool_name=tool_name, arguments=arguments, operation_map=mcp.operation_map ) - + # Verify the result assert len(result) == 1 assert isinstance(result[0], TextContent) - + # Verify that the response contains both ASCII and non-ASCII characters response_text = result[0].text assert "你好" in response_text # Chinese characters preserved assert "World" in response_text # ASCII characters preserved assert "标签2" in response_text # Chinese characters in tags preserved assert "这是一个测试描述" in response_text # All Chinese description preserved - + # Verify the HTTP client was called correctly - mock_client.get.assert_called_once_with( - "/items/1", - params={}, - headers={} - ) + mock_client.get.assert_called_once_with("/items/1", params={}, headers={})