Skip to content

Commit

Permalink
feat(client): support finetune model source field for model copy (#2967)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Nov 13, 2023
1 parent 0208100 commit 0ed5b27
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
26 changes: 18 additions & 8 deletions client/starwhale/core/model/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,26 @@ async def upload_model(
console.print("uploading metadata...")
blob_id = await _upload_meta_blobs(meta_blobs)
runtime_version = await runtime_version_recv.receive()

_request_json: t.Dict = {
"metaBlobId": blob_id,
"builtInRuntime": runtime_version,
"force": force,
}
_finetune_id = os.environ.get("SW_SERVER_TRIGGERED_FINETUNE_ID")
if _finetune_id is not None:
_request_json["modelSource"] = {
"type": "FINE_TUNE",
"id": int(_finetune_id),
}
await _http_request(
"POST",
path=f"/project/{dest_uri.project.id}"
+ f"/model/{dest_uri.name}"
+ f"/version/{dest_uri.version}/completeUpload",
json={
"metaBlobId": blob_id,
"builtInRuntime": runtime_version,
"force": force,
},
path=(
f"/project/{dest_uri.project.id}"
f"/model/{dest_uri.name}"
f"/version/{dest_uri.version}/completeUpload"
),
json=_request_json,
replace=False,
)
console.print("metadata uploaded")
Expand Down
12 changes: 7 additions & 5 deletions client/tests/base/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,11 +822,13 @@ def random_compressible_bytes(n: int) -> bytes:
swmp_path / "src" / "empty_dir", swmp_path / "src" / "empty_dir_symlink"
)
os.symlink(swmp_path / "src" / "empty", swmp_path / "src" / "empty_symlink")
BundleCopy(
src_uri="mnist/v1",
dest_uri="cloud://pre-bare/project/mnist/model/mnist-alias",
typ=ResourceType.model,
).do()

with patch.dict(os.environ, {"SW_SERVER_TRIGGERED_FINETUNE_ID": "1111"}):
BundleCopy(
src_uri="mnist/v1",
dest_uri="cloud://pre-bare/project/mnist/model/mnist-alias",
typ=ResourceType.model,
).do()

rm.request(
HTTPMethod.GET,
Expand Down

0 comments on commit 0ed5b27

Please sign in to comment.