Skip to content

Commit

Permalink
Improve handling of consumes and produces attributes (#55)
Browse files Browse the repository at this point in the history
* Fix get_consumes
* Generate produces for Operation
* Set global consumes and produces from rest framework DEFAULT_ settings
  • Loading branch information
axnsan12 authored Jan 24, 2018
1 parent a46b684 commit a3e81ef
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 65 deletions.
37 changes: 25 additions & 12 deletions src/drf_yasg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering
from rest_framework.schemas.inspectors import get_pk_description

from drf_yasg.errors import SwaggerGenerationError
from rest_framework.settings import api_settings as rest_framework_settings

from . import openapi
from .app_settings import swagger_settings
from .errors import SwaggerGenerationError
from .inspectors.field import get_basic_type_info, get_queryset_field
from .openapi import ReferenceResolver
from .utils import get_consumes, get_produces

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,6 +166,9 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info
self.version = version
self.consumes = []
self.produces = []

if url is None and swagger_settings.DEFAULT_API_URL is not None:
url = swagger_settings.DEFAULT_API_URL

Expand All @@ -191,22 +195,24 @@ def get_schema(self, request=None, public=False):
"""
endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES)
paths, prefix = self.get_paths(endpoints, components, request, public)

security_definitions = swagger_settings.SECURITY_DEFINITIONS
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
if security_requirements is None:
security_requirements = [{security_scheme: [] for security_scheme in swagger_settings.SECURITY_DEFINITIONS}]

url = self.url
if url is None and request is not None:
url = request.build_absolute_uri()

swagger = openapi.Swagger(
info=self.info, paths=paths,
return openapi.Swagger(
info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None,
security_definitions=security_definitions, security=security_requirements,
_url=url, _prefix=prefix, _version=self.version, **dict(components)
)
swagger.security_definitions = swagger_settings.SECURITY_DEFINITIONS
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
if security_requirements is None:
security_requirements = [{security_scheme: [] for security_scheme in swagger_settings.SECURITY_DEFINITIONS}]
swagger.security = security_requirements
return swagger

def create_view(self, callback, method, request=None):
"""Create a view instance from a view callback as registered in urlpatterns.
Expand Down Expand Up @@ -330,7 +336,6 @@ def get_operation(self, view, path, prefix, method, components, request):
:param Request request: the request made against the schema view; can be None
:rtype: openapi.Operation
"""

operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
overrides = self.get_overrides(view, method)

Expand All @@ -342,8 +347,16 @@ def get_operation(self, view, path, prefix, method, components, request):
# 3. on the swagger_auto_schema decorator
view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)

if view_inspector_cls is None:
return None

view_inspector = view_inspector_cls(view, path, method, components, request, overrides)
return view_inspector.get_operation(operation_keys)
operation = view_inspector.get_operation(operation_keys)
if set(operation.consumes) == set(self.consumes):
del operation.consumes
if set(operation.produces) == set(self.produces):
del operation.produces
return operation

def get_path_item(self, path, view_cls, operations):
"""Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
Expand Down
21 changes: 15 additions & 6 deletions src/drf_yasg/inspectors/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import force_serializer_instance, guess_response_status, is_list_view, no_body, param_list_to_odict
from ..utils import (
force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, no_body,
param_list_to_odict
)
from .base import ViewInspector


Expand All @@ -18,6 +21,7 @@ def __init__(self, view, path, method, components, request, overrides):

def get_operation(self, operation_keys):
consumes = self.get_consumes()
produces = self.get_produces()

body = self.get_request_body_parameters(consumes)
query = self.get_query_parameters()
Expand All @@ -39,6 +43,7 @@ def get_operation(self, operation_keys):
responses=responses,
parameters=parameters,
consumes=consumes,
produces=produces,
tags=tags,
security=security
)
Expand Down Expand Up @@ -296,7 +301,7 @@ def get_security(self):
authentication schemes). Returning ``None`` will inherit the top-level secuirty requirements.
:return: security requirements
:rtype: list"""
:rtype: list[dict[str,list[str]]]"""
return self.overrides.get('security', None)

