Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enum OAS generation (#3518) #3525

Merged
merged 13 commits into from
Nov 29, 2024
2 changes: 1 addition & 1 deletion docs/examples/openapi/plugins/swagger_ui_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from litestar.openapi.plugins import SwaggerRenderPlugin

swagger_plugin = SwaggerRenderPlugin(version="5.1.3", path="/swagger")
swagger_plugin = SwaggerRenderPlugin(version="5.18.2", path="/swagger")
62 changes: 38 additions & 24 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import copy
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum, EnumMeta
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -40,9 +40,7 @@
create_string_constrained_field_schema,
)
from litestar._openapi.schema_generation.utils import (
_should_create_enum_schema,
_should_create_literal_schema,
_type_or_first_not_none_inner_type,
get_json_schema_formatted_examples,
)
from litestar.datastructures import SecretBytes, SecretString, UploadFile
Expand Down Expand Up @@ -181,22 +179,6 @@ def _get_type_schema_name(field_definition: FieldDefinition) -> str:
return name


def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema:
"""Create a schema instance for an enum.

Args:
annotation: An enum.
include_null: Whether to include null as a possible value.

Returns:
A schema instance.
"""
enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated]
if include_null and None not in enum_values:
enum_values.append(None)
return Schema(type=_types_in_list(enum_values), enum=enum_values)


def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]:
"""Iterate over the flattened arguments of a Literal.

Expand Down Expand Up @@ -331,18 +313,20 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re
result = self.for_type_alias_type(field_definition)
elif plugin_for_annotation := self.get_plugin_for(field_definition):
result = self.for_plugin(field_definition, plugin_for_annotation)
elif _should_create_enum_schema(field_definition):
annotation = _type_or_first_not_none_inner_type(field_definition)
result = create_enum_schema(annotation, include_null=field_definition.is_optional)
elif _should_create_literal_schema(field_definition):
annotation = (
make_non_optional_union(field_definition.annotation)
if field_definition.is_optional
else field_definition.annotation
)
result = create_literal_schema(annotation, include_null=field_definition.is_optional)
result = create_literal_schema(
annotation,
include_null=field_definition.is_optional,
)
elif field_definition.is_optional:
result = self.for_optional_field(field_definition)
elif field_definition.is_enum:
result = self.for_enum_field(field_definition)
elif field_definition.is_union:
result = self.for_union_field(field_definition)
elif field_definition.is_type_var:
Expand Down Expand Up @@ -445,7 +429,7 @@ def for_optional_field(self, field_definition: FieldDefinition) -> Schema:
else:
result = [schema_or_reference]

return Schema(one_of=[Schema(type=OpenAPIType.NULL), *result])
return Schema(one_of=[*result, Schema(type=OpenAPIType.NULL)])

def for_union_field(self, field_definition: FieldDefinition) -> Schema:
"""Create a Schema for a union FieldDefinition.
Expand Down Expand Up @@ -569,6 +553,36 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) ->
# INFO: Removed because it was only for pydantic constrained collections
return schema

def for_enum_field(
self,
field_definition: FieldDefinition,
) -> Schema | Reference:
"""Create a schema instance for an enum.

Args:
field_definition: A signature field instance.

Returns:
A schema or reference instance.
"""
enum_type: None | OpenAPIType | list[OpenAPIType] = None
if issubclass(field_definition.annotation, str): # StrEnum
enum_type = OpenAPIType.STRING
elif issubclass(field_definition.annotation, int): # IntEnum
enum_type = OpenAPIType.INTEGER

enum_values: list[Any] = [v.value for v in field_definition.annotation] # pyright: ignore
Alc-Alc marked this conversation as resolved.
Show resolved Hide resolved
if enum_type is None:
enum_type = _types_in_list(enum_values)

schema = self.schema_registry.get_schema_for_field_definition(field_definition)
schema.type = enum_type
schema.enum = enum_values
schema.title = get_name(field_definition.annotation)
schema.description = field_definition.annotation.__doc__

return self.schema_registry.get_reference_for_field_definition(field_definition) or schema

def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference:
if field.kwarg_definition and field.is_const and field.has_default and schema.const is None:
schema.const = field.default
Expand Down
49 changes: 1 addition & 48 deletions litestar/_openapi/schema_generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping

from litestar.utils.helpers import get_name
Expand All @@ -11,53 +10,7 @@
from litestar.openapi.spec import Example
from litestar.typing import FieldDefinition

__all__ = (
"_should_create_enum_schema",
"_should_create_literal_schema",
"_type_or_first_not_none_inner_type",
)


def _type_or_first_not_none_inner_type(field_definition: FieldDefinition) -> Any:
"""Get the first inner type that is not None.

This is a narrow focussed utility to be used when we know that a field definition either represents
a single type, or a single type in a union with `None`, and we want the single type.

Args:
field_definition: A field definition instance.

Returns:
A field definition instance.
"""
if not field_definition.is_optional:
return field_definition.annotation
inner = next((t for t in field_definition.inner_types if not t.is_none_type), None)
if inner is None:
raise ValueError("Field definition has no inner type that is not None")
return inner.annotation


def _should_create_enum_schema(field_definition: FieldDefinition) -> bool:
"""Predicate to determine if we should create an enum schema for the field def, or not.

This returns true if the field definition is an enum, or if the field definition is a union
of an enum and ``None``.

When an annotation is ``SomeEnum | None`` we should create a schema for the enum that includes ``null``
in the enum values.

Args:
field_definition: A field definition instance.

Returns:
A boolean
"""
return field_definition.is_subclass_of(Enum) or (
field_definition.is_optional
and len(field_definition.args) == 2
and field_definition.has_inner_subclass_of(Enum)
)
__all__ = ("_should_create_literal_schema",)


