diff --git a/src/pyckson/__init__.py b/src/pyckson/__init__.py index a8ea8d6..21adbc9 100644 --- a/src/pyckson/__init__.py +++ b/src/pyckson/__init__.py @@ -2,6 +2,7 @@ from pyckson.json import * from pyckson.parser import parse from pyckson.parsers.base import Parser +from pyckson.parsers.base import ParserException from pyckson.serializer import serialize from pyckson.serializers.base import Serializer from pyckson.dates.helpers import configure_date_formatter, configure_explicit_nulls diff --git a/src/pyckson/helpers.py b/src/pyckson/helpers.py index 960ab97..b74fd54 100644 --- a/src/pyckson/helpers.py +++ b/src/pyckson/helpers.py @@ -127,6 +127,12 @@ def is_typing_dict_annotation(annotation): return False +def is_union_annotation(annotation) -> bool: + if hasattr(annotation, '__name__'): + return annotation.__name__ == 'Union' + return False + + def using(attr): def class_decorator(cls): set_cls_attr(cls, PYCKSON_RULE_ATTR, attr) diff --git a/src/pyckson/model/union.py b/src/pyckson/model/union.py index 73e68d6..bc17c1d 100644 --- a/src/pyckson/model/union.py +++ b/src/pyckson/model/union.py @@ -20,5 +20,14 @@ def inspect_optional_typing(annotation) -> Tuple[bool, type]: else: union_params = annotation.__union_params__ - is_optional = len(union_params) == 2 and isinstance(None, union_params[1]) - return is_optional, union_params[0] + try: + is_optional = isinstance(None, union_params[-1]) + except TypeError: + is_optional = False + if is_optional: + union_param = Union[union_params[:-1]] + elif len(union_params) > 1: + union_param = Union[union_params] + else: + union_param = union_params[0] + return is_optional, union_param diff --git a/src/pyckson/parsers/base.py b/src/pyckson/parsers/base.py index 8236056..cb1f0b3 100644 --- a/src/pyckson/parsers/base.py +++ b/src/pyckson/parsers/base.py @@ -1,4 +1,8 @@ from decimal import Decimal +from enum import Enum + +class ParserException(Exception): + pass class Parser: @@ -16,22 +20,30 @@ def __init__(self, cls): self.cls = cls def parse(self, json_value): + if not isinstance(json_value, self.cls): + raise ParserException(f'"{json_value}" is supposed to be a {self.cls.__name__}.') return self.cls(json_value) class ListParser(Parser): def __init__(self, sub_parser: Parser): self.sub_parser = sub_parser + self.cls = list def parse(self, json_value): + if not isinstance(json_value, list): + raise ParserException(f'"{json_value}" is supposed to be a list.') return [self.sub_parser.parse(item) for item in json_value] class SetParser(Parser): def __init__(self, sub_parser: Parser): self.sub_parser = sub_parser + self.cls = set def parse(self, json_value): + if not isinstance(json_value, set) and not isinstance(json_value, list): + raise ParserException(f'"{json_value}" is supposed to be a set or a list.') return {self.sub_parser.parse(item) for item in json_value} @@ -40,14 +52,19 @@ def __init__(self, cls): self.cls = cls def parse(self, value): + if value not in self.cls.__members__: + raise ParserException(f'"{value}" is not a valid value for "{self.cls.__name__}" Enum.') return self.cls[value] class CaseInsensitiveEnumParser(Parser): def __init__(self, cls): self.values = {member.name.lower(): member for member in cls} + self.cls = Enum def parse(self, value): + if value.lower() not in self.values: + raise ParserException(f'"{value}" is not a valid value for "{self.cls.__name__}" Enum.') return self.values[value.lower()] @@ -75,3 +92,17 @@ def parse(self, json_value): class DecimalParser(Parser): def parse(self, json_value): return Decimal(json_value) + + +class UnionParser(Parser): + def __init__(self, value_parsers: list[Parser]): + self.value_parsers = value_parsers + + def parse(self, json_value): + for parser in self.value_parsers: + if hasattr(parser, 'cls') and isinstance(json_value, parser.cls): + try: + return parser.parse(json_value) + except: + pass + raise TypeError(f'{json_value} is not compatible with Union type in Pyckson.') diff --git a/src/pyckson/parsers/provider.py b/src/pyckson/parsers/provider.py index f43507c..4085493 100644 --- a/src/pyckson/parsers/provider.py +++ b/src/pyckson/parsers/provider.py @@ -10,10 +10,10 @@ from pyckson.const import BASIC_TYPES, PYCKSON_TYPEINFO, PYCKSON_ENUM_OPTIONS, ENUM_CASE_INSENSITIVE, PYCKSON_PARSER, \ DATE_TYPES, EXTRA_TYPES, get_cls_attr, has_cls_attr, ENUM_USE_VALUES from pyckson.helpers import TypeProvider, is_list_annotation, is_set_annotation, is_enum_annotation, \ - is_basic_dict_annotation, is_typing_dict_annotation + is_basic_dict_annotation, is_typing_dict_annotation, is_union_annotation from pyckson.parsers.advanced import UnresolvedParser, ClassParser, CustomDeferredParser, DateParser from pyckson.parsers.base import Parser, BasicParser, ListParser, CaseInsensitiveEnumParser, DefaultEnumParser, \ - BasicDictParser, SetParser, BasicParserWithCast, TypingDictParser, DecimalParser, ValuesEnumParser + BasicDictParser, SetParser, BasicParserWithCast, TypingDictParser, DecimalParser, ValuesEnumParser, UnionParser from pyckson.providers import ParserProvider, ModelProvider @@ -66,6 +66,9 @@ def get(self, obj_type, parent_class, name_in_parent) -> Parser: if obj_type.__args__[0] != str: raise TypeError('typing.Dict key can only be str in class {}'.format(parent_class)) return TypingDictParser(self.get(obj_type.__args__[1], parent_class, name_in_parent)) + if is_union_annotation(obj_type): + return UnionParser([self.get(obj_type_arg, parent_class, name_in_parent) + for obj_type_arg in obj_type.__args__]) if has_cls_attr(obj_type, PYCKSON_PARSER): return CustomDeferredParser(obj_type) return ClassParser(obj_type, self.model_provider) diff --git a/src/pyckson/serializers/advanced.py b/src/pyckson/serializers/advanced.py index c3310eb..4f5b667 100644 --- a/src/pyckson/serializers/advanced.py +++ b/src/pyckson/serializers/advanced.py @@ -2,9 +2,9 @@ from pyckson.const import PYCKSON_SERIALIZER, has_cls_attr from pyckson.dates.helpers import get_class_date_formatter, get_class_use_explicit_nulls -from pyckson.helpers import is_base_type, get_custom_serializer +from pyckson.helpers import is_base_type, get_custom_serializer, is_base_type_with_cast from pyckson.providers import ModelProvider -from pyckson.serializers.base import Serializer, BasicSerializer +from pyckson.serializers.base import Serializer, BasicSerializer, ListSerializer class GenericSerializer(Serializer): @@ -12,10 +12,12 @@ def __init__(self, model_provider: ModelProvider): self.model_provider = model_provider def serialize(self, obj): - if is_base_type(obj): + if is_base_type(obj) or is_base_type_with_cast(obj): return BasicSerializer().serialize(obj) elif has_cls_attr(obj.__class__, PYCKSON_SERIALIZER): return get_custom_serializer(obj.__class__).serialize(obj) + elif isinstance(obj, list): + return ListSerializer(GenericSerializer(self.model_provider)).serialize(obj) else: return ClassSerializer(self.model_provider).serialize(obj) diff --git a/src/pyckson/serializers/provider.py b/src/pyckson/serializers/provider.py index 4256641..b8d1285 100644 --- a/src/pyckson/serializers/provider.py +++ b/src/pyckson/serializers/provider.py @@ -2,7 +2,7 @@ from pyckson.defaults import apply_enum_default from pyckson.helpers import is_list_annotation, is_set_annotation, is_enum_annotation, is_basic_dict_annotation, \ - is_typing_dict_annotation + is_typing_dict_annotation, is_union_annotation try: from typing import _ForwardRef as ForwardRef @@ -52,6 +52,8 @@ def get(self, obj_type, parent_class, name_in_parent) -> Serializer: if obj_type.__args__[0] != str: raise TypeError('typing.Dict key can only be str in class {}'.format(parent_class)) return TypingDictSerializer(self.get(obj_type.__args__[1], parent_class, name_in_parent)) + if is_union_annotation(obj_type): + return GenericSerializer(self.model_provider) if has_cls_attr(obj_type, PYCKSON_SERIALIZER): return CustomDeferredSerializer(obj_type) return ClassSerializer(self.model_provider) diff --git a/tests/model/test_union.py b/tests/model/test_union.py index 586aaa7..1575792 100644 --- a/tests/model/test_union.py +++ b/tests/model/test_union.py @@ -21,5 +21,8 @@ def test_union_with_none_should_be_optional(): def test_other_unions_should_not_be_optional(): - assert inspect_optional_typing(Union[int, str]) == (False, int) - assert inspect_optional_typing(Union[int, str, None]) == (False, int) + assert inspect_optional_typing(Union[int, str]) == (False, Union[int, str]) + + +def test_multiple_union_with_none_should_be_optional(): + assert inspect_optional_typing(Union[int, str, None]) == (True, Union[int, str]) diff --git a/tests/parsers/test_base.py b/tests/parsers/test_base.py new file mode 100644 index 0000000..af394b7 --- /dev/null +++ b/tests/parsers/test_base.py @@ -0,0 +1,88 @@ +from assertpy import assert_that + +from pyckson.parsers.base import ParserException, SetParser, UnionParser, BasicParserWithCast, ListParser, BasicParser + + +class TestBasicParserWithCast: + def test_should_handle_simple_type(self): + parser = BasicParserWithCast(int) + + result = parser.parse(5) + + assert_that(result).is_equal_to(5) + + def test_should_raise_when_it_is_not_the_correct_type(self): + parser = BasicParserWithCast(str) + + assert_that(parser.parse).raises(ParserException).when_called_with(5) + + +class TestUnionParser: + def test_should_parse_simple_union(self): + parser = UnionParser([BasicParserWithCast(int)]) + + result = parser.parse(5) + + assert result == 5 + + def test_should_parse_list_in_union(self): + parser = UnionParser([ListParser(BasicParserWithCast(int))]) + + result = parser.parse([5, 6]) + + assert result == [5, 6] + + def test_should_raise_if_parser_does_not_correspond_to_union_type(self): + parser = UnionParser([BasicParserWithCast(int)]) + + assert_that(parser.parse).raises(TypeError).when_called_with("str") + + def test_should_not_raise_if_parser_does_not_have_cls(self): + parser = UnionParser([BasicParser(), BasicParserWithCast(int)]) + + result = parser.parse(5) + + assert_that(result).is_equal_to(5) + + def test_should_parse_list_of_list_in_union(self): + parser = UnionParser([ListParser(BasicParserWithCast(int)), ListParser(ListParser(BasicParserWithCast(int)))]) + + result = parser.parse([[5], [6]]) + + assert result == [[5], [6]] + + + +class TestListParser: + def test_should_accept_list(self): + parser = ListParser(BasicParserWithCast(int)) + + result = parser.parse([5]) + + assert_that(result).is_equal_to([5]) + + def test_should_raise_when_parse_other_than_list(self): + parser = ListParser(BasicParserWithCast(int)) + + assert_that(parser.parse).raises(ParserException).when_called_with(5) + + +class TestSetParser: + def test_should_accept_set(self): + parser = SetParser(BasicParserWithCast(int)) + + result = parser.parse({5}) + + assert_that(result).is_equal_to({5}) + + def test_should_accept_list_as_set(self): + parser = SetParser(BasicParserWithCast(int)) + + result = parser.parse([5]) + + assert_that(result).is_equal_to({5}) + + def test_should_raise_when_parse_other_than_list(self): + parser = SetParser(BasicParserWithCast(int)) + + assert_that(parser.parse).raises(ParserException).when_called_with(5) diff --git a/tests/parsers/test_enum.py b/tests/parsers/test_enum.py index fa24677..c69f64a 100644 --- a/tests/parsers/test_enum.py +++ b/tests/parsers/test_enum.py @@ -1,7 +1,7 @@ from enum import Enum from unittest import TestCase -from pyckson.parsers.base import DefaultEnumParser, CaseInsensitiveEnumParser +from pyckson.parsers.base import DefaultEnumParser, CaseInsensitiveEnumParser, ParserException class MyEnum(Enum): @@ -20,11 +20,11 @@ def test_should_parse_value_in_enum(self): self.assertEqual(self.parser.parse('b'), MyEnum.b) def test_should_not_parse_uppercase_not_in_enum(self): - with self.assertRaises(KeyError): + with self.assertRaises(ParserException): self.parser.parse('B') def test_should_not_parse_value_not_in_enum(self): - with self.assertRaises(KeyError): + with self.assertRaises(ParserException): self.parser.parse('c') @@ -46,5 +46,5 @@ def test_should_parse_case_insensitive(self): self.assertEqual(self.parser.parse('b'), MyInsensitiveEnum.B) def test_should_not_parse_value_not_in_enum(self): - with self.assertRaises(KeyError): + with self.assertRaises(ParserException): self.parser.parse('c') diff --git a/tests/parsers/test_parser.py b/tests/parsers/test_parser.py index 60158cd..d83707e 100644 --- a/tests/parsers/test_parser.py +++ b/tests/parsers/test_parser.py @@ -2,7 +2,7 @@ from datetime import datetime, date from decimal import Decimal from enum import Enum -from typing import List, Dict, Set, Optional +from typing import List, Dict, Set, Optional, Union from unittest import TestCase from pyckson import date_formatter, loads @@ -377,3 +377,27 @@ def __init__(self, e: MyEnum): self.e = e assert parse(Foo, {'e': 'fooo'}).e == MyEnum.FOO + + +def test_parse_union_str_values(): + class Foo: + def __init__(self, e: Union[str, int]): + self.e = e + + assert parse(Foo, {'e': 'fooo'}).e == 'fooo' + + +def test_parse_union_int_values(): + class Foo: + def __init__(self, e: Union[str, int]): + self.e = e + + assert parse(Foo, {'e': 5}).e == 5 + + +def test_parse_union_list_values(): + class Foo: + def __init__(self, e: Union[str, List[str]]): + self.e = e + + assert parse(Foo, {'e': ['yo']}).e == ['yo'] diff --git a/tests/serializers/test_serializer.py b/tests/serializers/test_serializer.py index a24e935..f651e32 100644 --- a/tests/serializers/test_serializer.py +++ b/tests/serializers/test_serializer.py @@ -322,6 +322,30 @@ def __init__(self, foo: Union[X, Y]): assert serialize(Foo(X('a'))) == {'foo': {'x': 'a'}} +def test_serialize_union_str_values(): + class Foo: + def __init__(self, e: Union[str, int]): + self.e = e + + assert serialize(Foo('fooo')) == {'e': 'fooo'} + + +def test_serialize_union_int_values(): + class Foo: + def __init__(self, e: Union[str, int]): + self.e = e + + assert serialize(Foo(5)) == {'e': 5} + + +def test_serialize_union_list_values(): + class Foo: + def __init__(self, e: Union[str, List[str]]): + self.e = e + + assert serialize(Foo(['yo'])) == {'e': ['yo']} + + def test_should_serialize_decimal(): class Foo: def __init__(self, x: Decimal): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index dab995c..83fa41e 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,7 @@ +from typing import Union from unittest import TestCase -from pyckson.helpers import camel_case_name +from pyckson.helpers import camel_case_name, is_union_annotation class CamelCaseNameTest(TestCase): @@ -9,3 +10,11 @@ def test_should_do_nothing_on_simple_names(self): def test_should_camel_case_when_there_is_an_underscore(self): self.assertEqual(camel_case_name('foo_bar'), 'fooBar') + + +def test_should_detect_union_annotation(): + annotation = Union[str, int] + + is_union = is_union_annotation(annotation) + + assert is_union