def get_tags(self, operation_keys):
Expand All @@ -314,7 +319,11 @@ def get_consumes(self):
:rtype: list[str]
"""
media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
return media_types[:1]
return get_consumes(getattr(self.view, 'parser_classes', []))

def get_produces(self):
"""Return the MIME types this endpoint can produce.
:rtype: list[str]
"""
return get_produces(getattr(self.view, 'renderer_classes', []))
17 changes: 14 additions & 3 deletions src/drf_yasg/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,19 @@ def __init__(self, title, default_version, description=None, terms_of_service=No


class Swagger(SwaggerDict):
def __init__(self, info=None, _url=None, _prefix=None, _version=None, paths=None, definitions=None, **extra):
def __init__(self, info=None, _url=None, _prefix=None, _version=None, consumes=None, produces=None,
security_definitions=None, security=None, paths=None, definitions=None, **extra):
"""Root Swagger object.
:param .Info info: info object
:param str _url: URL used for setting the API host and scheme
:param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi
SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable
:param str _version: version string to override Info
:param list[dict] security_definitions: list of supported authentication mechanisms
:param list[dict] security: authentication mechanisms accepted by default; can be overriden in Operation
:param list[str] consumes: consumed MIME types; can be overriden in Operation
:param list[str] produces: produced MIME types; can be overriden in Operation
:param .Paths paths: paths object
:param dict[str,.Schema] definitions: named models
"""
Expand All @@ -234,6 +239,10 @@ def __init__(self, info=None, _url=None, _prefix=None, _version=None, paths=None
self.schemes = [url.scheme]

self.base_path = self.get_base_path(get_script_prefix(), _prefix)
self.consumes = consumes
self.produces = produces
self.security_definitions = filter_none(security_definitions)
self.security = filter_none(security)
self.paths = paths
self.definitions = filter_none(definitions)
self._insert_extras__()
Expand Down Expand Up @@ -304,8 +313,8 @@ def __init__(self, get=None, put=None, post=None, delete=None, options=None,


class Operation(SwaggerDict):
def __init__(self, operation_id, responses, parameters=None, consumes=None,
produces=None, summary=None, description=None, tags=None, **extra):
def __init__(self, operation_id, responses, parameters=None, consumes=None, produces=None, summary=None,
description=None, tags=None, security=None, **extra):
"""Information about an API operation (path + http method combination)
:param str operation_id: operation ID, should be unique across all operations
Expand All @@ -316,6 +325,7 @@ def __init__(self, operation_id, responses, parameters=None, consumes=None,
:param str summary: operation summary; should be < 120 characters
:param str description: operation description; can be of any length and supports markdown
:param list[str] tags: operation tags
:param list[dict[str,list[str]]] security: list of security requirements
"""
super(Operation, self).__init__(**extra)
self.operation_id = operation_id
Expand All @@ -326,6 +336,7 @@ def __init__(self, operation_id, responses, parameters=None, consumes=None,
self.consumes = filter_none(consumes)
self.produces = filter_none(produces)
self.tags = filter_none(tags)
self.security = filter_none(security)
self._insert_extras__()


Expand Down
28 changes: 28 additions & 0 deletions src/drf_yasg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from rest_framework import serializers, status
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
from rest_framework.request import is_form_media_type
from rest_framework.views import APIView

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -248,3 +249,30 @@ def force_serializer_instance(serializer):
assert isinstance(serializer, serializers.BaseSerializer), \
"Serializer class or instance required, not %s" % type(serializer).__name__
return serializer


def get_consumes(parser_classes):
"""Extract ``consumes`` MIME types from a list of parser classes.
:param list parser_classes: parser classes
:return: MIME types for ``consumes``
:rtype: list[str]
"""
media_types = [parser.media_type for parser in parser_classes or []]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
else:
media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
return media_types


def get_produces(renderer_classes):
"""Extract ``produces`` MIME types from a list of renderer classes.
:param list renderer_classes: renderer classes
:return: MIME types for ``produces``
:rtype: list[str]
"""
media_types = [renderer.media_type for renderer in renderer_classes or []]
media_types = [encoding for encoding in media_types if 'html' not in encoding]
return media_types
3 changes: 2 additions & 1 deletion testproj/snippets/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from djangorestframework_camel_case.render import CamelCaseJSONRenderer
from inflection import camelize
from rest_framework import generics
from rest_framework.parsers import FormParser

from drf_yasg import openapi
from drf_yasg.inspectors import SwaggerAutoSchema
Expand All @@ -21,7 +22,7 @@ class SnippetList(generics.ListCreateAPIView):
queryset = Snippet.objects.all()
serializer_class = SnippetSerializer

parser_classes = (CamelCaseJSONParser,)
parser_classes = (FormParser, CamelCaseJSONParser,)
renderer_classes = (CamelCaseJSONRenderer,)
swagger_schema = CamelCaseOperationIDAutoSchema

Expand Down
Loading

0 comments on commit a3e81ef

Please sign in to comment.