def _should_create_literal_schema(field_definition: FieldDefinition) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion litestar/openapi/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OpenAPIController(Controller):
"""Base styling of the html body."""
redoc_version: str = "next"
"""Redoc version to download from the CDN."""
swagger_ui_version: str = "5.1.3"
swagger_ui_version: str = "5.18.2"
"""SwaggerUI version to download from the CDN."""
stoplight_elements_version: str = "7.7.18"
"""StopLight Elements version to download from the CDN."""
Expand Down
2 changes: 1 addition & 1 deletion litestar/openapi/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ class SwaggerRenderPlugin(OpenAPIRenderPlugin):

def __init__(
self,
version: str = "5.1.3",
version: str = "5.18.2",
js_url: str | None = None,
css_url: str | None = None,
standalone_preset_js_url: str | None = None,
Expand Down
5 changes: 5 additions & 0 deletions litestar/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import abc
from copy import deepcopy
from dataclasses import dataclass, is_dataclass, replace
from enum import Enum
from inspect import Parameter, Signature
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, TypeVar, cast

Expand Down Expand Up @@ -339,6 +340,10 @@ def is_typeddict_type(self) -> bool:

return is_typeddict(self.origin or self.annotation)

@property
def is_enum(self) -> bool:
return self.is_subclass_of(Enum)

@property
def type_(self) -> Any:
"""The type of the annotation with all the wrappers removed, including the generic types."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,14 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
assert concert_schema
assert concert_schema.to_schema() == {
"properties": {
"band_1": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"band_2": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"venue": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"band_1": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
"band_2": {
"oneOf": [
{"type": "integer"},
{"type": "null"},
]
},
"venue": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
},
"required": [],
"title": "CreateConcertConcertRequestBody",
Expand All @@ -152,10 +157,10 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
assert record_studio_schema
assert record_studio_schema.to_schema() == {
"properties": {
"facilities": {"oneOf": [{"type": "null"}, {"type": "string"}]},
"facilities_b": {"oneOf": [{"type": "null"}, {"type": "string"}]},
"microphones": {"oneOf": [{"type": "null"}, {"items": {"type": "string"}, "type": "array"}]},
"id": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"facilities": {"oneOf": [{"type": "string"}, {"type": "null"}]},
"facilities_b": {"oneOf": [{"type": "string"}, {"type": "null"}]},
"microphones": {"oneOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}]},
"id": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
},
"required": [],
"title": "RetrieveStudioRecordingStudioResponseBody",
Expand All @@ -166,8 +171,8 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
assert venue_schema
assert venue_schema.to_schema() == {
"properties": {
"id": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
"name": {"oneOf": [{"type": "null"}, {"type": "string"}]},
"id": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
"name": {"oneOf": [{"type": "string"}, {"type": "null"}]},
},
"required": [],
"title": "RetrieveVenuesVenueResponseBody",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_openapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_persons(
from_date: Optional[Union[int, datetime, date]] = None,
to_date: Optional[Union[int, datetime, date]] = None,
gender: Optional[Union[Gender, List[Gender]]] = Parameter(
examples=[Example(value="M"), Example(value=["M", "O"])]
examples=[Example(value=Gender.MALE), Example(value=[Gender.MALE, Gender.OTHER])]
),
# header parameter
secret_header: str = Parameter(header="secret"),
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_openapi/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_default_redoc_cdn_urls(
def test_default_swagger_ui_cdn_urls(
person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig
) -> None:
default_swagger_ui_version = "5.1.3"
default_swagger_ui_version = "5.18.2"
default_swagger_bundles = [
f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui.css",
f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui-bundle.js",
Expand Down
23 changes: 8 additions & 15 deletions tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from litestar.exceptions import ImproperlyConfiguredException
from litestar.handlers import HTTPRouteHandler
from litestar.openapi import OpenAPIConfig
from litestar.openapi.spec import Example, OpenAPI, Schema
from litestar.openapi.spec import Example, OpenAPI, Reference, Schema
from litestar.openapi.spec.enums import OpenAPIType
from litestar.params import Dependency, Parameter
from litestar.routes import BaseRoute
from litestar.testing import create_test_client
from litestar.utils import find_index
from tests.unit.test_openapi.utils import Gender

if TYPE_CHECKING:
from litestar.openapi.spec.parameter import Parameter as OpenAPIParameter
Expand Down Expand Up @@ -104,23 +105,15 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
assert is_schema_value(gender.schema)
assert gender.schema == Schema(
one_of=[
Schema(type=OpenAPIType.NULL),
Schema(
type=OpenAPIType.STRING,
enum=["M", "F", "O", "A"],
examples=["M"],
),
Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
Schema(
type=OpenAPIType.ARRAY,
items=Schema(
type=OpenAPIType.STRING,
enum=["M", "F", "O", "A"],
examples=["F"],
),
examples=[["A"]],
items=Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
examples=[[Gender.MALE]],
),
Schema(type=OpenAPIType.NULL),
],
examples=["M", ["M", "O"]],
examples=[Gender.MALE, [Gender.MALE, Gender.OTHER]],
)
assert not gender.required

Expand Down Expand Up @@ -397,8 +390,8 @@ async def handler(
app = Litestar([handler])
assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr]
Schema(type=OpenAPIType.NULL),
Schema(type=OpenAPIType.STRING),
Schema(type=OpenAPIType.NULL),
]
assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
assert (
Expand Down
Loading
Loading