Skip to content

Commit

Permalink
enhance(client): make the list(model/runtime/dataset) output of stand…
Browse files Browse the repository at this point in the history
…alone instance consitent with cloud/server style (#3126)
  • Loading branch information
tianweidut authored Jan 17, 2024
1 parent c5ce480 commit 68e7c4c
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 5 deletions.
27 changes: 27 additions & 0 deletions client/starwhale/base/bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@
from starwhale.utils.config import SWCliConfigMixed
from starwhale.base.models.base import ListFilter
from starwhale.base.uri.project import Project
from starwhale.base.models.model import LocalModelInfoBase
from starwhale.base.uri.resource import Resource
from starwhale.base.models.dataset import LocalDatasetInfoBase
from starwhale.base.models.runtime import LocalRuntimeVersion

_LOCAL_INFO_TYPE = t.Union[
LocalModelInfoBase, LocalRuntimeVersion, LocalDatasetInfoBase
]


class BaseBundle(metaclass=ABCMeta):
Expand Down Expand Up @@ -278,3 +285,23 @@ def _do_remove(self, force: bool = False) -> t.Tuple[bool, str]:
False,
)
return _ok and _ok2, _reason + _reason2

@classmethod
def group_and_filter_local_info(
cls,
rows: t.List[_LOCAL_INFO_TYPE],
) -> t.List[_LOCAL_INFO_TYPE]:
rs: t.Dict[str, _LOCAL_INFO_TYPE] = {}
for row in rows:
if not isinstance(
row, (LocalModelInfoBase, LocalRuntimeVersion, LocalDatasetInfoBase)
):
raise TypeError(f"invalid type {type(row)}")

if row.name not in rs:
rs[row.name] = row
else:
if row.created_at > rs[row.name].created_at:
rs[row.name] = row

return list(rs.values())
2 changes: 1 addition & 1 deletion client/starwhale/base/models/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import typing as t
from typing import Union

from starwhale import Resource
from starwhale.base.models.base import SwBaseModel
from starwhale.base.uri.resource import Resource
from starwhale.base.client.models.models import RuntimeVo


Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/core/dataset/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def list(
)
)

return rs, {}
return cls.group_and_filter_local_info(rs), {} # type: ignore

def build_from_csv_files(self, paths: t.List[PathLike], **kwargs: t.Any) -> None:
from starwhale.api._impl.dataset.model import Dataset as SDKDataset
Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def list(
created_at=_info[CREATED_AT_KEY],
)
)
return rs, {}
return cls.group_and_filter_local_info(rs), {} # type: ignore

def buildImpl(self, workdir: Path, **kw: t.Any) -> None: # type: ignore[override]
model_config: ModelConfig = kw["model_config"]
Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/core/runtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def list(
)
)

return ret, {}
return cls.group_and_filter_local_info(ret), {} # type: ignore

@classmethod
def quickstart_from_uri(
Expand Down
2 changes: 1 addition & 1 deletion client/tests/core/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,7 @@ def test_build_from_runtime_yaml_in_venv_mode(
)

rts, _ = StandaloneRuntime.list(Project(""))
assert len(rts) == 2
assert len(rts) == 1

rtv = runtime_term_view(f"{name}/version/{build_version[:8]}")
ok, _ = rtv.remove()
Expand Down

0 comments on commit 68e7c4c

Please sign in to comment.