Skip to content

Commit

Permalink
Boxed types in Option now supported as well as external types such as…
Browse files Browse the repository at this point in the history
… Path
  • Loading branch information
martinjm97 committed Mar 27, 2021
1 parent 4b89fb9 commit 5198e0b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
2 changes: 1 addition & 1 deletion tap/_version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ['__version__']

# major, minor, patch
version_info = 1, 6, 1
version_info = 1, 6, 2

# Nice string for the version
__version__ = '.'.join(map(str, version_info))
27 changes: 10 additions & 17 deletions tap/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from warnings import warn
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, get_type_hints
from typing_inspect import is_literal_type, get_args, is_union_type
from typing_inspect import is_literal_type, get_args

from tap.utils import (
get_class_variables,
Expand Down Expand Up @@ -166,6 +166,15 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:

# If type is not explicitly provided, set it if it's one of our supported default types
if 'type' not in kwargs:

# Unbox Optional[type] and set var_type = type
if get_origin(var_type) in OPTIONAL_TYPES:
var_args = get_args(var_type)

if len(var_args) > 0:
var_type = get_args(var_type)[0]
explicit_bool = True

# First check whether it is a literal type or a boxed literal type
if is_literal_type(var_type):
var_type, kwargs['choices'] = get_literals(var_type, variable)
Expand Down Expand Up @@ -193,22 +202,6 @@ def _add_argument(self, *name_or_flags, **kwargs) -> None:
kwargs['nargs'] = len(types)

var_type = TupleTypeEnforcer(types=types, loop=loop)
# To identify an Optional type, check if it's a union of a None and something else
elif (
is_union_type(var_type)
and len(get_args(var_type)) == 2
and isinstance(None, get_args(var_type)[1])
and is_literal_type(get_args(var_type)[0])
):
var_type, kwargs['choices'] = get_literals(get_args(var_type)[0], variable)

# Unbox Optional[type] and set var_type = type
if get_origin(var_type) in OPTIONAL_TYPES:
var_args = get_args(var_type)

if len(var_args) > 0:
var_type = get_args(var_type)[0]
explicit_bool = True

if get_origin(var_type) in BOXED_TYPES:
# If List or Set type, set nargs
Expand Down
39 changes: 23 additions & 16 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ class NestedOptionalTypesTap(Tap):
tuple_bool: Optional[Tuple[bool]]
tuple_int: Optional[Tuple[int]]
tuple_str: Optional[Tuple[str]]
# tuple_pair: Optional[Tuple[bool, str, int]]
# tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]]
# tuple_arbitrary_len_int: Optional[Tuple[int, ...]]
# tuple_arbitrary_len_str: Optional[Tuple[str, ...]]
tuple_pair: Optional[Tuple[bool, str, int]]
tuple_arbitrary_len_bool: Optional[Tuple[bool, ...]]
tuple_arbitrary_len_int: Optional[Tuple[int, ...]]
tuple_arbitrary_len_str: Optional[Tuple[str, ...]]


class NestedOptionalTypeTests(TestCase):
Expand Down Expand Up @@ -208,10 +208,10 @@ def test_nested_optional_types(self):
'--tuple_bool', *stringify(tuple_bool),
'--tuple_int', *stringify(tuple_int),
'--tuple_str', *stringify(tuple_str),
# '--tuple_pair', *stringify(tuple_pair),
# '--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool),
# '--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int),
# '--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str),
'--tuple_pair', *stringify(tuple_pair),
'--tuple_arbitrary_len_bool', *stringify(tuple_arbitrary_len_bool),
'--tuple_arbitrary_len_int', *stringify(tuple_arbitrary_len_int),
'--tuple_arbitrary_len_str', *stringify(tuple_arbitrary_len_str),
])

self.assertEqual(args.list_bool, list_bool)
Expand All @@ -225,10 +225,10 @@ def test_nested_optional_types(self):
self.assertEqual(args.tuple_bool, tuple_bool)
self.assertEqual(args.tuple_int, tuple_int)
self.assertEqual(args.tuple_str, tuple_str)
# self.assertEqual(args.tuple_pair, tuple_pair)
# self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool)
# self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int)
# self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str)
self.assertEqual(args.tuple_pair, tuple_pair)
self.assertEqual(args.tuple_arbitrary_len_bool, tuple_arbitrary_len_bool)
self.assertEqual(args.tuple_arbitrary_len_int, tuple_arbitrary_len_int)
self.assertEqual(args.tuple_arbitrary_len_str, tuple_arbitrary_len_str)


class ComplexTypeTap(Tap):
Expand Down Expand Up @@ -261,6 +261,7 @@ def test_complex_types(self):
self.assertEqual(args.set_path, set_path)
self.assertEqual(args.tuple_path, tuple_path)


class Person:
def __init__(self, name: str):
self.name = name
Expand Down Expand Up @@ -631,26 +632,32 @@ def configure(self) -> None:
def test_complex_type(self) -> None:
class AddArgumentComplexTypeTap(IntegrationDefaultTap):
arg_person: Person = Person('tap')
# arg_person_required: Person # TODO
arg_person_required: Person
arg_person_untyped = Person('tap untyped')

# TODO: assert a crash if any complex types are not explicitly added in add_argument
def configure(self) -> None:
self.add_argument('--arg_person', type=Person)
# self.add_argument('--arg_person_required', type=Person) # TODO
self.add_argument('--arg_person_required', type=Person)
self.add_argument('--arg_person_untyped', type=Person)

args = AddArgumentComplexTypeTap().parse_args([])
arg_person_required = Person("hello, it's me")

args = AddArgumentComplexTypeTap().parse_args([
'--arg_person_required', arg_person_required.name,
])
self.assertEqual(args.arg_person, Person('tap'))
self.assertEqual(args.arg_person_required, arg_person_required)
self.assertEqual(args.arg_person_untyped, Person('tap untyped'))

arg_person = Person('hi there')
arg_person_untyped = Person('heyyyy')
args = AddArgumentComplexTypeTap().parse_args([
'--arg_person', arg_person.name,
'--arg_person_required', arg_person_required.name,
'--arg_person_untyped', arg_person_untyped.name
])
self.assertEqual(args.arg_person, arg_person)
self.assertEqual(args.arg_person_required, arg_person_required)
self.assertEqual(args.arg_person_untyped, arg_person_untyped)

def test_repeat_default(self) -> None:
Expand Down

0 comments on commit 5198e0b

Please sign in to comment.