Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored roi_ranges out of the model #2454

Merged
merged 5 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 11 additions & 18 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,13 @@ class SpectrumViewerWindowModel:
_normalise_stack: ImageStack | None = None
tof_range: tuple[int, int] = (0, 0)
tof_plot_range: tuple[float, float] | tuple[int, int] = (0, 0)
_roi_ranges: dict[str, SensibleROI]
tof_mode: ToFUnitMode = ToFUnitMode.WAVELENGTH
tof_data: np.ndarray | None = None
tof_range_full: tuple[int, int] = (0, 0)

def __init__(self, presenter: SpectrumViewerWindowPresenter):
self.presenter = presenter
self._roi_id_counter = 0
self._roi_ranges = {}
self.special_roi_list = [ROI_ALL]

self.units = UnitConversion()
Expand Down Expand Up @@ -128,8 +126,6 @@ def set_stack(self, stack: ImageStack | None) -> None:
self.tof_range = (0, stack.data.shape[0] - 1)
self.tof_range_full = self.tof_range
self.tof_data = self.get_stack_time_of_flight()
height, width = self.get_image_shape()
self._roi_ranges[ROI_ALL] = SensibleROI.from_list([0, 0, width, height])

def set_normalise_stack(self, normalise_stack: ImageStack | None) -> None:
self._normalise_stack = normalise_stack
Expand Down Expand Up @@ -305,7 +301,6 @@ def save_csv(self,
Iterates over all ROIs and saves the spectrum for each one to a CSV file.
@param path: The path to save the CSV file to.
@param normalized: Whether to save the normalized spectrum.

"""
if self._stack is None:
raise ValueError("No stack selected")
Expand Down Expand Up @@ -333,18 +328,19 @@ def save_csv(self,
csv_output.add_column(f"{roi_name}_norm",
self.get_spectrum(roi, SpecType.SAMPLE_NORMED, normalise_with_shuttercount),
"Counts")

with path.open("w") as outfile:
csv_output.write(outfile)
self.save_roi_coords(self.get_roi_coords_filename(path))
self.save_roi_coords(self.get_roi_coords_filename(path), rois)

def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode) -> None:
def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode, roi: SensibleROI) -> None:
"""
Saves the spectrum for the RITS ROI to a RITS file.

@param path: The path to save the CSV file to.
@param error_mode: Which version (standard deviation or propagated) of the error to use in the RITS export.
"""
self.save_rits_roi(path, error_mode, self._roi_ranges[ROI_RITS])
self.save_rits_roi(path, error_mode, roi)

def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI, normalise: bool = False) -> None:
"""
Expand Down Expand Up @@ -401,6 +397,7 @@ def save_rits_images(self,
error_mode: ErrorMode,
bin_size: int,
step: int,
roi: SensibleROI,
normalise: bool = False,
progress: Progress | None = None) -> None:
"""
Expand All @@ -420,11 +417,13 @@ def save_rits_images(self,
error_mode (ErrorMode): The error mode to use when saving the images.
bin_size (int): The size of the sub-regions.
step (int): The step size to use when sliding the window across the ROI.
roi (SensibleROI): The parent ROI to be subdivided.
normalise (bool): If True, the images will be normalised.
progress (Progress | None): Optional progress reporter.

Returns:
None
"""
roi = self._roi_ranges[ROI_RITS]
left, top, right, bottom = roi
x_iterations = min(ceil((right - left) / step), ceil((right - left - bin_size) / step) + 1)
y_iterations = min(ceil((bottom - top) / step), ceil((bottom - top - bin_size) / step) + 1)
Expand Down Expand Up @@ -463,22 +462,17 @@ def get_roi_coords_filename(self, path: Path) -> Path:
"""
return path.with_stem(f"{path.stem}_roi_coords")

def save_roi_coords(self, path: Path) -> None:
"""
Save the coordinates of the ROIs to a csv file (ROI name, x_min, x_max, y_min, y_max)
following Pascal VOC format.
@param path: The path to save the CSV file to.
"""
def save_roi_coords(self, path: Path, rois: dict[str, SensibleROI]) -> None:
with open(path, encoding='utf-8', mode='w') as f:
csv_writer = csv.DictWriter(f, fieldnames=["ROI", "X Min", "X Max", "Y Min", "Y Max"])
csv_writer.writeheader()
for roi_name, coords in self._roi_ranges.items():
for roi_name, coords in rois.items():
csv_writer.writerow({
"ROI": roi_name,
"X Min": coords.left,
"X Max": coords.right,
"Y Min": coords.top,
"Y Max": coords.bottom
"Y Max": coords.bottom,
})

def export_spectrum_to_rits(self, path: Path, tof: np.ndarray, transmission: np.ndarray,
Expand All @@ -494,7 +488,6 @@ def remove_all_roi(self) -> None:
Remove all ROIs from the model
"""
self._roi_id_counter = 0
self._roi_ranges = {}

def set_relevant_tof_units(self) -> None:
if self._stack is not None:
Expand Down
24 changes: 14 additions & 10 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,24 @@ def handle_rits_export(self) -> None:
if path is None:
LOG.debug("No path selected, aborting export")
return
run_function = partial(self.model.save_rits_images,
path,
error_mode,
self.view.bin_size,
self.view.bin_step,
normalise=self.view.shuttercount_norm_enabled())

run_function = partial(
self.model.save_rits_images,
path,
error_mode,
self.view.bin_size,
self.view.bin_step,
normalise=self.view.shuttercount_norm_enabled(),
)
start_async_task_view(self.view, run_function, self._async_save_done)

else:
path = self.view.get_rits_export_filename()
if path is None:
LOG.debug("No path selected, aborting export")
return
if path and path.suffix != ".dat":
path = path.with_suffix(".dat")
self.model.save_single_rits_spectrum(path, error_mode)
roi = self.view.spectrum_widget.get_roi(ROI_RITS)
self.model.save_single_rits_spectrum(path, error_mode, roi)

def _async_save_done(self, task: TaskWorkerThread) -> None:
if task.error is not None:
Expand Down Expand Up @@ -349,7 +350,10 @@ def change_roi_colour(self, roi_name: str, new_colour: tuple[int, int, int]) ->
self.view.on_visibility_change()

def add_rits_roi(self) -> None:
roi = self.model._roi_ranges.setdefault(ROI_RITS, SensibleROI.from_list([0, 0, *self.model.get_image_shape()]))
"""
Add the RITS ROI to the spectrum widget and initialize it with default dimensions.
"""
roi = SensibleROI.from_list([0, 0, *self.model.get_image_shape()])
self.view.spectrum_widget.add_roi(roi, ROI_RITS)
self.view.set_spectrum(ROI_RITS,
self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
Expand Down
56 changes: 21 additions & 35 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def _set_sample_stack(self, with_tof=False, with_shuttercount=False):
mock_shuttercounts.get_column.return_value = np.arange(5, 15)
stack._shutter_count_file = mock_shuttercounts
self.model.set_stack(stack)
height, width = stack.data.shape[1], stack.data.shape[2]
self.model._roi_ranges["roi"] = SensibleROI.from_list([0, 0, width, height])
return stack, spectrum

def _set_normalise_stack(self, with_shuttercount=False):
Expand Down Expand Up @@ -151,11 +149,8 @@ def test_normalise_issue(self):
self.assertEqual("", self.model.normalise_issue())

def test_set_stack_sets_roi(self):
self._set_sample_stack()
roi_all = self.model._roi_ranges["all"]
roi = self.model._roi_ranges["roi"]

self.assertEqual(roi_all, roi)
stack, _ = self._set_sample_stack()
roi_all = SensibleROI.from_list([0, 0, stack.data.shape[2], stack.data.shape[1]])
self.assertEqual(roi_all.top, 0)
self.assertEqual(roi_all.left, 0)
self.assertEqual(roi_all.right, 12)
Expand Down Expand Up @@ -228,14 +223,12 @@ def test_save_rits_roi_dat(self):
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2

self.model._roi_ranges["rits_roi"] = SensibleROI.from_list([0, 0, 10, 11])
self.model.set_normalise_stack(norm)
self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 10, 11])
roi = SensibleROI.from_list([0, 0, 10, 11])
mock_stream, mock_path = self._make_mock_path_stream()

