Skip to content

Commit

Permalink
Fix bundle download error from ngc source (#8307)
Browse files Browse the repository at this point in the history
Fixes #8306

This previous api has been deprecated, update based on:

https://docs.ngc.nvidia.com/api/?urls.primaryName=Private%20Artifacts%20(Models)%20API#/artifact-file-controller/downloadAllArtifactFiles

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu authored Jan 21, 2025
1 parent e39bad9 commit df1ba5d
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam


def _get_ngc_bundle_url(model_name: str, version: str) -> str:
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files"


def _get_ngc_private_base_url(repo: str) -> str:
Expand Down Expand Up @@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
return name


def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]:
if not has_requests:
raise ValueError("requests package is required, please install it.")
headers = {} if headers is None else headers
response = requests_get(request_url, headers=headers)
response.raise_for_status()
model_info = json.loads(response.text)

if not isinstance(model_info, dict) or "modelFiles" not in model_info:
raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")

model_files = model_info["modelFiles"]
return [f["path"] for f in model_files]


def _download_from_ngc(
download_path: Path,
filename: str,
Expand All @@ -229,12 +244,12 @@ def _download_from_ngc(
# ensure prefix is contained
filename = _add_ngc_prefix(filename, prefix=prefix)
url = _get_ngc_bundle_url(model_name=filename, version=version)
filepath = download_path / f"{filename}_v{version}.zip"
if remove_prefix:
filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
extract_path = download_path / f"{filename}"
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
filepath = download_path / filename
filepath.mkdir(parents=True, exist_ok=True)
for file in _get_all_download_files(url):
download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress)


def _download_from_ngc_private(
Expand Down

0 comments on commit df1ba5d

Please sign in to comment.