diff --git a/bunkerhill/examples/hippocampus/model.py b/bunkerhill/examples/hippocampus/model.py index 47af166..4ec4522 100644 --- a/bunkerhill/examples/hippocampus/model.py +++ b/bunkerhill/examples/hippocampus/model.py @@ -2,10 +2,11 @@ import subprocess -from typing import Dict +from typing import Dict, Tuple import numpy as np +from bunkerhill import image_utils from bunkerhill import nnunet_wrapper from bunkerhill.base_model import BaseModel from bunkerhill.bunkerhill_types import Outputs, SeriesInstanceUID @@ -48,17 +49,35 @@ def __init__(self): install_pretrained_model_cmd = [self._LOAD_WEIGHTS_COMMAND, self._PRETRAINED_MODEL_FILENAME] subprocess.check_call(install_pretrained_model_cmd, timeout=300) - def inference(self, pixel_array: Dict[SeriesInstanceUID, np.ndarray]) -> Outputs: + def inference( + self, + image_position_patient: Dict[SeriesInstanceUID, Dict[int, Tuple[float, float, float]]], + pixel_array: Dict[SeriesInstanceUID, np.ndarray], + pixel_spacing: Dict[SeriesInstanceUID, Tuple[float, float]], + ) -> Outputs: """Runs inference on the pixel array for a DICOM series. Args: + image_position_patient: The x, y, and z coordinates of the upper left hand corner of each + instance. pixel_array: A dict mapping the DICOM series UID to its pixel array. + pixel_spacing: The pair of values specifying physical distance in the patient between the + center of each pixel. Returns: A dictionary containing the output segmentation and softmax ndarrays. """ + pixel_spacing_z = image_utils.compute_z_dim_pixel_spacing( + next(iter(image_position_patient.values()))) + pixel_spacing_x, pixel_spacing_y = next(iter(pixel_spacing.values())) + first_series_pixel_array = next(iter(pixel_array.values())) + # Convert Bunkerhill pipeline's model arguments into format expected by nnUNet. - nnunet_wrapper.dump_pixel_array(next(iter(pixel_array.values())), self._paths.test_data_dirname) + nnunet_wrapper.dump_pixel_array( + [first_series_pixel_array], + [(pixel_spacing_x, pixel_spacing_y, pixel_spacing_z)], + self._paths.test_data_dirname + ) # Run model inference using nnUNet_predict command line tool. Save the softmax tensor in # addition to the segmentation. diff --git a/bunkerhill/examples/hippocampus/test_model.py b/bunkerhill/examples/hippocampus/test_model.py index ba23206..bf6f82f 100644 --- a/bunkerhill/examples/hippocampus/test_model.py +++ b/bunkerhill/examples/hippocampus/test_model.py @@ -1,6 +1,5 @@ """Test for model.py""" -import os import pickle import uuid @@ -23,7 +22,17 @@ def test_run_inference(tmp_path: Path, grpc_server: Server): study_identifier = str(uuid.uuid4()) pixel_array = np.random.randint(2, 165, (39, 47, 36), dtype=np.uint8) series_uid = '1.2.314159.117779' - model_arguments = {'pixel_array': {series_uid: pixel_array}} + + model_arguments = { + 'image_position_patient': { + series_uid: { + i + 1: (0., 0., float(i)) + for i in range(36) + } + }, + 'pixel_array': {series_uid: pixel_array}, + 'pixel_spacing': {series_uid: (1.0, 1.0)}, + } model_arguments_filename = shared_file_utils.get_model_arguments_filename( data_dirname, study_identifier) with open(model_arguments_filename, 'wb') as f: @@ -52,5 +61,5 @@ def test_run_inference(tmp_path: Path, grpc_server: Server): assert segmentation.dtype == np.uint8 softmax = outputs[MSDHippocampusModel._SOFTMAX_OUTPUT_ATTRIBUTE_NAME][series_uid] - assert softmax.shape == (3, 36, 47, 39) + assert softmax.shape == (3, 39, 47, 36) assert softmax.dtype == np.float16 diff --git a/bunkerhill/image_utils.py b/bunkerhill/image_utils.py new file mode 100644 index 0000000..7252449 --- /dev/null +++ b/bunkerhill/image_utils.py @@ -0,0 +1,30 @@ +"""Utility methods for image processing.""" + +import logging + + +from typing import Dict, Tuple + + +logger = logging.getLogger(__name__) + +def compute_z_dim_pixel_spacing( + image_position_patient: Dict[int, Tuple[float, float, float]] +) -> float: + """Computes the z dimension pixel spacing value from the Image Position (Patient) DICOM tag.""" + instance_indices = image_position_patient.keys() + first_instance_index = min(instance_indices) + last_instance_index = max(instance_indices) + first_z_position = image_position_patient[first_instance_index][2] + last_z_position = image_position_patient[last_instance_index][2] + + # Warn if there are missing instances. + num_expected_instances = last_instance_index - first_instance_index + 1 + num_actual_instances = len(instance_indices) + if num_actual_instances != num_expected_instances: + logger.warning( + 'Expected %s instances, but received %s instead.', + num_expected_instances, + num_actual_instances + ) + return abs(last_z_position - first_z_position) / (last_instance_index - first_instance_index) diff --git a/bunkerhill/nnunet_wrapper.py b/bunkerhill/nnunet_wrapper.py index e9ffbdb..39426a7 100644 --- a/bunkerhill/nnunet_wrapper.py +++ b/bunkerhill/nnunet_wrapper.py @@ -3,10 +3,10 @@ import dataclasses import os -from typing import Dict +from typing import Dict, List, Tuple -import nibabel as nib import numpy as np +import SimpleITK as sitk from bunkerhill.bunkerhill_types import SeriesInstanceUID @@ -72,7 +72,11 @@ def setup_paths(data_dirname: str, task: str) -> NNUNetPaths: ) -def dump_pixel_array(pixel_array: np.ndarray, nnunet_input_dirname: str) -> None: +def dump_pixel_array( + pixel_arrays: List[np.ndarray], + pixel_spacings: List[Tuple[float, float, float]], + nnunet_input_dirname: str +) -> None: """Converts pixel_array from NumPy ndarray to 3D NifTi file. nnUNet expects input images to be in 3D NifTi files, while Bunkerhill unpacks DICOM pixel_arrays @@ -80,14 +84,18 @@ def dump_pixel_array(pixel_array: np.ndarray, nnunet_input_dirname: str) -> None https://github.com/MIC-DKFZ/nnUNet/blob/7f1e273fa1021dd2ff00df2ada781ee3133096ef/documentation/dataset_conversion.md Args: - pixel_array: The pixel_array spanning all instances in the same series. + pixel_arrays: The pixel_array spanning all instances in the same series. + pixel_spacings: The distance between pixels along each of the dimensions (x, y, and z). nnunet_input_dirname: The directory path where the 3D NifTi pixel_array will be written. """ - model_argument_filename = os.path.join( - nnunet_input_dirname, f'{_TEST_INSTANCE_ID}_{_MODALITY_SUFFIX}.nii.gz' - ) - nifti_pixel_array = nib.Nifti1Image(pixel_array, affine=np.eye(4)) - nib.save(nifti_pixel_array, model_argument_filename) + for i, (pixel_array, pixel_spacing) in enumerate(zip(pixel_arrays, pixel_spacings)): + modality_suffix = '%04d' % i + model_argument_filename = os.path.join( + nnunet_input_dirname, f'{_TEST_INSTANCE_ID}_{modality_suffix}.nii.gz' + ) + sitk_pixel_array = sitk.GetImageFromArray(pixel_array) + sitk_pixel_array.SetSpacing(pixel_spacing) + sitk.WriteImage(sitk_pixel_array, model_argument_filename) def load_segmentation(outputs_dirname: str, output_attribute_name: str, @@ -110,8 +118,9 @@ def load_segmentation(outputs_dirname: str, output_attribute_name: str, } } """ - segmentation_ndarray = nib.load(os.path.join(outputs_dirname, - f'{_TEST_INSTANCE_ID}.nii.gz')).get_data() + img = sitk.ReadImage(os.path.join(outputs_dirname, + f'{_TEST_INSTANCE_ID}.nii.gz')) + segmentation_ndarray = sitk.GetArrayFromImage(img) return {output_attribute_name: {series_uid: segmentation_ndarray}} diff --git a/bunkerhill/utils/nifti_to_modelrunner_input.py b/bunkerhill/utils/nifti_to_modelrunner_input.py index a8366d3..1cd53ad 100644 --- a/bunkerhill/utils/nifti_to_modelrunner_input.py +++ b/bunkerhill/utils/nifti_to_modelrunner_input.py @@ -11,13 +11,13 @@ import argparse import pickle -import nibabel as nib +import SimpleITK as sitk from bunkerhill import shared_file_utils def main(args: argparse.Namespace) -> None: - input_array = nib.load(args.nifti_filename).get_data() + input_array = sitk.GetArrayFromImage(sitk.ReadImage(args.nifti_filename)) pixel_array = {'pixel_array': {args.series_uuid: input_array}} input_filename = shared_file_utils.get_model_arguments_filename( diff --git a/setup.py b/setup.py index fcd6ffb..311f62c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="bunkerhill", - version="0.0.1", + version="0.0.2", author="Bunkerhill Health", description="SDK for integration with Bunkerhill Health", long_description=long_description, @@ -21,8 +21,8 @@ 'grpcio==1.51.1', 'grpcio-testing==1.51.1', 'grpcio-tools==1.51.1', - 'nibabel==5.0.0', 'numpy>=1.24.0', + 'SimpleITK>=2.2.1', ], )