with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, self.model._roi_ranges["ROI_RITS"])
self.model.save_rits_roi(mock_path, ErrorMode.STANDARD_DEVIATION, roi)
mock_path.open.assert_called_once_with("w")
self.assertIn("0.0\t0.0\t0.0", mock_stream.captured[0])
self.assertIn("100000.0\t0.75\t0.25", mock_stream.captured[1])
Expand All @@ -253,11 +246,12 @@ def test_save_rits_data_errors(self, _, error_mode, expected_error):
stack.data[:, :, :5] *= 2
self.model.set_normalise_stack(norm)

self.model._roi_ranges["ROI_RITS"] = SensibleROI.from_list([0, 0, 10, 11])
roi = SensibleROI.from_list([0, 0, 10, 11])
mock_stream, mock_path = self._make_mock_path_stream()

with mock.patch.object(self.model, "save_roi_coords"):
with mock.patch.object(self.model, "export_spectrum_to_rits") as mock_export:
self.model.save_rits_roi(mock_path, error_mode, self.model._roi_ranges["ROI_RITS"])
self.model.save_rits_roi(mock_path, error_mode, roi)

calculated_errors = mock_export.call_args[0][3]
np.testing.assert_allclose(expected_error, calculated_errors, atol=1e-4)
Expand All @@ -279,7 +273,6 @@ def test_save_rits_no_norm_err(self):
mock_inst_log = mock.create_autospec(InstrumentLog, source_file="", instance=True)
stack.log_file = mock_inst_log
roi = SensibleROI.from_list([0, 0, 12, 11])
self.model._roi_ranges["ROI_RITS"] = roi

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
Expand All @@ -299,8 +292,9 @@ def test_save_rits_no_tof_err(self):

