diff --git a/tap/_version.py b/tap/_version.py index 3f983e3..2f6e9a4 100644 --- a/tap/_version.py +++ b/tap/_version.py @@ -1,7 +1,7 @@ __all__ = ['__version__'] # major, minor, patch -version_info = 1, 6, 1 +version_info = 1, 6, 2 # Nice string for the version __version__ = '.'.join(map(str, version_info)) diff --git a/tap/tap.py b/tap/tap.py index ef9df47..7ef16fa 100644 --- a/tap/tap.py +++ b/tap/tap.py @@ -8,7 +8,7 @@ from warnings import warn from types import MethodType from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints -from typing_inspect import is_literal_type, get_args, get_origin, is_union_type +from typing_inspect import is_literal_type, get_args from tap.utils import ( get_class_variables, @@ -16,6 +16,7 @@ get_git_root, get_dest, get_git_url, + get_origin, has_git, has_uncommitted_changes, is_option_arg, @@ -32,16 +33,10 @@ # Constants EMPTY_TYPE = get_args(List)[0] if len(get_args(List)) > 0 else tuple() +BOXED_COLLECTION_TYPES = {List, list, Set, set, Tuple, tuple} +OPTIONAL_TYPES = {Optional, Union} +BOXED_TYPES = BOXED_COLLECTION_TYPES | OPTIONAL_TYPES -SUPPORTED_DEFAULT_BASE_TYPES = {str, int, float, bool} -SUPPORTED_DEFAULT_OPTIONAL_TYPES = {Optional, Optional[str], Optional[int], Optional[float], Optional[bool]} -SUPPORTED_DEFAULT_LIST_TYPES = {List, List[str], List[int], List[float], List[bool]} -SUPPORTED_DEFAULT_SET_TYPES = {Set, Set[str], Set[int], Set[float], Set[bool]} -SUPPORTED_DEFAULT_COLLECTION_TYPES = SUPPORTED_DEFAULT_LIST_TYPES | SUPPORTED_DEFAULT_SET_TYPES | {Tuple} -SUPPORTED_DEFAULT_BOXED_TYPES = SUPPORTED_DEFAULT_OPTIONAL_TYPES | SUPPORTED_DEFAULT_COLLECTION_TYPES -SUPPORTED_DEFAULT_TYPES = set.union(SUPPORTED_DEFAULT_BASE_TYPES, - SUPPORTED_DEFAULT_OPTIONAL_TYPES, - SUPPORTED_DEFAULT_COLLECTION_TYPES) TapType = TypeVar('TapType', bound='Tap') @@ -125,6 +120,9 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None: :param name_or_flags: Either a name or a list of option strings, e.g. foo or -f, --foo. :param kwargs: Keyword arguments. """ + # Set explicit bool + explicit_bool = self._explicit_bool + # Get variable name variable = get_argument_name(*name_or_flags) @@ -168,6 +166,21 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None: # If type is not explicitly provided, set it if it's one of our supported default types if 'type' not in kwargs: + + # Unbox Optional[type] and set var_type = type + if get_origin(var_type) in OPTIONAL_TYPES: + var_args = get_args(var_type) + + if len(var_args) > 0: + var_type = get_args(var_type)[0] + + # If var_type is tuple as in Python 3.6, change to a typing type + # (e.g., (typing.List, ) ==> typing.List[bool]) + if isinstance(var_type, tuple): + var_type = var_type[0][var_type[1:]] + + explicit_bool = True + # First check whether it is a literal type or a boxed literal type if is_literal_type(var_type): var_type, kwargs['choices'] = get_literals(var_type, variable) @@ -195,27 +208,10 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None: kwargs['nargs'] = len(types) var_type = TupleTypeEnforcer(types=types, loop=loop) - # To identify an Optional type, check if it's a union of a None and something else - elif ( - is_union_type(var_type) - and len(get_args(var_type)) == 2 - and isinstance(None, get_args(var_type)[1]) - and is_literal_type(get_args(var_type)[0]) - ): - var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable) - elif var_type not in SUPPORTED_DEFAULT_TYPES: - is_required = kwargs.get('required', False) - arg_params = 'required=True' if is_required else f'default={getattr(self, variable)}' - raise ValueError( - f'Variable "{variable}" has type "{var_type}" which is not supported by default.\n' - f'Please explicitly add the argument to the parser by writing:\n\n' - f'def configure(self) -> None:\n' - f' self.add_argument("--{variable}", type=func, {arg_params})\n\n' - f'where "func" maps from str to {var_type}.') - - if var_type in SUPPORTED_DEFAULT_BOXED_TYPES: + + if get_origin(var_type) in BOXED_TYPES: # If List or Set type, set nargs - if (var_type in SUPPORTED_DEFAULT_COLLECTION_TYPES + if (get_origin(var_type) in BOXED_COLLECTION_TYPES and kwargs.get('action') not in {'append', 'append_const'}): kwargs['nargs'] = kwargs.get('nargs', '*') @@ -228,13 +224,12 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None: else: var_type = arg_types[0] - # Handle the cases of Optional[bool], List[bool], Set[bool] + # Handle the cases of List[bool], Set[bool], Tuple[bool] if var_type == bool: var_type = boolean_type - # If bool then set action, otherwise set type if var_type == bool: - if self._explicit_bool: + if explicit_bool: kwargs['type'] = boolean_type kwargs['choices'] = [True, False] # this makes the help message more helpful else: @@ -404,10 +399,14 @@ def parse_args(self: TapType, if type(value) == list: var_type = get_origin(self._annotations[variable]) - # TODO: remove this once typing_inspect.get_origin is fixed for Python 3.9 - # https://github.com/ilevkivskyi/typing_inspect/issues/64 - # https://github.com/ilevkivskyi/typing_inspect/issues/65 - var_type = var_type if var_type is not None else self._annotations[variable] + # Unpack nested boxed types such as Optional[List[int]] + if var_type is Union: + var_type = get_origin(get_args(self._annotations[variable])[0]) + + # If var_type is tuple as in Python 3.6, change to a typing type + # (e.g., (typing.Tuple, ) ==> typing.Tuple) + if isinstance(var_type, tuple): + var_type = var_type[0] if var_type in (Set, set): value = set(value) diff --git a/tap/utils.py b/tap/utils.py index 9fe2a4a..6ed8eeb 100644 --- a/tap/utils.py +++ b/tap/utils.py @@ -24,7 +24,7 @@ Union, ) from typing_extensions import Literal -from typing_inspect import get_args +from typing_inspect import get_args, get_origin as typing_inspect_get_origin NO_CHANGES_STATUS = """nothing to commit, working tree clean""" @@ -467,3 +467,16 @@ def enforce_reproducibility(saved_reproducibility_data: Optional[Dict[str, str]] if current_reproducibility_data['git_has_uncommitted_changes']: raise ValueError(f'{no_reproducibility_message}: Uncommitted changes ' f'in current args.') + + +# TODO: remove this once typing_inspect.get_origin is fixed for Python 3.8 and 3.9 +# https://github.com/ilevkivskyi/typing_inspect/issues/64 +# https://github.com/ilevkivskyi/typing_inspect/issues/65 +def get_origin(tp: Any) -> Any: + """Same as typing_inspect.get_origin but fixes unparameterized generic types like Set.""" + origin = typing_inspect_get_origin(tp) + + if origin is None: + origin = tp + + return origin diff --git a/tests/test_integration.py b/tests/test_integration.py index 475ca57..e5e45f8 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,8 +1,9 @@ from copy import deepcopy import os +from pathlib import Path import sys from tempfile import TemporaryDirectory -from typing import Any, List, Optional, Set, Tuple +from typing import Any, Iterable, List, Optional, Set, Tuple from typing_extensions import Literal import unittest from unittest import TestCase @@ -10,6 +11,15 @@ from tap import Tap +def stringify(arg_list: Iterable[Any]) -> List[str]: + """Converts an iterable of arguments of any type to a list of strings. + + :param arg_list: An iterable of arguments of any type. + :return: A list of the arguments as strings. + """ + return [str(arg) for arg in arg_list] + + class EdgeCaseTests(TestCase): def test_empty(self) -> None: class EmptyTap(Tap): @@ -112,17 +122,144 @@ def test_both_assigned_okay(self): self.assertEqual(args.arg_list_str_required, ['hi', 'there']) -class CrashesOnUnsupportedTypesTests(TestCase): +# TODO: need to implement list[str] etc. +# class ParameterizedStandardCollectionTap(Tap): +# arg_list_str: list[str] +# arg_list_int: list[int] +# arg_list_int_default: list[int] = [1, 2, 5] +# arg_set_float: set[float] +# arg_set_str_default: set[str] = ['one', 'two', 'five'] +# arg_tuple_int: tuple[int, ...] +# arg_tuple_float_default: tuple[float, float, float] = (1.0, 2.0, 5.0) +# arg_tuple_str_override: tuple[str, str] = ('hi', 'there') +# arg_optional_list_int: Optional[list[int]] = None + + +# class ParameterizedStandardCollectionTests(TestCase): +# @unittest.skipIf(sys.version_info < (3, 9), 'Parameterized standard collections (e.g., list[int]) introduced in Python 3.9') +# def test_parameterized_standard_collection(self): +# arg_list_str = ['a', 'b', 'pi'] +# arg_list_int = [-2, -5, 10] +# arg_set_float = {3.54, 2.235} +# arg_tuple_int = (-4, 5, 9, 103) +# arg_tuple_str_override = ('why', 'so', 'many', 'tests?') +# arg_optional_list_int = [5, 4, 3] + +# args = ParameterizedStandardCollectionTap().parse_args([ +# '--arg_list_str', *arg_list_str, +# '--arg_list_int', *[str(var) for var in arg_list_int], +# '--arg_set_float', *[str(var) for var in arg_set_float], +# '--arg_tuple_int', *[str(var) for var in arg_tuple_int], +# '--arg_tuple_str_override', *arg_tuple_str_override, +# '--arg_optional_list_int', *[str(var) for var in arg_optional_list_int] +# ]) + +# self.assertEqual(args.arg_list_str, arg_list_str) +# self.assertEqual(args.arg_list_int, arg_list_int) +# self.assertEqual(args.arg_list_int_default, ParameterizedStandardCollectionTap.arg_list_int_default) +# self.assertEqual(args.arg_set_float, arg_set_float) +# self.assertEqual(args.arg_set_str_default, ParameterizedStandardCollectionTap.arg_set_str_default) +# self.assertEqual(args.arg_tuple_int, arg_tuple_int) +# self.assertEqual(args.arg_tuple_float_default, ParameterizedStandardCollectionTap.arg_tuple_float_default) +# self.assertEqual(args.arg_tuple_str_override, arg_tuple_str_override) +# self.assertEqual(args.arg_optional_list_int, arg_optional_list_int) + + +class NestedOptionalTypesTap(Tap): + list_bool: Optional[List[bool]] + list_int: Optional[List[int]] + list_str: Optional[List[str]] + set_bool: Optional[Set[bool]] + set_int: Optional[Set[int]] + set_str: Optional[Set[str]] + tuple_bool: Optional[Tuple[bool]] + tuple_int: Optional[Tuple[int]] + tuple_str: Optional[Tuple[str]] + tuple_pair: Optional[Tuple[bool, str, int]] + tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]] + tuple_arbitrary_len_int: Optional[Tuple[int, ...]] + tuple_arbitrary_len_str: Optional[Tuple[str, ...]] + + +class NestedOptionalTypeTests(TestCase): + + def test_nested_optional_types(self): + list_bool = [True, False] + list_int = [0, 1, 2] + list_str = ['a', 'bee', 'cd', 'ee'] + set_bool = {True, False, True} + set_int = {0, 1} + set_str = {'a', 'bee', 'cd'} + tuple_bool = (False,) + tuple_int = (0,) + tuple_str = ('a',) + tuple_pair = (False, 'a', 1) + tuple_arbitrary_len_bool = (True, False, False) + tuple_arbitrary_len_int = (1, 2, 3, 4) + tuple_arbitrary_len_str = ('a', 'b') + + args = NestedOptionalTypesTap().parse_args([ + '--list_bool', *stringify(list_bool), + '--list_int', *stringify(list_int), + '--list_str', *stringify(list_str), + '--set_bool', *stringify(set_bool), + '--set_int', *stringify(set_int), + '--set_str', *stringify(set_str), + '--tuple_bool', *stringify(tuple_bool), + '--tuple_int', *stringify(tuple_int), + '--tuple_str', *stringify(tuple_str), + '--tuple_pair', *stringify(tuple_pair), + '--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool), + '--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int), + '--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str), + ]) - def test_crashes_on_unsupported(self): - # From PiDelport: https://github.com/swansonk14/typed-argument-parser/issues/27 - from pathlib import Path + self.assertEqual(args.list_bool, list_bool) + self.assertEqual(args.list_int, list_int) + self.assertEqual(args.list_str, list_str) + + self.assertEqual(args.set_bool, set_bool) + self.assertEqual(args.set_int, set_int) + self.assertEqual(args.set_str, set_str) + + self.assertEqual(args.tuple_bool, tuple_bool) + self.assertEqual(args.tuple_int, tuple_int) + self.assertEqual(args.tuple_str, tuple_str) + self.assertEqual(args.tuple_pair, tuple_pair) + self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool) + self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int) + self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str) + + +class ComplexTypeTap(Tap): + path: Path + optional_path: Optional[Path] + list_path: List[Path] + set_path: Set[Path] + tuple_path: Tuple[Path, Path] + + +class ComplexTypeTests(TestCase): + def test_complex_types(self): + path = Path('/path/to/file.txt') + optional_path = Path('/path/to/optional/file.txt') + list_path = [Path('/path/to/list/file1.txt'), Path('/path/to/list/file2.txt')] + set_path = {Path('/path/to/set/file1.txt'), Path('/path/to/set/file2.txt')} + tuple_path = (Path('/path/to/tuple/file1.txt'), Path('/path/to/tuple/file2.txt')) + + args = ComplexTypeTap().parse_args([ + '--path', str(path), + '--optional_path', str(optional_path), + '--list_path', *[str(path) for path in list_path], + '--set_path', *[str(path) for path in set_path], + '--tuple_path', *[str(path) for path in tuple_path] + ]) - class CrashingArgumentParser(Tap): - some_path: Path = 'some_path' - - with self.assertRaises(ValueError): - CrashingArgumentParser().parse_args([]) + self.assertEqual(args.path, path) + self.assertEqual(args.optional_path, optional_path) + self.assertEqual(args.list_path, list_path) + self.assertEqual(args.set_path, set_path) + self.assertEqual(args.tuple_path, tuple_path) class Person: @@ -312,7 +449,6 @@ def test_set_default_args(self) -> None: '--arg_list_bool', *arg_list_bool, '--arg_list_str_empty', *arg_list_str_empty, '--arg_list_literal', *arg_list_literal, - '--arg_set', *arg_set, '--arg_set_str', *arg_set_str, '--arg_set_int', *arg_set_int, @@ -496,26 +632,32 @@ def configure(self) -> None: def test_complex_type(self) -> None: class AddArgumentComplexTypeTap(IntegrationDefaultTap): arg_person: Person = Person('tap') - # arg_person_required: Person # TODO + arg_person_required: Person arg_person_untyped = Person('tap untyped') - # TODO: assert a crash if any complex types are not explicitly added in add_argument def configure(self) -> None: self.add_argument('--arg_person', type=Person) - # self.add_argument('--arg_person_required', type=Person) # TODO + self.add_argument('--arg_person_required', type=Person) self.add_argument('--arg_person_untyped', type=Person) - args = AddArgumentComplexTypeTap().parse_args([]) + arg_person_required = Person("hello, it's me") + + args = AddArgumentComplexTypeTap().parse_args([ + '--arg_person_required', arg_person_required.name, + ]) self.assertEqual(args.arg_person, Person('tap')) + self.assertEqual(args.arg_person_required, arg_person_required) self.assertEqual(args.arg_person_untyped, Person('tap untyped')) arg_person = Person('hi there') arg_person_untyped = Person('heyyyy') args = AddArgumentComplexTypeTap().parse_args([ '--arg_person', arg_person.name, + '--arg_person_required', arg_person_required.name, '--arg_person_untyped', arg_person_untyped.name ]) self.assertEqual(args.arg_person, arg_person) + self.assertEqual(args.arg_person_required, arg_person_required) self.assertEqual(args.arg_person_untyped, arg_person_untyped) def test_repeat_default(self) -> None: