From b62096ee99d41efeb8890b8cb044cfba1fb26b5e Mon Sep 17 00:00:00 2001 From: tianwei Date: Fri, 22 Dec 2023 19:45:00 +0800 Subject: [PATCH 1/3] support argument for model run --- client/starwhale/__init__.py | 2 + client/starwhale/api/_impl/argument.py | 192 ++++++++++++++++++ .../api/_impl/evaluation/pipeline.py | 25 ++- client/starwhale/api/_impl/experiment.py | 5 +- client/starwhale/api/argument.py | 3 + client/starwhale/core/model/cli.py | 18 +- example/helloworld/evaluation.py | 16 +- 7 files changed, 247 insertions(+), 14 deletions(-) create mode 100644 client/starwhale/api/_impl/argument.py create mode 100644 client/starwhale/api/argument.py diff --git a/client/starwhale/__init__.py b/client/starwhale/__init__.py index 548840fa96..b20f6e07dd 100644 --- a/client/starwhale/__init__.py +++ b/client/starwhale/__init__.py @@ -14,6 +14,7 @@ from starwhale.api.metric import multi_classification from starwhale.api.dataset import Dataset from starwhale.utils.debug import init_logger +from starwhale.api.argument import argument from starwhale.api.instance import login, logout from starwhale.base.context import Context, pass_context from starwhale.api.evaluation import Evaluation, PipelineHandler @@ -47,6 +48,7 @@ __all__ = [ "__version__", + "argument", "model", "Job", "job", diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py new file mode 100644 index 0000000000..0973548477 --- /dev/null +++ b/client/starwhale/api/_impl/argument.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import typing as t +import inspect +import threading +import dataclasses +from enum import Enum +from functools import wraps + +import click + +from starwhale.utils import console + + +# TODO: use a more elegant way to pass extra cli args +class ExtraCliArgsRegistry: + _args = None + _lock = threading.Lock() + + @classmethod + def set(cls, args: t.List[str]) -> None: + with cls._lock: + cls._args = args + + @classmethod + def get(cls) -> t.List[str]: + with cls._lock: + return cls._args or [] + + +def argument(dataclass_types: t.Any) -> t.Any: + """argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune). + + The decorated function will receive the instances of the dataclass types as the arguments. + When the decorated function is called, the command line arguments will be parsed to the dataclass instances + and passed to the decorated function as the keyword arguments that name is "argument". + + When use argument decorator, the decorated function must have a keyword argument named "argument" or use "**kw" keyword arguments. + + Argument: + dataclass_types: [required] The dataclass type of the arguments. + A list of dataclass types or a single dataclass type is supported. + + Examples: + ```python + from starwhale import argument, evaluation + + @dataclass + class EvaluationArguments: + reshape: int = field(default=64, metadata={"help": "reshape image size"}) + + @argument(EvaluationArguments) + @evaluation.predict + def predict_image(data, argument: EvaluationArguments): + ... + ``` + """ + is_sequence = True + if dataclasses.is_dataclass(dataclass_types): + dataclass_types = [dataclass_types] + is_sequence = False + + def _register_wrapper(func: t.Callable) -> t.Any: + # TODO: add `--help` for the arguments + # TODO: dump parser to json file when model building + # TODO: `@handler` decorator function supports @argument decorator + parser = get_parser_from_dataclasses(dataclass_types) + + @wraps(func) + def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: + dataclass_values = init_dataclasses_values(parser, dataclass_types) + if "argument" in kw: + raise RuntimeError( + "argument is a reserved keyword for @starwhale.argument decorator in the " + ) + kw["argument"] = dataclass_values if is_sequence else dataclass_values[0] + return func(*args, **kw) + + return _run_wrapper + + return _register_wrapper + + +def init_dataclasses_values( + parser: click.OptionParser, dataclass_types: t.Any +) -> t.Any: + args_map, _, params = parser.parse_args(ExtraCliArgsRegistry.get()) + param_map = {p.name: p for p in params} + + ret = [] + for dtype in dataclass_types: + keys = {f.name for f in dataclasses.fields(dtype) if f.init} + inputs = {k: param_map[k].type(v) for k, v in args_map.items() if k in keys} + for k in inputs: + del args_map[k] + ret.append(dtype(**inputs)) + if args_map: + console.warn(f"Unused args from command line: {args_map}") + return ret + + +def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: + parser = click.OptionParser() + for dtype in dataclass_types: + if not dataclasses.is_dataclass(dtype): + raise ValueError(f"{dtype} is not a dataclass type") + + type_hints: t.Dict[str, type] = t.get_type_hints(dtype) + for field in dataclasses.fields(dtype): + if not field.init: + continue + field.type = type_hints[field.name] + add_field_into_parser(parser, field) + + parser.ignore_unknown_options = True + return parser + + +def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None: + # TODO: field.name need format for click option? + kw: t.Dict[str, t.Any] = { + "param_decls": [f"--{field.name}"], + "help": field.metadata.get("help"), + "show_default": True, + "hidden": field.metadata.get("hidden", False), + } + + # reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py + # only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type + origin_type = getattr(field.type, "__origin__", field.type) + if origin_type is t.Union: + if (str not in field.type.__args and type(None) not in field.type.__args__) or ( + len(field.type.__args__) != 2 + ): + raise ValueError( + f"{field.type} is not supported." + "Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type" + ) + + if type(None) in field.type.__args__: + # ignore None type, use another type as the field type + field.type = ( + field.type.__args__[0] + if field.type.__args__[1] == type(None) + else field.type.__args__[1] + ) + origin_type = getattr(field.type, "__origin__", field.type) + else: + # ignore str and None type, use another type as the field type + field.type = ( + field.type.__args__[0] + if field.type.__args__[1] == str + else field.type.__args__[1] + ) + origin_type = getattr(field.type, "__origin__", field.type) + + if origin_type is t.Literal or ( + isinstance(field.type, type) and issubclass(field.type, Enum) + ): + if origin_type is t.Literal: + kw["type"] = click.Choice(field.type.__args__) + else: + kw["type"] = click.Choice([e.value for e in field.type]) + + kw["show_choices"] = True + if field.default is not dataclasses.MISSING: + kw["default"] = field.default + else: + kw["required"] = True + elif field.type is bool or field.type == t.Optional[bool]: + kw["is_flag"] = True + kw["type"] = bool + kw["default"] = False if field.default is dataclasses.MISSING else field.default + elif inspect.isclass(origin_type) and issubclass(origin_type, list): + kw["type"] = field.type.__args__[0] + kw["multiple"] = True + if field.default is not dataclasses.MISSING: + kw["default"] = field.default + elif field.default_factory is not dataclasses.MISSING: + kw["default"] = field.default_factory() + else: + kw["required"] = True + else: + kw["type"] = field.type + if field.default is not dataclasses.MISSING: + kw["default"] = field.default + elif field.default_factory is not dataclasses.MISSING: + kw["default"] = field.default_factory() + else: + kw["required"] = True + + click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore diff --git a/client/starwhale/api/_impl/evaluation/pipeline.py b/client/starwhale/api/_impl/evaluation/pipeline.py index b1aec3f024..00b1698f51 100644 --- a/client/starwhale/api/_impl/evaluation/pipeline.py +++ b/client/starwhale/api/_impl/evaluation/pipeline.py @@ -232,12 +232,18 @@ def _do_predict( # case1: only accept data argument # 1. def predict(self, data): ... # 2. def predict(self, data, /): ... + # 3. def predict(self, *args): ... # case2: accept data and external arguments - # 1. def predict(self, *args): ... # 2. def predict(self, **kwargs): ... - # 3. def predict(self, *args, **kwargs): ... - # 4. def predict(self, data, external: t.Dict): ... - # 5. def predict(self, data, **kwargs): ... + # case3: accept two parameters + # 1. def predict(self, data, external=None): ... + # 2. def predict(self, data, **kwargs): ... + # 3. def predict(self, data, argument=None): ... + # 4. def predict(self, *args, **kwargs): ... + # case4: accept more than two parameters + # 1. def predict(self, data, external=None, argument=None): ... + # 2. def predict(self, data, external=None, **kwargs): ... + # 3. def predict(self, data, argument=None, **kwargs): ... kind = inspect._ParameterKind @@ -249,7 +255,7 @@ def _do_predict( if len(parameters) <= 0: raise RuntimeError("predict/ppl function must have at least one argument") elif len(parameters) == 1: - parameter: inspect.Parameter = list(parameters.values())[0] + parameter = list(parameters.values())[0] if parameter.kind == kind.VAR_POSITIONAL: return func(data, external) elif parameter.kind == kind.VAR_KEYWORD: @@ -260,6 +266,15 @@ def _do_predict( raise RuntimeError( f"unsupported parameter kind for predict/ppl function: {parameter.kind}" ) + elif len(parameters) == 2: + # whether to inject `external` argument: + # case1: def predict(self, data, external=None): ... + # case2: def predict(self, data, **kw): ... + parameter = list(parameters.values())[-1] + if parameter.kind == kind.VAR_KEYWORD or parameter.name == "external": + return func(data, external=external) + else: + return func(data) else: return func(data, external=external) diff --git a/client/starwhale/api/_impl/experiment.py b/client/starwhale/api/_impl/experiment.py index ea43604860..8efba4b668 100644 --- a/client/starwhale/api/_impl/experiment.py +++ b/client/starwhale/api/_impl/experiment.py @@ -75,7 +75,7 @@ def _register_wrapper(func: t.Callable) -> t.Any: ) @wraps(func) - def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: + def _run_wrapper(*func_args: t.Any, **func_kw: t.Any) -> t.Any: ctx = Context.get_runtime_context() load_dataset = partial(Dataset.dataset, readonly=True, create="forbid") @@ -97,7 +97,8 @@ def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: # TODO: support arguments from command line add_event(f"Start to finetune model by {func.__qualname__} function") - ret = func(*inject_args) + + ret = func(*(inject_args + list(func_args)), **func_kw) if auto_build_model: console.info(f"building starwhale model package from {workdir}") diff --git a/client/starwhale/api/argument.py b/client/starwhale/api/argument.py new file mode 100644 index 0000000000..e6cd775803 --- /dev/null +++ b/client/starwhale/api/argument.py @@ -0,0 +1,3 @@ +from ._impl.argument import argument, ExtraCliArgsRegistry + +__all__ = ["argument", "ExtraCliArgsRegistry"] diff --git a/client/starwhale/core/model/cli.py b/client/starwhale/core/model/cli.py index 18884226b3..2d70d6834c 100644 --- a/client/starwhale/core/model/cli.py +++ b/client/starwhale/core/model/cli.py @@ -435,7 +435,13 @@ def _recover(model: str, force: bool) -> None: ModelTermView(model).recover(force) -@model_cmd.command("run") +@model_cmd.command( + "run", + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + ), +) @optgroup.group( "\n ** Model Selectors", cls=RequiredMutuallyExclusiveOptionGroup, @@ -579,8 +585,9 @@ def _recover(model: str, force: bool) -> None: multiple=True, help=f"validation dataset uri for finetune, env is {SWEnv.finetune_validation_dataset_uri}", ) -@click.argument("handler_args", nargs=-1) +@click.pass_context def _run( + ctx: click.Context, workdir: str, uri: str, handler: int | str, @@ -602,7 +609,6 @@ def _run( forbid_packaged_runtime: bool, forbid_snapshot: bool, cleanup_snapshot: bool, - handler_args: t.Tuple[str], ) -> None: """Run Model. Model Package and the model source directory are supported. @@ -639,6 +645,10 @@ def _run( # --> run with finetune validation dataset swcli model run --workdir . -m mnist.finetune --dataset mnist --val-dataset mnist-val """ + from starwhale.api.argument import ExtraCliArgsRegistry + + ExtraCliArgsRegistry.set(ctx.args) + # TODO: support run model in cluster mode run_project_uri = Project(run_project) log_project_uri = Project(log_project) @@ -719,7 +729,7 @@ def _run( "task_num": override_task_num, }, force_generate_jobs_yaml=uri is None, - handler_args=list(handler_args) if handler_args else [], + handler_args=ctx.args, ) diff --git a/example/helloworld/evaluation.py b/example/helloworld/evaluation.py index c4e7d89ec7..cb9b0ddf4b 100644 --- a/example/helloworld/evaluation.py +++ b/example/helloworld/evaluation.py @@ -1,13 +1,19 @@ from pathlib import Path +from dataclasses import field, dataclass import numpy as np import onnxruntime as rt -from starwhale import Image, evaluation, multi_classification +from starwhale import Image, argument, evaluation, multi_classification _g_model = None +@dataclass +class EvaluationArguments: + reshape: int = field(default=64, metadata={"help": "reshape image size"}) + + def _load_model(): global _g_model @@ -19,16 +25,20 @@ def _load_model(): return _g_model +@argument(EvaluationArguments) @evaluation.predict( resources={"memory": {"request": "500M", "limit": "2G"}}, log_mode="plain", ) -def predict_image(data): +def predict_image(data: dict, argument: EvaluationArguments): + # def predict_image(data: dict, argument: EvaluationArguments): + # def predict_image(data: dict, external=None): + # def predict_image(data: dict, external=None, argument: EvaluationArguments=None): img: Image = data["img"] model = _load_model() input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name - input_array = [img.to_numpy().astype(np.float32).reshape(64)] + input_array = [img.to_numpy().astype(np.float32).reshape(argument.reshape)] pred = model.run([label_name], {input_name: input_array})[0] return pred.item() From a8b3f3e127f0d04861093dad5e70bcccab4565a2 Mon Sep 17 00:00:00 2001 From: tianwei Date: Wed, 27 Dec 2023 17:08:23 +0800 Subject: [PATCH 2/3] add argument decorator test --- client/starwhale/api/_impl/argument.py | 52 +++++-- client/tests/sdk/test_argument.py | 196 +++++++++++++++++++++++++ client/tests/sdk/test_job_handler.py | 10 +- scripts/example/src/evaluator.py | 18 ++- 4 files changed, 259 insertions(+), 17 deletions(-) create mode 100644 client/tests/sdk/test_argument.py diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 0973548477..447c750001 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -90,7 +90,19 @@ def init_dataclasses_values( ret = [] for dtype in dataclass_types: keys = {f.name for f in dataclasses.fields(dtype) if f.init} - inputs = {k: param_map[k].type(v) for k, v in args_map.items() if k in keys} + inputs = {} + for k, v in args_map.items(): + if k not in keys: + continue + + # TODO: support dict type convert + # handle multiple args for list type + if isinstance(v, list): + v = [param_map[k].type(i) for i in v] + else: + v = param_map[k].type(v) + inputs[k] = v + for k in inputs: del args_map[k] ret.append(dtype(**inputs)) @@ -118,8 +130,11 @@ def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None: # TODO: field.name need format for click option? + decls = [f"--{field.name}"] + if "_" in field.name: + decls.append(f"--{field.name.replace('_', '-')}") kw: t.Dict[str, t.Any] = { - "param_decls": [f"--{field.name}"], + "param_decls": decls, "help": field.metadata.get("help"), "show_default": True, "hidden": field.metadata.get("hidden", False), @@ -129,9 +144,9 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) # only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type origin_type = getattr(field.type, "__origin__", field.type) if origin_type is t.Union: - if (str not in field.type.__args and type(None) not in field.type.__args__) or ( - len(field.type.__args__) != 2 - ): + if ( + str not in field.type.__args__ and type(None) not in field.type.__args__ + ) or (len(field.type.__args__) != 2): raise ValueError( f"{field.type} is not supported." "Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type" @@ -154,10 +169,16 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) ) origin_type = getattr(field.type, "__origin__", field.type) - if origin_type is t.Literal or ( + try: + # typing.Literal is only supported in python3.8+ + literal_type = t.Literal # type: ignore[attr-defined] + except AttributeError: + literal_type = None + + if (literal_type and origin_type is literal_type) or ( isinstance(field.type, type) and issubclass(field.type, Enum) ): - if origin_type is t.Literal: + if literal_type and origin_type is literal_type: kw["type"] = click.Choice(field.type.__args__) else: kw["type"] = click.Choice([e.value for e in field.type]) @@ -171,14 +192,17 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) kw["is_flag"] = True kw["type"] = bool kw["default"] = False if field.default is dataclasses.MISSING else field.default - elif inspect.isclass(origin_type) and issubclass(origin_type, list): - kw["type"] = field.type.__args__[0] - kw["multiple"] = True - if field.default is not dataclasses.MISSING: - kw["default"] = field.default - elif field.default_factory is not dataclasses.MISSING: + elif inspect.isclass(origin_type) and issubclass(origin_type, (list, dict)): + if issubclass(origin_type, list): + kw["type"] = field.type.__args__[0] + kw["multiple"] = True + elif issubclass(origin_type, dict): + kw["type"] = dict + + # list and dict types both need default_factory + if field.default_factory is not dataclasses.MISSING: kw["default"] = field.default_factory() - else: + elif field.default is dataclasses.MISSING: kw["required"] = True else: kw["type"] = field.type diff --git a/client/tests/sdk/test_argument.py b/client/tests/sdk/test_argument.py new file mode 100644 index 0000000000..7bf5590647 --- /dev/null +++ b/client/tests/sdk/test_argument.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import typing as t +import dataclasses +from enum import Enum + +import click +from pyfakefs.fake_filesystem_unittest import TestCase + +from starwhale.api._impl.argument import argument as argument_decorator +from starwhale.api._impl.argument import ( + ExtraCliArgsRegistry, + get_parser_from_dataclasses, +) + + +class IntervalStrategy(Enum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + +class DebugOption(Enum): + UNDERFLOW_OVERFLOW = "underflow_overflow" + TPU_METRICS_DEBUG = "tpu_metrics_debug" + + +@dataclasses.dataclass +class ScalarArguments: + no_field = 1 + batch: int = dataclasses.field(default=64, metadata={"help": "batch size"}) + overwrite: bool = dataclasses.field(default=False, metadata={"help": "overwrite"}) + learning_rate: float = dataclasses.field( + default=0.01, metadata={"help": "learning rate"} + ) + half_precision_backend: str = dataclasses.field( + default="auto", metadata={"help": "half precision backend"} + ) + epoch: int = dataclasses.field(default_factory=lambda: 1) + + +@dataclasses.dataclass +class ComposeArguments: + # simply huggingface transformers TrainingArguments for test + debug: t.Union[str, t.List[DebugOption]] = dataclasses.field( + default="", metadata={"help": "debug mode"} + ) + + lr_scheduler_kwargs: t.Optional[t.Dict] = dataclasses.field( + default_factory=dict, metadata={"help": "lr scheduler kwargs"} + ) + evaluation_strategy: t.Union[IntervalStrategy, str] = dataclasses.field( + default="no", metadata={"help": "evaluation strategy"} + ) + per_gpu_train_batch_size: t.Optional[int] = dataclasses.field(default=None) + eval_delay: t.Optional[float] = dataclasses.field( + default=0, metadata={"help": "evaluation delay"} + ) + label_names: t.Optional[t.List[str]] = dataclasses.field( + default=None, metadata={"help": "label names"} + ) + + +class ArgumentTestCase(TestCase): + def setUp(self) -> None: + self.setUpPyfakefs() + + def tearDown(self) -> None: + ExtraCliArgsRegistry._args = None + + def test_argument_exceptions(self) -> None: + @argument_decorator(ScalarArguments) + def no_argument_func(): + ... + + @argument_decorator(ScalarArguments) + def argument_keyword_func(argument): + ... + + with self.assertRaisesRegex(TypeError, "got an unexpected keyword argument"): + no_argument_func() + + with self.assertRaisesRegex(RuntimeError, "argument is a reserved keyword"): + argument_keyword_func(argument=1) + + def test_argument_decorator(self) -> None: + @argument_decorator((ScalarArguments, ComposeArguments)) + def assert_func(argument: t.Tuple) -> None: + scalar_argument, compose_argument = argument + assert isinstance(scalar_argument, ScalarArguments) + assert isinstance(compose_argument, ComposeArguments) + + assert scalar_argument.batch == 128 + assert scalar_argument.overwrite is True + assert scalar_argument.learning_rate == 0.02 + assert scalar_argument.half_precision_backend == "auto" + assert scalar_argument.epoch == 1 + + assert compose_argument.label_names == ["a", "b", "c"] + assert compose_argument.eval_delay == 0 + assert compose_argument.per_gpu_train_batch_size == 8 + assert compose_argument.evaluation_strategy == "steps" + assert compose_argument.debug == [DebugOption.UNDERFLOW_OVERFLOW] + + ExtraCliArgsRegistry.set( + [ + "--batch", + "128", + "--overwrite", + "--learning-rate=0.02", + "--debug", + "underflow_overflow", + "--evaluation_strategy", + "steps", + "--per_gpu_train_batch_size", + "8", + "--label_names", + "a", + "--label_names", + "b", + "--label_names", + "c", + "--no-defined-arg=1", + ] + ) + assert_func() + + def test_parser_exceptions(self) -> None: + with self.assertRaisesRegex(ValueError, "is not a dataclass type"): + get_parser_from_dataclasses([None]) + + def test_scalar_parser(self) -> None: + scalar_parser = get_parser_from_dataclasses([ScalarArguments]) + assert scalar_parser.ignore_unknown_options + + assert "--no_field" not in scalar_parser._long_opt + + batch = scalar_parser._long_opt["--batch"].obj + assert batch.type == click.INT + assert not batch.required + assert batch.help == "batch size" + assert not batch.is_flag + assert batch.default == 64 + overwrite = scalar_parser._long_opt["--overwrite"].obj + assert overwrite.type == click.BOOL + assert overwrite.is_flag + assert overwrite.default is False + assert scalar_parser._long_opt["--learning-rate"].obj.type == click.FLOAT + assert ( + scalar_parser._long_opt["--half_precision_backend"].obj.type == click.STRING + ) + assert scalar_parser._long_opt["--epoch"].obj.type == click.INT + assert scalar_parser._long_opt["--epoch"].obj.default == 1 + + def test_compose_parser(self) -> None: + compose_parser = get_parser_from_dataclasses([ComposeArguments]) + + dict_obj = compose_parser._long_opt["--lr-scheduler-kwargs"].obj + assert not dict_obj.required + assert dict_obj.default == {} + assert not dict_obj.multiple + assert isinstance(dict_obj.type, click.types.FuncParamType) + assert dict_obj.type.func == dict + + union_enum_obj = compose_parser._long_opt["--evaluation_strategy"].obj + assert not union_enum_obj.required + assert union_enum_obj.default == "no" + assert isinstance(union_enum_obj.type, click.Choice) + assert union_enum_obj.type.choices == ["no", "steps", "epoch"] + assert union_enum_obj.show_choices + assert not union_enum_obj.multiple + + union_list_obj = compose_parser._long_opt["--debug"].obj + assert isinstance(union_list_obj.type, click.types.FuncParamType) + assert union_list_obj.type.func == DebugOption + assert not union_list_obj.required + assert union_list_obj.default is None + assert union_list_obj.multiple + + optional_int_obj = compose_parser._long_opt["--per_gpu_train_batch_size"].obj + assert optional_int_obj.type == click.INT + assert not optional_int_obj.required + assert optional_int_obj.default is None + assert not optional_int_obj.multiple + + optional_float_obj = compose_parser._long_opt["--eval_delay"].obj + assert optional_float_obj.type == click.FLOAT + assert not optional_float_obj.required + assert optional_float_obj.default == 0 + assert not optional_float_obj.multiple + + optional_list_obj = compose_parser._long_opt["--label_names"].obj + assert optional_list_obj.type == click.STRING + assert not optional_list_obj.required + assert optional_list_obj.multiple + assert optional_list_obj.default is None diff --git a/client/tests/sdk/test_job_handler.py b/client/tests/sdk/test_job_handler.py index 89e9f02378..5e26f4439b 100644 --- a/client/tests/sdk/test_job_handler.py +++ b/client/tests/sdk/test_job_handler.py @@ -193,7 +193,7 @@ def _ensure_py_script(self, content: str) -> None: def test_multi_predict_decorators(self) -> None: content = """ -from starwhale import evaluation +from starwhale import evaluation, argument @evaluation.predict def img_predict_handler(*args, **kwargs): ... @@ -206,6 +206,12 @@ def video_predict_handler(*args, **kwargs): ... @evaluation.evaluate(needs=[video_predict_handler]) def video_evaluate_handler(*args, **kwargs): ... + +@evaluation.predict +def mock_predict_handler1(data, external): ... + +@evaluation.predict +def mock_predict_handler2(data, argument=None): ... """ self._ensure_py_script(content) yaml_path = self.workdir / "job.yaml" @@ -219,6 +225,8 @@ def video_evaluate_handler(*args, **kwargs): ... "mock_user_module:img_predict_handler", "mock_user_module:video_evaluate_handler", "mock_user_module:video_predict_handler", + "mock_user_module:mock_predict_handler1", + "mock_user_module:mock_predict_handler2", } == set(jobs_info.keys()) assert jobs_info["mock_user_module:img_evaluate_handler"] == [ StepSpecClient( diff --git a/scripts/example/src/evaluator.py b/scripts/example/src/evaluator.py index 853534802f..fe1aa5e95a 100644 --- a/scripts/example/src/evaluator.py +++ b/scripts/example/src/evaluator.py @@ -2,6 +2,7 @@ import random import typing as t import os.path as osp +import dataclasses from functools import wraps import numpy @@ -12,6 +13,7 @@ Context, Dataset, handler, + argument, IntInput, ListInput, evaluation, @@ -28,6 +30,11 @@ from util import random_image +@dataclasses.dataclass +class TestArguments: + epoch: int = dataclasses.field(default=10, metadata={"help": "epoch"}) + + def timing(func: t.Callable) -> t.Any: @wraps(func) def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: @@ -45,13 +52,16 @@ def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: log_mode="plain", log_dataset_features=["txt", "img", "label"], ) -def predict(data: t.Dict, external: t.Dict) -> t.Any: +@argument(TestArguments) +def predict(data: t.Dict, external: t.Dict, argument) -> t.Any: # Test relative path case file_name = osp.join("templates", "data.json") assert osp.exists(file_name) assert isinstance(external["context"], Context) assert external["dataset_uri"].name assert external["dataset_uri"].version + assert isinstance(argument, TestArguments) + assert argument.epoch == 10 if in_container(): assert osp.exists("/tmp/runtime-command-run.flag") @@ -74,7 +84,11 @@ def predict(data: t.Dict, external: t.Dict) -> t.Any: show_roc_auc=True, all_labels=[f"label-{i}" for i in range(0, 5)], ) -def evaluate(ppl_result: t.Iterator): +@argument(TestArguments) +def evaluate(ppl_result: t.Iterator, argument: TestArguments) -> t.Any: + assert isinstance(argument, TestArguments) + assert argument.epoch == 10 + result, label, pr = [], [], [] for _data in ppl_result: assert _data["_mode"] == "plain" From 29175f32fbb583232038797290a827bed98abb1f Mon Sep 17 00:00:00 2001 From: tianwei Date: Thu, 28 Dec 2023 16:24:29 +0800 Subject: [PATCH 3/3] support inject_name option for argument decorator --- client/starwhale/api/_impl/argument.py | 15 +++++++++++---- client/tests/sdk/test_argument.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 447c750001..77ee36b625 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -28,7 +28,7 @@ def get(cls) -> t.List[str]: return cls._args or [] -def argument(dataclass_types: t.Any) -> t.Any: +def argument(dataclass_types: t.Any, inject_name: str = "argument") -> t.Any: """argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune). The decorated function will receive the instances of the dataclass types as the arguments. @@ -40,6 +40,8 @@ def argument(dataclass_types: t.Any) -> t.Any: Argument: dataclass_types: [required] The dataclass type of the arguments. A list of dataclass types or a single dataclass type is supported. + inject_name: [optional] The name of the keyword argument that will be passed to the decorated function. + Default is "argument". Examples: ```python @@ -53,6 +55,11 @@ class EvaluationArguments: @evaluation.predict def predict_image(data, argument: EvaluationArguments): ... + + @argument(EvaluationArguments, inject_name="starwhale_arguments") + @evaluation.evaluate(needs=[]) + def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArguments): + ... ``` """ is_sequence = True @@ -69,11 +76,11 @@ def _register_wrapper(func: t.Callable) -> t.Any: @wraps(func) def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: dataclass_values = init_dataclasses_values(parser, dataclass_types) - if "argument" in kw: + if inject_name in kw: raise RuntimeError( - "argument is a reserved keyword for @starwhale.argument decorator in the " + f"{inject_name} has been used as a keyword argument in the decorated function, please use another name by the `inject_name` option." ) - kw["argument"] = dataclass_values if is_sequence else dataclass_values[0] + kw[inject_name] = dataclass_values if is_sequence else dataclass_values[0] return func(*args, **kw) return _run_wrapper diff --git a/client/tests/sdk/test_argument.py b/client/tests/sdk/test_argument.py index 7bf5590647..9dfe80fa32 100644 --- a/client/tests/sdk/test_argument.py +++ b/client/tests/sdk/test_argument.py @@ -80,13 +80,18 @@ def argument_keyword_func(argument): with self.assertRaisesRegex(TypeError, "got an unexpected keyword argument"): no_argument_func() - with self.assertRaisesRegex(RuntimeError, "argument is a reserved keyword"): + with self.assertRaisesRegex( + RuntimeError, + "has been used as a keyword argument in the decorated function", + ): argument_keyword_func(argument=1) def test_argument_decorator(self) -> None: - @argument_decorator((ScalarArguments, ComposeArguments)) - def assert_func(argument: t.Tuple) -> None: - scalar_argument, compose_argument = argument + @argument_decorator( + (ScalarArguments, ComposeArguments), inject_name="starwhale_argument" + ) + def assert_func(starwhale_argument: t.Tuple) -> None: + scalar_argument, compose_argument = starwhale_argument assert isinstance(scalar_argument, ScalarArguments) assert isinstance(compose_argument, ComposeArguments)