Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(controller): add ft eval create job api #2971

Merged
merged 15 commits into from
Nov 16, 2023
5 changes: 4 additions & 1 deletion client/starwhale/api/_impl/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
import threading
from enum import Enum, unique
Expand Down Expand Up @@ -104,7 +105,9 @@ def __init__(self, eval_id: str, project: Project):
] = (
lambda name: f"eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{name}"
)
self._eval_summary_table_name = "eval/summary"
self._eval_summary_table_name = os.getenv(
"SW_EVALUATION_SUMMARY_TABLE", "eval/summary"
)
self._data_store = data_store.get_data_store(
project.instance.url, project.instance.token
)
Expand Down
134 changes: 45 additions & 89 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class RevertModelVersionRequest(SwBaseModel):
version_url: str = Field(..., alias='versionUrl')


class BizType(Enum):
fine_tune = 'FINE_TUNE'


class Type1(Enum):
evaluation = 'EVALUATION'
train = 'TRAIN'
Expand All @@ -216,8 +220,11 @@ class DevWay(Enum):

class JobRequest(SwBaseModel):
model_version_id: Optional[str] = Field(None, alias='modelVersionId')
dataset_version_ids: Optional[List[str]] = Field(None, alias='datasetVersionIds')
runtime_version_id: Optional[str] = Field(None, alias='runtimeVersionId')
dataset_version_ids: Optional[List[str]] = Field(None, alias='datasetVersionIds')
eval_dataset_version_ids: Optional[List[str]] = Field(
None, alias='evalDatasetVersionIds'
)
time_to_live_in_sec: Optional[int] = Field(None, alias='timeToLiveInSec')
model_version_url: Optional[str] = Field(None, alias='modelVersionUrl')
dataset_version_urls: Optional[str] = Field(None, alias='datasetVersionUrls')
Expand All @@ -226,6 +233,8 @@ class JobRequest(SwBaseModel):
resource_pool: str = Field(..., alias='resourcePool')
handler: Optional[str] = None
step_spec_over_writes: Optional[str] = Field(None, alias='stepSpecOverWrites')
biz_type: Optional[BizType] = Field(None, alias='bizType')
biz_id: Optional[str] = Field(None, alias='bizId')
type: Optional[Type1] = None
dev_mode: Optional[bool] = Field(None, alias='devMode')
dev_password: Optional[str] = Field(None, alias='devPassword')
Expand Down Expand Up @@ -362,35 +371,6 @@ class ResponseMessageMapObjectObject(SwBaseModel):
data: Dict[str, Dict[str, Any]]


class Type3(Enum):
evaluation = 'EVALUATION'
train = 'TRAIN'
fine_tune = 'FINE_TUNE'
serving = 'SERVING'
built_in = 'BUILT_IN'


class FineTuneCreateRequest(SwBaseModel):
model_version_id: Optional[str] = Field(None, alias='modelVersionId')
dataset_version_ids: Optional[List[str]] = Field(None, alias='datasetVersionIds')
runtime_version_id: Optional[str] = Field(None, alias='runtimeVersionId')
time_to_live_in_sec: Optional[int] = Field(None, alias='timeToLiveInSec')
model_version_url: Optional[str] = Field(None, alias='modelVersionUrl')
dataset_version_urls: Optional[str] = Field(None, alias='datasetVersionUrls')
runtime_version_url: Optional[str] = Field(None, alias='runtimeVersionUrl')
comment: Optional[str] = None
resource_pool: str = Field(..., alias='resourcePool')
handler: Optional[str] = None
step_spec_over_writes: Optional[str] = Field(None, alias='stepSpecOverWrites')
type: Optional[Type3] = None
dev_mode: Optional[bool] = Field(None, alias='devMode')
dev_password: Optional[str] = Field(None, alias='devPassword')
dev_way: Optional[DevWay] = Field(None, alias='devWay')
eval_dataset_version_ids: Optional[List[int]] = Field(
None, alias='evalDatasetVersionIds'
)


