-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4949eea
commit 74b4fc0
Showing
4 changed files
with
156 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
from __future__ import annotations | ||
|
||
import typing as t | ||
import inspect | ||
import dataclasses | ||
from functools import wraps | ||
|
||
import click | ||
|
||
|
||
def argument(dataclass_types: t.Any) -> t.Any: | ||
"""argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune). | ||
The decorated function will receive the instances of the dataclass types as the arguments. | ||
Argument: | ||
dataclass_types: [required] The dataclass type of the arguments. | ||
A list of dataclass types or a single dataclass type is supported. | ||
Examples: | ||
```python | ||
from starwhale import argument, evaluation | ||
@dataclass | ||
class EvaluationArguments: | ||
reshape: int = field(default=64, metadata={"help": "reshape image size"}) | ||
@argument(EvaluationArguments) | ||
@evaluation.predict | ||
def predict_image(data, eval_args: EvaluationArguments): | ||
... | ||
``` | ||
""" | ||
|
||
def _register_wrapper(func: t.Callable) -> t.Any: | ||
parser = get_parser_from_dataclasses(dataclass_types) | ||
|
||
@wraps(func) | ||
def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: | ||
# parser parse from click ctx | ||
# inject args to func | ||
return func(*args, **kw) | ||
|
||
return _run_wrapper | ||
|
||
return _register_wrapper | ||
|
||
|
||
def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: | ||
if dataclasses.is_dataclass(dataclass_types): | ||
dataclass_types = [dataclass_types] | ||
|
||
parser = click.OptionParser() | ||
for dtype in dataclass_types: | ||
if not dataclasses.is_dataclass(dtype): | ||
raise ValueError(f"{dtype} is not a dataclass type") | ||
|
||
type_hints: t.Dict[str, type] = t.get_type_hints(dtype) | ||
for field in dataclasses.fields(dtype): | ||
if not field.init: | ||
continue | ||
field.type = type_hints[field.name] | ||
add_field_into_parser(parser, field) | ||
|
||
return parser | ||
|
||
|
||
def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None: | ||
# TODO: field.name need format for click option? | ||
kw: t.Dict[str, t.Any] = { | ||
"param_decls": [f"--{field.name}"], | ||
"help": field.metadata.get("help"), | ||
"show_default": True, | ||
"hidden": field.metadata.get("hidden", False), | ||
} | ||
|
||
# reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py | ||
# only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type | ||
origin_type = getattr(field.type, "__origin__", field.type) | ||
if origin_type is t.Union: | ||
if (str not in field.type.__args and type(None) not in field.type.__args__) or ( | ||
len(field.type.__args__) != 2 | ||
): | ||
raise ValueError( | ||
f"{field.type} is not supported." | ||
"Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type" | ||
) | ||
|
||
if type(None) in field.type.__args__: | ||
# ignore None type, use another type as the field type | ||
field.type = ( | ||
field.type.__args__[0] | ||
if field.type.__args__[1] == type(None) | ||
else field.type.__args__[1] | ||
) | ||
origin_type = getattr(field.type, "__origin__", field.type) | ||
else: | ||
# ignore str and None type, use another type as the field type | ||
field.type = ( | ||
field.type.__args__[0] | ||
if field.type.__args__[1] == str | ||
else field.type.__args__[1] | ||
) | ||
origin_type = getattr(field.type, "__origin__", field.type) | ||
|
||
if origin_type is t.Literal or ( | ||
isinstance(field.type, type) and issubclass(field.type, Enum) | ||
): | ||
if origin_type is t.Literal: | ||
kw["type"] = click.Choice(field.type.__args__) | ||
else: | ||
kw["type"] = click.Choice([e.value for e in field.type]) | ||
|
||
kw["show_choices"] = True | ||
if field.default is not dataclasses.MISSING: | ||
kw["default"] = field.default | ||
else: | ||
kw["required"] = True | ||
elif field.type is bool or field.type == t.Optional[bool]: | ||
kw["is_flag"] = True | ||
kw["type"] = bool | ||
kw["default"] = False if field.default is dataclasses.MISSING else field.default | ||
elif inspect.isclass(origin_type) and issubclass(origin_type, list): | ||
kw["type"] = field.type.__args__[0] | ||
kw["multiple"] = True | ||
if field.default is not dataclasses.MISSING: | ||
kw["default"] = field.default | ||
elif field.default_factory is not dataclasses.MISSING: | ||
kw["default"] = field.default_factory() | ||
else: | ||
kw["required"] = True | ||
else: | ||
kw["type"] = field.type | ||
if field.default is not dataclasses.MISSING: | ||
kw["default"] = field.default | ||
elif field.default_factory is not dataclasses.MISSING: | ||
kw["default"] = field.default_factory() | ||
else: | ||
kw["required"] = True | ||
|
||
click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._impl.argument import argument | ||
|
||
__all__ = ["argument"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters