feat(SDK): support argument for model run #3095

Merged · 3 commits · Dec 29, 2023
Changes from 2 commits
2 changes: 2 additions & 0 deletions client/starwhale/__init__.py
@@ -14,6 +14,7 @@
from starwhale.api.metric import multi_classification
from starwhale.api.dataset import Dataset
from starwhale.utils.debug import init_logger
from starwhale.api.argument import argument
from starwhale.api.instance import login, logout
from starwhale.base.context import Context, pass_context
from starwhale.api.evaluation import Evaluation, PipelineHandler
@@ -47,6 +48,7 @@

__all__ = [
"__version__",
"argument",
"model",
"Job",
"job",
216 changes: 216 additions & 0 deletions client/starwhale/api/_impl/argument.py
@@ -0,0 +1,216 @@
from __future__ import annotations

import typing as t
import inspect
import threading
import dataclasses
from enum import Enum
from functools import wraps

import click

from starwhale.utils import console


# TODO: use a more elegant way to pass extra cli args
class ExtraCliArgsRegistry:
_args = None
_lock = threading.Lock()

@classmethod
def set(cls, args: t.List[str]) -> None:
with cls._lock:
cls._args = args

@classmethod
def get(cls) -> t.List[str]:
with cls._lock:
return cls._args or []


def argument(dataclass_types: t.Any) -> t.Any:
"""argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune).

The decorated function will receive the instances of the dataclass types as the arguments.
When the decorated function is called, the command line arguments will be parsed to the dataclass instances
and passed to the decorated function as the keyword arguments that name is "argument".

When use argument decorator, the decorated function must have a keyword argument named "argument" or use "**kw" keyword arguments.

Argument:
dataclass_types: [required] The dataclass type of the arguments.
A list of dataclass types or a single dataclass type is supported.

Examples:
```python
from dataclasses import dataclass, field

from starwhale import argument, evaluation

@dataclass
class EvaluationArguments:
reshape: int = field(default=64, metadata={"help": "reshape image size"})

@argument(EvaluationArguments)
@evaluation.predict
def predict_image(data, argument: EvaluationArguments):
...
```
"""
is_sequence = True
if dataclasses.is_dataclass(dataclass_types):
dataclass_types = [dataclass_types]
is_sequence = False

def _register_wrapper(func: t.Callable) -> t.Any:
# TODO: add `--help` for the arguments
# TODO: dump parser to json file when model building
# TODO: `@handler` decorator function supports @argument decorator
parser = get_parser_from_dataclasses(dataclass_types)

@wraps(func)
def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
dataclass_values = init_dataclasses_values(parser, dataclass_types)
if "argument" in kw:
raise RuntimeError(
"argument is a reserved keyword for @starwhale.argument decorator in the "
)
kw["argument"] = dataclass_values if is_sequence else dataclass_values[0]
return func(*args, **kw)

return _run_wrapper

return _register_wrapper


def init_dataclasses_values(
parser: click.OptionParser, dataclass_types: t.Any
) -> t.Any:
args_map, _, params = parser.parse_args(ExtraCliArgsRegistry.get())
param_map = {p.name: p for p in params}

ret = []
for dtype in dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {}
for k, v in args_map.items():
if k not in keys:
continue

# TODO: support dict type convert
# handle multiple args for list type
if isinstance(v, list):
v = [param_map[k].type(i) for i in v]
else:
v = param_map[k].type(v)
inputs[k] = v

for k in inputs:
del args_map[k]
ret.append(dtype(**inputs))
if args_map:
console.warn(f"Unused args from command line: {args_map}")
return ret


def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
parser = click.OptionParser()
for dtype in dataclass_types:
if not dataclasses.is_dataclass(dtype):
raise ValueError(f"{dtype} is not a dataclass type")

type_hints: t.Dict[str, type] = t.get_type_hints(dtype)
for field in dataclasses.fields(dtype):
if not field.init:
continue
field.type = type_hints[field.name]
add_field_into_parser(parser, field)

parser.ignore_unknown_options = True
return parser


def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None:
# TODO: does field.name need formatting for the click option?
decls = [f"--{field.name}"]
if "_" in field.name:
decls.append(f"--{field.name.replace('_', '-')}")
kw: t.Dict[str, t.Any] = {
"param_decls": decls,
"help": field.metadata.get("help"),
"show_default": True,
"hidden": field.metadata.get("hidden", False),
}

# reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py
# only support: Union[xxx, None], Union[EnumType, str] or Union[List[EnumType], str] types
origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is t.Union:
if (
str not in field.type.__args__ and type(None) not in field.type.__args__
) or (len(field.type.__args__) != 2):
raise ValueError(
f"{field.type} is not supported."
"Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type"
)

if type(None) in field.type.__args__:
# ignore None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == type(None)
else field.type.__args__[1]
)
origin_type = getattr(field.type, "__origin__", field.type)
else:
# ignore str and None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == str
else field.type.__args__[1]
)
origin_type = getattr(field.type, "__origin__", field.type)

try:
# typing.Literal is only supported in python3.8+
literal_type = t.Literal # type: ignore[attr-defined]
except AttributeError:
literal_type = None

if (literal_type and origin_type is literal_type) or (
isinstance(field.type, type) and issubclass(field.type, Enum)
):
if literal_type and origin_type is literal_type:
kw["type"] = click.Choice(field.type.__args__)
else:
kw["type"] = click.Choice([e.value for e in field.type])

kw["show_choices"] = True
if field.default is not dataclasses.MISSING:
kw["default"] = field.default
else:
kw["required"] = True
elif field.type is bool or field.type == t.Optional[bool]:
kw["is_flag"] = True
kw["type"] = bool
kw["default"] = False if field.default is dataclasses.MISSING else field.default
elif inspect.isclass(origin_type) and issubclass(origin_type, (list, dict)):
if issubclass(origin_type, list):
kw["type"] = field.type.__args__[0]
kw["multiple"] = True
elif issubclass(origin_type, dict):
kw["type"] = dict

# list and dict types both need default_factory
if field.default_factory is not dataclasses.MISSING:
kw["default"] = field.default_factory()
elif field.default is dataclasses.MISSING:
kw["required"] = True
else:
kw["type"] = field.type
if field.default is not dataclasses.MISSING:
kw["default"] = field.default
elif field.default_factory is not dataclasses.MISSING:
kw["default"] = field.default_factory()
else:
kw["required"] = True

click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore
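Taken together, `ExtraCliArgsRegistry` holds whatever options the CLI left unparsed, and `@argument` turns them into dataclass instances when the decorated handler runs. A minimal sketch of that flow, assuming a hypothetical `EvaluationArguments` dataclass and manually seeded CLI args (in real use `swcli model run` seeds the registry, see `cli.py` below):

```python
from dataclasses import dataclass, field

from starwhale import argument
from starwhale.api.argument import ExtraCliArgsRegistry


@dataclass
class EvaluationArguments:
    # hypothetical fields, for illustration only
    reshape: int = field(default=64, metadata={"help": "reshape image size"})
    device: str = "cpu"


@argument(EvaluationArguments)
def evaluate(argument: EvaluationArguments) -> None:
    print(argument.reshape, argument.device)


# Simulate the extra options that `swcli model run` would have collected.
ExtraCliArgsRegistry.set(["--reshape", "32", "--device", "cuda"])
evaluate()  # the wrapper parses the registry and injects EvaluationArguments(reshape=32, device="cuda")
```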
25 changes: 20 additions & 5 deletions client/starwhale/api/_impl/evaluation/pipeline.py
@@ -232,12 +232,18 @@ def _do_predict(
# case1: only accept data argument
# 1. def predict(self, data): ...
# 2. def predict(self, data, /): ...
# 3. def predict(self, *args): ...
# case2: accept data and external arguments
# 1. def predict(self, *args): ...
# 2. def predict(self, **kwargs): ...
# 3. def predict(self, *args, **kwargs): ...
# 4. def predict(self, data, external: t.Dict): ...
# 5. def predict(self, data, **kwargs): ...
# case3: accept two parameters
# 1. def predict(self, data, external=None): ...
# 2. def predict(self, data, **kwargs): ...
# 3. def predict(self, data, argument=None): ...
# 4. def predict(self, *args, **kwargs): ...
# case4: accept more than two parameters
# 1. def predict(self, data, external=None, argument=None): ...
# 2. def predict(self, data, external=None, **kwargs): ...
# 3. def predict(self, data, argument=None, **kwargs): ...

kind = inspect._ParameterKind

@@ -249,7 +255,7 @@
if len(parameters) <= 0:
raise RuntimeError("predict/ppl function must have at least one argument")
elif len(parameters) == 1:
parameter: inspect.Parameter = list(parameters.values())[0]
parameter = list(parameters.values())[0]
if parameter.kind == kind.VAR_POSITIONAL:
return func(data, external)
elif parameter.kind == kind.VAR_KEYWORD:
@@ -260,6 +266,15 @@
raise RuntimeError(
f"unsupported parameter kind for predict/ppl function: {parameter.kind}"
)
elif len(parameters) == 2:
# whether to inject `external` argument:
# case1: def predict(self, data, external=None): ...
# case2: def predict(self, data, **kw): ...
parameter = list(parameters.values())[-1]
if parameter.kind == kind.VAR_KEYWORD or parameter.name == "external":
return func(data, external=external)
else:
return func(data)
else:
return func(data, external=external)
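In short, the new two-parameter branch only injects `external` when the handler can accept it by name or via `**kw`; otherwise the second parameter (for example `argument`, filled by the `@argument` wrapper) is left alone. Illustrative signatures, not taken from the PR, summarizing the visible dispatch rules:

```python
# The evaluation runtime supplies `data` and `external`; `argument` is injected
# by the @argument decorator, not by this dispatch.

def predict_external(data, external=None): ...   # 2 params, named `external` -> func(data, external=external)
def predict_kwargs(data, **kw): ...              # 2 params, **kw             -> func(data, external=external)
def predict_argument(data, argument=None): ...   # 2 params, other name       -> func(data)
def predict_full(data, external=None, argument=None): ...  # >2 params        -> func(data, external=external)
```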

5 changes: 3 additions & 2 deletions client/starwhale/api/_impl/experiment.py
@@ -75,7 +75,7 @@ def _register_wrapper(func: t.Callable) -> t.Any:
)

@wraps(func)
def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
def _run_wrapper(*func_args: t.Any, **func_kw: t.Any) -> t.Any:
ctx = Context.get_runtime_context()
load_dataset = partial(Dataset.dataset, readonly=True, create="forbid")

@@ -97,7 +98,8 @@ def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:

# TODO: support arguments from command line
add_event(f"Start to finetune model by {func.__qualname__} function")
ret = func(*inject_args)

ret = func(*(inject_args + list(func_args)), **func_kw)

if auto_build_model:
console.info(f"building starwhale model package from {workdir}")
3 changes: 3 additions & 0 deletions client/starwhale/api/argument.py
@@ -0,0 +1,3 @@
from ._impl.argument import argument, ExtraCliArgsRegistry

__all__ = ["argument", "ExtraCliArgsRegistry"]
18 changes: 14 additions & 4 deletions client/starwhale/core/model/cli.py
@@ -435,7 +435,13 @@ def _recover(model: str, force: bool) -> None:
ModelTermView(model).recover(force)


@model_cmd.command("run")
@model_cmd.command(
"run",
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
),
)
@optgroup.group(
"\n ** Model Selectors",
cls=RequiredMutuallyExclusiveOptionGroup,
@@ -579,8 +585,9 @@ def _recover(model: str, force: bool) -> None:
multiple=True,
help=f"validation dataset uri for finetune, env is {SWEnv.finetune_validation_dataset_uri}",
)
@click.argument("handler_args", nargs=-1)
@click.pass_context
def _run(
ctx: click.Context,
workdir: str,
uri: str,
handler: int | str,
@@ -602,7 +609,6 @@ def _run(
forbid_packaged_runtime: bool,
forbid_snapshot: bool,
cleanup_snapshot: bool,
handler_args: t.Tuple[str],
) -> None:
"""Run Model.
Model Package and the model source directory are supported.
@@ -639,6 +645,10 @@
# --> run with finetune validation dataset
swcli model run --workdir . -m mnist.finetune --dataset mnist --val-dataset mnist-val
"""
from starwhale.api.argument import ExtraCliArgsRegistry

ExtraCliArgsRegistry.set(ctx.args)

# TODO: support run model in cluster mode
run_project_uri = Project(run_project)
log_project_uri = Project(log_project)
@@ -719,7 +729,7 @@
"task_num": override_task_num,
},
force_generate_jobs_yaml=uri is None,
handler_args=list(handler_args) if handler_args else [],
handler_args=ctx.args,
)
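With `ignore_unknown_options` and `allow_extra_args` enabled, any trailing options that `swcli model run` does not define stay in `ctx.args`; `_run` stores them in `ExtraCliArgsRegistry` and forwards them as `handler_args`. A hedged sketch of that hand-off (the `--reshape` option and `mnist.evaluate` handler are illustrative, not documented swcli flags):

```python
# Hypothetical invocation; `--reshape 32` is unknown to swcli, so click keeps it
# in ctx.args and _run passes it on via ExtraCliArgsRegistry / handler_args:
#
#   swcli model run --workdir . -m mnist.evaluate --dataset mnist --reshape 32
#
from starwhale.api.argument import ExtraCliArgsRegistry

ExtraCliArgsRegistry.set(["--reshape", "32"])  # what _run does with ctx.args
assert ExtraCliArgsRegistry.get() == ["--reshape", "32"]
```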

