1
- from typing import TYPE_CHECKING , Any , List , Optional , Set , Type , TypeVar , Union , cast
1
+ from typing import TYPE_CHECKING , Any , List , Optional , Set , Type , TypeVar , cast
2
2
3
3
from pydantic import BaseModel
4
4
from pydantic .schema import schema
5
5
6
- from pydantic_openapi_schema import v3_0_3 , v3_1_0
6
+ from pydantic_openapi_schema import v3_1_0
7
7
8
8
if TYPE_CHECKING :
9
9
from typing import Dict
10
10
11
11
REF_PREFIX = "#/components/schemas/"
12
12
13
- T = TypeVar ("T" , bound = Union [v3_0_3 .OpenAPI , v3_1_0 .OpenAPI ])
14
-
15
-
16
- class OpenAPI303PydanticSchema (v3_0_3 .Schema ):
17
- """Special `Schema` class to indicate a reference from pydantic class."""
18
-
19
- schema_class : Type [BaseModel ]
20
- """the class that is used for generate the schema"""
13
+ T = TypeVar ("T" , bound = v3_1_0 .OpenAPI )
21
14
22
15
23
16
class OpenAPI310PydanticSchema (v3_1_0 .Schema ):
@@ -38,7 +31,7 @@ def construct_open_api_with_schema_class(
38
31
39
32
Args:
40
33
open_api_schema: the base `OpenAPI` object
41
- schema_classes: pydanitic classes that their schema will be used "#/components/schemas" values
34
+ schema_classes: pydantic classes that their schema will be used as "#/components/schemas" values
42
35
scan_for_pydantic_schema_reference: flag to indicate if scanning for `PydanticSchemaReference`
43
36
class is needed for "#/components/schemas" value updates
44
37
by_alias: construct schema by alias (default is True)
@@ -47,13 +40,11 @@ class is needed for "#/components/schemas" value updates
47
40
new OpenAPI object with "#/components/schemas" values updated. If there is no update in
48
41
"#/components/schemas" values, the original `open_api` will be returned.
49
42
"""
50
- specs = v3_1_0 if isinstance (open_api_schema , v3_1_0 .OpenAPI ) else v3_0_3
51
-
52
43
copied_schema = open_api_schema .copy (deep = True )
53
44
54
45
if scan_for_pydantic_schema_reference :
55
46
extracted_schema_classes = extract_pydantic_types_to_openapi_components (
56
- obj = copied_schema , ref_class = specs .Reference
47
+ obj = copied_schema , ref_class = v3_1_0 .Reference
57
48
)
58
49
schema_classes = list (
59
50
{* schema_classes , * extracted_schema_classes } if schema_classes else extracted_schema_classes
@@ -63,21 +54,19 @@ class is needed for "#/components/schemas" value updates
63
54
return open_api_schema
64
55
65
56
if not copied_schema .components :
66
- copied_schema .components = specs .Components (schemas = {})
57
+ copied_schema .components = v3_1_0 .Components (schemas = {})
67
58
elif not copied_schema .components .schemas :
68
59
copied_schema .components .schemas = cast ("Dict[str, Any]" , {})
69
60
70
61
schema_classes .sort (key = lambda x : x .__name__ )
71
62
schema_definitions = schema (schema_classes , by_alias = by_alias , ref_prefix = REF_PREFIX )["definitions" ]
72
63
copied_schema .components .schemas .update ( # type: ignore
73
- {key : specs .Schema .parse_obj (schema_dict ) for key , schema_dict in schema_definitions .items ()}
64
+ {key : v3_1_0 .Schema .parse_obj (schema_dict ) for key , schema_dict in schema_definitions .items ()}
74
65
)
75
- return cast ( "T" , copied_schema )
66
+ return copied_schema
76
67
77
68
78
- def extract_pydantic_types_to_openapi_components (
79
- obj : Any , ref_class : Union [Type [v3_0_3 .Reference ], Type [v3_1_0 .Reference ]]
80
- ) -> Set [Type [BaseModel ]]:
69
+ def extract_pydantic_types_to_openapi_components (obj : Any , ref_class : Type [v3_1_0 .Reference ]) -> Set [Type [BaseModel ]]:
81
70
"""Recursively traverses the OpenAPI document, replacing any found Pydantic
82
71
Models with $references to the schema's components section and returning
83
72
the pydantic models themselves.
@@ -94,21 +83,21 @@ def extract_pydantic_types_to_openapi_components(
94
83
fields = obj .__fields_set__
95
84
for field in fields :
96
85
child_obj = getattr (obj , field )
97
- if isinstance (child_obj , ( OpenAPI303PydanticSchema , OpenAPI310PydanticSchema ) ):
86
+ if isinstance (child_obj , OpenAPI310PydanticSchema ):
98
87
setattr (obj , field , ref_class (ref = REF_PREFIX + child_obj .schema_class .__name__ ))
99
88
pydantic_schemas .add (child_obj .schema_class )
100
89
else :
101
90
pydantic_schemas .update (extract_pydantic_types_to_openapi_components (child_obj , ref_class = ref_class ))
102
91
elif isinstance (obj , list ):
103
92
for index , elem in enumerate (obj ):
104
- if isinstance (elem , ( OpenAPI303PydanticSchema , OpenAPI310PydanticSchema ) ):
93
+ if isinstance (elem , OpenAPI310PydanticSchema ):
105
94
obj [index ] = ref_class (ref = REF_PREFIX + elem .schema_class .__name__ )
106
95
pydantic_schemas .add (elem .schema_class )
107
96
else :
108
97
pydantic_schemas .update (extract_pydantic_types_to_openapi_components (elem , ref_class = ref_class ))
109
98
elif isinstance (obj , dict ):
110
99
for key , value in obj .items ():
111
- if isinstance (value , ( OpenAPI303PydanticSchema , OpenAPI310PydanticSchema ) ):
100
+ if isinstance (value , OpenAPI310PydanticSchema ):
112
101
obj [key ] = ref_class (ref = REF_PREFIX + value .schema_class .__name__ )
113
102
pydantic_schemas .add (value .schema_class )
114
103
else :
0 commit comments