class RecordValueDesc(SwBaseModel):
key: str
value: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -553,7 +533,7 @@ class Flag(Enum):
unchanged = 'unchanged'


class Type4(Enum):
class Type3(Enum):
directory = 'directory'
file = 'file'

Expand All @@ -563,7 +543,7 @@ class FileNode(SwBaseModel):
signature: Optional[str] = None
flag: Optional[Flag] = None
mime: Optional[str] = None
type: Optional[Type4] = None
type: Optional[Type3] = None
desc: Optional[str] = None
size: Optional[str] = None

Expand Down Expand Up @@ -922,13 +902,13 @@ class DatasetVo(SwBaseModel):
version: DatasetVersionVo


class Type5(Enum):
class Type4(Enum):
dev_mode = 'DEV_MODE'
web_handler = 'WEB_HANDLER'


class ExposedLinkVo(SwBaseModel):
type: Type5
type: Type4
name: str
link: str

Expand Down Expand Up @@ -1193,7 +1173,7 @@ class Status2(Enum):
unknown = 'UNKNOWN'


class Type6(Enum):
class Type5(Enum):
image = 'IMAGE'
video = 'VIDEO'
audio = 'AUDIO'
Expand All @@ -1208,7 +1188,7 @@ class BuildRecordVo(SwBaseModel):
task_id: str = Field(..., alias='taskId')
dataset_name: str = Field(..., alias='datasetName')
status: Status2
type: Type6
type: Type5
create_time: int = Field(..., alias='createTime')


Expand Down Expand Up @@ -1285,24 +1265,39 @@ class ResponseMessagePageInfoFineTuneSpaceVo(SwBaseModel):
data: PageInfoFineTuneSpaceVo


class DsInfo(SwBaseModel):
pass
class FineTuneVo(SwBaseModel):
id: int
job: JobVo
train_datasets: List[DatasetVo] = Field(..., alias='trainDatasets')
eval_datasets: Optional[List[DatasetVo]] = Field(None, alias='evalDatasets')
target_model: ModelVo = Field(..., alias='targetModel')


class Status3(Enum):
created = 'CREATED'
ready = 'READY'
paused = 'PAUSED'
running = 'RUNNING'
cancelling = 'CANCELLING'
canceled = 'CANCELED'
success = 'SUCCESS'
fail = 'FAIL'
unknown = 'UNKNOWN'
class PageInfoFineTuneVo(SwBaseModel):
total: Optional[int] = None
list: Optional[List[FineTuneVo]] = None
page_num: Optional[int] = Field(None, alias='pageNum')
page_size: Optional[int] = Field(None, alias='pageSize')
size: Optional[int] = None
start_row: Optional[int] = Field(None, alias='startRow')
end_row: Optional[int] = Field(None, alias='endRow')
pages: Optional[int] = None
pre_page: Optional[int] = Field(None, alias='prePage')
next_page: Optional[int] = Field(None, alias='nextPage')
is_first_page: Optional[bool] = Field(None, alias='isFirstPage')
is_last_page: Optional[bool] = Field(None, alias='isLastPage')
has_previous_page: Optional[bool] = Field(None, alias='hasPreviousPage')
has_next_page: Optional[bool] = Field(None, alias='hasNextPage')
navigate_pages: Optional[int] = Field(None, alias='navigatePages')
navigatepage_nums: Optional[List[int]] = Field(None, alias='navigatepageNums')
navigate_first_page: Optional[int] = Field(None, alias='navigateFirstPage')
navigate_last_page: Optional[int] = Field(None, alias='navigateLastPage')


class ModelInfo(SwBaseModel):
pass
class ResponseMessagePageInfoFineTuneVo(SwBaseModel):
code: str
message: str
data: PageInfoFineTuneVo


class PanelPluginVo(SwBaseModel):
Expand Down Expand Up @@ -1619,45 +1614,6 @@ class ResponseMessageGraph(SwBaseModel):
data: Graph


