Skip to content

Commit

Permalink
feat(controller): Sft/release (#2976)
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Nov 15, 2023
1 parent cc1326b commit 4a93aca
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 25 deletions.
2 changes: 2 additions & 0 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,8 @@ class StepSpec(SwBaseModel):
job_name: Optional[str] = None
show_name: str
require_dataset: Optional[bool] = None
require_train_datasets: Optional[bool] = None
require_validation_datasets: Optional[bool] = None
container_spec: Optional[ContainerSpec] = None
ext_cmd_args: Optional[str] = None
parameters_sig: Optional[List[ParameterSignature]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,16 @@ public ResponseEntity<ResponseMessage<PageInfo<FineTuneVo>>> listFineTune(
PageInfo<FineTuneVo> pageInfo = fineTuneAppService.list(spaceId, pageNum, pageSize);
return ResponseEntity.ok(Code.success.asResponse(pageInfo));
}

@Operation(summary = "release fine-tune")
@PutMapping(value = "/project/{projectId}/ftspace/{spaceId}/ft/release", produces =
MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER')")
public ResponseEntity<ResponseMessage<String>> releaseFt(
@RequestParam Long ftId,
@RequestParam(required = false) String modelName
) {
fineTuneAppService.releaseFt(ftId, modelName, userService.currentUserDetail());
return ResponseEntity.ok(Code.success.asResponse(""));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,31 @@
import ai.starwhale.mlops.common.Constants;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
import ai.starwhale.mlops.domain.bundle.base.BundleEntity;
import ai.starwhale.mlops.domain.dataset.DatasetDao;
import ai.starwhale.mlops.domain.dataset.bo.DatasetVersion;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneMapper;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneSpaceMapper;
import ai.starwhale.mlops.domain.ft.po.FineTuneEntity;
import ai.starwhale.mlops.domain.ft.po.FineTuneSpaceEntity;
import ai.starwhale.mlops.domain.ft.vo.FineTuneVo;
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.job.mapper.JobMapper;
import ai.starwhale.mlops.domain.job.po.JobEntity;
import ai.starwhale.mlops.domain.job.spec.Env;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.model.ModelDao;
import ai.starwhale.mlops.domain.model.bo.ModelVersion;
import ai.starwhale.mlops.domain.model.po.ModelEntity;
import ai.starwhale.mlops.domain.model.po.ModelVersionEntity;
import ai.starwhale.mlops.domain.project.bo.Project;
import ai.starwhale.mlops.domain.user.bo.User;
import ai.starwhale.mlops.exception.SwNotFoundException;
import ai.starwhale.mlops.exception.SwNotFoundException.ResourceType;
import ai.starwhale.mlops.exception.SwValidationException;
import ai.starwhale.mlops.exception.SwValidationException.ValidSubject;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
Expand All @@ -53,6 +59,7 @@
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

Expand All @@ -77,6 +84,10 @@ public class FineTuneAppService {

final String instanceUri;

final FineTuneSpaceMapper fineTuneSpaceMapper;

final UserJobConverter userJobConverter;

public FineTuneAppService(
FeaturesProperties featuresProperties,
JobCreator jobCreator,
Expand All @@ -86,7 +97,9 @@ public FineTuneAppService(
IdConverter idConverter,
ModelDao modelDao,
@Value("${sw.instance-uri}") String instanceUri,
DatasetDao datasetDao
DatasetDao datasetDao,
FineTuneSpaceMapper fineTuneSpaceMapper,
UserJobConverter userJobConverter
) {
this.featuresProperties = featuresProperties;
this.jobCreator = jobCreator;
Expand All @@ -97,9 +110,12 @@ public FineTuneAppService(
this.modelDao = modelDao;
this.datasetDao = datasetDao;
this.instanceUri = instanceUri;
this.fineTuneSpaceMapper = fineTuneSpaceMapper;
this.userJobConverter = userJobConverter;
}


@Transactional
public void createFineTune(
Long spaceId,
Project project,
Expand All @@ -116,23 +132,8 @@ public void createFineTune(
.build();
fineTuneMapper.add(ft);
request = addEnvToRequest(ft.getId(), request);
Job job = jobCreator.createJob(
UserJobCreateRequest.builder()
.modelVersionId(idConverter.revert(request.getModelVersionId()))
.runtimeVersionId(idConverter.revert(request.getRuntimeVersionId()))
.datasetVersionIds(idConverter.revertList(request.getDatasetVersionIds()))
.devMode(request.isDevMode())
.devPassword(request.getDevPassword())
.ttlInSec(request.getTimeToLiveInSec())
.project(project)
.user(creator)
.comment(request.getComment())
.resourcePool(request.getResourcePool())
.handler(request.getHandler())
.stepSpecOverWrites(request.getStepSpecOverWrites())
.jobType(JobType.FINE_TUNE)
.build()
);
request.setType(JobType.FINE_TUNE);
Job job = jobCreator.createJob(userJobConverter.convert(project.getId().toString(), request));
fineTuneMapper.updateJobId(ft.getId(), job.getId());
}

Expand Down Expand Up @@ -205,7 +206,7 @@ public PageInfo<FineTuneVo> list(Long spaceId, Integer pageNum, Integer pageSize
.jobId(jobId)
.status(job.getJobStatus())
.startTime(job.getCreatedTime().getTime())
.endTime(job.getFinishedTime().getTime())
.endTime(null != job.getFinishedTime() ? job.getFinishedTime().getTime() : null)
.evalDatasets(List.of())//TODO
.trainDatasets(List.of())//TODO
.baseModel(null)//TODO
Expand All @@ -219,7 +220,46 @@ public void evalFt(List<Long> evalDatasetIds, Long runtimeId, String handerSpec,

}

public void releaseFt(Long ftId) {
@Transactional
public void releaseFt(Long ftId, String modelName, User user) {
FineTuneEntity ft = fineTuneMapper.findById(ftId);
if (null == ft) {
throw new SwNotFoundException(ResourceType.FINE_TUNE, "fine tune not found");
}
Long targetModelVersionId = ft.getTargetModelVersionId();
if (null == targetModelVersionId) {
throw new SwNotFoundException(ResourceType.FINE_TUNE, "target model has not been generated yet");
}
ModelVersionEntity modelVersion = modelDao.getModelVersion(targetModelVersionId.toString());
if (!modelVersion.getDraft()) {
throw new SwValidationException(
ValidSubject.MODEL,
"model has been released to modelId: " + modelVersion.getModelId()
);
}
Long modelId;
if (!StringUtils.hasText(modelName) || modelVersion.getModelName().equals(modelName)) {
modelId = modelVersion.getModelId();
} else {
//release to a new model
FineTuneSpaceEntity ftSpace = fineTuneSpaceMapper.findById(ft.getSpaceId());
Long projectId = ftSpace.getProjectId();
BundleEntity modelEntity = this.modelDao.findByNameForUpdate(modelName, projectId);
if (null == modelEntity) {
//create model
ModelEntity model = ModelEntity.builder()
.ownerId(user.getId())
.projectId(projectId)
.modelName(modelName)
.build();
modelDao.add(model);
modelId = model.getId();
} else {
modelId = modelEntity.getId();
}
}
// update model version model id to new model and set draft to false
modelDao.releaseModelVersion(targetModelVersionId, modelId);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ public interface FineTuneMapper {
@Select("select " + COLUMNS + " from fine_tune where job_id = #{jobId}")
FineTuneEntity findByJob(Long jobId);

@Results({
@Result(property = "evalDatasets", column = "eval_datasets", typeHandler = ListStringTypeHandler.class),
@Result(property = "trainDatasets", column = "train_datasets", typeHandler = ListStringTypeHandler.class)
})
@Select("select " + COLUMNS + " from fine_tune where id = #{id}")
FineTuneEntity findById(Long id);

@Update("update fine_tune set target_model_version_id = #{targetModelVersionId} where id = #{id}")
int updateTargetModel(Long id, Long targetModelVersionId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ public interface FineTuneSpaceMapper {
@Select("select " + COLUMNS + " from fine_tune_space where project_id = #{projectId} order by id desc")
List<FineTuneSpaceEntity> list(Long projectId);

@Select("select " + COLUMNS + " from fine_tune_space where id = #{id}")
FineTuneSpaceEntity findById(Long id);


@UpdateProvider(value = UpdateSqlProvider.class, method = "update")
int update(Long spaceId, String name, String description);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ public ModelVersionEntity getModelVersion(String versionUrl) {
return entity;
}

public void add(ModelEntity model) {
modelMapper.insert(model);
}

public void releaseModelVersion(Long versionId, Long modelId) {
versionMapper.updateModelRef(versionId, modelId);
}

@Override
public BundleEntity findById(Long id) {
return modelMapper.find(id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ List<ModelVersionEntity> list(
@Update("update model_version set shared = #{shared} where id = #{id}")
int updateShared(@Param("id") Long id, @Param("shared") Boolean shared);

@Update("update model_version set model_id = #{modelId} and draft = 0 where id = #{id}")
int updateModelRef(@Param("id") Long id, @Param("modelId") Long modelId);

@Select("select " + VERSION_VIEW_COLUMNS
+ " from model_info as m, model_version as v, project_info as p, user_info as u"
+ " where v.model_id = m.id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ public enum ResourceType {
USER("001", "User"),
PROJECT("002", "Project"),
BUNDLE("003", "Starwhale Bundle"),
BUNDLE_VERSION("004", "Starwhale Bundle Version");
BUNDLE_VERSION("004", "Starwhale Bundle Version"),
FINE_TUNE("005", "Starwhale Finetune");
final String code;
final String tipSubject;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,24 @@
import ai.starwhale.mlops.domain.dataset.DatasetDao;
import ai.starwhale.mlops.domain.dataset.bo.DatasetVersion;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneMapper;
import ai.starwhale.mlops.domain.ft.mapper.FineTuneSpaceMapper;
import ai.starwhale.mlops.domain.ft.po.FineTuneEntity;
import ai.starwhale.mlops.domain.ft.po.FineTuneSpaceEntity;
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.job.converter.UserJobConverter;
import ai.starwhale.mlops.domain.job.mapper.JobMapper;
import ai.starwhale.mlops.domain.job.po.JobEntity;
import ai.starwhale.mlops.domain.job.spec.JobSpecParser;
import ai.starwhale.mlops.domain.job.spec.StepSpec;
import ai.starwhale.mlops.domain.model.ModelDao;
import ai.starwhale.mlops.domain.model.po.ModelEntity;
import ai.starwhale.mlops.domain.model.po.ModelVersionEntity;
import ai.starwhale.mlops.domain.project.bo.Project;
import ai.starwhale.mlops.domain.user.bo.User;
import ai.starwhale.mlops.exception.SwNotFoundException;
import ai.starwhale.mlops.exception.SwValidationException;
import ai.starwhale.mlops.exception.api.StarwhaleApiException;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.util.List;
Expand All @@ -53,6 +61,7 @@ class FineTuneAppServiceTest {
JobCreator jobCreator;

FineTuneMapper fineTuneMapper;
FineTuneSpaceMapper fineTuneSpaceMapper;

JobMapper jobMapper;

Expand All @@ -74,6 +83,9 @@ public void setup() {
jobSpecParser = mock(JobSpecParser.class);
modelDao = mock(ModelDao.class);
datasetDao = mock(DatasetDao.class);
UserJobConverter jobConverter = mock(UserJobConverter.class);
when(jobConverter.convert(any(), any())).thenReturn(UserJobCreateRequest.builder().build());
fineTuneSpaceMapper = mock(FineTuneSpaceMapper.class);
featuresProperties = mock(FeaturesProperties.class);
when(featuresProperties.isFineTuneEnabled()).thenReturn(true);
fineTuneAppService = new FineTuneAppService(
Expand All @@ -85,7 +97,8 @@ public void setup() {
new IdConverter(),
modelDao,
"instanceuri",
datasetDao
datasetDao,
fineTuneSpaceMapper, jobConverter//todo
);
}

Expand All @@ -106,7 +119,7 @@ public Object answer(InvocationOnMock invocation) {
when(datasetDao.getDatasetVersion(anyLong())).thenReturn(DatasetVersion.builder().projectId(22L).datasetName(
"dsn").versionName("dsv").build());
when(jobSpecParser.parseAndFlattenStepFromYaml(any())).thenReturn(List.of(StepSpec.builder().build()));
fineTuneAppService.createFineTune(1L, Project.builder().build(), request, User.builder().build());
fineTuneAppService.createFineTune(1L, Project.builder().id(1L).build(), request, User.builder().build());

verify(fineTuneMapper).updateJobId(123L, 22L);

Expand All @@ -125,6 +138,64 @@ void evalFt() {

@Test
void releaseFt() {
User creator = User.builder().build();
when(fineTuneMapper.findById(1L)).thenReturn(null);
Assertions.assertThrows(SwNotFoundException.class, () -> {
fineTuneAppService.releaseFt(1L, "", null);
});

when(fineTuneMapper.findById(2L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(null)
.build()
);
Assertions.assertThrows(SwNotFoundException.class, () -> {
fineTuneAppService.releaseFt(2L, "", null);
});

when(fineTuneMapper.findById(3L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(4L)
.build()
);
when(modelDao.getModelVersion("4")).thenReturn(ModelVersionEntity
.builder()
.draft(false)
.build());
Assertions.assertThrows(SwValidationException.class, () -> {
fineTuneAppService.releaseFt(3L, "", null);
});

when(fineTuneMapper.findById(5L)).thenReturn(
FineTuneEntity.builder()
.targetModelVersionId(6L)
.spaceId(1L)
.build()
);
when(modelDao.getModelVersion("6")).thenReturn(ModelVersionEntity
.builder()
.modelId(10L)
.modelName("aac")
.draft(true)
.build());
fineTuneAppService.releaseFt(5L, null, creator);
verify(modelDao).releaseModelVersion(6L, 10L);


when(fineTuneSpaceMapper.findById(anyLong())).thenReturn(FineTuneSpaceEntity.builder().projectId(1L).build());
doAnswer(new Answer() {
public Object answer(InvocationOnMock invocation) {
Object[] args = invocation.getArguments();
((ModelEntity) args[0]).setId(123L);
return null; // void method, so return null
}
}).when(modelDao).add(any());
fineTuneAppService.releaseFt(5L, "aab", creator);
verify(modelDao).releaseModelVersion(6L, 123L);

when(modelDao.findByNameForUpdate(any(), anyLong())).thenReturn(ModelEntity.builder().id(124L).build());
fineTuneAppService.releaseFt(5L, "aabc", creator);
verify(modelDao).releaseModelVersion(6L, 124L);
}

@Test
Expand All @@ -141,4 +212,4 @@ void testFeatureDisabled() {

Assertions.assertThrows(StarwhaleApiException.class, () -> fineTuneAppService.list(1L, 1, 1));
}
}
}

0 comments on commit 4a93aca

Please sign in to comment.