Skip to content

Commit

Permalink
feat(controller): add ft eval create job api (#2971)
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing authored Nov 16, 2023
1 parent 147ff2e commit 0289bd8
Show file tree
Hide file tree
Showing 35 changed files with 982 additions and 1,032 deletions.
5 changes: 4 additions & 1 deletion client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
import threading
from enum import Enum, unique
Expand Down Expand Up @@ -104,7 +105,9 @@ def __init__(self, eval_id: str, project: Project):
] = (
lambda name: f"eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{name}"
)
self._eval_summary_table_name = "eval/summary"
self._eval_summary_table_name = os.getenv(
"SW_EVALUATION_SUMMARY_TABLE", "eval/summary"
)
self._data_store = data_store.get_data_store(
project.instance.url, project.instance.token
)
Expand Down
134 changes: 45 additions & 89 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class RevertModelVersionRequest(SwBaseModel):
version_url: str = Field(..., alias='versionUrl')


class BizType(Enum):
fine_tune = 'FINE_TUNE'


class Type1(Enum):
evaluation = 'EVALUATION'
train = 'TRAIN'
Expand All @@ -216,8 +220,11 @@ class DevWay(Enum):

class JobRequest(SwBaseModel):
model_version_id: Optional[str] = Field(None, alias='modelVersionId')
dataset_version_ids: Optional[List[str]] = Field(None, alias='datasetVersionIds')
runtime_version_id: Optional[str] = Field(None, alias='runtimeVersionId')
dataset_version_ids: Optional[List[str]] = Field(None, alias='datasetVersionIds')
eval_dataset_version_ids: Optional[List[str]] = Field(
None, alias='evalDatasetVersionIds'
)
time_to_live_in_sec: Optional[int] = Field(None, alias='timeToLiveInSec')
model_version_url: Optional[str] = Field(None, alias='modelVersionUrl')
dataset_version_urls: Optional[str] = Field(None, alias='datasetVersionUrls')
Expand All @@ -226,6 +233,8 @@ class JobRequest(SwBaseModel):
resource_pool: str = Field(..., alias='resourcePool')
handler: Optional[str] = None
step_spec_over_writes: Optional[str] = Field(None, alias='stepSpecOverWrites')
biz_type: Optional[BizType] = Field(None, alias='bizType')
biz_id: Optional[str] = Field(None, alias='bizId')
type: Optional[Type1] = None
dev_mode: Optional[bool] = Field(None, alias='devMode')
dev_password: Optional[str] = Field(None, alias='devPassword')
Expand Down Expand Up @@ -362,35 +371,6 @@ class ResponseMessageMapObjectObject(SwBaseModel):
data: Dict[str, Dict[str, Any]]


class Type3(Enum):
evaluation = 'EVALUATION'
train = 'TRAIN'
fine_tune = 'FINE_TUNE'
serving = 'SERVING'
built_in = 'BUILT_IN'


class FineTuneCreateRequest(SwBaseModel):
model_version_id: Optional[str] = Field(None, alias='modelVersionId')
dataset_version_ids: Optional[List[str]] = Field(None, alias='datasetVersionIds')
runtime_version_id: Optional[str] = Field(None, alias='runtimeVersionId')
time_to_live_in_sec: Optional[int] = Field(None, alias='timeToLiveInSec')
model_version_url: Optional[str] = Field(None, alias='modelVersionUrl')
dataset_version_urls: Optional[str] = Field(None, alias='datasetVersionUrls')
runtime_version_url: Optional[str] = Field(None, alias='runtimeVersionUrl')
comment: Optional[str] = None
resource_pool: str = Field(..., alias='resourcePool')
handler: Optional[str] = None
step_spec_over_writes: Optional[str] = Field(None, alias='stepSpecOverWrites')
type: Optional[Type3] = None
dev_mode: Optional[bool] = Field(None, alias='devMode')
dev_password: Optional[str] = Field(None, alias='devPassword')
dev_way: Optional[DevWay] = Field(None, alias='devWay')
eval_dataset_version_ids: Optional[List[int]] = Field(
None, alias='evalDatasetVersionIds'
)


class RecordValueDesc(SwBaseModel):
key: str
value: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -553,7 +533,7 @@ class Flag(Enum):
unchanged = 'unchanged'


class Type4(Enum):
class Type3(Enum):
directory = 'directory'
file = 'file'

Expand All @@ -563,7 +543,7 @@ class FileNode(SwBaseModel):
signature: Optional[str] = None
flag: Optional[Flag] = None
mime: Optional[str] = None
type: Optional[Type4] = None
type: Optional[Type3] = None
desc: Optional[str] = None
size: Optional[str] = None

Expand Down Expand Up @@ -922,13 +902,13 @@ class DatasetVo(SwBaseModel):
version: DatasetVersionVo


class Type5(Enum):
class Type4(Enum):
dev_mode = 'DEV_MODE'
web_handler = 'WEB_HANDLER'


class ExposedLinkVo(SwBaseModel):
type: Type5
type: Type4
name: str
link: str

Expand Down Expand Up @@ -1193,7 +1173,7 @@ class Status2(Enum):
unknown = 'UNKNOWN'


class Type6(Enum):
class Type5(Enum):
image = 'IMAGE'
video = 'VIDEO'
audio = 'AUDIO'
Expand All @@ -1208,7 +1188,7 @@ class BuildRecordVo(SwBaseModel):
task_id: str = Field(..., alias='taskId')
dataset_name: str = Field(..., alias='datasetName')
status: Status2
type: Type6
type: Type5
create_time: int = Field(..., alias='createTime')


Expand Down Expand Up @@ -1285,24 +1265,39 @@ class ResponseMessagePageInfoFineTuneSpaceVo(SwBaseModel):
data: PageInfoFineTuneSpaceVo


class DsInfo(SwBaseModel):
pass
class FineTuneVo(SwBaseModel):
id: int
job: JobVo
train_datasets: List[DatasetVo] = Field(..., alias='trainDatasets')
eval_datasets: Optional[List[DatasetVo]] = Field(None, alias='evalDatasets')
target_model: ModelVo = Field(..., alias='targetModel')


class Status3(Enum):
created = 'CREATED'
ready = 'READY'
paused = 'PAUSED'
running = 'RUNNING'
cancelling = 'CANCELLING'
canceled = 'CANCELED'
success = 'SUCCESS'
fail = 'FAIL'
unknown = 'UNKNOWN'
class PageInfoFineTuneVo(SwBaseModel):
total: Optional[int] = None
list: Optional[List[FineTuneVo]] = None
page_num: Optional[int] = Field(None, alias='pageNum')
page_size: Optional[int] = Field(None, alias='pageSize')
size: Optional[int] = None
start_row: Optional[int] = Field(None, alias='startRow')
end_row: Optional[int] = Field(None, alias='endRow')
pages: Optional[int] = None
pre_page: Optional[int] = Field(None, alias='prePage')
next_page: Optional[int] = Field(None, alias='nextPage')
is_first_page: Optional[bool] = Field(None, alias='isFirstPage')
is_last_page: Optional[bool] = Field(None, alias='isLastPage')
has_previous_page: Optional[bool] = Field(None, alias='hasPreviousPage')
has_next_page: Optional[bool] = Field(None, alias='hasNextPage')
navigate_pages: Optional[int] = Field(None, alias='navigatePages')
navigatepage_nums: Optional[List[int]] = Field(None, alias='navigatepageNums')
navigate_first_page: Optional[int] = Field(None, alias='navigateFirstPage')
navigate_last_page: Optional[int] = Field(None, alias='navigateLastPage')


class ModelInfo(SwBaseModel):
pass
class ResponseMessagePageInfoFineTuneVo(SwBaseModel):
code: str
message: str
data: PageInfoFineTuneVo


