From 8b69f049bc3e5bdba00ab7be4238dbedc91d3bc8 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Tue, 26 Nov 2024 02:48:14 +0000 Subject: [PATCH] fix: openapi schema generation --- robyn/openapi.py | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/robyn/openapi.py b/robyn/openapi.py index be5aeecb..2182f6fc 100644 --- a/robyn/openapi.py +++ b/robyn/openapi.py @@ -5,7 +5,7 @@ from importlib import resources from inspect import Signature from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union from robyn.responses import html from robyn.robyn import QueryParams, Response @@ -370,7 +370,6 @@ def get_schema_object(self, parameter: str, param_type: Any) -> dict: @param param_type: Any the type to be inferred @return: dict the properties object """ - properties = { "title": parameter.capitalize(), } @@ -384,25 +383,44 @@ def get_schema_object(self, parameter: str, param_type: Any) -> dict: list: "array", } + # Handle basic types for type_name in type_mapping: if param_type is type_name: properties["type"] = type_mapping[type_name] return properties - # check for Optional type - if param_type.__module__ == "typing": - properties["anyOf"] = [{"type": self.get_openapi_type(param_type.__args__[0])}, {"type": "null"}] - return properties - # check for custom classes and TypedDicts + # Handle typing module types (Optional, List, etc) + if hasattr(param_type, "__module__") and param_type.__module__ == "typing": + origin = typing.get_origin(param_type) + args = typing.get_args(param_type) + + # Handle Optional types + if origin is Union and type(None) in args: + non_none_type = next(t for t in args if t is not type(None)) + properties["anyOf"] = [ + {"type": self.get_openapi_type(non_none_type)}, + {"type": "null"} + ] + return properties + + # Handle List types + elif origin in (list, List): + properties["type"] = "array" + if args: + item_type = args[0] + properties["items"] = self.get_schema_object("item", item_type) + return properties + + # Handle custom classes and TypedDicts elif inspect.isclass(param_type): properties["type"] = "object" - properties["properties"] = {} - for e in param_type.__annotations__: - properties["properties"][e] = self.get_schema_object(e, param_type.__annotations__[e]) - - properties["type"] = "object" + if hasattr(param_type, "__annotations__"): + for e in param_type.__annotations__: + properties["properties"][e] = self.get_schema_object( + e, param_type.__annotations__[e] + ) return properties