diff --git a/client/starwhale/__init__.py b/client/starwhale/__init__.py
index 548840fa96..b20f6e07dd 100644
--- a/client/starwhale/__init__.py
+++ b/client/starwhale/__init__.py
@@ -14,6 +14,7 @@
 from starwhale.api.metric import multi_classification
 from starwhale.api.dataset import Dataset
 from starwhale.utils.debug import init_logger
+from starwhale.api.argument import argument
 from starwhale.api.instance import login, logout
 from starwhale.base.context import Context, pass_context
 from starwhale.api.evaluation import Evaluation, PipelineHandler
@@ -47,6 +48,7 @@
 __all__ = [
     "__version__",
+    "argument",
     "model",
     "Job",
     "job",
diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py
new file mode 100644
index 0000000000..3cec30dd98
--- /dev/null
+++ b/client/starwhale/api/_impl/argument.py
@@ -0,0 +1,141 @@
+from __future__ import annotations
+
+import typing as t
+import inspect
+import dataclasses
+from enum import Enum
+from functools import wraps
+
+import click
+
+
+def argument(dataclass_types: t.Any) -> t.Any:
+    """argument is a decorator function to define arguments for model running (predict, evaluate, serve and finetune).
+
+    The decorated function will receive the instances of the dataclass types as its arguments.
+
+    Arguments:
+        dataclass_types: [required] The dataclass type of the arguments.
+            A list of dataclass types or a single dataclass type is supported.
+
+    Examples:
+    ```python
+    from dataclasses import field, dataclass
+
+    from starwhale import argument, evaluation
+
+    @dataclass
+    class EvaluationArguments:
+        reshape: int = field(default=64, metadata={"help": "reshape image size"})
+
+    @argument(EvaluationArguments)
+    @evaluation.predict
+    def predict_image(data, eval_args: EvaluationArguments):
+        ...
+    ```
+    """
+
+    def _register_wrapper(func: t.Callable) -> t.Any:
+        parser = get_parser_from_dataclasses(dataclass_types)
+
+        @wraps(func)
+        def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
+            # TODO: parse the arguments from the click ctx
+            # TODO: inject the parsed arguments into func
+            return func(*args, **kw)
+
+        return _run_wrapper
+
+    return _register_wrapper
+
+
+def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
+    if dataclasses.is_dataclass(dataclass_types):
+        dataclass_types = [dataclass_types]
+
+    parser = click.OptionParser()
+    for dtype in dataclass_types:
+        if not dataclasses.is_dataclass(dtype):
+            raise ValueError(f"{dtype} is not a dataclass type")
+
+        type_hints: t.Dict[str, type] = t.get_type_hints(dtype)
+        for field in dataclasses.fields(dtype):
+            if not field.init:
+                continue
+            field.type = type_hints[field.name]
+            add_field_into_parser(parser, field)
+
+    return parser
+
+
+def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None:
+    # TODO: does field.name need formatting for the click option?
+    kw: t.Dict[str, t.Any] = {
+        "param_decls": [f"--{field.name}"],
+        "help": field.metadata.get("help"),
+        "show_default": True,
+        "hidden": field.metadata.get("hidden", False),
+    }
+
+    # reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py
+    # only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type
+    origin_type = getattr(field.type, "__origin__", field.type)
+    if origin_type is t.Union:
+        if (str not in field.type.__args__ and type(None) not in field.type.__args__) or (
+            len(field.type.__args__) != 2
+        ):
+            raise ValueError(
+                f"{field.type} is not supported. "
+ "Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type" + ) + + if type(None) in field.type.__args__: + # ignore None type, use another type as the field type + field.type = ( + field.type.__args__[0] + if field.type.__args__[1] == type(None) + else field.type.__args__[1] + ) + origin_type = getattr(field.type, "__origin__", field.type) + else: + # ignore str and None type, use another type as the field type + field.type = ( + field.type.__args__[0] + if field.type.__args__[1] == str + else field.type.__args__[1] + ) + origin_type = getattr(field.type, "__origin__", field.type) + + if origin_type is t.Literal or ( + isinstance(field.type, type) and issubclass(field.type, Enum) + ): + if origin_type is t.Literal: + kw["type"] = click.Choice(field.type.__args__) + else: + kw["type"] = click.Choice([e.value for e in field.type]) + + kw["show_choices"] = True + if field.default is not dataclasses.MISSING: + kw["default"] = field.default + else: + kw["required"] = True + elif field.type is bool or field.type == t.Optional[bool]: + kw["is_flag"] = True + kw["type"] = bool + kw["default"] = False if field.default is dataclasses.MISSING else field.default + elif inspect.isclass(origin_type) and issubclass(origin_type, list): + kw["type"] = field.type.__args__[0] + kw["multiple"] = True + if field.default is not dataclasses.MISSING: + kw["default"] = field.default + elif field.default_factory is not dataclasses.MISSING: + kw["default"] = field.default_factory() + else: + kw["required"] = True + else: + kw["type"] = field.type + if field.default is not dataclasses.MISSING: + kw["default"] = field.default + elif field.default_factory is not dataclasses.MISSING: + kw["default"] = field.default_factory() + else: + kw["required"] = True + + click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore diff --git a/client/starwhale/api/argument.py b/client/starwhale/api/argument.py new file mode 100644 index 0000000000..b329a240a7 --- /dev/null +++ b/client/starwhale/api/argument.py @@ -0,0 +1,3 @@ +from ._impl.argument import argument + +__all__ = ["argument"] diff --git a/example/helloworld/evaluation.py b/example/helloworld/evaluation.py index c4e7d89ec7..3caf346c72 100644 --- a/example/helloworld/evaluation.py +++ b/example/helloworld/evaluation.py @@ -1,13 +1,19 @@ from pathlib import Path +from dataclasses import field, dataclass import numpy as np import onnxruntime as rt -from starwhale import Image, evaluation, multi_classification +from starwhale import Image, argument, evaluation, multi_classification _g_model = None +@dataclass +class EvaluationArguments: + reshape: int = field(default=64, metadata={"help": "reshape image size"}) + + def _load_model(): global _g_model @@ -19,16 +25,17 @@ def _load_model(): return _g_model +@argument(EvaluationArguments) @evaluation.predict( resources={"memory": {"request": "500M", "limit": "2G"}}, log_mode="plain", ) -def predict_image(data): +def predict_image(data, args: EvaluationArguments): img: Image = data["img"] model = _load_model() input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name - input_array = [img.to_numpy().astype(np.float32).reshape(64)] + input_array = [img.to_numpy().astype(np.float32).reshape(args.reshape)] pred = model.run([label_name], {input_name: input_array})[0] return pred.item()