class PanelPluginVo(SwBaseModel):
Expand Down Expand Up @@ -1619,45 +1614,6 @@ class ResponseMessageGraph(SwBaseModel):
data: Graph


class FineTuneVo(SwBaseModel):
id: int
job_id: int = Field(..., alias='jobId')
status: Status3
start_time: int = Field(..., alias='startTime')
end_time: Optional[int] = Field(None, alias='endTime')
train_datasets: Optional[List[DsInfo]] = Field(None, alias='trainDatasets')
eval_datasets: Optional[List[DsInfo]] = Field(None, alias='evalDatasets')
base_model: ModelInfo = Field(..., alias='baseModel')
target_model: Optional[ModelInfo] = Field(None, alias='targetModel')


class PageInfoFineTuneVo(SwBaseModel):
total: Optional[int] = None
list: Optional[List[FineTuneVo]] = None
page_num: Optional[int] = Field(None, alias='pageNum')
page_size: Optional[int] = Field(None, alias='pageSize')
size: Optional[int] = None
start_row: Optional[int] = Field(None, alias='startRow')
end_row: Optional[int] = Field(None, alias='endRow')
pages: Optional[int] = None
pre_page: Optional[int] = Field(None, alias='prePage')
next_page: Optional[int] = Field(None, alias='nextPage')
is_first_page: Optional[bool] = Field(None, alias='isFirstPage')
is_last_page: Optional[bool] = Field(None, alias='isLastPage')
has_previous_page: Optional[bool] = Field(None, alias='hasPreviousPage')
has_next_page: Optional[bool] = Field(None, alias='hasNextPage')
navigate_pages: Optional[int] = Field(None, alias='navigatePages')
navigatepage_nums: Optional[List[int]] = Field(None, alias='navigatepageNums')
navigate_first_page: Optional[int] = Field(None, alias='navigateFirstPage')
navigate_last_page: Optional[int] = Field(None, alias='navigateLastPage')


class ResponseMessagePageInfoFineTuneVo(SwBaseModel):
code: str
message: str
data: PageInfoFineTuneVo


