From 8ad5964727736c8fb9ce569af62fe3f0161c9465 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 20 Jan 2025 21:03:24 +0800 Subject: [PATCH 1/2] fix #8306 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 131c78008b..fdd5fdc7da 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" + return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files" def _get_ngc_private_base_url(repo: str) -> str: @@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: return name +def _get_all_download_files(request_url, headers=None) -> list[str]: + if not has_requests: + raise ValueError("requests package is required, please install it.") + headers = {} if headers is None else headers + response = requests_get(request_url, headers=headers) + response.raise_for_status() + model_info = json.loads(response.text) + + if not isinstance(model_info, dict) or "modelFiles" not in model_info: + raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.") + + model_files = model_info["modelFiles"] + return [f["path"] for f in model_files] + + def _download_from_ngc( download_path: Path, filename: str, @@ -229,12 +244,12 @@ def _download_from_ngc( # ensure prefix is contained filename = _add_ngc_prefix(filename, prefix=prefix) url = _get_ngc_bundle_url(model_name=filename, version=version) - filepath = download_path / f"{filename}_v{version}.zip" if remove_prefix: filename = _remove_ngc_prefix(filename, prefix=remove_prefix) - extract_path = download_path / f"{filename}" - download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - extractall(filepath=filepath, output_dir=extract_path, has_base=True) + filepath = download_path / filename + filepath.mkdir(parents=True, exist_ok=True) + for file in _get_all_download_files(url): + download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress) def _download_from_ngc_private( From 942b62b0c9c1723674b53e1ba1529abf2ce57b2f Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 20 Jan 2025 22:15:02 +0800 Subject: [PATCH 2/2] fix mypy Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index fdd5fdc7da..5089f0c045 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -218,7 +218,7 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: return name -def _get_all_download_files(request_url, headers=None) -> list[str]: +def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]: if not has_requests: raise ValueError("requests package is required, please install it.") headers = {} if headers is None else headers