class FineTuneVo(SwBaseModel):
id: int
job_id: int = Field(..., alias='jobId')
status: Status3
start_time: int = Field(..., alias='startTime')
end_time: Optional[int] = Field(None, alias='endTime')
train_datasets: Optional[List[DsInfo]] = Field(None, alias='trainDatasets')
eval_datasets: Optional[List[DsInfo]] = Field(None, alias='evalDatasets')
base_model: ModelInfo = Field(..., alias='baseModel')
target_model: Optional[ModelInfo] = Field(None, alias='targetModel')


class PageInfoFineTuneVo(SwBaseModel):
total: Optional[int] = None
list: Optional[List[FineTuneVo]] = None
page_num: Optional[int] = Field(None, alias='pageNum')
page_size: Optional[int] = Field(None, alias='pageSize')
size: Optional[int] = None
start_row: Optional[int] = Field(None, alias='startRow')
end_row: Optional[int] = Field(None, alias='endRow')
pages: Optional[int] = None
pre_page: Optional[int] = Field(None, alias='prePage')
next_page: Optional[int] = Field(None, alias='nextPage')
is_first_page: Optional[bool] = Field(None, alias='isFirstPage')
is_last_page: Optional[bool] = Field(None, alias='isLastPage')
has_previous_page: Optional[bool] = Field(None, alias='hasPreviousPage')
has_next_page: Optional[bool] = Field(None, alias='hasNextPage')
navigate_pages: Optional[int] = Field(None, alias='navigatePages')
navigatepage_nums: Optional[List[int]] = Field(None, alias='navigatepageNums')
navigate_first_page: Optional[int] = Field(None, alias='navigateFirstPage')
navigate_last_page: Optional[int] = Field(None, alias='navigateLastPage')


class ResponseMessagePageInfoFineTuneVo(SwBaseModel):
code: str
message: str
data: PageInfoFineTuneVo


