1
- from typing import TYPE_CHECKING , Any , List , Optional , Set , Type , TypeVar , cast
1
+ from typing import TYPE_CHECKING , Any , Set , Type , TypeVar , cast
2
2
3
- from pydantic import BaseModel
3
+ from pydantic import BaseModel , create_model
4
4
from pydantic .schema import schema
5
5
6
6
from pydantic_openapi_schema import v3_1_0
9
9
from typing import Dict
10
10
11
11
REF_PREFIX = "#/components/schemas/"
12
+ SCHEMA_NAME_ATTRIBUTE = "__schema_name__"
12
13
13
14
T = TypeVar ("T" , bound = v3_1_0 .OpenAPI )
14
15
@@ -22,45 +23,35 @@ class OpenAPI310PydanticSchema(v3_1_0.Schema):
22
23
23
24
def construct_open_api_with_schema_class (
24
25
open_api_schema : T ,
25
- schema_classes : Optional [List [Type [BaseModel ]]] = None ,
26
- scan_for_pydantic_schema_reference : bool = True ,
27
- by_alias : bool = True ,
28
26
) -> T :
29
27
"""Construct a new OpenAPI object, with the use of pydantic classes to
30
28
produce JSON schemas.
31
29
32
30
Args:
33
- open_api_schema: the base `OpenAPI` object
34
- schema_classes: pydantic classes that their schema will be used as "#/components/schemas" values
35
- scan_for_pydantic_schema_reference: flag to indicate if scanning for `PydanticSchemaReference`
36
- class is needed for "#/components/schemas" value updates
37
- by_alias: construct schema by alias (default is True)
31
+ open_api_schema: An instance of the OpenAPI model.
38
32
39
33
Returns:
40
34
new OpenAPI object with "#/components/schemas" values updated. If there is no update in
41
35
"#/components/schemas" values, the original `open_api` will be returned.
42
36
"""
43
37
copied_schema = open_api_schema .copy (deep = True )
44
-
45
- if scan_for_pydantic_schema_reference :
46
- extracted_schema_classes = extract_pydantic_types_to_openapi_components (
47
- obj = copied_schema , ref_class = v3_1_0 .Reference
48
- )
49
- schema_classes = list (
50
- {* schema_classes , * extracted_schema_classes } if schema_classes else extracted_schema_classes
51
- )
38
+ schema_classes = list (extract_pydantic_types_to_openapi_components (obj = copied_schema , ref_class = v3_1_0 .Reference ))
52
39
53
40
if not schema_classes :
54
41
return open_api_schema
55
42
56
43
if not copied_schema .components :
57
44
copied_schema .components = v3_1_0 .Components (schemas = {})
58
- elif not copied_schema .components .schemas :
45
+ if copied_schema .components .schemas is None : # pragma: no cover
59
46
copied_schema .components .schemas = cast ("Dict[str, Any]" , {})
60
47
48
+ schema_classes = [
49
+ cls if not hasattr (cls , "__schema_name__" ) else create_model (getattr (cls , SCHEMA_NAME_ATTRIBUTE ), __base__ = cls )
50
+ for cls in schema_classes
51
+ ]
61
52
schema_classes .sort (key = lambda x : x .__name__ )
62
- schema_definitions = schema (schema_classes , by_alias = by_alias , ref_prefix = REF_PREFIX )["definitions" ]
63
- copied_schema .components .schemas .update ( # type: ignore
53
+ schema_definitions = schema (schema_classes , ref_prefix = REF_PREFIX )["definitions" ]
54
+ copied_schema .components .schemas .update (
64
55
{key : v3_1_0 .Schema .parse_obj (schema_dict ) for key , schema_dict in schema_definitions .items ()}
65
56
)
66
57
return copied_schema
@@ -84,22 +75,34 @@ def extract_pydantic_types_to_openapi_components(obj: Any, ref_class: Type[v3_1_
84
75
for field in fields :
85
76
child_obj = getattr (obj , field )
86
77
if isinstance (child_obj , OpenAPI310PydanticSchema ):
87
- setattr (obj , field , ref_class (ref = REF_PREFIX + child_obj .schema_class . __name__ ))
78
+ setattr (obj , field , ref_class (ref = create_ref_prefix ( child_obj .schema_class ) ))
88
79
pydantic_schemas .add (child_obj .schema_class )
89
80
else :
90
81
pydantic_schemas .update (extract_pydantic_types_to_openapi_components (child_obj , ref_class = ref_class ))
91
82
elif isinstance (obj , list ):
92
83
for index , elem in enumerate (obj ):
93
84
if isinstance (elem , OpenAPI310PydanticSchema ):
94
- obj [index ] = ref_class (ref = REF_PREFIX + elem .schema_class . __name__ )
85
+ obj [index ] = ref_class (ref = create_ref_prefix ( elem .schema_class ) )
95
86
pydantic_schemas .add (elem .schema_class )
96
87
else :
97
88
pydantic_schemas .update (extract_pydantic_types_to_openapi_components (elem , ref_class = ref_class ))
98
89
elif isinstance (obj , dict ):
99
90
for key , value in obj .items ():
100
91
if isinstance (value , OpenAPI310PydanticSchema ):
101
- obj [key ] = ref_class (ref = REF_PREFIX + value .schema_class . __name__ )
92
+ obj [key ] = ref_class (ref = create_ref_prefix ( value .schema_class ) )
102
93
pydantic_schemas .add (value .schema_class )
103
94
else :
104
95
pydantic_schemas .update (extract_pydantic_types_to_openapi_components (value , ref_class = ref_class ))
105
96
return pydantic_schemas
97
+
98
+
99
+ def create_ref_prefix (model : Type [BaseModel ]) -> str :
100
+ """
101
+
102
+ Args:
103
+ model: Pydantic model instance.
104
+
105
+ Returns:
106
+ A prefixed name.
107
+ """
108
+ return REF_PREFIX + getattr (model , SCHEMA_NAME_ATTRIBUTE , model .__name__ )
0 commit comments