diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index bac597c953..8f7573b8f6 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -1,6 +1,4 @@ # Optional packages which may be used with REST framework. -coreapi==2.3.1 -coreschema==0.0.4 django-filter django-guardian>=2.4.0,<2.5 inflection==0.5.1 diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 50f9acbd90..cd6c9d679c 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -1,10 +1,7 @@ from rest_framework import parsers, renderers from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer -from rest_framework.compat import coreapi, coreschema from rest_framework.response import Response -from rest_framework.schemas import ManualSchema -from rest_framework.schemas import coreapi as coreapi_schema from rest_framework.views import APIView @@ -15,31 +12,6 @@ class ObtainAuthToken(APIView): renderer_classes = (renderers.JSONRenderer,) serializer_class = AuthTokenSerializer - if coreapi_schema.is_enabled(): - schema = ManualSchema( - fields=[ - coreapi.Field( - name="username", - required=True, - location='form', - schema=coreschema.String( - title="Username", - description="Valid username for authentication", - ), - ), - coreapi.Field( - name="password", - required=True, - location='form', - schema=coreschema.String( - title="Password", - description="Valid password for authentication", - ), - ), - ], - encoding="application/json", - ) - def get_serializer_context(self): return { 'request': self.request, diff --git a/rest_framework/compat.py b/rest_framework/compat.py index ff21bacff4..7cf68168ca 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -23,32 +23,20 @@ def unicode_http_header(value): postgres_fields = None -# coreapi is required for CoreAPI schema generation -try: - import coreapi -except ImportError: - coreapi = None - -# uritemplate is required for OpenAPI and CoreAPI schema generation +# uritemplate is required for OpenAPI schema generation try: import uritemplate except ImportError: uritemplate = None -# coreschema is optional -try: - import coreschema -except ImportError: - coreschema = None - - # pyyaml is optional try: import yaml except ImportError: yaml = None + # inflection is optional try: import inflection diff --git a/rest_framework/documentation.py b/rest_framework/documentation.py deleted file mode 100644 index 53e5ab551a..0000000000 --- a/rest_framework/documentation.py +++ /dev/null @@ -1,88 +0,0 @@ -from django.urls import include, path - -from rest_framework.renderers import ( - CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer -) -from rest_framework.schemas import SchemaGenerator, get_schema_view -from rest_framework.settings import api_settings - - -def get_docs_view( - title=None, description=None, schema_url=None, urlconf=None, - public=True, patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, - renderer_classes=None): - - if renderer_classes is None: - renderer_classes = [DocumentationRenderer, CoreJSONRenderer] - - return get_schema_view( - title=title, - url=schema_url, - urlconf=urlconf, - description=description, - renderer_classes=renderer_classes, - public=public, - patterns=patterns, - generator_class=generator_class, - authentication_classes=authentication_classes, - permission_classes=permission_classes, - ) - - -def get_schemajs_view( - title=None, description=None, schema_url=None, urlconf=None, - public=True, patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): - renderer_classes = [SchemaJSRenderer] - - return get_schema_view( - title=title, - url=schema_url, - urlconf=urlconf, - description=description, - renderer_classes=renderer_classes, - public=public, - patterns=patterns, - generator_class=generator_class, - authentication_classes=authentication_classes, - permission_classes=permission_classes, - ) - - -def include_docs_urls( - title=None, description=None, schema_url=None, urlconf=None, - public=True, patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, - renderer_classes=None): - docs_view = get_docs_view( - title=title, - description=description, - schema_url=schema_url, - urlconf=urlconf, - public=public, - patterns=patterns, - generator_class=generator_class, - authentication_classes=authentication_classes, - renderer_classes=renderer_classes, - permission_classes=permission_classes, - ) - schema_js_view = get_schemajs_view( - title=title, - description=description, - schema_url=schema_url, - urlconf=urlconf, - public=public, - patterns=patterns, - generator_class=generator_class, - authentication_classes=authentication_classes, - permission_classes=permission_classes, - ) - urls = [ - path('', docs_view, name='docs-index'), - path('schema.js', schema_js_view, name='schema-js') - ] - return include((urls, 'api-docs'), namespace='api-docs') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 3f4730da84..3586de22a1 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,7 +3,6 @@ returned by list views. """ import operator -import warnings from functools import reduce from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured @@ -14,8 +13,6 @@ from django.utils.text import smart_split, unescape_string_literal from django.utils.translation import gettext_lazy as _ -from rest_framework import RemovedInDRF317Warning -from rest_framework.compat import coreapi, coreschema from rest_framework.fields import CharField from rest_framework.settings import api_settings @@ -48,13 +45,6 @@ def filter_queryset(self, request, queryset, view): """ raise NotImplementedError(".filter_queryset() must be overridden.") - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [] - def get_schema_operation_parameters(self, view): return [] @@ -186,23 +176,6 @@ def to_html(self, request, queryset, view): template = loader.get_template(self.template) return template.render(context) - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [ - coreapi.Field( - name=self.search_param, - required=False, - location='query', - schema=coreschema.String( - title=force_str(self.search_title), - description=force_str(self.search_description) - ) - ) - ] - def get_schema_operation_parameters(self, view): return [ { @@ -348,23 +321,6 @@ def to_html(self, request, queryset, view): context = self.get_template_context(request, queryset, view) return template.render(context) - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [ - coreapi.Field( - name=self.ordering_param, - required=False, - location='query', - schema=coreschema.String( - title=force_str(self.ordering_title), - description=force_str(self.ordering_description) - ) - ) - ] - def get_schema_operation_parameters(self, view): return [ { diff --git a/rest_framework/management/__init__.py b/rest_framework/management/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/rest_framework/management/commands/__init__.py b/rest_framework/management/commands/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py deleted file mode 100644 index 8c73e4b9c8..0000000000 --- a/rest_framework/management/commands/generateschema.py +++ /dev/null @@ -1,71 +0,0 @@ -from django.core.management.base import BaseCommand -from django.utils.module_loading import import_string - -from rest_framework import renderers -from rest_framework.schemas import coreapi -from rest_framework.schemas.openapi import SchemaGenerator - -OPENAPI_MODE = 'openapi' -COREAPI_MODE = 'coreapi' - - -class Command(BaseCommand): - help = "Generates configured API schema for project." - - def get_mode(self): - return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE - - def add_arguments(self, parser): - parser.add_argument('--title', dest="title", default='', type=str) - parser.add_argument('--url', dest="url", default=None, type=str) - parser.add_argument('--description', dest="description", default=None, type=str) - if self.get_mode() == COREAPI_MODE: - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) - else: - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) - parser.add_argument('--urlconf', dest="urlconf", default=None, type=str) - parser.add_argument('--generator_class', dest="generator_class", default=None, type=str) - parser.add_argument('--file', dest="file", default=None, type=str) - parser.add_argument('--api_version', dest="api_version", default='', type=str) - - def handle(self, *args, **options): - if options['generator_class']: - generator_class = import_string(options['generator_class']) - else: - generator_class = self.get_generator_class() - generator = generator_class( - url=options['url'], - title=options['title'], - description=options['description'], - urlconf=options['urlconf'], - version=options['api_version'], - ) - schema = generator.get_schema(request=None, public=True) - renderer = self.get_renderer(options['format']) - output = renderer.render(schema, renderer_context={}) - - if options['file']: - with open(options['file'], 'wb') as f: - f.write(output) - else: - self.stdout.write(output.decode()) - - def get_renderer(self, format): - if self.get_mode() == COREAPI_MODE: - renderer_cls = { - 'corejson': renderers.CoreJSONRenderer, - 'openapi': renderers.CoreAPIOpenAPIRenderer, - 'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer, - }[format] - return renderer_cls() - - renderer_cls = { - 'openapi': renderers.OpenAPIRenderer, - 'openapi-json': renderers.JSONOpenAPIRenderer, - }[format] - return renderer_cls() - - def get_generator_class(self): - if self.get_mode() == COREAPI_MODE: - return coreapi.SchemaGenerator - return SchemaGenerator diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index a543ceeb50..164d7c2f85 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -4,7 +4,6 @@ """ import contextlib -import warnings from base64 import b64decode, b64encode from collections import namedtuple from urllib import parse @@ -15,8 +14,6 @@ from django.utils.encoding import force_str from django.utils.translation import gettext_lazy as _ -from rest_framework import RemovedInDRF317Warning -from rest_framework.compat import coreapi, coreschema from rest_framework.exceptions import NotFound from rest_framework.response import Response from rest_framework.settings import api_settings @@ -151,12 +148,6 @@ def to_html(self): # pragma: no cover def get_results(self, data): return data['results'] - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - return [] - def get_schema_operation_parameters(self, view): return [] @@ -313,36 +304,6 @@ def to_html(self): context = self.get_html_context() return template.render(context) - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - fields = [ - coreapi.Field( - name=self.page_query_param, - required=False, - location='query', - schema=coreschema.Integer( - title='Page', - description=force_str(self.page_query_description) - ) - ) - ] - if self.page_size_query_param is not None: - fields.append( - coreapi.Field( - name=self.page_size_query_param, - required=False, - location='query', - schema=coreschema.Integer( - title='Page size', - description=force_str(self.page_size_query_description) - ) - ) - ) - return fields - def get_schema_operation_parameters(self, view): parameters = [ { @@ -530,32 +491,6 @@ def get_count(self, queryset): except (AttributeError, TypeError): return len(queryset) - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - return [ - coreapi.Field( - name=self.limit_query_param, - required=False, - location='query', - schema=coreschema.Integer( - title='Limit', - description=force_str(self.limit_query_description) - ) - ), - coreapi.Field( - name=self.offset_query_param, - required=False, - location='query', - schema=coreschema.Integer( - title='Offset', - description=force_str(self.offset_query_description) - ) - ) - ] - def get_schema_operation_parameters(self, view): parameters = [ { @@ -933,36 +868,6 @@ def to_html(self): context = self.get_html_context() return template.render(context) - def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' - fields = [ - coreapi.Field( - name=self.cursor_query_param, - required=False, - location='query', - schema=coreschema.String( - title='Cursor', - description=force_str(self.cursor_query_description) - ) - ) - ] - if self.page_size_query_param is not None: - fields.append( - coreapi.Field( - name=self.page_size_query_param, - required=False, - location='query', - schema=coreschema.Integer( - title='Page size', - description=force_str(self.page_size_query_description) - ) - ) - ) - return fields - def get_schema_operation_parameters(self, view): parameters = [ { diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index b81f9ab46c..74677dbd80 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -7,10 +7,8 @@ REST framework also provides an HTML renderer that renders the browsable API. """ -import base64 import contextlib import datetime -from urllib import parse from django import forms from django.conf import settings @@ -18,14 +16,12 @@ from django.core.paginator import Page from django.template import engines, loader from django.urls import NoReverseMatch -from django.utils.html import mark_safe from django.utils.http import parse_header_parameters from django.utils.safestring import SafeString from rest_framework import VERSION, exceptions, serializers, status from rest_framework.compat import ( - INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema, - pygments_css, yaml + INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, pygments_css, yaml ) from rest_framework.exceptions import ParseError from rest_framework.request import is_form_media_type, override_method @@ -418,7 +414,7 @@ def get_content(self, renderer, data, render_style = getattr(renderer, 'render_style', 'text') assert render_style in ['text', 'binary'], 'Expected .render_style ' \ - '"text" or "binary", but got "%s"' % render_style + '"text" or "binary", but got "%s"' % render_style if render_style == 'binary': return '[%d bytes of binary content]' % len(content) @@ -487,8 +483,8 @@ def get_rendered_html_form(self, data, view, method, request): has_serializer_class = getattr(view, 'serializer_class', None) if ( - (not has_serializer and not has_serializer_class) or - not any(is_form_media_type(parser.media_type) for parser in view.parser_classes) + (not has_serializer and not has_serializer_class) or + not any(is_form_media_type(parser.media_type) for parser in view.parser_classes) ): return @@ -837,7 +833,7 @@ def get_result_url(self, result, view): and viewset-like (has `.basename` / `.reverse_action()`). """ if not hasattr(view, 'reverse_action') or \ - not hasattr(view, 'lookup_field'): + not hasattr(view, 'lookup_field'): return lookup_field = view.lookup_field @@ -850,57 +846,6 @@ def get_result_url(self, result, view): return -class DocumentationRenderer(BaseRenderer): - media_type = 'text/html' - format = 'html' - charset = 'utf-8' - template = 'rest_framework/docs/index.html' - error_template = 'rest_framework/docs/error.html' - code_style = 'emacs' - languages = ['shell', 'javascript', 'python'] - - def get_context(self, data, request): - return { - 'document': data, - 'langs': self.languages, - 'lang_htmls': ["rest_framework/docs/langs/%s.html" % language for language in self.languages], - 'lang_intro_htmls': ["rest_framework/docs/langs/%s-intro.html" % language for language in self.languages], - 'code_style': pygments_css(self.code_style), - 'request': request - } - - def render(self, data, accepted_media_type=None, renderer_context=None): - if isinstance(data, coreapi.Document): - template = loader.get_template(self.template) - context = self.get_context(data, renderer_context['request']) - return template.render(context, request=renderer_context['request']) - else: - template = loader.get_template(self.error_template) - context = { - "data": data, - "request": renderer_context['request'], - "response": renderer_context['response'], - "debug": settings.DEBUG, - } - return template.render(context, request=renderer_context['request']) - - -class SchemaJSRenderer(BaseRenderer): - media_type = 'application/javascript' - format = 'javascript' - charset = 'utf-8' - template = 'rest_framework/schema.js' - - def render(self, data, accepted_media_type=None, renderer_context=None): - codec = coreapi.codecs.CoreJSONCodec() - schema = base64.b64encode(codec.encode(data)).decode('ascii') - - template = loader.get_template(self.template) - context = {'schema': mark_safe(schema)} - request = renderer_context['request'] - return template.render(context, request=request) - - class MultiPartRenderer(BaseRenderer): media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' format = 'multipart' @@ -921,139 +866,6 @@ def render(self, data, accepted_media_type=None, renderer_context=None): return encode_multipart(self.BOUNDARY, data) -class CoreJSONRenderer(BaseRenderer): - media_type = 'application/coreapi+json' - charset = None - format = 'corejson' - - def __init__(self): - assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not installed.' - - def render(self, data, media_type=None, renderer_context=None): - indent = bool(renderer_context.get('indent', 0)) - codec = coreapi.codecs.CoreJSONCodec() - return codec.dump(data, indent=indent) - - -class _BaseOpenAPIRenderer: - def get_schema(self, instance): - CLASS_TO_TYPENAME = { - coreschema.Object: 'object', - coreschema.Array: 'array', - coreschema.Number: 'number', - coreschema.Integer: 'integer', - coreschema.String: 'string', - coreschema.Boolean: 'boolean', - } - - schema = {} - if instance.__class__ in CLASS_TO_TYPENAME: - schema['type'] = CLASS_TO_TYPENAME[instance.__class__] - schema['title'] = instance.title - schema['description'] = instance.description - if hasattr(instance, 'enum'): - schema['enum'] = instance.enum - return schema - - def get_parameters(self, link): - parameters = [] - for field in link.fields: - if field.location not in ['path', 'query']: - continue - parameter = { - 'name': field.name, - 'in': field.location, - } - if field.required: - parameter['required'] = True - if field.description: - parameter['description'] = field.description - if field.schema: - parameter['schema'] = self.get_schema(field.schema) - parameters.append(parameter) - return parameters - - def get_operation(self, link, name, tag): - operation_id = "%s_%s" % (tag, name) if tag else name - parameters = self.get_parameters(link) - - operation = { - 'operationId': operation_id, - } - if link.title: - operation['summary'] = link.title - if link.description: - operation['description'] = link.description - if parameters: - operation['parameters'] = parameters - if tag: - operation['tags'] = [tag] - return operation - - def get_paths(self, document): - paths = {} - - tag = None - for name, link in document.links.items(): - path = parse.urlparse(link.url).path - method = link.action.lower() - paths.setdefault(path, {}) - paths[path][method] = self.get_operation(link, name, tag=tag) - - for tag, section in document.data.items(): - for name, link in section.links.items(): - path = parse.urlparse(link.url).path - method = link.action.lower() - paths.setdefault(path, {}) - paths[path][method] = self.get_operation(link, name, tag=tag) - - return paths - - def get_structure(self, data): - return { - 'openapi': '3.0.0', - 'info': { - 'version': '', - 'title': data.title, - 'description': data.description - }, - 'servers': [{ - 'url': data.url - }], - 'paths': self.get_paths(data) - } - - -class CoreAPIOpenAPIRenderer(_BaseOpenAPIRenderer): - media_type = 'application/vnd.oai.openapi' - charset = None - format = 'openapi' - - def __init__(self): - assert coreapi, 'Using CoreAPIOpenAPIRenderer, but `coreapi` is not installed.' - assert yaml, 'Using CoreAPIOpenAPIRenderer, but `pyyaml` is not installed.' - - def render(self, data, media_type=None, renderer_context=None): - structure = self.get_structure(data) - return yaml.dump(structure, default_flow_style=False).encode() - - -class CoreAPIJSONOpenAPIRenderer(_BaseOpenAPIRenderer): - media_type = 'application/vnd.oai.openapi+json' - charset = None - format = 'openapi-json' - ensure_ascii = not api_settings.UNICODE_JSON - - def __init__(self): - assert coreapi, 'Using CoreAPIJSONOpenAPIRenderer, but `coreapi` is not installed.' - - def render(self, data, media_type=None, renderer_context=None): - structure = self.get_structure(data) - return json.dumps( - structure, indent=4, - ensure_ascii=self.ensure_ascii).encode('utf-8') - - class OpenAPIRenderer(BaseRenderer): media_type = 'application/vnd.oai.openapi' charset = None @@ -1067,6 +879,7 @@ def render(self, data, media_type=None, renderer_context=None): class Dumper(yaml.Dumper): def ignore_aliases(self, data): return True + Dumper.add_representer(SafeString, Dumper.represent_str) Dumper.add_representer(datetime.timedelta, encoders.CustomScalar.represent_timedelta) return yaml.dump(data, default_flow_style=False, sort_keys=False, Dumper=Dumper).encode('utf-8') diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py deleted file mode 100644 index b63cb23536..0000000000 --- a/rest_framework/schemas/__init__.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -rest_framework.schemas - -schemas: - __init__.py - generators.py # Top-down schema generation - inspectors.py # Per-endpoint view introspection - utils.py # Shared helper functions - views.py # Houses `SchemaView`, `APIView` subclass. - -We expose a minimal "public" API directly from `schemas`. This covers the -basic use-cases: - - from rest_framework.schemas import ( - AutoSchema, - ManualSchema, - get_schema_view, - SchemaGenerator, - ) - -Other access should target the submodules directly -""" -from rest_framework.settings import api_settings - -from . import coreapi, openapi -from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa -from .inspectors import DefaultSchema # noqa - - -def get_schema_view( - title=None, url=None, description=None, urlconf=None, renderer_classes=None, - public=False, patterns=None, generator_class=None, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, - version=None): - """ - Return a schema view. - """ - if generator_class is None: - if coreapi.is_enabled(): - generator_class = coreapi.SchemaGenerator - else: - generator_class = openapi.SchemaGenerator - - generator = generator_class( - title=title, url=url, description=description, - urlconf=urlconf, patterns=patterns, version=version - ) - - # Avoid import cycle on APIView - from .views import SchemaView - return SchemaView.as_view( - renderer_classes=renderer_classes, - schema_generator=generator, - public=public, - authentication_classes=authentication_classes, - permission_classes=permission_classes, - ) diff --git a/rest_framework/schemas/coreapi.py b/rest_framework/schemas/coreapi.py deleted file mode 100644 index 657178304a..0000000000 --- a/rest_framework/schemas/coreapi.py +++ /dev/null @@ -1,626 +0,0 @@ -import warnings -from collections import Counter -from urllib import parse - -from django.db import models -from django.utils.encoding import force_str - -from rest_framework import RemovedInDRF317Warning, exceptions, serializers -from rest_framework.compat import coreapi, coreschema, uritemplate -from rest_framework.settings import api_settings - -from .generators import BaseSchemaGenerator -from .inspectors import ViewInspector -from .utils import get_pk_description, is_list_view - - -def common_path(paths): - split_paths = [path.strip('/').split('/') for path in paths] - s1 = min(split_paths) - s2 = max(split_paths) - common = s1 - for i, c in enumerate(s1): - if c != s2[i]: - common = s1[:i] - break - return '/' + '/'.join(common) - - -def is_custom_action(action): - return action not in { - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - } - - -def distribute_links(obj): - for key, value in obj.items(): - distribute_links(value) - - for preferred_key, link in obj.links: - key = obj.get_available_key(preferred_key) - obj[key] = link - - -INSERT_INTO_COLLISION_FMT = """ -Schema Naming Collision. - -coreapi.Link for URL path {value_url} cannot be inserted into schema. -Position conflicts with coreapi.Link for URL path {target_url}. - -Attempted to insert link with keys: {keys}. - -Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` -to customise schema structure. -""" - - -class LinkNode(dict): - def __init__(self): - self.links = [] - self.methods_counter = Counter() - super().__init__() - - def get_available_key(self, preferred_key): - if preferred_key not in self: - return preferred_key - - while True: - current_val = self.methods_counter[preferred_key] - self.methods_counter[preferred_key] += 1 - - key = f'{preferred_key}_{current_val}' - if key not in self: - return key - - -def insert_into(target, keys, value): - """ - Nested dictionary insertion. - - >>> example = {} - >>> insert_into(example, ['a', 'b', 'c'], 123) - >>> example - LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) - """ - for key in keys[:-1]: - if key not in target: - target[key] = LinkNode() - target = target[key] - - try: - target.links.append((keys[-1], value)) - except TypeError: - msg = INSERT_INTO_COLLISION_FMT.format( - value_url=value.url, - target_url=target.url, - keys=keys - ) - raise ValueError(msg) - - -class SchemaGenerator(BaseSchemaGenerator): - """ - Original CoreAPI version. - """ - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } - - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - - def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None): - assert coreapi, '`coreapi` must be installed for schema support.' - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - assert coreschema, '`coreschema` must be installed for schema support.' - - super().__init__(title, url, description, patterns, urlconf) - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - - def get_links(self, request=None): - """ - Return a dictionary containing all the links that should be - included in the API schema. - """ - links = LinkNode() - - paths, view_endpoints = self._get_paths_and_endpoints(request) - - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) - - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) - - return links - - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ - self._initialise_endpoints() - - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - # Method for generating the link layout.... - def get_keys(self, subpath, method, view): - """ - Return a list of keys that should be used to layout a link within - the schema document. - - /users/ ("users", "list"), ("users", "create") - /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") - /users/enabled/ ("users", "enabled") # custom viewset list action - /users/{pk}/star/ ("users", "star") # custom viewset detail action - /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") - /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") - """ - if hasattr(view, 'action'): - # Viewsets have explicitly named actions. - action = view.action - else: - # Views have no associated action, so we determine one from the method. - if is_list_view(subpath, method, view): - action = 'list' - else: - action = self.default_mapping[method.lower()] - - named_path_components = [ - component for component - in subpath.strip('/').split('/') - if '{' not in component - ] - - if is_custom_action(action): - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - mapped_methods = { - # Don't count head mapping, e.g. not part of the schema - method for method in view.action_map if method != 'head' - } - if len(mapped_methods) > 1: - action = self.default_mapping[method.lower()] - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - return named_path_components + [action] - else: - return named_path_components[:-1] + [action] - - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - - # Default action, eg "/users/", "/users/{pk}/" - return named_path_components + [action] - - def determine_path_prefix(self, paths): - """ - Given a list of all paths, return the common prefix which should be - discounted when generating a schema structure. - - This will be the longest common string that does not include that last - component of the URL, or the last component before a path parameter. - - For example: - - /api/v1/users/ - /api/v1/users/{pk}/ - - The path prefix is '/api/v1' - """ - prefixes = [] - for path in paths: - components = path.strip('/').split('/') - initial_components = [] - for component in components: - if '{' in component: - break - initial_components.append(component) - prefix = '/'.join(initial_components[:-1]) - if not prefix: - # We can just break early in the case that there's at least - # one URL that doesn't have a path prefix. - return '/' - prefixes.append('/' + prefix + '/') - return common_path(prefixes) - -# View Inspectors # - - -def field_to_schema(field): - title = force_str(field.label) if field.label else '' - description = force_str(field.help_text) if field.help_text else '' - - if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = field_to_schema(field.child) - return coreschema.Array( - items=child_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.DictField): - return coreschema.Object( - title=title, - description=description - ) - elif isinstance(field, serializers.Serializer): - return coreschema.Object( - properties={ - key: field_to_schema(value) - for key, value - in field.fields.items() - }, - title=title, - description=description - ) - elif isinstance(field, serializers.ManyRelatedField): - related_field_schema = field_to_schema(field.child_relation) - - return coreschema.Array( - items=related_field_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.PrimaryKeyRelatedField): - schema_cls = coreschema.String - model = getattr(field.queryset, 'model', None) - if model is not None: - model_field = model._meta.pk - if isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - return schema_cls(title=title, description=description) - elif isinstance(field, serializers.RelatedField): - return coreschema.String(title=title, description=description) - elif isinstance(field, serializers.MultipleChoiceField): - return coreschema.Array( - items=coreschema.Enum(enum=list(field.choices)), - title=title, - description=description - ) - elif isinstance(field, serializers.ChoiceField): - return coreschema.Enum( - enum=list(field.choices), - title=title, - description=description - ) - elif isinstance(field, serializers.BooleanField): - return coreschema.Boolean(title=title, description=description) - elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): - return coreschema.Number(title=title, description=description) - elif isinstance(field, serializers.IntegerField): - return coreschema.Integer(title=title, description=description) - elif isinstance(field, serializers.DateField): - return coreschema.String( - title=title, - description=description, - format='date' - ) - elif isinstance(field, serializers.DateTimeField): - return coreschema.String( - title=title, - description=description, - format='date-time' - ) - elif isinstance(field, serializers.JSONField): - return coreschema.Object(title=title, description=description) - - if field.style.get('base_template') == 'textarea.html': - return coreschema.String( - title=title, - description=description, - format='textarea' - ) - - return coreschema.String(title=title, description=description) - - -class AutoSchema(ViewInspector): - """ - Default inspector for APIView - - Responsible for per-view introspection and schema generation. - """ - def __init__(self, manual_fields=None): - """ - Parameters: - - * `manual_fields`: list of `coreapi.Field` instances that - will be added to auto-generated fields, overwriting on `Field.name` - """ - super().__init__() - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - - if manual_fields is None: - manual_fields = [] - self._manual_fields = manual_fields - - def get_link(self, path, method, base_url): - """ - Generate `coreapi.Link` for self.view, path and method. - - This is the main _public_ access point. - - Parameters: - - * path: Route path for view from URLConf. - * method: The HTTP request method. - * base_url: The project "mount point" as given to SchemaGenerator - """ - fields = self.get_path_fields(path, method) - fields += self.get_serializer_fields(path, method) - fields += self.get_pagination_fields(path, method) - fields += self.get_filter_fields(path, method) - - manual_fields = self.get_manual_fields(path, method) - fields = self.update_fields(fields, manual_fields) - - if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method) - else: - encoding = None - - description = self.get_description(path, method) - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=parse.urljoin(base_url, path), - action=method.lower(), - encoding=encoding, - fields=fields, - description=description - ) - - def get_path_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - templated path variables. - """ - view = self.view - model = getattr(getattr(view, 'queryset', None), 'model', None) - fields = [] - - for variable in uritemplate.variables(path): - title = '' - description = '' - schema_cls = coreschema.String - kwargs = {} - if model is not None: - # Attempt to infer a field description if possible. - try: - model_field = model._meta.get_field(variable) - except Exception: - model_field = None - - if model_field is not None and model_field.verbose_name: - title = force_str(model_field.verbose_name) - - if model_field is not None and model_field.help_text: - description = force_str(model_field.help_text) - elif model_field is not None and model_field.primary_key: - description = get_pk_description(model, model_field) - - if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex - elif isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - - field = coreapi.Field( - name=variable, - location='path', - required=True, - schema=schema_cls(title=title, description=description, **kwargs) - ) - fields.append(field) - - return fields - - def get_serializer_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - request body input, as determined by the serializer class. - """ - view = self.view - - if method not in ('PUT', 'PATCH', 'POST'): - return [] - - if not hasattr(view, 'get_serializer'): - return [] - - try: - serializer = view.get_serializer() - except exceptions.APIException: - serializer = None - warnings.warn('{}.get_serializer() raised an exception during ' - 'schema generation. Serializer fields will not be ' - 'generated for {} {}.' - .format(view.__class__.__name__, method, path)) - - if isinstance(serializer, serializers.ListSerializer): - return [ - coreapi.Field( - name='data', - location='body', - required=True, - schema=coreschema.Array() - ) - ] - - if not isinstance(serializer, serializers.Serializer): - return [] - - fields = [] - for field in serializer.fields.values(): - if field.read_only or isinstance(field, serializers.HiddenField): - continue - - required = field.required and method != 'PATCH' - field = coreapi.Field( - name=field.field_name, - location='form', - required=required, - schema=field_to_schema(field) - ) - fields.append(field) - - return fields - - def get_pagination_fields(self, path, method): - view = self.view - - if not is_list_view(path, method, view): - return [] - - pagination = getattr(view, 'pagination_class', None) - if not pagination: - return [] - - paginator = view.pagination_class() - return paginator.get_schema_fields(view) - - def _allows_filters(self, path, method): - """ - Determine whether to include filter Fields in schema. - - Default implementation looks for ModelViewSet or GenericAPIView - actions/methods that cause filtering on the default implementation. - - Override to adjust behaviour for your view. - - Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) - to allow changes based on user experience. - """ - if getattr(self.view, 'filter_backends', None) is None: - return False - - if hasattr(self.view, 'action'): - return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] - - return method.lower() in ["get", "put", "patch", "delete"] - - def get_filter_fields(self, path, method): - if not self._allows_filters(path, method): - return [] - - fields = [] - for filter_backend in self.view.filter_backends: - fields += filter_backend().get_schema_fields(self.view) - return fields - - def get_manual_fields(self, path, method): - return self._manual_fields - - @staticmethod - def update_fields(fields, update_with): - """ - Update list of coreapi.Field instances, overwriting on `Field.name`. - - Utility function to handle replacing coreapi.Field fields - from a list by name. Used to handle `manual_fields`. - - Parameters: - - * `fields`: list of `coreapi.Field` instances to update - * `update_with: list of `coreapi.Field` instances to add or replace. - """ - if not update_with: - return fields - - by_name = {f.name: f for f in fields} - for f in update_with: - by_name[f.name] = f - fields = list(by_name.values()) - return fields - - def get_encoding(self, path, method): - """ - Return the 'encoding' parameter to use for a given endpoint. - """ - view = self.view - - # Core API supports the following request encodings over HTTP... - supported_media_types = { - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data', - } - parser_classes = getattr(view, 'parser_classes', []) - for parser_class in parser_classes: - media_type = getattr(parser_class, 'media_type', None) - if media_type in supported_media_types: - return media_type - # Raw binary uploads are supported with "application/octet-stream" - if media_type == '*/*': - return 'application/octet-stream' - - return None - - -class ManualSchema(ViewInspector): - """ - Allows providing a list of coreapi.Fields, - plus an optional description. - """ - def __init__(self, fields, description='', encoding=None): - """ - Parameters: - - * `fields`: list of `coreapi.Field` instances. - * `description`: String description for view. Optional. - """ - super().__init__() - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - - assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" - self._fields = fields - self._description = description - self._encoding = encoding - - def get_link(self, path, method, base_url): - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=parse.urljoin(base_url, path), - action=method.lower(), - encoding=self._encoding, - fields=self._fields, - description=self._description - ) - - -def is_enabled(): - """Is CoreAPI Mode enabled?""" - if coreapi is not None: - warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning) - return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema) diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py deleted file mode 100644 index f59e25c213..0000000000 --- a/rest_framework/schemas/generators.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -generators.py # Top-down schema generation - -See schemas.__init__.py for package overview. -""" -import re -from importlib import import_module - -from django.conf import settings -from django.contrib.admindocs.views import simplify_regex -from django.core.exceptions import PermissionDenied -from django.http import Http404 -from django.urls import URLPattern, URLResolver - -from rest_framework import exceptions -from rest_framework.request import clone_request -from rest_framework.settings import api_settings -from rest_framework.utils.model_meta import _get_pk - - -def get_pk_name(model): - meta = model._meta.concrete_model._meta - return _get_pk(meta).name - - -def is_api_view(callback): - """ - Return `True` if the given view callback is a REST framework view/viewset. - """ - # Avoid import cycle on APIView - from rest_framework.views import APIView - cls = getattr(callback, 'cls', None) - return (cls is not None) and issubclass(cls, APIView) - - -def endpoint_ordering(endpoint): - path, method, callback = endpoint - method_priority = { - 'GET': 0, - 'POST': 1, - 'PUT': 2, - 'PATCH': 3, - 'DELETE': 4 - }.get(method, 5) - return (method_priority,) - - -_PATH_PARAMETER_COMPONENT_RE = re.compile( - r'<(?:(?P[^>:]+):)?(?P\w+)>' -) - - -class EndpointEnumerator: - """ - A class to determine the available API endpoints that a project exposes. - """ - def __init__(self, patterns=None, urlconf=None): - if patterns is None: - if urlconf is None: - # Use the default Django URL conf - urlconf = settings.ROOT_URLCONF - - # Load the given URLconf module - if isinstance(urlconf, str): - urls = import_module(urlconf) - else: - urls = urlconf - patterns = urls.urlpatterns - - self.patterns = patterns - - def get_api_endpoints(self, patterns=None, prefix=''): - """ - Return a list of all available API endpoints by inspecting the URL conf. - """ - if patterns is None: - patterns = self.patterns - - api_endpoints = [] - - for pattern in patterns: - path_regex = prefix + str(pattern.pattern) - if isinstance(pattern, URLPattern): - path = self.get_path_from_regex(path_regex) - callback = pattern.callback - if self.should_include_endpoint(path, callback): - for method in self.get_allowed_methods(callback): - endpoint = (path, method, callback) - api_endpoints.append(endpoint) - - elif isinstance(pattern, URLResolver): - nested_endpoints = self.get_api_endpoints( - patterns=pattern.url_patterns, - prefix=path_regex - ) - api_endpoints.extend(nested_endpoints) - - return sorted(api_endpoints, key=endpoint_ordering) - - def get_path_from_regex(self, path_regex): - """ - Given a URL conf regex, return a URI template string. - """ - # ???: Would it be feasible to adjust this such that we generate the - # path, plus the kwargs, plus the type from the converter, such that we - # could feed that straight into the parameter schema object? - - path = simplify_regex(path_regex) - - # Strip Django 2.0 converters as they are incompatible with uritemplate format - return re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g}', path) - - def should_include_endpoint(self, path, callback): - """ - Return `True` if the given endpoint should be included. - """ - if not is_api_view(callback): - return False # Ignore anything except REST framework views. - - if callback.cls.schema is None: - return False - - if 'schema' in callback.initkwargs: - if callback.initkwargs['schema'] is None: - return False - - if path.endswith('.{format}') or path.endswith('.{format}/'): - return False # Ignore .json style URLs. - - return True - - def get_allowed_methods(self, callback): - """ - Return a list of the valid HTTP methods for this endpoint. - """ - if hasattr(callback, 'actions'): - actions = set(callback.actions) - http_method_names = set(callback.cls.http_method_names) - methods = [method.upper() for method in actions & http_method_names] - else: - methods = callback.cls().allowed_methods - - return [method for method in methods if method not in ('OPTIONS', 'HEAD')] - - -class BaseSchemaGenerator: - endpoint_inspector_cls = EndpointEnumerator - - # 'pk' isn't great as an externally exposed name for an identifier, - # so by default we prefer to use the actual model field name for schemas. - # Set by 'SCHEMA_COERCE_PATH_PK'. - coerce_path_pk = None - - def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None): - if url and not url.endswith('/'): - url += '/' - - self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK - - self.patterns = patterns - self.urlconf = urlconf - self.title = title - self.description = description - self.version = version - self.url = url - self.endpoints = None - - def _initialise_endpoints(self): - if self.endpoints is None: - inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) - self.endpoints = inspector.get_api_endpoints() - - def _get_paths_and_endpoints(self, request): - """ - Generate (path, method, view) given (path, method, callback) for paths. - """ - paths = [] - view_endpoints = [] - for path, method, callback in self.endpoints: - view = self.create_view(callback, method, request) - path = self.coerce_path(path, method, view) - paths.append(path) - view_endpoints.append((path, method, view)) - - return paths, view_endpoints - - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - - def coerce_path(self, path, method, view): - """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") - """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) - - def get_schema(self, request=None, public=False): - raise NotImplementedError(".get_schema() must be implemented in subclasses.") - - def has_view_permissions(self, path, method, view): - """ - Return `True` if the incoming request has the correct view permissions. - """ - if view.request is None: - return True - - try: - view.check_permissions(view.request) - except (exceptions.APIException, Http404, PermissionDenied): - return False - return True diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py deleted file mode 100644 index e027b46a70..0000000000 --- a/rest_framework/schemas/inspectors.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -inspectors.py # Per-endpoint view introspection - -See schemas.__init__.py for package overview. -""" -import re -from weakref import WeakKeyDictionary - -from django.utils.encoding import smart_str - -from rest_framework.settings import api_settings -from rest_framework.utils import formatting - - -class ViewInspector: - """ - Descriptor class on APIView. - - Provide subclass for per-view schema generation - """ - - # Used in _get_description_section() - header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') - - def __init__(self): - self.instance_schemas = WeakKeyDictionary() - - def __get__(self, instance, owner): - """ - Enables `ViewInspector` as a Python _Descriptor_. - - This is how `view.schema` knows about `view`. - - `__get__` is called when the descriptor is accessed on the owner. - (That will be when view.schema is called in our case.) - - `owner` is always the owner class. (An APIView, or subclass for us.) - `instance` is the view instance or `None` if accessed from the class, - rather than an instance. - - See: https://docs.python.org/3/howto/descriptor.html for info on - descriptor usage. - """ - if instance in self.instance_schemas: - return self.instance_schemas[instance] - - self.view = instance - return self - - def __set__(self, instance, other): - self.instance_schemas[instance] = other - if other is not None: - other.view = instance - - @property - def view(self): - """View property.""" - assert self._view is not None, ( - "Schema generation REQUIRES a view instance. (Hint: you accessed " - "`schema` from the view class rather than an instance.)" - ) - return self._view - - @view.setter - def view(self, value): - self._view = value - - @view.deleter - def view(self): - self._view = None - - def get_description(self, path, method): - """ - Determine a path description. - - This will be based on the method docstring if one exists, - or else the class docstring. - """ - view = self.view - - method_name = getattr(view, 'action', method.lower()) - method_func = getattr(view, method_name, None) - method_docstring = method_func.__doc__ - if method_func and method_docstring: - # An explicit docstring on the method or action. - return self._get_description_section(view, method.lower(), formatting.dedent(smart_str(method_docstring))) - else: - return self._get_description_section(view, getattr(view, 'action', method.lower()), - view.get_view_description()) - - def _get_description_section(self, view, header, description): - lines = description.splitlines() - current_section = '' - sections = {'': ''} - - for line in lines: - if self.header_regex.match(line): - current_section, separator, lead = line.partition(':') - sections[current_section] = lead.strip() - else: - sections[current_section] += '\n' + line - - # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` - coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - if header in sections: - return sections[header].strip() - if header in coerce_method_names: - if coerce_method_names[header] in sections: - return sections[coerce_method_names[header]].strip() - return sections[''].strip() - - -class DefaultSchema(ViewInspector): - """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" - def __get__(self, instance, owner): - result = super().__get__(instance, owner) - if not isinstance(result, DefaultSchema): - return result - - inspector_class = api_settings.DEFAULT_SCHEMA_CLASS - assert issubclass(inspector_class, ViewInspector), ( - "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" - ) - inspector = inspector_class() - inspector.view = instance - return inspector diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py deleted file mode 100644 index eb7dc909d9..0000000000 --- a/rest_framework/schemas/openapi.py +++ /dev/null @@ -1,721 +0,0 @@ -import re -import warnings -from decimal import Decimal -from operator import attrgetter -from urllib.parse import urljoin - -from django.core.validators import ( - DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator, - MinLengthValidator, MinValueValidator, RegexValidator, URLValidator -) -from django.db import models -from django.utils.encoding import force_str - -from rest_framework import exceptions, renderers, serializers -from rest_framework.compat import inflection, uritemplate -from rest_framework.fields import _UnvalidatedField, empty -from rest_framework.settings import api_settings - -from .generators import BaseSchemaGenerator -from .inspectors import ViewInspector -from .utils import get_pk_description, is_list_view - - -class SchemaGenerator(BaseSchemaGenerator): - - def get_info(self): - # Title and version are required by openapi specification 3.x - info = { - 'title': self.title or '', - 'version': self.version or '' - } - - if self.description is not None: - info['description'] = self.description - - return info - - def check_duplicate_operation_id(self, paths): - ids = {} - for route in paths: - for method in paths[route]: - if 'operationId' not in paths[route][method]: - continue - operation_id = paths[route][method]['operationId'] - if operation_id in ids: - warnings.warn( - 'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n' - '\tRoute: {route1}, Method: {method1}\n' - '\tRoute: {route2}, Method: {method2}\n' - '\tAn operationId has to be unique across your schema. Your schema may not work in other tools.' - .format( - route1=ids[operation_id]['route'], - method1=ids[operation_id]['method'], - route2=route, - method2=method, - operation_id=operation_id - ) - ) - ids[operation_id] = { - 'route': route, - 'method': method - } - - def get_schema(self, request=None, public=False): - """ - Generate a OpenAPI schema. - """ - self._initialise_endpoints() - components_schemas = {} - - # Iterate endpoints generating per method path operations. - paths = {} - _, view_endpoints = self._get_paths_and_endpoints(None if public else request) - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - - operation = view.schema.get_operation(path, method) - components = view.schema.get_components(path, method) - for k in components.keys(): - if k not in components_schemas: - continue - if components_schemas[k] == components[k]: - continue - warnings.warn(f'Schema component "{k}" has been overridden with a different value.') - - components_schemas.update(components) - - # Normalise path for any provided mount url. - if path.startswith('/'): - path = path[1:] - path = urljoin(self.url or '/', path) - - paths.setdefault(path, {}) - paths[path][method.lower()] = operation - - self.check_duplicate_operation_id(paths) - - # Compile final schema. - schema = { - 'openapi': '3.0.2', - 'info': self.get_info(), - 'paths': paths, - } - - if len(components_schemas) > 0: - schema['components'] = { - 'schemas': components_schemas - } - - return schema - -# View Inspectors - - -class AutoSchema(ViewInspector): - - def __init__(self, tags=None, operation_id_base=None, component_name=None): - """ - :param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name. - :param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name. - """ - if tags and not all(isinstance(tag, str) for tag in tags): - raise ValueError('tags must be a list or tuple of string.') - self._tags = tags - self.operation_id_base = operation_id_base - self.component_name = component_name - super().__init__() - - request_media_types = [] - response_media_types = [] - - method_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partialUpdate', - 'delete': 'destroy', - } - - def get_operation(self, path, method): - operation = {} - - operation['operationId'] = self.get_operation_id(path, method) - operation['description'] = self.get_description(path, method) - - parameters = [] - parameters += self.get_path_parameters(path, method) - parameters += self.get_pagination_parameters(path, method) - parameters += self.get_filter_parameters(path, method) - operation['parameters'] = parameters - - request_body = self.get_request_body(path, method) - if request_body: - operation['requestBody'] = request_body - operation['responses'] = self.get_responses(path, method) - operation['tags'] = self.get_tags(path, method) - - return operation - - def get_component_name(self, serializer): - """ - Compute the component's name from the serializer. - Raise an exception if the serializer's class name is "Serializer" (case-insensitive). - """ - if self.component_name is not None: - return self.component_name - - # use the serializer's class name as the component name. - component_name = serializer.__class__.__name__ - # We remove the "serializer" string from the class name. - pattern = re.compile("serializer", re.IGNORECASE) - component_name = pattern.sub("", component_name) - - if component_name == "": - raise Exception( - '"{}" is an invalid class name for schema generation. ' - 'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"' - .format(serializer.__class__.__name__) - ) - - return component_name - - def get_components(self, path, method): - """ - Return components with their properties from the serializer. - """ - - if method.lower() == 'delete': - return {} - - request_serializer = self.get_request_serializer(path, method) - response_serializer = self.get_response_serializer(path, method) - - components = {} - - if isinstance(request_serializer, serializers.Serializer): - component_name = self.get_component_name(request_serializer) - content = self.map_serializer(request_serializer) - components.setdefault(component_name, content) - - if isinstance(response_serializer, serializers.Serializer): - component_name = self.get_component_name(response_serializer) - content = self.map_serializer(response_serializer) - components.setdefault(component_name, content) - - return components - - def _to_camel_case(self, snake_str): - components = snake_str.split('_') - # We capitalize the first letter of each component except the first one - # with the 'title' method and join them together. - return components[0] + ''.join(x.title() for x in components[1:]) - - def get_operation_id_base(self, path, method, action): - """ - Compute the base part for operation ID from the model, serializer or view name. - """ - model = getattr(getattr(self.view, 'queryset', None), 'model', None) - - if self.operation_id_base is not None: - name = self.operation_id_base - - # Try to deduce the ID from the view's model - elif model is not None: - name = model.__name__ - - # Try with the serializer class name - elif self.get_serializer(path, method) is not None: - name = self.get_serializer(path, method).__class__.__name__ - if name.endswith('Serializer'): - name = name[:-10] - - # Fallback to the view name - else: - name = self.view.__class__.__name__ - if name.endswith('APIView'): - name = name[:-7] - elif name.endswith('View'): - name = name[:-4] - - # Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly - # comes at the end of the name - if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ... - name = name[:-len(action)] - - if action == 'list': - assert inflection, '`inflection` must be installed for OpenAPI schema support.' - name = inflection.pluralize(name) - - return name - - def get_operation_id(self, path, method): - """ - Compute an operation ID from the view type and get_operation_id_base method. - """ - method_name = getattr(self.view, 'action', method.lower()) - if is_list_view(path, method, self.view): - action = 'list' - elif method_name not in self.method_mapping: - action = self._to_camel_case(method_name) - else: - action = self.method_mapping[method.lower()] - - name = self.get_operation_id_base(path, method, action) - - return action + name - - def get_path_parameters(self, path, method): - """ - Return a list of parameters from templated path variables. - """ - assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.' - - model = getattr(getattr(self.view, 'queryset', None), 'model', None) - parameters = [] - - for variable in uritemplate.variables(path): - description = '' - if model is not None: # TODO: test this. - # Attempt to infer a field description if possible. - try: - model_field = model._meta.get_field(variable) - except Exception: - model_field = None - - if model_field is not None and model_field.help_text: - description = force_str(model_field.help_text) - elif model_field is not None and model_field.primary_key: - description = get_pk_description(model, model_field) - - parameter = { - "name": variable, - "in": "path", - "required": True, - "description": description, - 'schema': { - 'type': 'string', # TODO: integer, pattern, ... - }, - } - parameters.append(parameter) - - return parameters - - def get_filter_parameters(self, path, method): - if not self.allows_filters(path, method): - return [] - parameters = [] - for filter_backend in self.view.filter_backends: - parameters += filter_backend().get_schema_operation_parameters(self.view) - return parameters - - def allows_filters(self, path, method): - """ - Determine whether to include filter Fields in schema. - - Default implementation looks for ModelViewSet or GenericAPIView - actions/methods that cause filtering on the default implementation. - """ - if getattr(self.view, 'filter_backends', None) is None: - return False - if hasattr(self.view, 'action'): - return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] - return method.lower() in ["get", "put", "patch", "delete"] - - def get_pagination_parameters(self, path, method): - view = self.view - - if not is_list_view(path, method, view): - return [] - - paginator = self.get_paginator() - if not paginator: - return [] - - return paginator.get_schema_operation_parameters(view) - - def map_choicefield(self, field): - choices = list(dict.fromkeys(field.choices)) # preserve order and remove duplicates - if all(isinstance(choice, bool) for choice in choices): - type = 'boolean' - elif all(isinstance(choice, int) for choice in choices): - type = 'integer' - elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` - # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21 - type = 'number' - elif all(isinstance(choice, str) for choice in choices): - type = 'string' - else: - type = None - - mapping = { - # The value of `enum` keyword MUST be an array and SHOULD be unique. - # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20 - 'enum': choices - } - - # If We figured out `type` then and only then we should set it. It must be a string. - # Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type - # It is optional but it can not be null. - # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21 - if type: - mapping['type'] = type - return mapping - - def map_field(self, field): - - # Nested Serializers, `many` or not. - if isinstance(field, serializers.ListSerializer): - return { - 'type': 'array', - 'items': self.map_serializer(field.child) - } - if isinstance(field, serializers.Serializer): - data = self.map_serializer(field) - data['type'] = 'object' - return data - - # Related fields. - if isinstance(field, serializers.ManyRelatedField): - return { - 'type': 'array', - 'items': self.map_field(field.child_relation) - } - if isinstance(field, serializers.PrimaryKeyRelatedField): - if getattr(field, "pk_field", False): - return self.map_field(field=field.pk_field) - model = getattr(field.queryset, 'model', None) - if model is not None: - model_field = model._meta.pk - if isinstance(model_field, models.AutoField): - return {'type': 'integer'} - - # ChoiceFields (single and multiple). - # Q: - # - Is 'type' required? - # - can we determine the TYPE of a choicefield? - if isinstance(field, serializers.MultipleChoiceField): - return { - 'type': 'array', - 'items': self.map_choicefield(field) - } - - if isinstance(field, serializers.ChoiceField): - return self.map_choicefield(field) - - # ListField. - if isinstance(field, serializers.ListField): - mapping = { - 'type': 'array', - 'items': {}, - } - if not isinstance(field.child, _UnvalidatedField): - mapping['items'] = self.map_field(field.child) - return mapping - - # DateField and DateTimeField type is string - if isinstance(field, serializers.DateField): - return { - 'type': 'string', - 'format': 'date', - } - - if isinstance(field, serializers.DateTimeField): - return { - 'type': 'string', - 'format': 'date-time', - } - - # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." - # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types - # see also: https://swagger.io/docs/specification/data-models/data-types/#string - if isinstance(field, serializers.EmailField): - return { - 'type': 'string', - 'format': 'email' - } - - if isinstance(field, serializers.URLField): - return { - 'type': 'string', - 'format': 'uri' - } - - if isinstance(field, serializers.UUIDField): - return { - 'type': 'string', - 'format': 'uuid' - } - - if isinstance(field, serializers.IPAddressField): - content = { - 'type': 'string', - } - if field.protocol != 'both': - content['format'] = field.protocol - return content - - if isinstance(field, serializers.DecimalField): - if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING): - content = { - 'type': 'string', - 'format': 'decimal', - } - else: - content = { - 'type': 'number' - } - - if field.decimal_places: - content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1') - if field.max_whole_digits: - content['maximum'] = int(field.max_whole_digits * '9') + 1 - content['minimum'] = -content['maximum'] - self._map_min_max(field, content) - return content - - if isinstance(field, serializers.FloatField): - content = { - 'type': 'number', - } - self._map_min_max(field, content) - return content - - if isinstance(field, serializers.IntegerField): - content = { - 'type': 'integer' - } - self._map_min_max(field, content) - # 2147483647 is max for int32_size, so we use int64 for format - if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647: - content['format'] = 'int64' - return content - - if isinstance(field, serializers.FileField): - return { - 'type': 'string', - 'format': 'binary' - } - - # Simplest cases, default to 'string' type: - FIELD_CLASS_SCHEMA_TYPE = { - serializers.BooleanField: 'boolean', - serializers.JSONField: 'object', - serializers.DictField: 'object', - serializers.HStoreField: 'object', - } - return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} - - def _map_min_max(self, field, content): - if field.max_value: - content['maximum'] = field.max_value - if field.min_value: - content['minimum'] = field.min_value - - def map_serializer(self, serializer): - # Assuming we have a valid serializer instance. - required = [] - properties = {} - - for field in serializer.fields.values(): - if isinstance(field, serializers.HiddenField): - continue - - if field.required and not serializer.partial: - required.append(self.get_field_name(field)) - - schema = self.map_field(field) - if field.read_only: - schema['readOnly'] = True - if field.write_only: - schema['writeOnly'] = True - if field.allow_null: - schema['nullable'] = True - if field.default is not None and field.default != empty and not callable(field.default): - schema['default'] = field.default - if field.help_text: - schema['description'] = str(field.help_text) - self.map_field_validators(field, schema) - - properties[self.get_field_name(field)] = schema - - result = { - 'type': 'object', - 'properties': properties - } - if required: - result['required'] = required - - return result - - def map_field_validators(self, field, schema): - """ - map field validators - """ - for v in field.validators: - # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." - # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types - if isinstance(v, EmailValidator): - schema['format'] = 'email' - if isinstance(v, URLValidator): - schema['format'] = 'uri' - if isinstance(v, RegexValidator): - # In Python, the token \Z does what \z does in other engines. - # https://stackoverflow.com/questions/53283160 - schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z') - elif isinstance(v, MaxLengthValidator): - attr_name = 'maxLength' - if isinstance(field, serializers.ListField): - attr_name = 'maxItems' - schema[attr_name] = v.limit_value - elif isinstance(v, MinLengthValidator): - attr_name = 'minLength' - if isinstance(field, serializers.ListField): - attr_name = 'minItems' - schema[attr_name] = v.limit_value - elif isinstance(v, MaxValueValidator): - schema['maximum'] = v.limit_value - elif isinstance(v, MinValueValidator): - schema['minimum'] = v.limit_value - elif isinstance(v, DecimalValidator) and \ - not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING): - if v.decimal_places: - schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1') - if v.max_digits: - digits = v.max_digits - if v.decimal_places is not None and v.decimal_places > 0: - digits -= v.decimal_places - schema['maximum'] = int(digits * '9') + 1 - schema['minimum'] = -schema['maximum'] - - def get_field_name(self, field): - """ - Override this method if you want to change schema field name. - For example, convert snake_case field name to camelCase. - """ - return field.field_name - - def get_paginator(self): - pagination_class = getattr(self.view, 'pagination_class', None) - if pagination_class: - return pagination_class() - return None - - def map_parsers(self, path, method): - return list(map(attrgetter('media_type'), self.view.parser_classes)) - - def map_renderers(self, path, method): - media_types = [] - for renderer in self.view.renderer_classes: - # BrowsableAPIRenderer not relevant to OpenAPI spec - if issubclass(renderer, renderers.BrowsableAPIRenderer): - continue - media_types.append(renderer.media_type) - return media_types - - def get_serializer(self, path, method): - view = self.view - - if not hasattr(view, 'get_serializer'): - return None - - try: - return view.get_serializer() - except exceptions.APIException: - warnings.warn('{}.get_serializer() raised an exception during ' - 'schema generation. Serializer fields will not be ' - 'generated for {} {}.' - .format(view.__class__.__name__, method, path)) - return None - - def get_request_serializer(self, path, method): - """ - Override this method if your view uses a different serializer for - handling request body. - """ - return self.get_serializer(path, method) - - def get_response_serializer(self, path, method): - """ - Override this method if your view uses a different serializer for - populating response data. - """ - return self.get_serializer(path, method) - - def get_reference(self, serializer): - return {'$ref': f'#/components/schemas/{self.get_component_name(serializer)}'} - - def get_request_body(self, path, method): - if method not in ('PUT', 'PATCH', 'POST'): - return {} - - self.request_media_types = self.map_parsers(path, method) - - serializer = self.get_request_serializer(path, method) - - if not isinstance(serializer, serializers.Serializer): - item_schema = {} - else: - item_schema = self.get_reference(serializer) - - return { - 'content': { - ct: {'schema': item_schema} - for ct in self.request_media_types - } - } - - def get_responses(self, path, method): - if method == 'DELETE': - return { - '204': { - 'description': '' - } - } - - self.response_media_types = self.map_renderers(path, method) - - serializer = self.get_response_serializer(path, method) - - if not isinstance(serializer, serializers.Serializer): - item_schema = {} - else: - item_schema = self.get_reference(serializer) - - if is_list_view(path, method, self.view): - response_schema = { - 'type': 'array', - 'items': item_schema, - } - paginator = self.get_paginator() - if paginator: - response_schema = paginator.get_paginated_response_schema(response_schema) - else: - response_schema = item_schema - status_code = '201' if method == 'POST' else '200' - return { - status_code: { - 'content': { - ct: {'schema': response_schema} - for ct in self.response_media_types - }, - # description is a mandatory property, - # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject - # TODO: put something meaningful into it - 'description': "" - } - } - - def get_tags(self, path, method): - # If user have specified tags, use them. - if self._tags: - return self._tags - - # First element of a specific path could be valid tag. This is a fallback solution. - # PUT, PATCH, GET(Retrieve), DELETE: /user_profile/{id}/ tags = [user-profile] - # POST, GET(List): /user_profile/ tags = [user-profile] - if path.startswith('/'): - path = path[1:] - - return [path.split('/')[0].replace('_', '-')] diff --git a/rest_framework/schemas/utils.py b/rest_framework/schemas/utils.py deleted file mode 100644 index 60ed698294..0000000000 --- a/rest_framework/schemas/utils.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -utils.py # Shared helper functions - -See schemas.__init__.py for package overview. -""" -from django.db import models -from django.utils.translation import gettext_lazy as _ - -from rest_framework.mixins import RetrieveModelMixin - - -def is_list_view(path, method, view): - """ - Return True if the given path/method appears to represent a list view. - """ - if hasattr(view, 'action'): - # Viewsets have an explicitly defined action, which we can inspect. - return view.action == 'list' - - if method.lower() != 'get': - return False - if isinstance(view, RetrieveModelMixin): - return False - path_components = path.strip('/').split('/') - if path_components and '{' in path_components[-1]: - return False - return True - - -def get_pk_description(model, model_field): - if isinstance(model_field, models.AutoField): - value_type = _('unique integer value') - elif isinstance(model_field, models.UUIDField): - value_type = _('UUID string') - else: - value_type = _('unique value') - - return _('A {value_type} identifying this {name}.').format( - value_type=value_type, - name=model._meta.verbose_name, - ) diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py deleted file mode 100644 index 527a23236f..0000000000 --- a/rest_framework/schemas/views.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -views.py # Houses `SchemaView`, `APIView` subclass. - -See schemas.__init__.py for package overview. -""" -from rest_framework import exceptions, renderers -from rest_framework.response import Response -from rest_framework.schemas import coreapi -from rest_framework.settings import api_settings -from rest_framework.views import APIView - - -class SchemaView(APIView): - _ignore_model_permissions = True - schema = None # exclude from schema - renderer_classes = None - schema_generator = None - public = False - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.renderer_classes is None: - if coreapi.is_enabled(): - self.renderer_classes = [ - renderers.CoreAPIOpenAPIRenderer, - renderers.CoreJSONRenderer - ] - else: - self.renderer_classes = [ - renderers.OpenAPIRenderer, - renderers.JSONOpenAPIRenderer, - ] - if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: - self.renderer_classes += [renderers.BrowsableAPIRenderer] - - def get(self, request, *args, **kwargs): - schema = self.schema_generator.get_schema(request, self.public) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - - def handle_exception(self, exc): - # Schema renderers do not render exceptions, so re-perform content - # negotiation with default renderers. - self.renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES - neg = self.perform_content_negotiation(self.request, force=True) - self.request.accepted_renderer, self.request.accepted_media_type = neg - return super().handle_exception(exc) diff --git a/tests/schemas/__init__.py b/tests/schemas/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/schemas/test_coreapi.py b/tests/schemas/test_coreapi.py deleted file mode 100644 index a97b02fe1f..0000000000 --- a/tests/schemas/test_coreapi.py +++ /dev/null @@ -1,1486 +0,0 @@ -import unittest - -import pytest -from django.core.exceptions import PermissionDenied -from django.http import Http404 -from django.test import TestCase, override_settings -from django.urls import include, path - -from rest_framework import ( - RemovedInDRF317Warning, filters, generics, pagination, permissions, - serializers -) -from rest_framework.compat import coreapi, coreschema -from rest_framework.decorators import action, api_view, schema -from rest_framework.filters import ( - BaseFilterBackend, OrderingFilter, SearchFilter -) -from rest_framework.pagination import ( - BasePagination, CursorPagination, LimitOffsetPagination, - PageNumberPagination -) -from rest_framework.request import Request -from rest_framework.routers import DefaultRouter, SimpleRouter -from rest_framework.schemas import ( - AutoSchema, ManualSchema, SchemaGenerator, get_schema_view -) -from rest_framework.schemas.coreapi import field_to_schema, is_enabled -from rest_framework.schemas.generators import EndpointEnumerator -from rest_framework.schemas.utils import is_list_view -from rest_framework.test import APIClient, APIRequestFactory -from rest_framework.utils import formatting -from rest_framework.views import APIView -from rest_framework.viewsets import GenericViewSet, ModelViewSet - -from ..models import BasicModel, ForeignKeySource, ManyToManySource -from . import views - -factory = APIRequestFactory() - - -class MockUser: - def is_authenticated(self): - return True - - -class ExamplePagination(pagination.PageNumberPagination): - page_size = 100 - page_size_query_param = 'page_size' - - -class EmptySerializer(serializers.Serializer): - pass - - -class ExampleSerializer(serializers.Serializer): - a = serializers.CharField(required=True, help_text='A field description') - b = serializers.CharField(required=False) - read_only = serializers.CharField(read_only=True) - hidden = serializers.HiddenField(default='hello') - - -class AnotherSerializerWithDictField(serializers.Serializer): - a = serializers.DictField() - - -class AnotherSerializerWithListFields(serializers.Serializer): - a = serializers.ListField(child=serializers.IntegerField()) - b = serializers.ListSerializer(child=serializers.CharField()) - - -class AnotherSerializer(serializers.Serializer): - c = serializers.CharField(required=True) - d = serializers.CharField(required=False) - - -class ExampleViewSet(ModelViewSet): - pagination_class = ExamplePagination - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - filter_backends = [filters.OrderingFilter] - serializer_class = ExampleSerializer - - @action(methods=['post'], detail=True, serializer_class=AnotherSerializer) - def custom_action(self, request, pk): - """ - A description of custom action. - """ - raise NotImplementedError - - @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField) - def custom_action_with_dict_field(self, request, pk): - """ - A custom action using a dict field in the serializer. - """ - raise NotImplementedError - - @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields) - def custom_action_with_list_fields(self, request, pk): - """ - A custom action using both list field and list serializer in the serializer. - """ - raise NotImplementedError - - @action(detail=False) - def custom_list_action(self, request): - raise NotImplementedError - - @action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer) - def custom_list_action_multiple_methods(self, request): - """Custom description.""" - raise NotImplementedError - - @custom_list_action_multiple_methods.mapping.delete - def custom_list_action_multiple_methods_delete(self, request): - """Deletion description.""" - raise NotImplementedError - - @action(detail=False, schema=None) - def excluded_action(self, request): - pass - - def get_serializer(self, *args, **kwargs): - assert self.request - assert self.action - return super().get_serializer(*args, **kwargs) - - @action(methods=['get', 'post'], detail=False) - def documented_custom_action(self, request): - """ - get: - A description of the get method on the custom action. - - post: - A description of the post method on the custom action. - """ - pass - - @documented_custom_action.mapping.put - def put_documented_custom_action(self, request, *args, **kwargs): - """ - A description of the put method on the custom action from mapping. - """ - pass - - -with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): - if coreapi: - schema_view = get_schema_view(title='Example API') - else: - def schema_view(request): - pass - -router = DefaultRouter() -router.register('example', ExampleViewSet, basename='example') -urlpatterns = [ - path('', schema_view), - path('', include(router.urls)) -] - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestRouterGeneratedSchema(TestCase): - def test_anonymous_request(self): - client = APIClient() - response = client.get('/', HTTP_ACCEPT='application/coreapi+json') - assert response.status_code == 200 - expected = coreapi.Document( - url='http://testserver/', - title='Example API', - content={ - 'example': { - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[ - coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), - coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'custom_list_action': coreapi.Link( - url='/example/custom_list_action/', - action='get' - ), - 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='get', - description='Custom description.', - ) - }, - 'documented_custom_action': { - 'read': coreapi.Link( - url='/example/documented_custom_action/', - action='get', - description='A description of the get method on the custom action.', - ) - }, - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ) - } - } - ) - assert response.data == expected - - def test_authenticated_request(self): - client = APIClient() - client.force_authenticate(MockUser()) - response = client.get('/', HTTP_ACCEPT='application/coreapi+json') - assert response.status_code == 200 - expected = coreapi.Document( - url='http://testserver/', - title='Example API', - content={ - 'example': { - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[ - coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), - coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'create': coreapi.Link( - url='/example/', - action='post', - encoding='application/json', - fields=[ - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) - ] - ), - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'custom_action': coreapi.Link( - url='/example/{id}/custom_action/', - action='post', - encoding='application/json', - description='A description of custom action.', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('c', required=True, location='form', schema=coreschema.String(title='C')), - coreapi.Field('d', required=False, location='form', schema=coreschema.String(title='D')), - ] - ), - 'custom_action_with_dict_field': coreapi.Link( - url='/example/{id}/custom_action_with_dict_field/', - action='post', - encoding='application/json', - description='A custom action using a dict field in the serializer.', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=True, location='form', schema=coreschema.Object(title='A')), - ] - ), - 'custom_action_with_list_fields': coreapi.Link( - url='/example/{id}/custom_action_with_list_fields/', - action='post', - encoding='application/json', - description='A custom action using both list field and list serializer in the serializer.', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=True, location='form', schema=coreschema.Array(title='A', items=coreschema.Integer())), - coreapi.Field('b', required=True, location='form', schema=coreschema.Array(title='B', items=coreschema.String())), - ] - ), - 'custom_list_action': coreapi.Link( - url='/example/custom_list_action/', - action='get' - ), - 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='get', - description='Custom description.', - ), - 'create': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='post', - description='Custom description.', - ), - 'delete': coreapi.Link( - url='/example/custom_list_action_multiple_methods/', - action='delete', - description='Deletion description.', - ), - }, - 'documented_custom_action': { - 'read': coreapi.Link( - url='/example/documented_custom_action/', - action='get', - description='A description of the get method on the custom action.', - ), - 'create': coreapi.Link( - url='/example/documented_custom_action/', - action='post', - description='A description of the post method on the custom action.', - encoding='application/json', - fields=[ - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) - ] - ), - 'update': coreapi.Link( - url='/example/documented_custom_action/', - action='put', - description='A description of the put method on the custom action from mapping.', - encoding='application/json', - fields=[ - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')) - ] - ), - }, - 'update': coreapi.Link( - url='/example/{id}/', - action='put', - encoding='application/json', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description=('A field description'))), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'partial_update': coreapi.Link( - url='/example/{id}/', - action='patch', - encoding='application/json', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('a', required=False, location='form', schema=coreschema.String(title='A', description='A field description')), - coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'delete': coreapi.Link( - url='/example/{id}/', - action='delete', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ) - } - } - ) - assert response.data == expected - - -class DenyAllUsingHttp404(permissions.BasePermission): - - def has_permission(self, request, view): - raise Http404() - - def has_object_permission(self, request, view, obj): - raise Http404() - - -class DenyAllUsingPermissionDenied(permissions.BasePermission): - - def has_permission(self, request, view): - raise PermissionDenied() - - def has_object_permission(self, request, view, obj): - raise PermissionDenied() - - -class Http404ExampleViewSet(ExampleViewSet): - permission_classes = [DenyAllUsingHttp404] - - -class PermissionDeniedExampleViewSet(ExampleViewSet): - permission_classes = [DenyAllUsingPermissionDenied] - - -class MethodLimitedViewSet(ExampleViewSet): - permission_classes = [] - http_method_names = ['get', 'head', 'options'] - - -class ExampleListView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class ExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGenerator(TestCase): - def setUp(self): - self.patterns = [ - path('example/', views.ExampleListView.as_view()), - path('example//', views.ExampleDetailView.as_view()), - path('example//sub/', views.ExampleDetailView.as_view()), - ] - - def test_schema_for_regular_views(self): - """ - Ensure that schema generation works for APIView classes. - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - fields=[] - ), - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[] - ), - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ), - 'sub': { - 'list': coreapi.Link( - url='/example/{id}/sub/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ) - } - } - } - ) - assert schema == expected - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorDjango2(TestCase): - def setUp(self): - self.patterns = [ - path('example/', views.ExampleListView.as_view()), - path('example//', views.ExampleDetailView.as_view()), - path('example//sub/', views.ExampleDetailView.as_view()), - ] - - def test_schema_for_regular_views(self): - """ - Ensure that schema generation works for APIView classes. - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - fields=[] - ), - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[] - ), - 'read': coreapi.Link( - url='/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ), - 'sub': { - 'list': coreapi.Link( - url='/example/{id}/sub/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ) - } - } - } - ) - assert schema == expected - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorNotAtRoot(TestCase): - def setUp(self): - self.patterns = [ - path('api/v1/example/', views.ExampleListView.as_view()), - path('api/v1/example//', views.ExampleDetailView.as_view()), - path('api/v1/example//sub/', views.ExampleDetailView.as_view()), - ] - - def test_schema_for_regular_views(self): - """ - Ensure that schema generation with an API that is not at the URL - root continues to use correct structure for link keys. - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/api/v1/example/', - action='post', - fields=[] - ), - 'list': coreapi.Link( - url='/api/v1/example/', - action='get', - fields=[] - ), - 'read': coreapi.Link( - url='/api/v1/example/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ), - 'sub': { - 'list': coreapi.Link( - url='/api/v1/example/{id}/sub/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()) - ] - ) - } - } - } - ) - assert schema == expected - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): - def setUp(self): - router = DefaultRouter() - router.register('example1', MethodLimitedViewSet, basename='example1') - self.patterns = [ - path('', include(router.urls)) - ] - - def test_schema_for_regular_views(self): - """ - Ensure that schema generation works for ViewSet classes - with method limitation by Django CBV's http_method_names attribute - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - request = factory.get('/example1/') - schema = generator.get_schema(Request(request)) - - expected = coreapi.Document( - url='http://testserver/example1/', - title='Example API', - content={ - 'example1': { - 'list': coreapi.Link( - url='/example1/', - action='get', - fields=[ - coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')), - coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ), - 'custom_list_action': coreapi.Link( - url='/example1/custom_list_action/', - action='get' - ), - 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( - url='/example1/custom_list_action_multiple_methods/', - action='get', - description='Custom description.', - ) - }, - 'documented_custom_action': { - 'read': coreapi.Link( - url='/example1/documented_custom_action/', - action='get', - description='A description of the get method on the custom action.', - ), - }, - 'read': coreapi.Link( - url='/example1/{id}/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.')) - ] - ) - } - } - ) - assert schema == expected - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorWithRestrictedViewSets(TestCase): - def setUp(self): - router = DefaultRouter() - router.register('example1', Http404ExampleViewSet, basename='example1') - router.register('example2', PermissionDeniedExampleViewSet, basename='example2') - self.patterns = [ - path('example/', views.ExampleListView.as_view()), - path('', include(router.urls)) - ] - - def test_schema_for_regular_views(self): - """ - Ensure that schema generation works for ViewSet classes - with permission classes raising exceptions. - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - request = factory.get('/') - schema = generator.get_schema(Request(request)) - expected = coreapi.Document( - url='http://testserver/', - title='Example API', - content={ - 'example': { - 'list': coreapi.Link( - url='/example/', - action='get', - fields=[] - ), - }, - } - ) - assert schema == expected - - -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource - fields = ('id', 'name', 'target') - - -class ForeignKeySourceView(generics.CreateAPIView): - queryset = ForeignKeySource.objects.all() - serializer_class = ForeignKeySourceSerializer - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorWithForeignKey(TestCase): - def setUp(self): - self.patterns = [ - path('example/', ForeignKeySourceView.as_view()), - ] - - def test_schema_for_regular_views(self): - """ - Ensure that AutoField foreign keys are output as Integer. - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - encoding='application/json', - fields=[ - coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), - coreapi.Field('target', required=True, location='form', schema=coreschema.Integer(description='Target', title='Target')), - ] - ) - } - } - ) - assert schema == expected - - -class ManyToManySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ManyToManySource - fields = ('id', 'name', 'targets') - - -class ManyToManySourceView(generics.CreateAPIView): - queryset = ManyToManySource.objects.all() - serializer_class = ManyToManySourceSerializer - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorWithManyToMany(TestCase): - def setUp(self): - self.patterns = [ - path('example/', ManyToManySourceView.as_view()), - ] - - def test_schema_for_regular_views(self): - """ - Ensure that AutoField many to many fields are output as Integer. - """ - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'create': coreapi.Link( - url='/example/', - action='post', - encoding='application/json', - fields=[ - coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')), - coreapi.Field('targets', required=True, location='form', schema=coreschema.Array(title='Targets', items=coreschema.Integer())), - ] - ) - } - } - ) - assert schema == expected - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestSchemaGeneratorActionKeysViewSets(TestCase): - def test_action_not_coerced_for_get_and_head(self): - """ - Ensure that action name is preserved when action map contains "head". - """ - class CustomViewSet(GenericViewSet): - serializer_class = EmptySerializer - - @action(methods=['get', 'head'], detail=True) - def custom_read(self, request, pk): - raise NotImplementedError - - @action(methods=['put', 'patch'], detail=True) - def custom_mixed_update(self, request, pk): - raise NotImplementedError - - self.router = DefaultRouter() - self.router.register('example', CustomViewSet, basename='example') - self.patterns = [ - path('', include(self.router.urls)) - ] - - generator = SchemaGenerator(title='Example API', patterns=self.patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Example API', - content={ - 'example': { - 'custom_read': coreapi.Link( - url='/example/{id}/custom_read/', - action='get', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - ] - ), - 'custom_mixed_update': { - 'update': coreapi.Link( - url='/example/{id}/custom_mixed_update/', - action='put', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - ] - ), - 'partial_update': coreapi.Link( - url='/example/{id}/custom_mixed_update/', - action='patch', - fields=[ - coreapi.Field('id', required=True, location='path', schema=coreschema.String()), - ] - ) - } - } - } - ) - assert schema == expected - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class Test4605Regression(TestCase): - def test_4605_regression(self): - generator = SchemaGenerator() - prefix = generator.determine_path_prefix([ - '/api/v1/items/', - '/auth/convert-token/' - ]) - assert prefix == '/' - - -class CustomViewInspector(AutoSchema): - """A dummy AutoSchema subclass""" - pass - - -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestAutoSchema(TestCase): - - def test_apiview_schema_descriptor(self): - view = APIView() - assert hasattr(view, 'schema') - assert isinstance(view.schema, AutoSchema) - - def test_set_custom_inspector_class_on_view(self): - class CustomView(APIView): - schema = CustomViewInspector() - - view = CustomView() - assert isinstance(view.schema, CustomViewInspector) - - def test_set_custom_inspector_class_via_settings(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}): - view = APIView() - assert isinstance(view.schema, CustomViewInspector) - - def test_get_link_requires_instance(self): - descriptor = APIView.schema # Accessed from class - with pytest.raises(AssertionError): - descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert? - - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_update_fields(self): - """ - That updating fields by-name helper is correct - - Recall: `update_fields(fields, update_with)` - """ - schema = AutoSchema() - fields = [] - - # Adds a field... - fields = schema.update_fields(fields, [ - coreapi.Field( - "my_field", - required=True, - location="path", - schema=coreschema.String() - ), - ]) - - assert len(fields) == 1 - assert fields[0].name == "my_field" - - # Replaces a field... - fields = schema.update_fields(fields, [ - coreapi.Field( - "my_field", - required=False, - location="path", - schema=coreschema.String() - ), - ]) - - assert len(fields) == 1 - assert fields[0].required is False - - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_get_manual_fields(self): - """That get_manual_fields is applied during get_link""" - - class CustomView(APIView): - schema = AutoSchema(manual_fields=[ - coreapi.Field( - "my_extra_field", - required=True, - location="path", - schema=coreschema.String() - ), - ]) - - view = CustomView() - link = view.schema.get_link('/a/url/{id}/', 'GET', '') - fields = link.fields - - assert len(fields) == 2 - assert "my_extra_field" in [f.name for f in fields] - - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_viewset_action_with_schema(self): - class CustomViewSet(GenericViewSet): - @action(detail=True, schema=AutoSchema(manual_fields=[ - coreapi.Field( - "my_extra_field", - required=True, - location="path", - schema=coreschema.String() - ), - ])) - def extra_action(self, pk, **kwargs): - pass - - router = SimpleRouter() - router.register(r'detail', CustomViewSet, basename='detail') - - generator = SchemaGenerator() - view = generator.create_view(router.urls[0].callback, 'GET') - link = view.schema.get_link('/a/url/{id}/', 'GET', '') - fields = link.fields - - assert len(fields) == 2 - assert "my_extra_field" in [f.name for f in fields] - - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_viewset_action_with_null_schema(self): - class CustomViewSet(GenericViewSet): - @action(detail=True, schema=None) - def extra_action(self, pk, **kwargs): - pass - - router = SimpleRouter() - router.register(r'detail', CustomViewSet, basename='detail') - - generator = SchemaGenerator() - view = generator.create_view(router.urls[0].callback, 'GET') - assert view.schema is None - - @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') - def test_view_with_manual_schema(self): - - path = '/example' - method = 'get' - base_url = None - - fields = [ - coreapi.Field( - "first_field", - required=True, - location="path", - schema=coreschema.String() - ), - coreapi.Field( - "second_field", - required=True, - location="path", - schema=coreschema.String() - ), - coreapi.Field( - "third_field", - required=True, - location="path", - schema=coreschema.String() - ), - ] - description = "A test endpoint" - - class CustomView(APIView): - """ - ManualSchema takes list of fields for endpoint. - - Provides url and action, which are always dynamic - """ - schema = ManualSchema(fields, description) - - expected = coreapi.Link( - url=path, - action=method, - fields=fields, - description=description - ) - - view = CustomView() - link = view.schema.get_link(path, method, base_url) - assert link == expected - - @unittest.skipUnless(coreschema, 'coreschema is not installed') - def test_field_to_schema(self): - label = 'Test label' - help_text = 'This is a helpful test text' - - cases = [ - # tuples are ([field], [expected schema]) - # TODO: Add remaining cases - ( - serializers.BooleanField(label=label, help_text=help_text), - coreschema.Boolean(title=label, description=help_text) - ), - ( - serializers.DecimalField(1000, 1000, label=label, help_text=help_text), - coreschema.Number(title=label, description=help_text) - ), - ( - serializers.FloatField(label=label, help_text=help_text), - coreschema.Number(title=label, description=help_text) - ), - ( - serializers.IntegerField(label=label, help_text=help_text), - coreschema.Integer(title=label, description=help_text) - ), - ( - serializers.DateField(label=label, help_text=help_text), - coreschema.String(title=label, description=help_text, format='date') - ), - ( - serializers.DateTimeField(label=label, help_text=help_text), - coreschema.String(title=label, description=help_text, format='date-time') - ), - ( - serializers.JSONField(label=label, help_text=help_text), - coreschema.Object(title=label, description=help_text) - ), - ] - - for case in cases: - self.assertEqual(field_to_schema(case[0]), case[1]) - - -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -def test_docstring_is_not_stripped_by_get_description(): - class ExampleDocstringAPIView(APIView): - """ - === title - - * item a - * item a-a - * item a-b - * item b - - - item 1 - - item 2 - - code block begin - code - code - code - code block end - - the end - """ - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - view = ExampleDocstringAPIView() - schema = view.schema - descr = schema.get_description('example', 'get') - # the first and last character are '\n' correctly removed by get_description - assert descr == formatting.dedent(ExampleDocstringAPIView.__doc__[1:][:-1]) - - -# Views for SchemaGenerationExclusionTests -with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): - class ExcludedAPIView(APIView): - schema = None - - def get(self, request, *args, **kwargs): - pass - - @api_view(['GET']) - @schema(None) - def excluded_fbv(request): - pass - - @api_view(['GET']) - def included_fbv(request): - pass - - -@unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class SchemaGenerationExclusionTests(TestCase): - def setUp(self): - self.patterns = [ - path('excluded-cbv/', ExcludedAPIView.as_view()), - path('excluded-fbv/', excluded_fbv), - path('included-fbv/', included_fbv), - ] - - def test_schema_generator_excludes_correctly(self): - """Schema should not include excluded views""" - generator = SchemaGenerator(title='Exclusions', patterns=self.patterns) - schema = generator.get_schema() - expected = coreapi.Document( - url='', - title='Exclusions', - content={ - 'included-fbv': { - 'list': coreapi.Link(url='/included-fbv/', action='get') - } - } - ) - - assert len(schema.data) == 1 - assert 'included-fbv' in schema.data - assert schema == expected - - def test_endpoint_enumerator_excludes_correctly(self): - """It is responsibility of EndpointEnumerator to exclude views""" - inspector = EndpointEnumerator(self.patterns) - endpoints = inspector.get_api_endpoints() - - assert len(endpoints) == 1 - path, method, callback = endpoints[0] - assert path == '/included-fbv/' - - def test_should_include_endpoint_excludes_correctly(self): - """This is the specific method that should handle the exclusion""" - inspector = EndpointEnumerator(self.patterns) - - # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test - pairs = [(inspector.get_path_from_regex(pattern.pattern.regex.pattern), pattern.callback) - for pattern in self.patterns] - - should_include = [ - inspector.should_include_endpoint(*pair) for pair in pairs - ] - - expected = [False, False, True] - - assert should_include == expected - - -class BasicModelSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel - fields = "__all__" - - -class NamingCollisionView(generics.RetrieveUpdateDestroyAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - - -class BasicNamingCollisionView(generics.RetrieveAPIView): - queryset = BasicModel.objects.all() - - -class NamingCollisionViewSet(GenericViewSet): - """ - Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/ - """ - permission_classes = () - - @action(detail=False) - def detail(self, request): - return {} - - @action(detail=False, url_path='detail/export') - def detail_export(self, request): - return {} - - -naming_collisions_router = SimpleRouter() -naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision") - - -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) -class TestURLNamingCollisions(TestCase): - """ - Ref: https://github.com/encode/django-rest-framework/issues/4704 - """ - def test_manually_routing_nested_routes(self): - @api_view(["GET"]) - def simple_fbv(request): - pass - - patterns = [ - path('test', simple_fbv), - path('test/list/', simple_fbv), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - expected = coreapi.Document( - url='', - title='Naming Colisions', - content={ - 'test': { - 'list': { - 'list': coreapi.Link(url='/test/list/', action='get') - }, - 'list_0': coreapi.Link(url='/test', action='get') - } - } - ) - - assert expected == schema - - def _verify_cbv_links(self, loc, url, methods=None, suffixes=None): - if methods is None: - methods = ('read', 'update', 'partial_update', 'delete') - if suffixes is None: - suffixes = (None for m in methods) - - for method, suffix in zip(methods, suffixes): - if suffix is not None: - key = f'{method}_{suffix}' - else: - key = method - assert loc[key].url == url - - def test_manually_routing_generic_view(self): - patterns = [ - path('test', NamingCollisionView.as_view()), - path('test/retrieve/', NamingCollisionView.as_view()), - path('test/update/', NamingCollisionView.as_view()), - - # Fails with method names: - path('test/get/', NamingCollisionView.as_view()), - path('test/put/', NamingCollisionView.as_view()), - path('test/delete/', NamingCollisionView.as_view()), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - - schema = generator.get_schema() - - self._verify_cbv_links(schema['test']['delete'], '/test/delete/') - self._verify_cbv_links(schema['test']['put'], '/test/put/') - self._verify_cbv_links(schema['test']['get'], '/test/get/') - self._verify_cbv_links(schema['test']['update'], '/test/update/') - self._verify_cbv_links(schema['test']['retrieve'], '/test/retrieve/') - self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0')) - - def test_from_router(self): - patterns = [ - path('from-router', include(naming_collisions_router.urls)), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - # not important here - desc_0 = schema['detail']['detail_export'].description - desc_1 = schema['detail_0'].description - - expected = coreapi.Document( - url='', - title='Naming Colisions', - content={ - 'detail': { - 'detail_export': coreapi.Link( - url='/from-routercollision/detail/export/', - action='get', - description=desc_0) - }, - 'detail_0': coreapi.Link( - url='/from-routercollision/detail/', - action='get', - description=desc_1 - ) - } - ) - - assert schema == expected - - def test_url_under_same_key_not_replaced(self): - patterns = [ - path('example//', BasicNamingCollisionView.as_view()), - path('example//', BasicNamingCollisionView.as_view()), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - assert schema['example']['read'].url == '/example/{id}/' - assert schema['example']['read_0'].url == '/example/{slug}/' - - def test_url_under_same_key_not_replaced_another(self): - - @api_view(["GET"]) - def simple_fbv(request): - pass - - patterns = [ - path('test/list/', simple_fbv), - path('test//list/', simple_fbv), - ] - - generator = SchemaGenerator(title='Naming Colisions', patterns=patterns) - schema = generator.get_schema() - - assert schema['test']['list']['list'].url == '/test/list/' - assert schema['test']['list']['list_0'].url == '/test/{id}/list/' - - -def test_is_list_view_recognises_retrieve_view_subclasses(): - class TestView(generics.RetrieveAPIView): - pass - - path = '/looks/like/a/list/view/' - method = 'get' - view = TestView() - - is_list = is_list_view(path, method, view) - assert not is_list, "RetrieveAPIView subclasses should not be classified as list views." - - -def test_head_and_options_methods_are_excluded(): - """ - Regression test for #5528 - https://github.com/encode/django-rest-framework/issues/5528 - - Viewset OPTIONS actions were not being correctly excluded - - Initial cases here shown to be working as expected. - """ - - @api_view(['options', 'get']) - def fbv(request): - pass - - inspector = EndpointEnumerator() - - path = '/a/path/' - callback = fbv - - assert inspector.should_include_endpoint(path, callback) - assert inspector.get_allowed_methods(callback) == ["GET"] - - class AnAPIView(APIView): - - def get(self, request, *args, **kwargs): - pass - - def options(self, request, *args, **kwargs): - pass - - callback = AnAPIView.as_view() - - assert inspector.should_include_endpoint(path, callback) - assert inspector.get_allowed_methods(callback) == ["GET"] - - class AViewSet(ModelViewSet): - - @action(methods=['options', 'get'], detail=True) - def custom_action(self, request, pk): - pass - - callback = AViewSet.as_view({ - "options": "custom_action", - "get": "custom_action" - }) - - assert inspector.should_include_endpoint(path, callback) - assert inspector.get_allowed_methods(callback) == ["GET"] - - -class MockAPIView(APIView): - filter_backends = [filters.OrderingFilter] - - def _test(self, method): - view = self.MockAPIView() - fields = view.schema.get_filter_fields('', method) - field_names = [f.name for f in fields] - - return 'ordering' in field_names - - def test_get(self): - assert self._test('get') - - def test_GET(self): - assert self._test('GET') - - def test_put(self): - assert self._test('put') - - def test_PUT(self): - assert self._test('PUT') - - def test_patch(self): - assert self._test('patch') - - def test_PATCH(self): - assert self._test('PATCH') - - def test_delete(self): - assert self._test('delete') - - def test_DELETE(self): - assert self._test('DELETE') - - def test_post(self): - assert not self._test('post') - - def test_POST(self): - assert not self._test('POST') - - def test_foo(self): - assert not self._test('foo') - - def test_FOO(self): - assert not self._test('FOO') - - -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -def test_schema_handles_exception(): - schema_view = get_schema_view(permission_classes=[DenyAllUsingPermissionDenied]) - request = factory.get('/') - response = schema_view(request) - response.render() - assert response.status_code == 403 - assert b"You do not have permission to perform this action." in response.content - - -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -def test_coreapi_deprecation(): - with pytest.warns(RemovedInDRF317Warning): - SchemaGenerator() - - with pytest.warns(RemovedInDRF317Warning): - AutoSchema() - - with pytest.warns(RemovedInDRF317Warning): - ManualSchema({}) - - with pytest.warns(RemovedInDRF317Warning): - deprecated_filter = OrderingFilter() - deprecated_filter.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - deprecated_filter = BaseFilterBackend() - deprecated_filter.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - deprecated_filter = SearchFilter() - deprecated_filter.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - paginator = BasePagination() - paginator.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - paginator = PageNumberPagination() - paginator.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - paginator = LimitOffsetPagination() - paginator.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - paginator = CursorPagination() - paginator.get_schema_fields({}) - - with pytest.warns(RemovedInDRF317Warning): - is_enabled() diff --git a/tests/schemas/test_get_schema_view.py b/tests/schemas/test_get_schema_view.py deleted file mode 100644 index f582c64954..0000000000 --- a/tests/schemas/test_get_schema_view.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest -from django.test import TestCase, override_settings - -from rest_framework import renderers -from rest_framework.schemas import coreapi, get_schema_view, openapi - - -class GetSchemaViewTests(TestCase): - """For the get_schema_view() helper.""" - def test_openapi(self): - schema_view = get_schema_view(title="With OpenAPI") - assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator) - assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes - - @pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed') - def test_coreapi(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): - schema_view = get_schema_view(title="With CoreAPI") - assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator) - assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes diff --git a/tests/schemas/test_managementcommand.py b/tests/schemas/test_managementcommand.py deleted file mode 100644 index fa1b75fbf1..0000000000 --- a/tests/schemas/test_managementcommand.py +++ /dev/null @@ -1,154 +0,0 @@ -import io -import os -import tempfile - -import pytest -from django.core.management import call_command -from django.test import TestCase -from django.test.utils import override_settings -from django.urls import path - -from rest_framework.compat import coreapi, uritemplate, yaml -from rest_framework.management.commands import generateschema -from rest_framework.utils import formatting, json -from rest_framework.views import APIView - - -class FooView(APIView): - def get(self, request): - pass - - -urlpatterns = [ - path('', FooView.as_view()) -] - - -class CustomSchemaGenerator: - SCHEMA = {"key": "value"} - - def __init__(self, *args, **kwargs): - pass - - def get_schema(self, **kwargs): - return self.SCHEMA - - -@override_settings(ROOT_URLCONF=__name__) -@pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed') -class GenerateSchemaTests(TestCase): - """Tests for management command generateschema.""" - - def setUp(self): - self.out = io.StringIO() - - def test_command_detects_schema_generation_mode(self): - """Switching between CoreAPI & OpenAPI""" - command = generateschema.Command() - assert command.get_mode() == generateschema.OPENAPI_MODE - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}): - assert command.get_mode() == generateschema.COREAPI_MODE - - @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') - def test_renders_default_schema_with_custom_title_url_and_description(self): - call_command('generateschema', - '--title=ExampleAPI', - '--url=http://api.example.com', - '--description=Example description', - stdout=self.out) - # Check valid YAML was output. - schema = yaml.safe_load(self.out.getvalue()) - assert schema['openapi'] == '3.0.2' - - def test_renders_openapi_json_schema(self): - call_command('generateschema', - '--format=openapi-json', - stdout=self.out) - # Check valid JSON was output. - out_json = json.loads(self.out.getvalue()) - assert out_json['openapi'] == '3.0.2' - - def test_accepts_custom_schema_generator(self): - call_command('generateschema', - f'--generator_class={__name__}.{CustomSchemaGenerator.__name__}', - stdout=self.out) - out_json = yaml.safe_load(self.out.getvalue()) - assert out_json == CustomSchemaGenerator.SCHEMA - - def test_writes_schema_to_file_on_parameter(self): - fd, path = tempfile.mkstemp() - try: - call_command('generateschema', f'--file={path}', stdout=self.out) - # nothing on stdout - assert not self.out.getvalue() - - call_command('generateschema', stdout=self.out) - expected_out = self.out.getvalue() - # file output identical to stdout output - with os.fdopen(fd) as fh: - assert expected_out and fh.read() == expected_out - finally: - os.remove(path) - - @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') - @pytest.mark.skipif(coreapi is None, reason='coreapi is required.') - @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) - def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self): - expected_out = """info: - description: Example description - title: ExampleAPI - version: '' - openapi: 3.0.0 - paths: - /: - get: - operationId: list - servers: - - url: http://api.example.com/ - """ - call_command('generateschema', - '--title=ExampleAPI', - '--url=http://api.example.com', - '--description=Example description', - stdout=self.out) - - self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) - - @pytest.mark.skipif(coreapi is None, reason='coreapi is required.') - @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) - def test_coreapi_renders_openapi_json_schema(self): - expected_out = { - "openapi": "3.0.0", - "info": { - "version": "", - "title": "", - "description": "" - }, - "servers": [ - { - "url": "" - } - ], - "paths": { - "/": { - "get": { - "operationId": "list" - } - } - } - } - call_command('generateschema', - '--format=openapi-json', - stdout=self.out) - out_json = json.loads(self.out.getvalue()) - - self.assertDictEqual(out_json, expected_out) - - @pytest.mark.skipif(coreapi is None, reason='coreapi is required.') - @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) - def test_renders_corejson_schema(self): - expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" - call_command('generateschema', - '--format=corejson', - stdout=self.out) - self.assertIn(expected_out, self.out.getvalue()) diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py deleted file mode 100644 index a168cb4661..0000000000 --- a/tests/schemas/test_openapi.py +++ /dev/null @@ -1,1366 +0,0 @@ -import uuid -import warnings - -import pytest -from django.db import models -from django.test import RequestFactory, TestCase, override_settings -from django.urls import path -from django.utils.safestring import SafeString -from django.utils.translation import gettext_lazy as _ - -from rest_framework import filters, generics, pagination, routers, serializers -from rest_framework.authtoken.views import obtain_auth_token -from rest_framework.compat import uritemplate -from rest_framework.parsers import JSONParser, MultiPartParser -from rest_framework.renderers import ( - BaseRenderer, BrowsableAPIRenderer, JSONOpenAPIRenderer, JSONRenderer, - OpenAPIRenderer -) -from rest_framework.request import Request -from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator - -from ..models import BasicModel -from . import views - - -def create_request(path): - factory = RequestFactory() - request = Request(factory.get(path)) - return request - - -def create_view(view_cls, method, request): - generator = SchemaGenerator() - view = generator.create_view(view_cls.as_view(), method, request) - return view - - -class TestBasics(TestCase): - def dummy_view(request): - pass - - def test_filters(self): - classes = [filters.SearchFilter, filters.OrderingFilter] - for c in classes: - f = c() - assert f.get_schema_operation_parameters(self.dummy_view) - - def test_pagination(self): - classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination] - for c in classes: - f = c() - assert f.get_schema_operation_parameters(self.dummy_view) - - -class TestFieldMapping(TestCase): - def test_list_field_mapping(self): - uuid1 = uuid.uuid4() - uuid2 = uuid.uuid4() - inspector = AutoSchema() - cases = [ - (serializers.ListField(), {'items': {}, 'type': 'array'}), - (serializers.ListField(child=serializers.BooleanField()), {'items': {'type': 'boolean'}, 'type': 'array'}), - (serializers.ListField(child=serializers.FloatField()), {'items': {'type': 'number'}, 'type': 'array'}), - (serializers.ListField(child=serializers.CharField()), {'items': {'type': 'string'}, 'type': 'array'}), - (serializers.ListField(child=serializers.IntegerField(max_value=4294967295)), - {'items': {'type': 'integer', 'maximum': 4294967295, 'format': 'int64'}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[('a', 'Choice A'), ('b', 'Choice B')])), - {'items': {'enum': ['a', 'b'], 'type': 'string'}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[(1, 'One'), (2, 'Two')])), - {'items': {'enum': [1, 2], 'type': 'integer'}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[(1.1, 'First'), (2.2, 'Second')])), - {'items': {'enum': [1.1, 2.2], 'type': 'number'}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[(True, 'true'), (False, 'false')])), - {'items': {'enum': [True, False], 'type': 'boolean'}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[(uuid1, 'uuid1'), (uuid2, 'uuid2')])), - {'items': {'enum': [uuid1, uuid2]}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[(1, 'One'), ('a', 'Choice A')])), - {'items': {'enum': [1, 'a']}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[ - (1, 'One'), ('a', 'Choice A'), (1.1, 'First'), (1.1, 'First'), (1, 'One'), ('a', 'Choice A'), (1, 'One') - ])), - {'items': {'enum': [1, 'a', 1.1]}, 'type': 'array'}), - (serializers.ListField(child=serializers.ChoiceField(choices=[ - (1, 'One'), (2, 'Two'), (3, 'Three'), (2, 'Two'), (3, 'Three'), (1, 'One'), - ])), - {'items': {'enum': [1, 2, 3], 'type': 'integer'}, 'type': 'array'}), - (serializers.IntegerField(min_value=2147483648), - {'type': 'integer', 'minimum': 2147483648, 'format': 'int64'}), - ] - for field, mapping in cases: - with self.subTest(field=field): - assert inspector.map_field(field) == mapping - - def test_lazy_string_field(self): - class ItemSerializer(serializers.Serializer): - text = serializers.CharField(help_text=_('lazy string')) - - inspector = AutoSchema() - - data = inspector.map_serializer(ItemSerializer()) - assert isinstance(data['properties']['text']['description'], str), "description must be str" - - def test_boolean_default_field(self): - class Serializer(serializers.Serializer): - default_true = serializers.BooleanField(default=True) - default_false = serializers.BooleanField(default=False) - without_default = serializers.BooleanField() - - inspector = AutoSchema() - - data = inspector.map_serializer(Serializer()) - assert data['properties']['default_true']['default'] is True, "default must be true" - assert data['properties']['default_false']['default'] is False, "default must be false" - assert 'default' not in data['properties']['without_default'], "default must not be defined" - - def test_custom_field_name(self): - class CustomSchema(AutoSchema): - def get_field_name(self, field): - return 'custom_' + field.field_name - - class Serializer(serializers.Serializer): - text_field = serializers.CharField() - - inspector = CustomSchema() - - data = inspector.map_serializer(Serializer()) - assert 'custom_text_field' in data['properties'] - assert 'text_field' not in data['properties'] - - def test_nullable_fields(self): - class Model(models.Model): - rw_field = models.CharField(null=True) - ro_field = models.CharField(null=True) - - class Serializer(serializers.ModelSerializer): - class Meta: - model = Model - fields = ["rw_field", "ro_field"] - read_only_fields = ["ro_field"] - - inspector = AutoSchema() - - data = inspector.map_serializer(Serializer()) - assert data['properties']['rw_field']['nullable'], "rw_field nullable must be true" - assert data['properties']['ro_field']['nullable'], "ro_field nullable must be true" - assert data['properties']['ro_field']['readOnly'], "ro_field read_only must be true" - - def test_primary_key_related_field(self): - class PrimaryKeyRelatedFieldSerializer(serializers.Serializer): - basic = serializers.PrimaryKeyRelatedField(queryset=BasicModel.objects.all()) - uuid = serializers.PrimaryKeyRelatedField(queryset=BasicModel.objects.all(), - pk_field=serializers.UUIDField()) - char = serializers.PrimaryKeyRelatedField(queryset=BasicModel.objects.all(), - pk_field=serializers.CharField()) - - serializer = PrimaryKeyRelatedFieldSerializer() - inspector = AutoSchema() - - data = inspector.map_serializer(serializer=serializer) - assert data['properties']['basic']['type'] == "integer" - assert data['properties']['uuid']['format'] == "uuid" - assert data['properties']['char']['type'] == "string" - - -@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') -class TestOperationIntrospection(TestCase): - - def test_path_without_parameters(self): - path = '/example/' - method = 'GET' - - view = create_view( - views.DocStringExampleListView, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - operation = inspector.get_operation(path, method) - assert operation == { - 'operationId': 'listDocStringExamples', - 'description': 'A description of my GET operation.', - 'parameters': [], - 'tags': ['example'], - 'responses': { - '200': { - 'description': '', - 'content': { - 'application/json': { - 'schema': { - 'type': 'array', - 'items': {}, - }, - }, - }, - }, - }, - } - - def test_path_with_id_parameter(self): - path = '/example/{id}/' - method = 'GET' - - view = create_view( - views.DocStringExampleDetailView, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - operation = inspector.get_operation(path, method) - assert operation == { - 'operationId': 'retrieveDocStringExampleDetail', - 'description': 'A description of my GET operation.', - 'parameters': [{ - 'description': '', - 'in': 'path', - 'name': 'id', - 'required': True, - 'schema': { - 'type': 'string', - }, - }], - 'tags': ['example'], - 'responses': { - '200': { - 'description': '', - 'content': { - 'application/json': { - 'schema': { - }, - }, - }, - }, - }, - } - - def test_request_body(self): - path = '/' - method = 'POST' - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - read_only = serializers.CharField(read_only=True) - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - request_body = inspector.get_request_body(path, method) - print(request_body) - assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' - - components = inspector.get_components(path, method) - assert components['Item']['required'] == ['text'] - assert sorted(list(components['Item']['properties'].keys())) == ['read_only', 'text'] - - def test_invalid_serializer_class_name(self): - path = '/' - method = 'POST' - - class Serializer(serializers.Serializer): - text = serializers.CharField() - read_only = serializers.CharField(read_only=True) - - class View(generics.GenericAPIView): - serializer_class = Serializer - - view = create_view( - View, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - serializer = inspector.get_serializer(path, method) - - with pytest.raises(Exception) as exc: - inspector.get_component_name(serializer) - assert "is an invalid class name for schema generation" in str(exc.value) - - def test_empty_required(self): - path = '/' - method = 'POST' - - class ItemSerializer(serializers.Serializer): - read_only = serializers.CharField(read_only=True) - write_only = serializers.CharField(write_only=True, required=False) - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['Item'] - # there should be no empty 'required' property, see #6834 - assert 'required' not in component - - for response in inspector.get_responses(path, method).values(): - assert 'required' not in component - - def test_empty_required_with_patch_method(self): - path = '/' - method = 'PATCH' - - class ItemSerializer(serializers.Serializer): - read_only = serializers.CharField(read_only=True) - write_only = serializers.CharField(write_only=True, required=False) - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['Item'] - # there should be no empty 'required' property, see #6834 - assert 'required' not in component - for response in inspector.get_responses(path, method).values(): - assert 'required' not in component - - def test_response_body_generation(self): - path = '/' - method = 'POST' - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - write_only = serializers.CharField(write_only=True) - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path) - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' - - components = inspector.get_components(path, method) - assert sorted(components['Item']['required']) == ['text', 'write_only'] - assert sorted(list(components['Item']['properties'].keys())) == ['text', 'write_only'] - assert 'description' in responses['201'] - - def test_response_body_nested_serializer(self): - path = '/' - method = 'POST' - - class NestedSerializer(serializers.Serializer): - number = serializers.IntegerField() - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - nested = NestedSerializer() - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' - components = inspector.get_components(path, method) - assert components['Item'] - - schema = components['Item'] - assert sorted(schema['required']) == ['nested', 'text'] - assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] - assert schema['properties']['nested']['type'] == 'object' - assert list(schema['properties']['nested']['properties'].keys()) == ['number'] - assert schema['properties']['nested']['required'] == ['number'] - - def test_response_body_partial_serializer(self): - path = '/' - method = 'GET' - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.partial = True - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses == { - '200': { - 'description': '', - 'content': { - 'application/json': { - 'schema': { - 'type': 'array', - 'items': { - '$ref': '#/components/schemas/Item' - }, - }, - }, - }, - }, - } - components = inspector.get_components(path, method) - assert components == { - 'Item': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - } - } - - def test_list_response_body_generation(self): - """Test that an array schema is returned for list views.""" - path = '/' - method = 'GET' - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses == { - '200': { - 'description': '', - 'content': { - 'application/json': { - 'schema': { - 'type': 'array', - 'items': { - '$ref': '#/components/schemas/Item' - }, - }, - }, - }, - }, - } - components = inspector.get_components(path, method) - assert components == { - 'Item': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], - } - } - - def test_paginated_list_response_body_generation(self): - """Test that pagination properties are added for a paginated list view.""" - path = '/' - method = 'GET' - - class Pagination(pagination.BasePagination): - def get_paginated_response_schema(self, schema): - return { - 'type': 'object', - 'item': schema, - } - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - pagination_class = Pagination - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses == { - '200': { - 'description': '', - 'content': { - 'application/json': { - 'schema': { - 'type': 'object', - 'item': { - 'type': 'array', - 'items': { - '$ref': '#/components/schemas/Item' - }, - }, - }, - }, - }, - }, - } - components = inspector.get_components(path, method) - assert components == { - 'Item': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], - } - } - - def test_delete_response_body_generation(self): - """Test that a view's delete method generates a proper response body schema.""" - path = '/{id}/' - method = 'DELETE' - - class View(generics.DestroyAPIView): - serializer_class = views.ExampleSerializer - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses == { - '204': { - 'description': '', - }, - } - - def test_parser_mapping(self): - """Test that view's parsers are mapped to OA media types""" - path = '/{id}/' - method = 'POST' - - class View(generics.CreateAPIView): - serializer_class = views.ExampleSerializer - parser_classes = [JSONParser, MultiPartParser] - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - request_body = inspector.get_request_body(path, method) - - assert len(request_body['content'].keys()) == 2 - assert 'multipart/form-data' in request_body['content'] - assert 'application/json' in request_body['content'] - - def test_renderer_mapping(self): - """Test that view's renderers are mapped to OA media types""" - path = '/{id}/' - method = 'GET' - - class CustomBrowsableAPIRenderer(BrowsableAPIRenderer): - media_type = 'image/jpeg' # that's a wild API renderer - - class TextRenderer(BaseRenderer): - media_type = 'text/plain' - format = 'text' - - class View(generics.CreateAPIView): - serializer_class = views.ExampleSerializer - renderer_classes = [JSONRenderer, TextRenderer, BrowsableAPIRenderer, CustomBrowsableAPIRenderer] - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - # TODO this should be changed once the multiple response - # schema support is there - success_response = responses['200'] - - # Check that the API renderers aren't included, but custom renderers are - assert set(success_response['content']) == {'application/json', 'text/plain'} - - def test_openapi_yaml_rendering_without_aliases(self): - renderer = OpenAPIRenderer() - - reused_object = {'test': 'test'} - data = { - 'o1': reused_object, - 'o2': reused_object, - } - assert ( - renderer.render(data) == b'o1:\n test: test\no2:\n test: test\n' or - renderer.render(data) == b'o2:\n test: test\no1:\n test: test\n' # py <= 3.5 - ) - - def test_openapi_yaml_safestring_render(self): - renderer = OpenAPIRenderer() - data = {'o1': SafeString('test')} - assert renderer.render(data) == b'o1: test\n' - - def test_serializer_filefield(self): - path = '/{id}/' - method = 'POST' - - class ItemSerializer(serializers.Serializer): - attachment = serializers.FileField() - - class View(generics.CreateAPIView): - serializer_class = ItemSerializer - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['Item'] - properties = component['properties'] - assert properties['attachment']['format'] == 'binary' - - def test_retrieve_response_body_generation(self): - """ - Test that a list of properties is returned for retrieve item views. - - Pagination properties should not be added as the view represents a single item. - """ - path = '/{id}/' - method = 'GET' - - class Pagination(pagination.BasePagination): - def get_paginated_response_schema(self, schema): - return { - 'type': 'object', - 'item': schema, - } - - class ItemSerializer(serializers.Serializer): - text = serializers.CharField() - - class View(generics.GenericAPIView): - serializer_class = ItemSerializer - pagination_class = Pagination - - view = create_view( - View, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - responses = inspector.get_responses(path, method) - assert responses == { - '200': { - 'description': '', - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/Item' - }, - }, - }, - }, - } - - components = inspector.get_components(path, method) - assert components == { - 'Item': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], - } - } - - def test_operation_id_generation(self): - path = '/' - method = 'GET' - - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - operationId = inspector.get_operation_id(path, method) - assert operationId == 'listExamples' - - def test_operation_id_custom_operation_id_base(self): - path = '/' - method = 'GET' - - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema(operation_id_base="Ulysse") - inspector.view = view - - operationId = inspector.get_operation_id(path, method) - assert operationId == 'listUlysses' - - def test_operation_id_custom_name(self): - path = '/' - method = 'GET' - - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema(operation_id_base='Ulysse') - inspector.view = view - - operationId = inspector.get_operation_id(path, method) - assert operationId == 'listUlysses' - - def test_operation_id_plural(self): - path = '/' - method = 'GET' - - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema(operation_id_base='City') - inspector.view = view - - operationId = inspector.get_operation_id(path, method) - assert operationId == 'listCities' - - def test_operation_id_override_get(self): - class CustomSchema(AutoSchema): - def get_operation_id(self, path, method): - return 'myCustomOperationId' - - path = '/' - method = 'GET' - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = CustomSchema() - inspector.view = view - - operationId = inspector.get_operation_id(path, method) - assert operationId == 'myCustomOperationId' - - def test_operation_id_override_base(self): - class CustomSchema(AutoSchema): - def get_operation_id_base(self, path, method, action): - return 'Item' - - path = '/' - method = 'GET' - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = CustomSchema() - inspector.view = view - - operationId = inspector.get_operation_id(path, method) - assert operationId == 'listItem' - - def test_different_request_response_objects(self): - class RequestSerializer(serializers.Serializer): - text = serializers.CharField() - - class ResponseSerializer(serializers.Serializer): - text = serializers.BooleanField() - - class CustomSchema(AutoSchema): - def get_request_serializer(self, path, method): - return RequestSerializer() - - def get_response_serializer(self, path, method): - return ResponseSerializer() - - path = '/' - method = 'POST' - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = CustomSchema() - inspector.view = view - - components = inspector.get_components(path, method) - assert components == { - 'Request': { - 'properties': { - 'text': { - 'type': 'string' - } - }, - 'required': ['text'], - 'type': 'object' - }, - 'Response': { - 'properties': { - 'text': { - 'type': 'boolean' - } - }, - 'required': ['text'], - 'type': 'object' - } - } - - operation = inspector.get_operation(path, method) - assert operation == { - 'operationId': 'createExample', - 'description': '', - 'parameters': [], - 'requestBody': { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/Request' - } - }, - 'application/x-www-form-urlencoded': { - 'schema': { - '$ref': '#/components/schemas/Request' - } - }, - 'multipart/form-data': { - 'schema': { - '$ref': '#/components/schemas/Request' - } - } - } - }, - 'responses': { - '201': { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/Response' - } - } - }, - 'description': '' - } - }, - 'tags': [''] - } - - def test_repeat_operation_ids(self): - router = routers.SimpleRouter() - router.register('account', views.ExampleGenericViewSet, basename="account") - urlpatterns = router.urls - - generator = SchemaGenerator(patterns=urlpatterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - schema_str = str(schema) - print(schema_str) - assert schema_str.count("operationId") == 2 - assert schema_str.count("newExample") == 1 - assert schema_str.count("oldExample") == 1 - - def test_duplicate_operation_id(self): - patterns = [ - path('duplicate1/', views.ExampleOperationIdDuplicate1.as_view()), - path('duplicate2/', views.ExampleOperationIdDuplicate2.as_view()), - ] - - generator = SchemaGenerator(patterns=patterns) - request = create_request('/') - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - generator.get_schema(request=request) - - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - print(str(w[-1].message)) - assert 'You have a duplicated operationId' in str(w[-1].message) - - def test_operation_id_viewset(self): - router = routers.SimpleRouter() - router.register('account', views.ExampleViewSet, basename="account") - urlpatterns = router.urls - - generator = SchemaGenerator(patterns=urlpatterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - print(schema) - assert schema['paths']['/account/']['get']['operationId'] == 'listExampleViewSets' - assert schema['paths']['/account/']['post']['operationId'] == 'createExampleViewSet' - assert schema['paths']['/account/{id}/']['get']['operationId'] == 'retrieveExampleViewSet' - assert schema['paths']['/account/{id}/']['put']['operationId'] == 'updateExampleViewSet' - assert schema['paths']['/account/{id}/']['patch']['operationId'] == 'partialUpdateExampleViewSet' - assert schema['paths']['/account/{id}/']['delete']['operationId'] == 'destroyExampleViewSet' - - def test_serializer_datefield(self): - path = '/' - method = 'GET' - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['Example'] - properties = component['properties'] - assert properties['date']['type'] == properties['datetime']['type'] == 'string' - assert properties['date']['format'] == 'date' - assert properties['datetime']['format'] == 'date-time' - - def test_serializer_hstorefield(self): - path = '/' - method = 'GET' - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['Example'] - properties = component['properties'] - assert properties['hstore']['type'] == 'object' - - def test_serializer_callable_default(self): - path = '/' - method = 'GET' - view = create_view( - views.ExampleGenericAPIView, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['Example'] - properties = component['properties'] - assert 'default' not in properties['uuid_field'] - - def test_serializer_validators(self): - path = '/' - method = 'GET' - view = create_view( - views.ExampleValidatedAPIView, - method, - create_request(path), - ) - inspector = AutoSchema() - inspector.view = view - - components = inspector.get_components(path, method) - component = components['ExampleValidated'] - properties = component['properties'] - - assert properties['integer']['type'] == 'integer' - assert properties['integer']['maximum'] == 99 - assert properties['integer']['minimum'] == -11 - - assert properties['string']['minLength'] == 2 - assert properties['string']['maxLength'] == 10 - - assert properties['lst']['minItems'] == 2 - assert properties['lst']['maxItems'] == 10 - - assert properties['regex']['pattern'] == r'[ABC]12{3}' - assert properties['regex']['description'] == 'must have an A, B, or C followed by 1222' - - assert properties['decimal1']['type'] == 'number' - assert properties['decimal1']['multipleOf'] == .01 - assert properties['decimal1']['maximum'] == 10000 - assert properties['decimal1']['minimum'] == -10000 - - assert properties['decimal2']['type'] == 'number' - assert properties['decimal2']['multipleOf'] == .0001 - - assert properties['decimal3'] == { - 'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01 - } - assert properties['decimal4'] == { - 'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01 - } - assert properties['decimal5'] == { - 'type': 'string', 'format': 'decimal', 'maximum': 10000, 'minimum': -10000, 'multipleOf': 0.01 - } - - assert properties['email']['type'] == 'string' - assert properties['email']['format'] == 'email' - assert properties['email']['default'] == 'foo@bar.com' - - assert properties['url']['type'] == 'string' - assert properties['url']['nullable'] is True - assert properties['url']['default'] == 'http://www.example.com' - assert '\\Z' not in properties['url']['pattern'] - - assert properties['uuid']['type'] == 'string' - assert properties['uuid']['format'] == 'uuid' - - assert properties['ip4']['type'] == 'string' - assert properties['ip4']['format'] == 'ipv4' - - assert properties['ip6']['type'] == 'string' - assert properties['ip6']['format'] == 'ipv6' - - assert properties['ip']['type'] == 'string' - assert 'format' not in properties['ip'] - - def test_overridden_tags(self): - class ExampleStringTagsViewSet(views.ExampleGenericAPIView): - schema = AutoSchema(tags=['example1', 'example2']) - - url_patterns = [ - path('test/', ExampleStringTagsViewSet.as_view()), - ] - generator = SchemaGenerator(patterns=url_patterns) - schema = generator.get_schema(request=create_request('/')) - assert schema['paths']['/test/']['get']['tags'] == ['example1', 'example2'] - - def test_overridden_get_tags_method(self): - class MySchema(AutoSchema): - def get_tags(self, path, method): - if path.endswith('/new/'): - tags = ['tag1', 'tag2'] - elif path.endswith('/old/'): - tags = ['tag2', 'tag3'] - else: - tags = ['tag4', 'tag5'] - - return tags - - class ExampleStringTagsViewSet(views.ExampleGenericViewSet): - schema = MySchema() - - router = routers.SimpleRouter() - router.register('example', ExampleStringTagsViewSet, basename="example") - generator = SchemaGenerator(patterns=router.urls) - schema = generator.get_schema(request=create_request('/')) - assert schema['paths']['/example/new/']['get']['tags'] == ['tag1', 'tag2'] - assert schema['paths']['/example/old/']['get']['tags'] == ['tag2', 'tag3'] - - def test_auto_generated_apiview_tags(self): - class RestaurantAPIView(views.ExampleGenericAPIView): - schema = AutoSchema(operation_id_base="restaurant") - pass - - class BranchAPIView(views.ExampleGenericAPIView): - pass - - url_patterns = [ - path('any-dash_underscore/', RestaurantAPIView.as_view()), - path('restaurants/branches/', BranchAPIView.as_view()) - ] - generator = SchemaGenerator(patterns=url_patterns) - schema = generator.get_schema(request=create_request('/')) - assert schema['paths']['/any-dash_underscore/']['get']['tags'] == ['any-dash-underscore'] - assert schema['paths']['/restaurants/branches/']['get']['tags'] == ['restaurants'] - - -@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') -@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'}) -class TestGenerator(TestCase): - - def test_override_settings(self): - assert isinstance(views.ExampleListView.schema, AutoSchema) - - def test_paths_construction(self): - """Construction of the `paths` key.""" - patterns = [ - path('example/', views.ExampleListView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - generator._initialise_endpoints() - - paths = generator.get_schema()["paths"] - - assert '/example/' in paths - example_operations = paths['/example/'] - assert len(example_operations) == 2 - assert 'get' in example_operations - assert 'post' in example_operations - - def test_prefixed_paths_construction(self): - """Construction of the `paths` key maintains a common prefix.""" - patterns = [ - path('v1/example/', views.ExampleListView.as_view()), - path('v1/example/{pk}/', views.ExampleDetailView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - generator._initialise_endpoints() - - paths = generator.get_schema()["paths"] - - assert '/v1/example/' in paths - assert '/v1/example/{id}/' in paths - - def test_mount_url_prefixed_to_paths(self): - patterns = [ - path('example/', views.ExampleListView.as_view()), - path('example/{pk}/', views.ExampleDetailView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns, url='/api') - generator._initialise_endpoints() - - paths = generator.get_schema()["paths"] - - assert '/api/example/' in paths - assert '/api/example/{id}/' in paths - - def test_schema_construction(self): - """Construction of the top level dictionary.""" - patterns = [ - path('example/', views.ExampleListView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - - assert 'openapi' in schema - assert 'paths' in schema - - def test_schema_rendering_to_json(self): - patterns = [ - path('example/', views.ExampleGenericAPIView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - ret = JSONOpenAPIRenderer().render(schema) - - assert b'"openapi": "' in ret - assert b'"default": "0.0"' in ret - - def test_schema_rendering_to_yaml(self): - patterns = [ - path('example/', views.ExampleGenericAPIView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - ret = OpenAPIRenderer().render(schema) - assert b"openapi: " in ret - assert b"default: '0.0'" in ret - - def test_schema_rendering_timedelta_to_yaml_with_validator(self): - - patterns = [ - path('example/', views.ExampleValidatedAPIView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - ret = OpenAPIRenderer().render(schema) - assert b"openapi: " in ret - assert b"duration:\n type: string\n minimum: \'10.0\'\n" in ret - - def test_schema_with_no_paths(self): - patterns = [] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - - assert schema['paths'] == {} - - def test_schema_information(self): - """Construction of the top level dictionary.""" - patterns = [ - path('example/', views.ExampleListView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns, title='My title', version='1.2.3', description='My description') - - request = create_request('/') - schema = generator.get_schema(request=request) - - assert schema['info']['title'] == 'My title' - assert schema['info']['version'] == '1.2.3' - assert schema['info']['description'] == 'My description' - - def test_schema_information_empty(self): - """Construction of the top level dictionary.""" - patterns = [ - path('example/', views.ExampleListView.as_view()), - ] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - - assert schema['info']['title'] == '' - assert schema['info']['version'] == '' - - def test_serializer_model(self): - """Construction of the top level dictionary.""" - patterns = [ - path('example/', views.ExampleGenericAPIViewModel.as_view()), - ] - - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - - print(schema) - - assert 'components' in schema - assert 'schemas' in schema['components'] - assert 'ExampleModel' in schema['components']['schemas'] - - def test_authtoken_serializer(self): - patterns = [ - path('api-token-auth/', obtain_auth_token) - ] - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - - print(schema) - - route = schema['paths']['/api-token-auth/']['post'] - body_schema = route['requestBody']['content']['application/json']['schema'] - - assert body_schema == { - '$ref': '#/components/schemas/AuthToken' - } - assert schema['components']['schemas']['AuthToken'] == { - 'type': 'object', - 'properties': { - 'username': {'type': 'string', 'writeOnly': True}, - 'password': {'type': 'string', 'writeOnly': True}, - 'token': {'type': 'string', 'readOnly': True}, - }, - 'required': ['username', 'password'] - } - - def test_component_name(self): - patterns = [ - path('example/', views.ExampleAutoSchemaComponentName.as_view()), - ] - - generator = SchemaGenerator(patterns=patterns) - - request = create_request('/') - schema = generator.get_schema(request=request) - - print(schema) - assert 'components' in schema - assert 'schemas' in schema['components'] - assert 'Ulysses' in schema['components']['schemas'] - - def test_duplicate_component_name(self): - patterns = [ - path('duplicate1/', views.ExampleAutoSchemaDuplicate1.as_view()), - path('duplicate2/', views.ExampleAutoSchemaDuplicate2.as_view()), - ] - - generator = SchemaGenerator(patterns=patterns) - request = create_request('/') - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - schema = generator.get_schema(request=request) - - assert len(w) == 1 - assert issubclass(w[-1].category, UserWarning) - assert 'has been overridden with a different value.' in str(w[-1].message) - - assert 'components' in schema - assert 'schemas' in schema['components'] - assert 'Duplicate' in schema['components']['schemas'] - - def test_component_should_not_be_generated_for_delete_method(self): - class ExampleView(generics.DestroyAPIView): - schema = AutoSchema(operation_id_base='example') - - url_patterns = [ - path('example/', ExampleView.as_view()), - ] - generator = SchemaGenerator(patterns=url_patterns) - schema = generator.get_schema(request=create_request('/')) - assert 'components' not in schema - assert 'content' not in schema['paths']['/example/']['delete']['responses']['204'] diff --git a/tests/schemas/views.py b/tests/schemas/views.py deleted file mode 100644 index c08208bf26..0000000000 --- a/tests/schemas/views.py +++ /dev/null @@ -1,250 +0,0 @@ -import uuid -from datetime import timedelta - -from django.core.validators import ( - DecimalValidator, MaxLengthValidator, MaxValueValidator, - MinLengthValidator, MinValueValidator, RegexValidator -) -from django.db import models - -from rest_framework import generics, permissions, serializers -from rest_framework.decorators import action -from rest_framework.response import Response -from rest_framework.schemas.openapi import AutoSchema -from rest_framework.views import APIView -from rest_framework.viewsets import GenericViewSet, ViewSet - - -class ExampleListView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class ExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - -class DocStringExampleListView(APIView): - """ - get: A description of my GET operation. - post: A description of my POST operation. - """ - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class DocStringExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - """ - A description of my GET operation. - """ - pass - - -# Generics. -class ExampleSerializer(serializers.Serializer): - date = serializers.DateField() - datetime = serializers.DateTimeField() - duration = serializers.DurationField(default=timedelta()) - hstore = serializers.HStoreField() - uuid_field = serializers.UUIDField(default=uuid.uuid4) - - -class ExampleGenericAPIView(generics.GenericAPIView): - serializer_class = ExampleSerializer - - def get(self, *args, **kwargs): - from datetime import datetime - now = datetime.now() - - serializer = self.get_serializer(data=now.date(), datetime=now) - return Response(serializer.data) - - -class ExampleGenericViewSet(GenericViewSet): - serializer_class = ExampleSerializer - - def get(self, *args, **kwargs): - from datetime import datetime - now = datetime.now() - - serializer = self.get_serializer(data=now.date(), datetime=now) - return Response(serializer.data) - - @action(detail=False) - def new(self, *args, **kwargs): - pass - - @action(detail=False) - def old(self, *args, **kwargs): - pass - - -# Validators and/or equivalent Field attributes. -class ExampleValidatedSerializer(serializers.Serializer): - integer = serializers.IntegerField( - validators=( - MaxValueValidator(limit_value=99), - MinValueValidator(limit_value=-11), - ) - ) - string = serializers.CharField( - validators=( - MaxLengthValidator(limit_value=10), - MinLengthValidator(limit_value=2), - ) - ) - regex = serializers.CharField( - validators=( - RegexValidator(regex=r'[ABC]12{3}'), - ), - help_text='must have an A, B, or C followed by 1222' - ) - lst = serializers.ListField( - validators=( - MaxLengthValidator(limit_value=10), - MinLengthValidator(limit_value=2), - ) - ) - decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False) - decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False, - validators=(DecimalValidator(max_digits=17, decimal_places=4),)) - decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True) - decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True, - validators=(DecimalValidator(max_digits=17, decimal_places=4),)) - decimal5 = serializers.DecimalField(max_digits=6, decimal_places=2) - email = serializers.EmailField(default='foo@bar.com') - url = serializers.URLField(default='http://www.example.com', allow_null=True) - uuid = serializers.UUIDField() - ip4 = serializers.IPAddressField(protocol='ipv4') - ip6 = serializers.IPAddressField(protocol='ipv6') - ip = serializers.IPAddressField() - duration = serializers.DurationField( - validators=( - MinValueValidator(timedelta(seconds=10)), - ) - ) - - -class ExampleValidatedAPIView(generics.GenericAPIView): - serializer_class = ExampleValidatedSerializer - - def get(self, *args, **kwargs): - serializer = self.get_serializer(integer=33, string='hello', regex='foo', decimal1=3.55, - decimal2=5.33, email='a@b.co', - url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1', - ip='192.168.1.1') - return Response(serializer.data) - - -# Serializer with model. -class OpenAPIExample(models.Model): - first_name = models.CharField(max_length=30) - - -class ExampleSerializerModel(serializers.Serializer): - date = serializers.DateField() - datetime = serializers.DateTimeField() - hstore = serializers.HStoreField() - uuid_field = serializers.UUIDField(default=uuid.uuid4) - - class Meta: - model = OpenAPIExample - - -class ExampleOperationIdDuplicate1(generics.GenericAPIView): - serializer_class = ExampleSerializerModel - - def get(self, *args, **kwargs): - pass - - -class ExampleOperationIdDuplicate2(generics.GenericAPIView): - serializer_class = ExampleSerializerModel - - def get(self, *args, **kwargs): - pass - - -class ExampleGenericAPIViewModel(generics.GenericAPIView): - serializer_class = ExampleSerializerModel - - def get(self, *args, **kwargs): - from datetime import datetime - now = datetime.now() - - serializer = self.get_serializer(data=now.date(), datetime=now) - return Response(serializer.data) - - -class ExampleAutoSchemaComponentName(generics.GenericAPIView): - serializer_class = ExampleSerializerModel - schema = AutoSchema(component_name="Ulysses") - - def get(self, *args, **kwargs): - from datetime import datetime - now = datetime.now() - - serializer = self.get_serializer(data=now.date(), datetime=now) - return Response(serializer.data) - - -class ExampleAutoSchemaDuplicate1(generics.GenericAPIView): - serializer_class = ExampleValidatedSerializer - schema = AutoSchema(component_name="Duplicate") - - def get(self, *args, **kwargs): - from datetime import datetime - now = datetime.now() - - serializer = self.get_serializer(data=now.date(), datetime=now) - return Response(serializer.data) - - -class ExampleAutoSchemaDuplicate2(generics.GenericAPIView): - serializer_class = ExampleSerializerModel - schema = AutoSchema(component_name="Duplicate") - - def get(self, *args, **kwargs): - from datetime import datetime - now = datetime.now() - - serializer = self.get_serializer(data=now.date(), datetime=now) - return Response(serializer.data) - - -class ExampleViewSet(ViewSet): - serializer_class = ExampleSerializerModel - - def list(self, request): - pass - - def create(self, request): - pass - - def retrieve(self, request, pk=None): - pass - - def update(self, request, pk=None): - pass - - def partial_update(self, request, pk=None): - pass - - def destroy(self, request, pk=None): - pass diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 1b396575d4..6b51ff34cc 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -5,18 +5,16 @@ from django.core.cache import cache from django.db import models from django.http.request import HttpRequest -from django.template import loader from django.test import TestCase, override_settings from django.urls import include, path, re_path from django.utils.safestring import SafeText from django.utils.translation import gettext_lazy as _ from rest_framework import permissions, serializers, status -from rest_framework.compat import coreapi from rest_framework.decorators import action from rest_framework.renderers import ( - AdminRenderer, BaseRenderer, BrowsableAPIRenderer, DocumentationRenderer, - HTMLFormRenderer, JSONRenderer, SchemaJSRenderer, StaticHTMLRenderer + AdminRenderer, BaseRenderer, BrowsableAPIRenderer, HTMLFormRenderer, + JSONRenderer, StaticHTMLRenderer ) from rest_framework.request import Request from rest_framework.response import Response @@ -871,61 +869,3 @@ def reverse_action(view, url_name, args=None, kwargs=None): self.assertEqual(results[1]['url'], '/example') self.assertEqual(results[2]['url'], None) self.assertNotIn('url', results[3]) - - -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestDocumentationRenderer(TestCase): - - def test_document_with_link_named_data(self): - """ - Ref #5395: Doc's `document.data` would fail with a Link named "data". - As per #4972, use templatetag instead. - """ - document = coreapi.Document( - title='Data Endpoint API', - url='https://api.example.org/', - content={ - 'data': coreapi.Link( - url='/data/', - action='get', - fields=[], - description='Return data.' - ) - } - ) - - factory = APIRequestFactory() - request = factory.get('/') - - renderer = DocumentationRenderer() - - html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request}) - assert '

Data Endpoint API

' in html - - def test_shell_code_example_rendering(self): - template = loader.get_template('rest_framework/docs/langs/shell.html') - context = { - 'document': coreapi.Document(url='https://api.example.org/'), - 'link_key': 'testcases > list', - 'link': coreapi.Link(url='/data/', action='get', fields=[]), - } - html = template.render(context) - assert 'testcases list' in html - - -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestSchemaJSRenderer(TestCase): - - def test_schemajs_output(self): - """ - Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates, - and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix. - """ - factory = APIRequestFactory() - request = factory.get('/') - - renderer = SchemaJSRenderer() - - output = renderer.render('data', renderer_context={"request": request}) - assert "'ImRhdGEi'" in output - assert "'b'ImRhdGEi''" not in output