class PageInfoPanelPluginVo(SwBaseModel):
total: Optional[int] = None
list: Optional[List[PanelPluginVo]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@

import ai.starwhale.mlops.api.protocol.Code;
import ai.starwhale.mlops.api.protocol.ResponseMessage;
import ai.starwhale.mlops.api.protocol.ft.FineTuneCreateRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceCreateRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
import ai.starwhale.mlops.domain.ft.FineTuneSpaceService;
import ai.starwhale.mlops.domain.ft.vo.FineTuneVo;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.project.ProjectService;
import ai.starwhale.mlops.domain.user.UserService;
import com.github.pagehelper.PageInfo;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import javax.validation.Valid;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -58,16 +58,23 @@ public class FineTuneController {

final FineTuneAppService fineTuneAppService;

final FeaturesProperties featuresProperties;

final UserJobConverter userJobConverter;

public FineTuneController(
ProjectService projectService,
UserService userService,
FineTuneSpaceService fineTuneSpaceService,
FineTuneAppService fineTuneAppService
) {
FineTuneAppService fineTuneAppService,
FeaturesProperties featuresProperties,
UserJobConverter userJobConverter) {
this.projectService = projectService;
this.userService = userService;
this.fineTuneSpaceService = fineTuneSpaceService;
this.fineTuneAppService = fineTuneAppService;
this.featuresProperties = featuresProperties;
this.userJobConverter = userJobConverter;
}

@Operation(summary = "Get the list of fine-tune spaces")
Expand Down Expand Up @@ -115,25 +122,6 @@ public ResponseEntity<ResponseMessage<String>> updateSpace(
return ResponseEntity.ok(Code.success.asResponse(""));
}

@Operation(summary = "Create fine-tune")
@PostMapping(value = "/project/{projectId}/ftspace/{spaceId}/ft", produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
public ResponseEntity<ResponseMessage<String>> createFineTune(
@PathVariable("projectId") Long projectId,
@PathVariable("spaceId") Long spaceId,
@Valid @RequestBody FineTuneCreateRequest request
) {

fineTuneAppService.createFineTune(
spaceId,
projectService.findProject(projectId),
request,
userService.currentUserDetail()

);
return ResponseEntity.ok(Code.success.asResponse(""));
}

@Operation(summary = "List fine-tune")
@GetMapping(value = "/project/{projectId}/ftspace/{spaceId}/ft", produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
Expand All @@ -159,4 +147,5 @@ public ResponseEntity<ResponseMessage<String>> releaseFt(
fineTuneAppService.releaseFt(ftId, modelName, userService.currentUserDetail());
return ResponseEntity.ok(Code.success.asResponse(""));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
import ai.starwhale.mlops.domain.dag.DagQuerier;
import ai.starwhale.mlops.domain.dag.bo.Graph;
import ai.starwhale.mlops.domain.event.EventService;
import ai.starwhale.mlops.domain.ft.FineTuneAppService;
import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.job.JobServiceForWeb;
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.ModelServingService;
import ai.starwhale.mlops.domain.job.RuntimeSuggestionService;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
Expand Down Expand Up @@ -80,6 +83,7 @@
public class JobController {

private final JobServiceForWeb jobServiceForWeb;
private final FineTuneAppService fineTuneAppService;
private final TaskService taskService;
private final ModelServingService modelServingService;
private final RuntimeSuggestionService runtimeSuggestionService;
Expand All @@ -94,7 +98,7 @@ public class JobController {

public JobController(
JobServiceForWeb jobServiceForWeb,
TaskService taskService,
FineTuneAppService fineTuneAppService, TaskService taskService,
ModelServingService modelServingService,
RuntimeSuggestionService runtimeSuggestionService,
IdConverter idConvertor,
Expand All @@ -105,6 +109,7 @@ public JobController(
UserJobConverter userJobConverter
) {
this.jobServiceForWeb = jobServiceForWeb;
this.fineTuneAppService = fineTuneAppService;
this.taskService = taskService;
this.modelServingService = modelServingService;
this.runtimeSuggestionService = runtimeSuggestionService;
Expand Down Expand Up @@ -204,13 +209,22 @@ public ResponseEntity<ResponseMessage<String>> createJob(
@Valid @RequestBody JobRequest jobRequest
) {
if (jobRequest.isDevMode() && !featuresProperties.isJobDevEnabled()) {
throw new StarwhaleApiException(new SwValidationException(ValidSubject.JOB, "dev mode is not enabled"),
HttpStatus.BAD_REQUEST);
throw new StarwhaleApiException(
new SwValidationException(ValidSubject.JOB, "dev mode is not enabled"),
HttpStatus.BAD_REQUEST
);
}
Long jobId = null;
if (jobRequest.getBizType() == BizType.FINE_TUNE) {
Long spaceId = idConvertor.revert(jobRequest.getBizId());
if (jobRequest.getType() == JobType.FINE_TUNE) {
jobId = fineTuneAppService.createFineTune(projectUrl, spaceId, jobRequest);
} else if (jobRequest.getType() == JobType.EVALUATION) {
jobId = fineTuneAppService.createEvaluationJob(projectUrl, spaceId, jobRequest);
}
} else {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
}

var req = userJobConverter.convert(projectUrl, jobRequest);
Long jobId = jobServiceForWeb.createJob(req);

return ResponseEntity.ok(Code.success.asResponse(idConvertor.convert(jobId)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

import ai.starwhale.mlops.api.protocol.job.JobRequest;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import lombok.Data;
import lombok.EqualsAndHashCode;

@Data
public class FineTuneCreateRequest extends JobRequest {
@EqualsAndHashCode(callSuper = true)
public class FineTuneEvalCreateRequest extends JobRequest {

@JsonProperty("evalDatasetVersionIds")
private List<Long> evalDatasetVersionIds;
@JsonProperty("spaceId")
private Long spaceId;
}
Loading
Loading