Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP enhancement(controller): add type for step spec #3076

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions client/starwhale/api/_impl/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from starwhale.consts import DecoratorInjectAttr
from starwhale.base.type import PredictLogMode
from starwhale.base.client.models.models import StepType

from .log import Evaluation
from .pipeline import PipelineHandler
Expand Down Expand Up @@ -109,6 +110,7 @@ def _register_predict(
dataset_uris=datasets,
),
built_in=True,
typ=StepType.evaluation,
)(func)


Expand Down Expand Up @@ -173,6 +175,7 @@ def _register_evaluate(
predict_auto_log=use_predict_auto_log,
),
built_in=True,
typ=StepType.evaluation,
)(func)


Expand Down
3 changes: 2 additions & 1 deletion client/starwhale/api/_impl/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from starwhale.base.context import Context
from starwhale.api._impl.model import build as build_starwhale_model
from starwhale.api._impl.dataset import Dataset
from starwhale.base.client.models.models import FineTune
from starwhale.base.client.models.models import FineTune, StepType


# TODO: support arguments
Expand Down Expand Up @@ -140,5 +140,6 @@ def _register_ft(
require_train_datasets=require_train_datasets,
require_validation_datasets=require_validation_datasets,
),
typ=StepType.fine_tune,
)(func)
setattr(func, DecoratorInjectAttr.FineTune, True)
7 changes: 7 additions & 0 deletions client/starwhale/api/_impl/job/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from starwhale.api._impl.evaluation import PipelineHandler
from starwhale.base.client.models.models import (
FineTune,
StepType,
RuntimeResource,
ParameterSignature,
)
Expand Down Expand Up @@ -105,6 +106,7 @@ def register(
require_dataset: bool = False,
built_in: bool = False,
fine_tune: FineTune | None = None,
typ: StepType | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mabe typo for type: StepType

) -> t.Callable:
"""Register a function as a handler. Enable the function execute by needs handler, run with gpu/cpu/mem resources in server side,
and control replicas of handler run.
Expand All @@ -125,6 +127,8 @@ def register(
built_in: [bool, optional] A special flag to distinguish user defined args in handler function from the StarWhale ones.
This should always be False unless you know what it does.
fine_tune: [FineTune, optional The fine tune config for the handler. Default is None.
typ:
- [str, optional] The type of the handler. Default is None.

Example:
```python
Expand Down Expand Up @@ -202,6 +206,7 @@ def decorator(func: t.Callable) -> t.Callable:
parameters_sig=parameters_sig,
ext_cmd_args=ext_cmd_args,
fine_tune=fine_tune,
step_type=typ,
)

cls._register(_handler, func)
Expand Down Expand Up @@ -347,6 +352,7 @@ def _preload_registering_handlers(
name="predict",
require_dataset=True,
built_in=True,
typ=StepType.evaluation,
)

evaluate_register = partial(
Expand All @@ -355,6 +361,7 @@ def _preload_registering_handlers(
name="evaluate",
built_in=True,
replicas=1,
typ=StepType.evaluation,
)

run_info = getattr(_cls, "_registered_run_info", None)
Expand Down
7 changes: 7 additions & 0 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,12 @@ class RuntimeResource(SwBaseModel):
limit: Optional[float] = None


class StepType(Enum):
evaluation = 'EVALUATION'
fine_tune = 'FINE_TUNE'
serving = 'SERVING'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add plain or customized type for @handler annotation



class DatasetVersionViewVo(SwBaseModel):
id: str
version_name: str = Field(..., alias='versionName')
Expand Down Expand Up @@ -1492,6 +1498,7 @@ class StepSpec(SwBaseModel):
virtual: Optional[bool] = None
job_name: Optional[str] = None
show_name: str
step_type: Optional[StepType] = None
require_dataset: Optional[bool] = None
fine_tune: Optional[FineTune] = None
container_spec: Optional[ContainerSpec] = None
Expand Down
3 changes: 2 additions & 1 deletion client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from starwhale.core.runtime.model import StandaloneRuntime
from starwhale.base.client.api.job import JobApi
from starwhale.base.client.api.model import ModelApi
from starwhale.base.client.models.models import JobRequest, ModelInfoVo
from starwhale.base.client.models.models import StepType, JobRequest, ModelInfoVo


@unique
Expand Down Expand Up @@ -271,6 +271,7 @@ def _gen_model_serving(self, search_modules: t.List[str], workdir: Path) -> None
cls_name, _, func_name = func.__qualname__.rpartition(".")
h = StepSpecClient(
name="serving",
step_type=StepType.serving,
show_name="virtual handler for model serving",
func_name=func.__qualname__,
module_name=func.__module__,
Expand Down
3 changes: 3 additions & 0 deletions client/starwhale/utils/json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import dataclasses
from enum import Enum
from typing import Any

from pydantic import BaseModel
Expand All @@ -11,6 +12,8 @@ def default(self, o: Any) -> Any:
return o.decode("utf-8")
if isinstance(o, BaseModel):
return json.loads(o.json(exclude_unset=True))
if isinstance(o, Enum):
return o.value
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
6 changes: 6 additions & 0 deletions client/tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from starwhale.base.client.models.models import (
UserVo,
ModelVo,
StepType,
JobRequest,
ModelVersionVo,
ResponseMessageString,
Expand Down Expand Up @@ -120,6 +121,7 @@ def test_build_workflow(
show_name="t1",
func_name="test_func1",
module_name="mock",
step_type="EVALUATION",
)
Handler._registered_handlers["depend"] = StepSpecClient(
name="depend",
Expand Down Expand Up @@ -191,6 +193,7 @@ def test_build_workflow(
func_name="test_func1",
module_name="mock",
replicas=1,
step_type=StepType.evaluation,
)
],
"depend": [
Expand All @@ -200,6 +203,7 @@ def test_build_workflow(
func_name="test_func1",
module_name="mock",
replicas=1,
step_type=StepType.evaluation,
),
StepSpecClient(
name="depend",
Expand All @@ -208,11 +212,13 @@ def test_build_workflow(
module_name="mock",
replicas=1,
needs=["base"],
step_type=None,
),
],
"serving": [
StepSpecClient(
name="serving",
step_type=StepType.serving,
show_name="virtual handler for model serving",
func_name="StandaloneModel._serve_handler",
module_name="starwhale.core.model.model",
Expand Down
17 changes: 17 additions & 0 deletions client/tests/sdk/test_job_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from starwhale.base.models.model import JobHandlers, StepSpecClient
from starwhale.base.client.models.models import (
FineTune,
StepType,
RuntimeResource,
ParameterSignature,
)
Expand Down Expand Up @@ -242,6 +243,7 @@ def video_evaluate_handler(*args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
StepSpecClient(
cls_name="",
Expand All @@ -257,6 +259,7 @@ def video_evaluate_handler(*args, **kwargs): ...
require_dataset=False,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
]

Expand All @@ -282,6 +285,7 @@ def video_evaluate_handler(*args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
StepSpecClient(
cls_name="",
Expand All @@ -297,6 +301,7 @@ def video_evaluate_handler(*args, **kwargs): ...
require_dataset=False,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
]

Expand Down Expand Up @@ -346,6 +351,7 @@ def evaluate_handler(*args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
StepSpecClient(
cls_name="",
Expand All @@ -361,6 +367,7 @@ def evaluate_handler(*args, **kwargs): ...
require_dataset=False,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
],
"mock_user_module:predict_handler": [
Expand All @@ -385,6 +392,7 @@ def evaluate_handler(*args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
)
],
}
Expand Down Expand Up @@ -511,6 +519,7 @@ def evaluate(self, *args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
StepSpecClient(
cls_name="MockHandler",
Expand All @@ -528,6 +537,7 @@ def evaluate(self, *args, **kwargs): ...
require_dataset=False,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
]
assert jobs_info["mock_user_module:MockHandler.predict"] == [
Expand All @@ -546,6 +556,7 @@ def evaluate(self, *args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
)
]
_, steps = Step.get_steps_from_yaml(
Expand Down Expand Up @@ -614,6 +625,7 @@ def evaluate_handler(self, *args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
)
]

Expand All @@ -639,6 +651,7 @@ def evaluate_handler(self, *args, **kwargs): ...
require_dataset=True,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
StepSpecClient(
cls_name="MockHandler",
Expand All @@ -654,6 +667,7 @@ def evaluate_handler(self, *args, **kwargs): ...
require_dataset=False,
parameters_sig=[],
ext_cmd_args="",
step_type=StepType.evaluation,
),
]

Expand Down Expand Up @@ -793,6 +807,7 @@ def ft2(): ...
require_train_datasets=True,
require_validation_datasets=True,
),
step_type=StepType.fine_tune,
),
],
"mock_user_module:ft2": [
Expand All @@ -816,6 +831,7 @@ def ft2(): ...
require_train_datasets=True,
require_validation_datasets=True,
),
step_type=StepType.fine_tune,
),
StepSpecClient(
name="mock_user_module:ft2",
Expand Down Expand Up @@ -843,6 +859,7 @@ def ft2(): ...
require_train_datasets=False,
require_validation_datasets=False,
),
step_type=StepType.fine_tune,
),
],
}
Expand Down
1 change: 1 addition & 0 deletions console/src/api/server/data-contracts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,7 @@ export interface IStepSpec {
virtual?: boolean
job_name?: string
show_name: string
step_type?: 'EVALUATION' | 'FINE_TUNE' | 'SERVING'
require_dataset?: boolean
fine_tune?: IFineTune
container_spec?: IContainerSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ public ResponseEntity<ResponseMessage<String>> createJob(
} else {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
}
if (jobId == null) {
if (jobId == null) {
throw new StarwhaleApiException(
new SwValidationException(ValidSubject.JOB, "job request is invalid"),
HttpStatus.BAD_REQUEST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package ai.starwhale.mlops.api.protocol.model;

import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.job.spec.step.StepSpec;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
import javax.validation.constraints.NotNull;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package ai.starwhale.mlops.api.protocol.model;

import ai.starwhale.mlops.api.protocol.user.UserVo;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.job.spec.step.StepSpec;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.io.Serializable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.bo.JobCreateRequest;
import ai.starwhale.mlops.domain.job.spec.Env;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.job.spec.step.Env;
import ai.starwhale.mlops.domain.job.spec.step.StepSpec;
import ai.starwhale.mlops.domain.job.step.VirtualJobLoader;
import ai.starwhale.mlops.domain.project.ProjectService;
import ai.starwhale.mlops.domain.project.bo.Project;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.job.mapper.JobMapper;
import ai.starwhale.mlops.domain.job.po.JobEntity;
import ai.starwhale.mlops.domain.job.spec.Env;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.job.spec.step.Env;
import ai.starwhale.mlops.domain.job.spec.step.StepSpec;
import ai.starwhale.mlops.domain.job.status.JobStatusMachine;
import ai.starwhale.mlops.domain.model.ModelDao;
import ai.starwhale.mlops.domain.model.ModelService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.job.mapper.ModelServingMapper;
import ai.starwhale.mlops.domain.job.po.ModelServingEntity;
import ai.starwhale.mlops.domain.job.spec.Env;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.ModelServingSpec;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.job.spec.step.Env;
import ai.starwhale.mlops.domain.job.spec.step.StepSpec;
import ai.starwhale.mlops.domain.job.status.JobStatus;
import ai.starwhale.mlops.domain.job.step.VirtualJobLoader;
import ai.starwhale.mlops.domain.model.ModelDao;
Expand Down
Loading
Loading