support argument for model run
tianweidut committed Dec 22, 2023
1 parent 4949eea commit 74b4fc0
Showing 4 changed files with 156 additions and 3 deletions.
2 changes: 2 additions & 0 deletions client/starwhale/__init__.py
@@ -14,6 +14,7 @@
 from starwhale.api.metric import multi_classification
 from starwhale.api.dataset import Dataset
 from starwhale.utils.debug import init_logger
+from starwhale.api.argument import argument
 from starwhale.api.instance import login, logout
 from starwhale.base.context import Context, pass_context
 from starwhale.api.evaluation import Evaluation, PipelineHandler
@@ -47,6 +48,7 @@
 
 __all__ = [
     "__version__",
+    "argument",
     "model",
     "Job",
     "job",
141 changes: 141 additions & 0 deletions client/starwhale/api/_impl/argument.py
@@ -0,0 +1,141 @@
from __future__ import annotations

import typing as t
import inspect
import dataclasses
from enum import Enum
from functools import wraps

import click


def argument(dataclass_types: t.Any) -> t.Any:
"""argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune).
The decorated function will receive the instances of the dataclass types as the arguments.
Argument:
dataclass_types: [required] The dataclass type of the arguments.
A list of dataclass types or a single dataclass type is supported.
Examples:
```python
from dataclasses import field, dataclass

from starwhale import argument, evaluation
@dataclass
class EvaluationArguments:
reshape: int = field(default=64, metadata={"help": "reshape image size"})
@argument(EvaluationArguments)
@evaluation.predict
def predict_image(data, eval_args: EvaluationArguments):
...
```
"""

def _register_wrapper(func: t.Callable) -> t.Any:
parser = get_parser_from_dataclasses(dataclass_types)

@wraps(func)
def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
# TODO: parse the arguments from the click context with the generated parser
# and inject the dataclass instances into func
return func(*args, **kw)

return _run_wrapper

return _register_wrapper


def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
if dataclasses.is_dataclass(dataclass_types):
dataclass_types = [dataclass_types]

parser = click.OptionParser()
for dtype in dataclass_types:
if not dataclasses.is_dataclass(dtype):
raise ValueError(f"{dtype} is not a dataclass type")

type_hints: t.Dict[str, type] = t.get_type_hints(dtype)
for field in dataclasses.fields(dtype):
if not field.init:
continue
field.type = type_hints[field.name]
add_field_into_parser(parser, field)

return parser


def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None:
# TODO: does field.name need to be formatted (e.g. underscores to dashes) for the click option name?
kw: t.Dict[str, t.Any] = {
"param_decls": [f"--{field.name}"],
"help": field.metadata.get("help"),
"show_default": True,
"hidden": field.metadata.get("hidden", False),
}

# reference from huggingface transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py
# only support: Union[xxx, None], Union[EnumType, str] or Union[List[EnumType], str] types
origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is t.Union:
if (str not in field.type.__args__ and type(None) not in field.type.__args__) or (
len(field.type.__args__) != 2
):
raise ValueError(
f"{field.type} is not supported."
"Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type"
)

if type(None) in field.type.__args__:
# ignore None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == type(None)
else field.type.__args__[1]
)
origin_type = getattr(field.type, "__origin__", field.type)
else:
# ignore str and None type, use another type as the field type
field.type = (
field.type.__args__[0]
if field.type.__args__[1] == str
else field.type.__args__[1]
)
origin_type = getattr(field.type, "__origin__", field.type)

if origin_type is t.Literal or (
isinstance(field.type, type) and issubclass(field.type, Enum)
):
if origin_type is t.Literal:
kw["type"] = click.Choice(field.type.__args__)
else:
kw["type"] = click.Choice([e.value for e in field.type])

kw["show_choices"] = True
if field.default is not dataclasses.MISSING:
kw["default"] = field.default
else:
kw["required"] = True
elif field.type is bool or field.type == t.Optional[bool]:
kw["is_flag"] = True
kw["type"] = bool
kw["default"] = False if field.default is dataclasses.MISSING else field.default
elif inspect.isclass(origin_type) and issubclass(origin_type, list):
kw["type"] = field.type.__args__[0]
kw["multiple"] = True
if field.default is not dataclasses.MISSING:
kw["default"] = field.default
elif field.default_factory is not dataclasses.MISSING:
kw["default"] = field.default_factory()
else:
kw["required"] = True
else:
kw["type"] = field.type
if field.default is not dataclasses.MISSING:
kw["default"] = field.default
elif field.default_factory is not dataclasses.MISSING:
kw["default"] = field.default_factory()
else:
kw["required"] = True

click.Option(**kw).add_to_parser(parser=parser, ctx=None) # type: ignore
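
For reference, here is a minimal sketch (not part of the commit) of a dataclass whose fields cover the shapes `add_field_into_parser` handles: plain scalars, bools as flags, Enum choices, repeatable `List` fields, and `Optional` unwrapping. The names `Device` and `ExampleArguments` are illustrative only.

```python
import typing as t
from enum import Enum
from dataclasses import field, dataclass

# the module path follows the new file added in this commit
from starwhale.api._impl.argument import get_parser_from_dataclasses


class Device(Enum):
    CPU = "cpu"
    GPU = "gpu"


@dataclass
class ExampleArguments:
    epochs: int = 10                      # plain scalar -> "--epochs" with default 10
    verbose: bool = False                 # bool -> click flag, default False
    device: Device = Device.CPU           # Enum -> click.Choice over the enum values
    labels: t.List[str] = field(default_factory=list)  # List -> repeatable option (multiple=True)
    checkpoint: t.Optional[str] = None    # Optional[x] -> unwrapped to x (here str)


# expected to build a click.OptionParser with one "--<field name>" option per field
parser = get_parser_from_dataclasses([ExampleArguments])
```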
3 changes: 3 additions & 0 deletions client/starwhale/api/argument.py
@@ -0,0 +1,3 @@
from ._impl.argument import argument

__all__ = ["argument"]
13 changes: 10 additions & 3 deletions example/helloworld/evaluation.py
@@ -1,13 +1,19 @@
 from pathlib import Path
+from dataclasses import field, dataclass
 
 import numpy as np
 import onnxruntime as rt
 
-from starwhale import Image, evaluation, multi_classification
+from starwhale import Image, argument, evaluation, multi_classification
 
 _g_model = None
 
 
+@dataclass
+class EvaluationArguments:
+    reshape: int = field(default=64, metadata={"help": "reshape image size"})
+
+
 def _load_model():
     global _g_model
 
@@ -19,16 +25,17 @@ def _load_model():
     return _g_model
 
 
+@argument(EvaluationArguments)
 @evaluation.predict(
     resources={"memory": {"request": "500M", "limit": "2G"}},
     log_mode="plain",
 )
-def predict_image(data):
+def predict_image(data, args: EvaluationArguments):
     img: Image = data["img"]
     model = _load_model()
     input_name = model.get_inputs()[0].name
     label_name = model.get_outputs()[0].name
-    input_array = [img.to_numpy().astype(np.float32).reshape(64)]
+    input_array = [img.to_numpy().astype(np.float32).reshape(args.reshape)]
     pred = model.run([label_name], {input_name: input_array})[0]
     return pred.item()
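
The decorator docstring also allows a list of dataclass types. A hedged sketch of how the helloworld handler could declare two argument groups is below; `DatasetArguments` and `batch_size` are hypothetical names, and actual injection of the parsed values depends on the `_run_wrapper` TODO in `_impl/argument.py` being completed.

```python
from dataclasses import field, dataclass

from starwhale import argument, evaluation


@dataclass
class EvaluationArguments:
    reshape: int = field(default=64, metadata={"help": "reshape image size"})


@dataclass
class DatasetArguments:
    # hypothetical second argument group, for illustration only
    batch_size: int = field(default=32, metadata={"help": "predict batch size"})


# per the docstring, a list of dataclass types is also accepted;
# one instance per type would then be passed to the decorated handler
@argument([EvaluationArguments, DatasetArguments])
@evaluation.predict(log_mode="plain")
def predict_image(data, eval_args: EvaluationArguments, ds_args: DatasetArguments):
    ...
```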
