diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 77ee36b625..2c0944e555 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -6,6 +6,7 @@ import dataclasses from enum import Enum from functools import wraps +from collections import defaultdict import click @@ -28,6 +29,46 @@ def get(cls) -> t.List[str]: return cls._args or [] +class ArgumentContext: + _instance = None + _lock = threading.Lock() + + def __init__(self) -> None: + self._click_ctx = click.Context(click.Command("Starwhale Argument Decorator")) + self._options: t.Dict[str, list] = defaultdict(list) + + @classmethod + def get_current_context(cls) -> ArgumentContext: + with cls._lock: + if cls._instance is None: + cls._instance = ArgumentContext() + return cls._instance + + def add_option(self, option: click.Option, group: str) -> None: + with self._lock: + self._options[group].append(option) + + def echo_help(self) -> None: + if not self._options: + click.echo("No options") + return + + formatter = self._click_ctx.make_formatter() + formatter.write_heading("\nOptions from Starwhale Argument Decorator") + + for group, options in self._options.items(): + help_records = [] + for option in options: + record = option.get_help_record(self._click_ctx) + if record: + help_records.append(record) + + with formatter.section(f"** {group}"): + formatter.write_dl(help_records) + + click.echo(formatter.getvalue().rstrip("\n")) + + def argument(dataclass_types: t.Any, inject_name: str = "argument") -> t.Any: """argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune). @@ -68,9 +109,7 @@ def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArgumen is_sequence = False def _register_wrapper(func: t.Callable) -> t.Any: - # TODO: add `--help` for the arguments # TODO: dump parser to json file when model building - # TODO: `@handler` decorator function supports @argument decorator parser = get_parser_from_dataclasses(dataclass_types) @wraps(func) @@ -113,12 +152,14 @@ def init_dataclasses_values( for k in inputs: del args_map[k] ret.append(dtype(**inputs)) + if args_map: console.warn(f"Unused args from command line: {args_map}") return ret def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: + argument_ctx = ArgumentContext.get_current_context() parser = click.OptionParser() for dtype in dataclass_types: if not dataclasses.is_dataclass(dtype): @@ -129,13 +170,17 @@ def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: if not field.init: continue field.type = type_hints[field.name] - add_field_into_parser(parser, field) + option = convert_field_to_option(field) + option.add_to_parser(parser=parser, ctx=parser.ctx) # type: ignore + argument_ctx.add_option( + option=option, group=f"{dtype.__module__}.{dtype.__qualname__}" + ) parser.ignore_unknown_options = True return parser -def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None: +def convert_field_to_option(field: dataclasses.Field) -> click.Option: # TODO: field.name need format for click option? decls = [f"--{field.name}"] if "_" in field.name: @@ -220,4 +265,4 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) else: kw["required"] = True - click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore + return click.Option(**kw) diff --git a/client/starwhale/core/model/cli.py b/client/starwhale/core/model/cli.py index 2d70d6834c..dc6cd59ded 100644 --- a/client/starwhale/core/model/cli.py +++ b/client/starwhale/core/model/cli.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import sys import typing as t from pathlib import Path @@ -472,6 +473,13 @@ def _recover(model: str, force: bool) -> None: multiple=True, help="module name, the format is python module import path, handlers will be searched in the module. The option supports set multiple times.", ) +@optgroup.option( # type: ignore[no-untyped-call] + "-sa", + "--show-argument", + is_flag=True, + default=False, + help="[ONLY STANDALONE]Show the argument help info by the @starwhale.argument decorator registered arguments. The help info only analysis the imported modules.", +) @optgroup.option( # type: ignore[no-untyped-call] "-f", "--model-yaml", @@ -609,6 +617,7 @@ def _run( forbid_packaged_runtime: bool, forbid_snapshot: bool, cleanup_snapshot: bool, + show_argument: bool, ) -> None: """Run Model. Model Package and the model source directory are supported. @@ -644,9 +653,15 @@ def _run( \b # --> run with finetune validation dataset swcli model run --workdir . -m mnist.finetune --dataset mnist --val-dataset mnist-val + + \b + # --> echo the argument help info by the @starwhale argument decorator + swcli model run --workdir . -m mnist.finetune --show-argument + swcli model run --uri mnist --show-argument """ from starwhale.api.argument import ExtraCliArgsRegistry + # TODO: currently, ExtraCliArgsRegistry must be set before the model run. We will find a better way to set it, such as ctx hooking. ExtraCliArgsRegistry.set(ctx.args) # TODO: support run model in cluster mode @@ -698,6 +713,21 @@ def _run( forbid_packaged_runtime=forbid_packaged_runtime, ) + if show_argument: + search_modules = model_config.run.modules + if not search_modules: + click.echo( + "no modules specified, please use `--module` option to set search modules" + ) + sys.exit(1) + + ModelTermView.show_argument( + model_src_dir=model_src_dir, + search_modules=search_modules, + runtime_uri=runtime_uri, + ) + return + if in_container: ModelTermView.run_in_container( model_src_dir=model_src_dir, diff --git a/client/starwhale/core/model/view.py b/client/starwhale/core/model/view.py index 3424921dbb..46520e911e 100644 --- a/client/starwhale/core/model/view.py +++ b/client/starwhale/core/model/view.py @@ -245,6 +245,26 @@ def run_in_server( return ok, version_or_reason + @classmethod + @BaseTermView._only_standalone + def show_argument( + cls, + model_src_dir: Path | str, + search_modules: t.List[str], + runtime_uri: t.Optional[Resource] = None, + ) -> None: + if runtime_uri: + RuntimeProcess(uri=runtime_uri).run() + else: + from starwhale.api._impl.argument import ArgumentContext + from starwhale.api._impl.job.handler import Handler + + Handler._preload_registering_handlers( + search_modules=search_modules, package_dir=Path(model_src_dir) + ) + ctx = ArgumentContext.get_current_context() + ctx.echo_help() + @classmethod @BaseTermView._only_standalone def run_in_host( diff --git a/client/tests/sdk/test_argument.py b/client/tests/sdk/test_argument.py index 9dfe80fa32..aae43faa51 100644 --- a/client/tests/sdk/test_argument.py +++ b/client/tests/sdk/test_argument.py @@ -3,12 +3,14 @@ import typing as t import dataclasses from enum import Enum +from unittest.mock import patch, MagicMock import click from pyfakefs.fake_filesystem_unittest import TestCase from starwhale.api._impl.argument import argument as argument_decorator from starwhale.api._impl.argument import ( + ArgumentContext, ExtraCliArgsRegistry, get_parser_from_dataclasses, ) @@ -67,6 +69,7 @@ def setUp(self) -> None: def tearDown(self) -> None: ExtraCliArgsRegistry._args = None + ArgumentContext._instance = None def test_argument_exceptions(self) -> None: @argument_decorator(ScalarArguments) @@ -157,6 +160,14 @@ def test_scalar_parser(self) -> None: assert scalar_parser._long_opt["--epoch"].obj.type == click.INT assert scalar_parser._long_opt["--epoch"].obj.default == 1 + argument_ctx = ArgumentContext.get_current_context() + assert len(argument_ctx._options) == 1 + options = argument_ctx._options["tests.sdk.test_argument.ScalarArguments"] + assert len(options) == 5 + assert options[0].name == "batch" + assert options[-1].name == "epoch" + argument_ctx.echo_help() + def test_compose_parser(self) -> None: compose_parser = get_parser_from_dataclasses([ComposeArguments]) @@ -199,3 +210,36 @@ def test_compose_parser(self) -> None: assert not optional_list_obj.required assert optional_list_obj.multiple assert optional_list_obj.default is None + + argument_ctx = ArgumentContext.get_current_context() + assert len(argument_ctx._options) == 1 + options = argument_ctx._options["tests.sdk.test_argument.ComposeArguments"] + assert len(options) == 6 + assert options[0].name == "debug" + argument_ctx.echo_help() + + @patch("click.echo") + def test_argument_help_output(self, mock_echo: MagicMock): + @argument_decorator((ScalarArguments, ComposeArguments)) + def mock_func(starwhale_argument: t.Tuple) -> None: + ... + + ArgumentContext.get_current_context().echo_help() + help_output = mock_echo.call_args[0][0] + cases = [ + "tests.sdk.test_argument.ScalarArguments:", + "--batch INTEGER", + "--overwrite", + "--learning_rate, --learning-rate FLOAT", + "--half_precision_backend, --half-precision-backend TEXT", + "--epoch INTEGER", + "tests.sdk.test_argument.ComposeArguments:", + "--debug DEBUGOPTION", + "--lr_scheduler_kwargs, --lr-scheduler-kwargs DICT", + "--evaluation_strategy, --evaluation-strategy [no|steps|epoch]", + "--per_gpu_train_batch_size, --per-gpu-train-batch-size INTEGER", + "--eval_delay, --eval-delay FLOAT", + "--label_names, --label-names TEXT", + ] + for case in cases: + assert case in help_output