class PageInfoPanelPluginVo(SwBaseModel):
total: Optional[int] = None
list: Optional[List[PanelPluginVo]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@

import ai.starwhale.mlops.api.protocol.Code;
import ai.starwhale.mlops.api.protocol.ResponseMessage;
import ai.starwhale.mlops.api.protocol.ft.FineTuneCreateRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceCreateRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
import ai.starwhale.mlops.domain.ft.FineTuneSpaceService;
import ai.starwhale.mlops.domain.ft.vo.FineTuneVo;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.project.ProjectService;
import ai.starwhale.mlops.domain.user.UserService;
import com.github.pagehelper.PageInfo;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import javax.validation.Valid;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -58,16 +58,23 @@ public class FineTuneController {

final FineTuneAppService fineTuneAppService;

final FeaturesProperties featuresProperties;

final UserJobConverter userJobConverter;

public FineTuneController(
ProjectService projectService,
UserService userService,
FineTuneSpaceService fineTuneSpaceService,
FineTuneAppService fineTuneAppService
) {
FineTuneAppService fineTuneAppService,
FeaturesProperties featuresProperties,
UserJobConverter userJobConverter) {
this.projectService = projectService;
this.userService = userService;
this.fineTuneSpaceService = fineTuneSpaceService;
this.fineTuneAppService = fineTuneAppService;
this.featuresProperties = featuresProperties;
this.userJobConverter = userJobConverter;
}

@Operation(summary = "Get the list of fine-tune spaces")
Expand Down Expand Up @@ -115,25 +122,6 @@ public ResponseEntity<ResponseMessage<String>> updateSpace(
return ResponseEntity.ok(Code.success.asResponse(""));
}

@Operation(summary = "Create fine-tune")
@PostMapping(value = "/project/{projectId}/ftspace/{spaceId}/ft", produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
public ResponseEntity<ResponseMessage<String>> createFineTune(
@PathVariable("projectId") Long projectId,
@PathVariable("spaceId") Long spaceId,
@Valid @RequestBody FineTuneCreateRequest request
) {

fineTuneAppService.createFineTune(
spaceId,
projectService.findProject(projectId),
request,
userService.currentUserDetail()

);
return ResponseEntity.ok(Code.success.asResponse(""));
}

@Operation(summary = "List fine-tune")
@GetMapping(value = "/project/{projectId}/ftspace/{spaceId}/ft", produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
Expand All @@ -159,4 +147,5 @@ public ResponseEntity<ResponseMessage<String>> releaseFt(
fineTuneAppService.releaseFt(ftId, modelName, userService.currentUserDetail());
return ResponseEntity.ok(Code.success.asResponse(""));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
import ai.starwhale.mlops.domain.dag.DagQuerier;
import ai.starwhale.mlops.domain.dag.bo.Graph;
import ai.starwhale.mlops.domain.event.EventService;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.job.JobServiceForWeb;
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.ModelServingService;
import ai.starwhale.mlops.domain.job.RuntimeSuggestionService;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
Expand Down Expand Up @@ -80,6 +83,7 @@
public class JobController {

private final JobServiceForWeb jobServiceForWeb;
private final FineTuneAppService fineTuneAppService;
private final TaskService taskService;
private final ModelServingService modelServingService;
private final RuntimeSuggestionService runtimeSuggestionService;
Expand All @@ -94,7 +98,7 @@ public class JobController {

public JobController(
JobServiceForWeb jobServiceForWeb,
TaskService taskService,
FineTuneAppService fineTuneAppService, TaskService taskService,
ModelServingService modelServingService,
RuntimeSuggestionService runtimeSuggestionService,
IdConverter idConvertor,
Expand All @@ -105,6 +109,7 @@ public JobController(
UserJobConverter userJobConverter
) {
this.jobServiceForWeb = jobServiceForWeb;
this.fineTuneAppService = fineTuneAppService;
this.taskService = taskService;
this.modelServingService = modelServingService;
this.runtimeSuggestionService = runtimeSuggestionService;
Expand Down Expand Up @@ -204,13 +209,22 @@ public ResponseEntity<ResponseMessage<String>> createJob(
@Valid @RequestBody JobRequest jobRequest
) {
if (jobRequest.isDevMode() && !featuresProperties.isJobDevEnabled()) {
throw new StarwhaleApiException(new SwValidationException(ValidSubject.JOB, "dev mode is not enabled"),
HttpStatus.BAD_REQUEST);
throw new StarwhaleApiException(
new SwValidationException(ValidSubject.JOB, "dev mode is not enabled"),
HttpStatus.BAD_REQUEST
);
}
Long jobId = null;
if (jobRequest.getBizType() == BizType.FINE_TUNE) {
Long spaceId = idConvertor.revert(jobRequest.getBizId());
if (jobRequest.getType() == JobType.FINE_TUNE) {
jobId = fineTuneAppService.createFineTune(projectUrl, spaceId, jobRequest);
} else if (jobRequest.getType() == JobType.EVALUATION) {
jobId = fineTuneAppService.createEvaluationJob(projectUrl, spaceId, jobRequest);
}
} else {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
}

var req = userJobConverter.convert(projectUrl, jobRequest);
Long jobId = jobServiceForWeb.createJob(req);

return ResponseEntity.ok(Code.success.asResponse(idConvertor.convert(jobId)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

import ai.starwhale.mlops.api.protocol.job.JobRequest;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import lombok.Data;
import lombok.EqualsAndHashCode;

@Data
public class FineTuneCreateRequest extends JobRequest {
@EqualsAndHashCode(callSuper = true)
public class FineTuneEvalCreateRequest extends JobRequest {

@JsonProperty("evalDatasetVersionIds")
private List<Long> evalDatasetVersionIds;
@JsonProperty("spaceId")
private Long spaceId;
}
Loading

0 comments on commit 0289bd8

Please sign in to comment.