def test_WHEN_save_csv_called_THEN_save_roi_coords_called_WITH_correct_args(self):
path = Path("test_file.csv")
rois = {"roi1": SensibleROI.from_list([0, 0, 10, 10]), "roi2": SensibleROI.from_list([10, 10, 20, 20])}
with mock.patch('builtins.open', mock.mock_open()) as mock_open:
self.model.save_roi_coords(path)
self.model.save_roi_coords(path, rois)
mock_open.assert_called_once_with(path, encoding='utf-8', mode='w')

def test_WHEN_get_roi_coords_filename_called_THEN_correct_filename_returned(self):
Expand Down Expand Up @@ -388,12 +382,8 @@ def test_WHEN_stack_value_set_THEN_can_export_returns_(self, _, image_stack, exp

def test_WHEN_remove_all_rois_called_THEN_all_but_default_rois_removed(self):
self.model.set_stack(generate_images())
rois = ["new_roi", "new_roi_2"]
for roi in rois:
self.model._roi_ranges[roi] = SensibleROI.from_list([0, 0, 10, 10])
self.assertListEqual(list(self.model._roi_ranges.keys()), ["all"] + rois)
self.model.remove_all_roi()
self.assertListEqual(list(self.model._roi_ranges.keys()), [])
self.assertEqual(self.model._roi_id_counter, 0)

def test_WHEN_no_stack_tof_THEN_time_of_flight_none(self):
# No Stack
Expand Down Expand Up @@ -449,43 +439,40 @@ def test_save_rits_images_write_correct_number_of_files(self, _, roi_size, bin_s
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2
roi_name = "rits_roi"
roi = SensibleROI.from_list([0, 0, roi_size, roi_size])
self.model._roi_ranges[roi_name] = roi
self.model.set_normalise_stack(norm)

Mx, My = roi.width, roi.height
x_iterations = min(math.ceil(Mx / step), math.ceil((Mx - bin_size) / step) + 1)
y_iterations = min(math.ceil(My / step), math.ceil((My - bin_size) / step) + 1)
expected_number_of_calls = x_iterations * y_iterations
_, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, bin_size, step)
self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, bin_size, step, roi)

self.assertEqual(mock_save_rits_roi.call_count, expected_number_of_calls)

@mock.patch.object(SpectrumViewerWindowModel, "save_rits_roi")
def test_save_single_rits_spectrum(self, mock_save_rits_roi):
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2
self.model._roi_ranges["rits_roi"] = SensibleROI.from_list([0, 0, 5, 5])
self.model.set_normalise_stack(norm)
roi = SensibleROI.from_list([0, 0, 5, 5])
_, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_single_rits_spectrum(mock_path, ErrorMode.STANDARD_DEVIATION)
mock_save_rits_roi.assert_called_once_with(mock_path, mock.ANY, SensibleROI.from_list([0, 0, 5, 5]))
self.model.save_single_rits_spectrum(mock_path, ErrorMode.STANDARD_DEVIATION, roi)
mock_save_rits_roi.assert_called_once_with(mock_path, mock.ANY, roi)

