Merge branch 'dev' into numpy2_reintro
Signed-off-by: Eric Kerfoot <[email protected]>
ericspod authored Feb 3, 2025
2 parents 3eabf20 + 8dcb9dc commit 1689c12
Showing 11 changed files with 299 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -66,7 +66,7 @@ repos:
         )$
   - repo: https://github.com/hadialqattan/pycln
-    rev: v2.4.0
+    rev: v2.5.0
     hooks:
       - id: pycln
         args: [--config=pyproject.toml]
25 changes: 20 additions & 5 deletions monai/bundle/scripts.py
@@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
 
 
 def _get_ngc_bundle_url(model_name: str, version: str) -> str:
-    return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
+    return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files"
 
 
 def _get_ngc_private_base_url(repo: str) -> str:
@@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
     return name
 
 
+def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]:
+    if not has_requests:
+        raise ValueError("requests package is required, please install it.")
+    headers = {} if headers is None else headers
+    response = requests_get(request_url, headers=headers)
+    response.raise_for_status()
+    model_info = json.loads(response.text)
+
+    if not isinstance(model_info, dict) or "modelFiles" not in model_info:
+        raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")
+
+    model_files = model_info["modelFiles"]
+    return [f["path"] for f in model_files]
+
+
 def _download_from_ngc(
     download_path: Path,
     filename: str,
@@ -229,12 +244,12 @@ def _download_from_ngc(
     # ensure prefix is contained
     filename = _add_ngc_prefix(filename, prefix=prefix)
     url = _get_ngc_bundle_url(model_name=filename, version=version)
-    filepath = download_path / f"{filename}_v{version}.zip"
     if remove_prefix:
         filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
-    extract_path = download_path / f"{filename}"
-    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
-    extractall(filepath=filepath, output_dir=extract_path, has_base=True)
+    filepath = download_path / filename
+    filepath.mkdir(parents=True, exist_ok=True)
+    for file in _get_all_download_files(url):
+        download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress)
 
 
 def _download_from_ngc_private(
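The change above switches NGC bundle downloads from fetching one zip at the `zip` endpoint to listing `modelFiles` and fetching each file from the `files` endpoint. A minimal standalone sketch of that flow, for illustration only: `list_bundle_files` is a hypothetical stand-in for `_get_all_download_files`, and the bundle name, version, and base URL below are placeholder assumptions, not verified NGC assets.

    import requests

    def list_bundle_files(files_url: str) -> list[str]:
        # the "files" endpoint returns JSON whose "modelFiles" entry
        # lists every file belonging to the bundle
        response = requests.get(files_url)
        response.raise_for_status()
        return [f["path"] for f in response.json()["modelFiles"]]

    # placeholder URL layout following _get_ngc_bundle_url above
    url = "<NGC_BASE_URL>/monai_spleen_ct_segmentation/versions/0.1.0/files"
    for path in list_bundle_files(url):
        print(f"would download {url}/{path}")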
219 changes: 170 additions & 49 deletions monai/data/image_reader.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions monai/data/meta_tensor.py
@@ -607,3 +607,8 @@ def print_verbose(self) -> None:
         print(self)
         if self.meta is not None:
             print(self.meta.__repr__())
+
+
+# needed in later versions of Pytorch to indicate the class is safe for serialisation
+if hasattr(torch.serialization, "add_safe_globals"):
+    torch.serialization.add_safe_globals([MetaTensor])
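This registration matters once `torch.load` defaults to `weights_only=True`: the safe unpickler only reconstructs allow-listed classes, so without it a checkpoint containing a `MetaTensor` would fail to load. A minimal round-trip sketch, assuming a PyTorch version that provides `add_safe_globals` (2.4+) and that everything stored in `meta` is itself a safe-to-unpickle type:

    import torch
    from monai.data import MetaTensor

    t = MetaTensor(torch.rand(2, 3), meta={"space": "RAS"})
    torch.save(t, "t.pt")
    # the weights-only unpickler accepts MetaTensor because the module
    # registered it via torch.serialization.add_safe_globals above
    loaded = torch.load("t.pt", weights_only=True)
    print(type(loaded), loaded.meta.get("space"))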
23 changes: 19 additions & 4 deletions monai/inferers/merger.py
@@ -15,12 +15,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from contextlib import nullcontext
+from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any
 
 import numpy as np
 import torch
 
-from monai.utils import ensure_tuple_size, optional_import, require_pkg
+from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq
 
 if TYPE_CHECKING:
     import zarr
@@ -233,7 +234,7 @@ def __init__(
         store: zarr.storage.Store | str = "merged.zarr",
         value_store: zarr.storage.Store | str | None = None,
         count_store: zarr.storage.Store | str | None = None,
-        compressor: str = "default",
+        compressor: str | None = None,
         value_compressor: str | None = None,
         count_compressor: str | None = None,
         chunks: Sequence[int] | bool = True,
@@ -246,8 +247,22 @@ def __init__(
         self.value_dtype = value_dtype
         self.count_dtype = count_dtype
         self.store = store
-        self.value_store = zarr.storage.TempStore() if value_store is None else value_store
-        self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+        self.tmpdir: TemporaryDirectory | None
+        if version_geq(get_package_version("zarr"), "3.0.0"):
+            if value_store is None:
+                self.tmpdir = TemporaryDirectory()
+                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
+            else:
+                self.value_store = value_store
+            if count_store is None:
+                self.tmpdir = TemporaryDirectory()
+                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
+            else:
+                self.count_store = count_store
+        else:
+            self.tmpdir = None
+            self.value_store = zarr.storage.TempStore() if value_store is None else value_store
+            self.count_store = zarr.storage.TempStore() if count_store is None else count_store
         self.chunks = chunks
         self.compressor = compressor
         self.value_compressor = value_compressor
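The version gate exists because `zarr.storage.TempStore` was removed in Zarr 3.0; on 3.x the merger now creates its own `TemporaryDirectory` and wraps it in a `LocalStore`, keeping the `TemporaryDirectory` object alive on the instance so the backing files are not cleaned up mid-merge. A condensed sketch of the same pattern, assuming the version helpers from `monai.utils`:

    from tempfile import TemporaryDirectory

    import zarr
    from monai.utils import get_package_version, version_geq

    if version_geq(get_package_version("zarr"), "3.0.0"):
        # hold a reference to tmpdir: its directory is deleted on cleanup
        tmpdir = TemporaryDirectory()
        value_store = zarr.storage.LocalStore(tmpdir.name)
    else:
        value_store = zarr.storage.TempStore()  # removed in zarr 3.0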
2 changes: 1 addition & 1 deletion monai/utils/jupyter_utils.py
@@ -234,7 +234,7 @@ def plot_engine_status(


 def _get_loss_from_output(
-    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
+    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
 ) -> torch.Tensor:
     """Returns a single value from the network output, which is a dict or tensor."""

4 changes: 2 additions & 2 deletions monai/visualize/img2tensorboard.py
@@ -65,11 +65,11 @@ def _image3_animated_gif(
     img_str = b""
     for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
         img_str += b_data
-    img_str += b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00"
+    img_str += b"\x21\xff\x0b\x4e\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2e\x30\x03\x01\x00\x00\x00"
     for i in ims:
         for b_data in PIL.GifImagePlugin.getdata(i):
             img_str += b_data
-    img_str += b"\x3B"
+    img_str += b"\x3b"
 
     summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
     summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
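The escape-case changes here are purely cosmetic (lowercase hex escapes are the style Black enforces); the byte values are identical:

    # same GIF trailer byte, different escape casing
    assert b"\x3B" == b"\x3b" == b";"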
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -18,7 +18,7 @@ pep8-naming
 pycodestyle
 pyflakes
 black>=22.12
-isort>=5.1
+isort>=5.1, <6.0
 ruff
 pytype>=2020.6.1; platform_system != "Windows"
 types-setuptools
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
-torch>=1.9
+torch>=1.9,<2.6
 numpy>=1.24,<3.0
58 changes: 52 additions & 6 deletions tests/test_load_image.py
@@ -168,6 +168,16 @@ def get_data(self, _obj):
 # test reader consistency between PydicomReader and ITKReader on dicom data
 TEST_CASE_22 = ["tests/testing_data/CT_DICOM"]
 
+# test pydicom gpu reader
+TEST_CASE_GPU_5 = [{"reader": "PydicomReader", "to_gpu": True}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)]
+
+TEST_CASE_GPU_6 = [
+    {"reader": "PydicomReader", "ensure_channel_first": True, "force": True, "to_gpu": True},
+    "tests/testing_data/CT_DICOM",
+    (16, 16, 4),
+    (1, 16, 16, 4),
+]
+
 TESTS_META = []
 for track_meta in (False, True):
     TESTS_META.append([{}, (128, 128, 128), track_meta])
@@ -242,16 +252,17 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
 
     @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
     def test_itk_reader(self, input_param, filenames, expected_shape):
-        test_image = np.random.rand(128, 128, 128)
+        test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy()
+        print("Test image value range:", test_image.min(), test_image.max())
         with tempfile.TemporaryDirectory() as tempdir:
             for i, name in enumerate(filenames):
                 filenames[i] = os.path.join(tempdir, name)
-                itk_np_view = itk.image_view_from_array(test_image)
-                itk.imwrite(itk_np_view, filenames[i])
+                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
             result = LoadImage(image_only=True, **input_param)(filenames)
-            self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz"))
-            diag = torch.as_tensor(np.diag([-1, -1, 1, 1]))
-            np.testing.assert_allclose(result.affine, diag)
+            ext = "".join(Path(name).suffixes)
+            self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
+            self.assertEqual(result.meta["space"], "RAS")
+            assert_allclose(result.affine, torch.eye(4))
             self.assertTupleEqual(result.shape, expected_shape)
 
@parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_19, TEST_CASE_20, TEST_CASE_21])
@@ -271,6 +282,26 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e
         )
         self.assertTupleEqual(result.shape, expected_np_shape)
 
+    @SkipIfNoModule("pydicom")
+    @SkipIfNoModule("cupy")
+    @SkipIfNoModule("kvikio")
+    @parameterized.expand([TEST_CASE_GPU_5, TEST_CASE_GPU_6])
+    def test_pydicom_gpu_reader(self, input_param, filenames, expected_shape, expected_np_shape):
+        result = LoadImage(image_only=True, **input_param)(filenames)
+        self.assertEqual(result.meta["filename_or_obj"], f"{Path(filenames)}")
+        assert_allclose(
+            result.affine,
+            torch.tensor(
+                [
+                    [-0.488281, 0.0, 0.0, 125.0],
+                    [0.0, -0.488281, 0.0, 128.100006],
+                    [0.0, 0.0, 68.33333333, -99.480003],
+                    [0.0, 0.0, 0.0, 1.0],
+                ]
+            ),
+        )
+        self.assertTupleEqual(result.shape, expected_np_shape)
+
     def test_no_files(self):
         with self.assertRaisesRegex(RuntimeError, "list index out of range"):  # fname_regex excludes everything
             LoadImage(image_only=True, reader="PydicomReader", fname_regex=r"^(?!.*).*")("tests/testing_data/CT_DICOM")
@@ -317,6 +348,21 @@ def test_dicom_reader_consistency(self, filenames):
             np.testing.assert_allclose(pydicom_result, itk_result)
             np.testing.assert_allclose(pydicom_result.affine, itk_result.affine)
 
+    @SkipIfNoModule("pydicom")
+    @SkipIfNoModule("cupy")
+    @SkipIfNoModule("kvikio")
+    @parameterized.expand([TEST_CASE_22])
+    def test_pydicom_reader_gpu_cpu_consistency(self, filenames):
+        gpu_param = {"reader": "PydicomReader", "to_gpu": True}
+        cpu_param = {"reader": "PydicomReader", "to_gpu": False}
+        for affine_flag in [True, False]:
+            gpu_param["affine_lps_to_ras"] = affine_flag
+            cpu_param["affine_lps_to_ras"] = affine_flag
+            gpu_result = LoadImage(image_only=True, **gpu_param)(filenames)
+            cpu_result = LoadImage(image_only=True, **cpu_param)(filenames)
+            np.testing.assert_allclose(gpu_result.cpu(), cpu_result)
+            np.testing.assert_allclose(gpu_result.affine.cpu(), cpu_result.affine)
+
     def test_dicom_reader_consistency_single(self):
         itk_param = {"reader": "ITKReader"}
         pydicom_param = {"reader": "PydicomReader"}
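The new tests exercise `PydicomReader`'s `to_gpu` path, which relies on `cupy` and `kvikio` for device-side loading. A minimal usage sketch, assuming those packages plus a CUDA device are available and reusing the repository's test data path:

    from monai.transforms import LoadImage

    # load a DICOM series directly onto the GPU via kvikio/cupy
    loader = LoadImage(image_only=True, reader="PydicomReader", to_gpu=True)
    image = loader("tests/testing_data/CT_DICOM")
    print(image.device, image.shape)  # expected: a CUDA-backed MetaTensor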
52 changes: 27 additions & 25 deletions tests/test_zarr_avg_merger.py
@@ -260,42 +260,44 @@
     TENSOR_4x4,
 ]
 
+ALL_TESTS = [
+    TEST_CASE_0_DEFAULT_DTYPE,
+    TEST_CASE_1_DEFAULT_DTYPE,
+    TEST_CASE_2_DEFAULT_DTYPE,
+    TEST_CASE_3_DEFAULT_DTYPE,
+    TEST_CASE_4_DEFAULT_DTYPE,
+    TEST_CASE_5_VALUE_DTYPE,
+    TEST_CASE_6_COUNT_DTYPE,
+    TEST_CASE_7_COUNT_VALUE_DTYPE,
+    TEST_CASE_8_DTYPE,
+    TEST_CASE_9_LARGER_SHAPE,
+    TEST_CASE_10_DIRECTORY_STORE,
+    TEST_CASE_11_MEMORY_STORE,
+    TEST_CASE_12_CHUNKS,
+    TEST_CASE_16_WITH_LOCK,
+    TEST_CASE_17_WITHOUT_LOCK,
+]
+
+# add compression tests only when using Zarr version before 3.0
+if not version_geq(get_package_version("zarr"), "3.0.0"):
+    ALL_TESTS += [TEST_CASE_13_COMPRESSOR_LZ4, TEST_CASE_14_COMPRESSOR_PICKLE, TEST_CASE_15_COMPRESSOR_LZMA]
+
 
@unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)")
class ZarrAvgMergerTests(unittest.TestCase):

-    @parameterized.expand(
-        [
-            TEST_CASE_0_DEFAULT_DTYPE,
-            TEST_CASE_1_DEFAULT_DTYPE,
-            TEST_CASE_2_DEFAULT_DTYPE,
-            TEST_CASE_3_DEFAULT_DTYPE,
-            TEST_CASE_4_DEFAULT_DTYPE,
-            TEST_CASE_5_VALUE_DTYPE,
-            TEST_CASE_6_COUNT_DTYPE,
-            TEST_CASE_7_COUNT_VALUE_DTYPE,
-            TEST_CASE_8_DTYPE,
-            TEST_CASE_9_LARGER_SHAPE,
-            TEST_CASE_10_DIRECTORY_STORE,
-            TEST_CASE_11_MEMORY_STORE,
-            TEST_CASE_12_CHUNKS,
-            TEST_CASE_13_COMPRESSOR_LZ4,
-            TEST_CASE_14_COMPRESSOR_PICKLE,
-            TEST_CASE_15_COMPRESSOR_LZMA,
-            TEST_CASE_16_WITH_LOCK,
-            TEST_CASE_17_WITHOUT_LOCK,
-        ]
-    )
+    @parameterized.expand(ALL_TESTS)
def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):
+        codec_reg = numcodecs.registry.codec_registry
         if "compressor" in arguments:
             if arguments["compressor"] != "default":
-                arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]()
+                arguments["compressor"] = codec_reg[arguments["compressor"].lower()]()
         if "value_compressor" in arguments:
             if arguments["value_compressor"] != "default":
-                arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]()
+                arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]()
         if "count_compressor" in arguments:
             if arguments["count_compressor"] != "default":
-                arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]()
+                arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]()
         merger = ZarrAvgMerger(**arguments)
         for pl in patch_locations:
             merger.aggregate(pl[0], pl[1])
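The tests now resolve compressors through `numcodecs.registry.codec_registry` instead of the `zarr.codec_registry` attribute, which Zarr 3.0 no longer exposes. A small sketch of the lookup, assuming numcodecs is installed:

    import numcodecs

    # map a codec name to its class and instantiate it with default settings
    compressor = numcodecs.registry.codec_registry["lz4"]()
    print(compressor)  # e.g. LZ4(acceleration=1)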
