diff --git a/fastapi_mcp/openapi/utils.py b/fastapi_mcp/openapi/utils.py index 1821d57..b80d35c 100644 --- a/fastapi_mcp/openapi/utils.py +++ b/fastapi_mcp/openapi/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional, Set def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: @@ -16,17 +16,24 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str: return param_schema.get("type", "string") -def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]: +def resolve_schema_references( + schema_part: Dict[str, Any], + reference_schema: Dict[str, Any], + seen: Optional[Set[str]] = None, +) -> Dict[str, Any]: """ Resolve schema references in OpenAPI schemas. Args: schema_part: The part of the schema being processed that may contain references reference_schema: The complete schema used to resolve references from + seen: A set of already seen references to avoid infinite recursion Returns: The schema with references resolved """ + seen = seen or set() + # Make a copy to avoid modifying the input schema schema_part = schema_part.copy() @@ -35,6 +42,9 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic ref_path = schema_part["$ref"] # Standard OpenAPI references are in the format "#/components/schemas/ModelName" if ref_path.startswith("#/components/schemas/"): + if ref_path in seen: + return {"$ref": ref_path} + seen.add(ref_path) model_name = ref_path.split("/")[-1] if "components" in reference_schema and "schemas" in reference_schema["components"]: if model_name in reference_schema["components"]["schemas"]: @@ -47,11 +57,12 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic # Recursively resolve references in all dictionary values for key, value in schema_part.items(): if isinstance(value, dict): - schema_part[key] = resolve_schema_references(value, reference_schema) + schema_part[key] = resolve_schema_references(value, reference_schema, seen) elif isinstance(value, list): # Only process list items that are dictionaries since only they can contain refs schema_part[key] = [ - resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value + resolve_schema_references(item, reference_schema, seen) if isinstance(item, dict) else item + for item in value ] return schema_part diff --git a/tests/fixtures/types.py b/tests/fixtures/types.py index e5314fa..e95a5f8 100644 --- a/tests/fixtures/types.py +++ b/tests/fixtures/types.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Optional, List, Dict, Any from datetime import datetime, date from enum import Enum @@ -95,6 +96,7 @@ class Product(BaseModel): updated_at: Optional[datetime] = None is_available: bool = True metadata: Dict[str, Any] = {} + related_products: Optional[List[Product]] = None class OrderItem(BaseModel):