@mock.patch.object(SpectrumViewerWindowModel, "export_spectrum_to_rits")
def test_save_rits_correct_transmission(self, mock_save_rits_roi):
stack, spectrum = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
for i in range(10):
stack.data[:, :, i] *= i
self.model._roi_ranges["rits_roi"] = SensibleROI.from_list([1, 0, 6, 4])
self.model.set_normalise_stack(norm)
mock_path = mock.create_autospec(Path, instance=True)

self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, 3, 1)
roi = SensibleROI.from_list([1, 0, 6, 4])
self.model.save_rits_images(mock_path, ErrorMode.STANDARD_DEVIATION, 3, 1, roi)

self.assertEqual(6, len(mock_save_rits_roi.call_args_list))
expected_means = [1, 1.5, 2, 1, 1.5, 2] # running average of [1, 2, 3, 4, 5], divided by 2 for normalisation
Expand Down Expand Up @@ -540,20 +527,19 @@ def test_get_transmission_error_standard_dev(self):
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]
roi = SensibleROI.from_list([0, 0, 5, 5])
self.model._roi_ranges["roi"] = roi

left, top, right, bottom = roi
sample = stack.data[:, top:bottom, left:right]
open = normalise_stack.data[:, top:bottom, left:right]
expected = np.divide(sample, open, out=np.zeros_like(sample), where=open != 0) / average_shutter_counts
expected = np.std(expected, axis=(1, 2))

with (mock.patch.object(self.model,
"get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as
mock_get_shuttercount_normalised_correction_parameter):
with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_standard_dev(roi, normalise_with_shuttercount=True)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

Expand Down
13 changes: 5 additions & 8 deletions mantidimaging/gui/windows/spectrum_viewer/test/presenter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,19 @@ def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock, mock_
normalise_with_shuttercount=False)

@parameterized.expand(["/fake/path", "/fake/path.dat"])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi")
def test_handle_rits_export(self, path_name: str, mock_save_rits_roi: mock.Mock):
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_single_rits_spectrum")
def test_handle_rits_export(self, path_name: str, mock_save_single_rits_spectrum: mock.Mock):
self.view.get_rits_export_filename = mock.Mock(return_value=Path(path_name))
self.view.transmission_error_mode = "Standard Deviation"

mock_roi = SensibleROI.from_list([0, 0, 5, 5])
self.presenter.model._roi_ranges[ROI_RITS] = mock_roi
self.view.spectrum_widget.get_roi = mock.Mock(return_value=mock_roi)
self.presenter.model.set_stack(generate_images())
self.presenter.handle_rits_export()

self.view.get_rits_export_filename.assert_called_once()
mock_save_rits_roi.assert_called_once_with(Path("/fake/path.dat"), ErrorMode.STANDARD_DEVIATION, mock_roi)
mock_save_single_rits_spectrum.assert_called_once_with(Path("/fake/path.dat"), ErrorMode.STANDARD_DEVIATION,
mock_roi)

def test_WHEN_do_add_roi_called_THEN_new_roi_added(self):
self.view.spectrum_widget.roi_dict = {"all": mock.Mock()}
Expand Down Expand Up @@ -292,19 +293,15 @@ def test_WHEN_ROI_renamed_THEN_roi_renamed(self):
def test_WHEN_invalid_ROI_renamed_THEN_error_raised(self):
rois = ["all", "roi", "roi_1"]
self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois}
self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois}
self.view.spectrum_widget.rename_roi = mock.Mock(side_effect=KeyError("Invalid ROI"))
self.view.spectrum_widget.rois = {roi: mock.Mock() for roi in rois}
with self.assertRaises(KeyError):
self.presenter.rename_roi("invalid_roi", "new_name")

def test_WHEN_do_remove_roi_called_with_no_arguments_THEN_all_rois_removed(self):
rois = ["all", "roi", "roi_1", "roi_2"]
self.view.spectrum_widget.roi_dict = {roi: mock.Mock() for roi in rois}
self.presenter.model._roi_ranges = {roi: mock.Mock() for roi in rois}
self.presenter.do_remove_roi()
self.assertEqual(self.view.spectrum_widget.roi_dict, {})
self.assertEqual(self.presenter.model._roi_ranges, {})

@parameterized.expand([("Image Index", ToFUnitMode.IMAGE_NUMBER), ("Wavelength", ToFUnitMode.WAVELENGTH),
("Energy", ToFUnitMode.ENERGY), ("Time of Flight (\u03BCs)", ToFUnitMode.TOF_US)])
Expand Down
Loading