diff --git a/click_params/__init__.py b/click_params/__init__.py index 6e8abdb..6284299 100644 --- a/click_params/__init__.py +++ b/click_params/__init__.py @@ -1,6 +1,6 @@ __version__ = '0.2.0' -from .base import BaseParamType, ValidatorParamType, RangeParamType, ListParamType, UnionParamType +from .base import BaseParamType, ValidatorParamType, RangeParamType, ListParamType, UnionParamType, EnumParamType from .domain import ( DOMAIN, PUBLIC_URL, URL, EMAIL, SLUG, EmailParamType, DomainListParamType, PublicUrlListParamType, UrlListParamType, EmailListParamType, SlugListParamType @@ -21,7 +21,7 @@ __all__ = [ # base - 'BaseParamType', 'ValidatorParamType', 'RangeParamType', 'ListParamType', 'UnionParamType', + 'BaseParamType', 'ValidatorParamType', 'RangeParamType', 'ListParamType', 'UnionParamType', 'EnumParamType', # domain 'DOMAIN', 'PUBLIC_URL', 'URL', 'EmailParamType', 'EMAIL', 'SLUG', 'DomainListParamType', 'PublicUrlListParamType', diff --git a/click_params/base.py b/click_params/base.py index 975034a..a33bd6a 100644 --- a/click_params/base.py +++ b/click_params/base.py @@ -1,5 +1,6 @@ """Base classes to implement various parameter types""" -from typing import Union, Tuple, Callable, List, Any, Optional, Sequence +import enum +from typing import Union, Tuple, Callable, List, Any, Optional, Sequence, Type import click @@ -139,3 +140,33 @@ def convert(self, value, param, ctx): def __repr__(self): return self.name.upper() + + +class EnumParamType(CustomParamType): + + def __init__(self, enum_type: Type[enum.Enum], transform_upper: bool = True): + self.enum_type = enum_type + self.transform_upper = transform_upper + + def convert(self, value, param, ctx): + if self.transform_upper: + value = value.upper() + try: + return self.enum_type[value] + except KeyError: + raise click.BadParameter( + "Unknown {enum_type} value: {value}".format( + enum_type=self.enum_type.__name__, + value=value + ) + ) + + def get_metavar(self, param): + choices_str = '|'.join([element.name for element in self.enum_type]) + + # Use curly braces to indicate a required argument. + if param.required and param.param_type_name == 'argument': + return '{{{choices_str}}}'.format(choices_str=choices_str) + + # Use square braces to indicate an option or optional argument. + return '[{choices_str}]'.format(choices_str=choices_str) diff --git a/docs/usage/miscellaneous.md b/docs/usage/miscellaneous.md index c5572e0..d423d47 100644 --- a/docs/usage/miscellaneous.md +++ b/docs/usage/miscellaneous.md @@ -213,4 +213,42 @@ Two remarks compared to the last script. - The order of parameter types in the union is the order click will try to parse the value. - In the last two examples click was unable to parse because they were neither an integer nor a string from allowed -choices. \ No newline at end of file +choices. + +## EnumParamType + +Signature: `EnumParamType(enum_type: Type[enum.Enum], transform_upper: bool = True)` + +Converts string to an enum value. if `transform_upper` is set to be true, convert name +to uppercase before getting the enum value. + +````python +import enum +import click +from click_params import EnumParamType + + +class MyEnum(enum.Enum): + + ONE = "one" + TWO = "two" + THREE = "three" + ONE_ALIAS = ONE + + +@click.command() +@click.option("-c", "--choice", type=EnumParamType(MyEnum)) +def f(choice): + click.echo("You have chosen {choice}".format(choice=choice)) +```` + +````bash +$ python cli.py -c one +You have chosen MyEnum.ONE + +$ python cli.py -c three +You have chosen MyEnum.Three + +$ python cli.py -c unreal +Error: Unknown MyEnum value: UNREAL +```` \ No newline at end of file diff --git a/tests/test_base.py b/tests/test_base.py index c50dc13..097bbc3 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,10 +1,12 @@ +import enum from fractions import Fraction import click import pytest from validators.utils import validator -from click_params.base import RangeParamType, BaseParamType, ValidatorParamType, ListParamType, UnionParamType +from click_params.base import RangeParamType, BaseParamType, ValidatorParamType, ListParamType, UnionParamType, \ + EnumParamType from click_params.numeric import DECIMAL, FRACTION, COMPLEX @@ -218,3 +220,48 @@ def test_parse_expression_unsuccessfully(self, expression, param_types): union_type = UnionParamType(param_types=param_types) with pytest.raises(click.BadParameter): union_type.convert(expression, None, None) + + +class MyEnum(enum.Enum): + + ONE = 1 + TWO = 2 + THREE = 3 + ONE_ALIAS = 1 + + +class TestEnumParamType: + """Test class EnumParamType""" + + @pytest.mark.parametrize(('expression', 'transform_upper', 'value'), [ + ('one', True, MyEnum.ONE), + ('One', True, MyEnum.ONE), + ('ONE', True, MyEnum.ONE), + ('ONE', False, MyEnum.ONE), + ('two', True, MyEnum.TWO), + ('TWO', True, MyEnum.TWO), + ('TWO', False, MyEnum.TWO), + ('three', True, MyEnum.THREE), + ('THREE', True, MyEnum.THREE), + ('THREE', False, MyEnum.THREE), + ('one_alias', True, MyEnum.ONE), + ]) + def test_parse_expression_successfully(self, expression, transform_upper, value): + enum_type = EnumParamType(MyEnum, transform_upper=transform_upper) + converted_value = enum_type.convert(expression, None, None) + assert type(value) == MyEnum + assert value == converted_value + + @pytest.mark.parametrize(('expression', 'transform_upper'), [ + ('one', False), + ('One', False), + ('two', False), + ('three', False), + ('one_alias', False), + ('unreal', True), + ('unreal', False), + ]) + def test_parse_expression_unsuccessfully(self, expression, transform_upper): + enum_type = EnumParamType(MyEnum, transform_upper=transform_upper) + with pytest.raises(click.BadParameter): + enum_type.convert(expression, None, None)