Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance(sdk): support new typing hint example for transformers 4.36.0+ #3116

Merged
merged 1 commit into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,32 +259,27 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option:
}

# reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py
# only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type
# only support Union: NoneType(optional), str(optional) and other types, such as: Optional[int], Union[int], Union[int, str], Union[List[str], str] and Optional[Union[List[str], str]]
origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is t.Union:
if (
str not in field.type.__args__ and type(None) not in field.type.__args__
) or (len(field.type.__args__) != 2):
_args = list(field.type.__args__)
if type(None) in _args:
_args.remove(type(None))

_args_cnt = len(_args)
if (_args_cnt == 2 and str not in _args) or _args_cnt > 2 or _args_cnt == 0:
raise ValueError(
f"{field.type} is not supported."
"Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type"
"Only `Union[X, str, NoneType]` (i.e., `Optional[X]`) or `Union[X, str]` is allowed for `Union` because"
" the argument parser only supports one type per argument."
f" Problem encountered in field '{field.name}'."
)

if type(None) in field.type.__args__:
# ignore None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == type(None)
else field.type.__args__[1]
)
if _args_cnt == 1:
field.type = _args[0]
origin_type = getattr(field.type, "__origin__", field.type)
else:
# ignore str and None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == str
else field.type.__args__[1]
)
elif _args_cnt == 2:
# filter `str` in Union
field.type = _args[0] if _args[1] == str else _args[1]
origin_type = getattr(field.type, "__origin__", field.type)

if (origin_type is Literal) or (
Expand All @@ -300,7 +295,7 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option:
kw["default"] = field.default
else:
kw["required"] = True
elif field.type is bool or field.type == t.Optional[bool]:
elif field.type is bool:
kw["is_flag"] = True
kw["type"] = bool
kw["default"] = False if field.default is dataclasses.MISSING else field.default
Expand Down
24 changes: 23 additions & 1 deletion client/tests/sdk/test_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ class DebugOption(Enum):
TPU_METRICS_DEBUG = "tpu_metrics_debug"


class FSDPOption(Enum):
FSDP = "fsdp"
FSDP2 = "fsdp2"


@dataclasses.dataclass
class ScalarArguments:
no_field = 1
Expand Down Expand Up @@ -62,6 +67,13 @@ class ComposeArguments:
label_names: t.Optional[t.List[str]] = dataclasses.field(
default=None, metadata={"help": "label names"}
)
fsdp: t.Optional[t.Union[t.List[FSDPOption], str]] = dataclasses.field(
default="", metadata={"help": "fsdp"}
)
fsdp2: t.Optional[t.Union[str, t.List[FSDPOption]]] = dataclasses.field(
default="", metadata={"help": "fsdp2"}
)
tf32: t.Optional[bool] = dataclasses.field(default=None, metadata={"help": "tf32"})


class ArgumentTestCase(TestCase):
Expand Down Expand Up @@ -212,10 +224,20 @@ def test_compose_parser(self) -> None:
assert optional_list_obj.multiple
assert optional_list_obj.default is None

fsdp_obj = compose_parser._long_opt["--fsdp"].obj
assert isinstance(fsdp_obj.type, click.types.FuncParamType)
assert fsdp_obj.type.func == FSDPOption

fsdp_obj2 = compose_parser._long_opt["--fsdp2"].obj
assert fsdp_obj2.type.func == fsdp_obj.type.func

tf32_obj = compose_parser._long_opt["--tf32"].obj
assert tf32_obj.type == click.BOOL

argument_ctx = ArgumentContext.get_current_context()
assert len(argument_ctx._options) == 1
options = argument_ctx._options["tests.sdk.test_argument:ComposeArguments"]
assert len(options) == 6
assert len(options) == 9
assert options[0].name == "debug"
argument_ctx.echo_help()

Expand Down
Loading