diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 4d456fbf..ec25d7a6 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -6,7 +6,6 @@ from contextlib import suppress from collections import OrderedDict from decimal import Decimal -from inspect import signature as inspect_signature import typing from django.core import validators @@ -193,7 +192,7 @@ def get_queryset_from_view(view, serializer=None): if queryset is not None and serializer is not None: # make sure the view is actually using *this* serializer - assert type(serializer) == call_view_method(view, 'get_serializer_class', 'serializer_class') + assert type(serializer) is call_view_method(view, 'get_serializer_class', 'serializer_class') return queryset except Exception: # pragma: no cover @@ -617,17 +616,19 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** return self.probe_field_inspectors(serializer, swagger_object_type, use_references, read_only=True) else: # look for Python 3.5+ style type hinting of the return value - hint_class = inspect_signature(method).return_annotation - - if not inspect.isclass(hint_class) and hasattr(hint_class, '__args__'): - hint_class = hint_class.__args__[0] - if inspect.isclass(hint_class) and not issubclass(hint_class, inspect._empty): - type_info = get_basic_type_info_from_hint(hint_class) - - if type_info is not None: - SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, - use_references, **kwargs) - return SwaggerType(**type_info) + hint_class = typing.get_type_hints(method).get('return') + + # annotations such as typing.Optional have an __instancecheck__ + # hook and will not look like classes, but `issubclass` needs + # a class as its first argument, so only in that case abort + if inspect.isclass(hint_class) and issubclass(hint_class, inspect._empty): + return NotHandled + + type_info = get_basic_type_info_from_hint(hint_class) + if type_info is not None: + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, + use_references, **kwargs) + return SwaggerType(**type_info) return NotHandled diff --git a/src/drf_yasg/openapi.py b/src/drf_yasg/openapi.py index 0a2cfb8f..c4632173 100644 --- a/src/drf_yasg/openapi.py +++ b/src/drf_yasg/openapi.py @@ -88,7 +88,7 @@ class SwaggerDict(OrderedDict): def __init__(self, **attrs): super(SwaggerDict, self).__init__() self._extras__ = attrs - if type(self) == SwaggerDict: + if type(self) is SwaggerDict: self._insert_extras__() def __setattr__(self, key, value): @@ -441,6 +441,8 @@ def __init__(self, name, in_, description=None, required=None, schema=None, # path parameters must always be required assert required is not False, "path parameter cannot be optional" self.required = True + if self['in'] == IN_QUERY and type == TYPE_ARRAY: + self.collection_format = 'multi' if self['in'] != IN_BODY and schema is not None: raise AssertionError("schema can only be applied to a body Parameter, not %s" % type) if default and not type: @@ -516,7 +518,7 @@ def __init__(self, resolver, name, scope, expected_type, ignore_unresolved=False :param bool ignore_unresolved: do not throw if the referenced object does not exist """ super(_Ref, self).__init__() - assert not type(self) == _Ref, "do not instantiate _Ref directly" + assert not type(self) is _Ref, "do not instantiate _Ref directly" ref_name = "#/{scope}/{name}".format(scope=scope, name=name) if not ignore_unresolved: obj = resolver.get(name, scope) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index c78a4821..480c35d0 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -442,7 +442,7 @@ def force_real_str(s, encoding='utf-8', strings_only=False, errors='strict'): """ if s is not None: s = force_str(s, encoding, strings_only, errors) - if type(s) != str: + if type(s) is not str: s = '' + s # Remove common indentation to get the correct Markdown rendering diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py index 063e7c08..85999074 100644 --- a/tests/test_schema_generator.py +++ b/tests/test_schema_generator.py @@ -345,7 +345,7 @@ def retrieve(self, request, pk=None): ) swagger = generator.get_schema(None, True) property_schema = swagger["definitions"]["OptionalMethod"]["properties"]["x"] - assert property_schema == openapi.Schema(title='X', type=expected_type, readOnly=True) + assert property_schema == openapi.Schema(title='X', type=expected_type, readOnly=True, x_nullable=True) EXPECTED_DESCRIPTION = """\