From d87cdda3ad2ae837f9ffe9328fd9c671e5ae84f3 Mon Sep 17 00:00:00 2001 From: fazeelghafoor Date: Mon, 29 Jul 2024 05:30:31 +0500 Subject: [PATCH] add union type and subtypes check in schema model signature --- ninja/signature/details.py | 61 ++++++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/ninja/signature/details.py b/ninja/signature/details.py index a05903c96..765be9d84 100644 --- a/ninja/signature/details.py +++ b/ninja/signature/details.py @@ -198,13 +198,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]] def _model_flatten_map(self, model: TModel, prefix: str) -> Generator: field: FieldInfo - for attr, field in model.model_fields.items(): - field_name = field.alias or attr - name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}" - if is_pydantic_model(field.annotation): - yield from self._model_flatten_map(field.annotation, name) # type: ignore - else: - yield field_name, name + if get_origin(model) in UNION_TYPES: + # If the model is a union type, process each type in the union + for arg in get_args(model): + if type(arg) is None: + continue # Skip NoneType + yield from self._model_flatten_map(arg, prefix) + else: + for attr, field in model.model_fields.items(): + field_name = field.alias or attr + name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}" + if is_pydantic_model(field.annotation): + yield from self._model_flatten_map(field.annotation, name) # type: ignore + else: + yield field_name, name def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: # _EMPTY = self.signature.empty @@ -278,7 +285,11 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam: def is_pydantic_model(cls: Any) -> bool: try: if get_origin(cls) in UNION_TYPES: - return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls)) + return any( + issubclass(arg, pydantic.BaseModel) + for arg in get_args(cls) + if (type(arg) is not None) + ) return issubclass(cls, pydantic.BaseModel) except TypeError: return False @@ -321,14 +332,32 @@ def detect_collection_fields( for attr in path[1:]: if hasattr(annotation_or_field, "annotation"): annotation_or_field = annotation_or_field.annotation - annotation_or_field = next( - ( - a - for a in annotation_or_field.model_fields.values() - if a.alias == attr - ), - annotation_or_field.model_fields.get(attr), - ) # pragma: no cover + + # check union types + if get_origin(annotation_or_field) in UNION_TYPES: + for arg in get_args(annotation_or_field): + if type(arg) is None: + continue # Skip NoneType + if hasattr(arg, "model_fields"): + annotation_or_field = next( + ( + a + for a in arg.model_fields.values() + if a.alias == attr + ), + arg.model_fields.get(attr), + ) # pragma: no cover + else: + continue + else: + annotation_or_field = next( + ( + a + for a in annotation_or_field.model_fields.values() + if a.alias == attr + ), + annotation_or_field.model_fields.get(attr), + ) # pragma: no cover annotation_or_field = getattr( annotation_or_field, "outer_type_", annotation_or_field