diff --git a/conda/meta.yaml b/conda/meta.yaml index f129f9df21a..3a9668a813e 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -41,6 +41,8 @@ requirements: - qt-material=2.14 - darkdetect=0.8.0 - qt-gtk-platformtheme # [linux] + - dask + - dask-image build: diff --git a/docs/release_notes/next/dev-2311-dask-live-viewer b/docs/release_notes/next/dev-2311-dask-live-viewer new file mode 100644 index 00000000000..807bbad4db2 --- /dev/null +++ b/docs/release_notes/next/dev-2311-dask-live-viewer @@ -0,0 +1 @@ +#2311: The Live Viewer now uses Dask to load in images and create a delayed datastack for operations \ No newline at end of file diff --git a/docs/release_notes/next/feature-2327-live-viewer-spectrum-experimental b/docs/release_notes/next/feature-2327-live-viewer-spectrum-experimental new file mode 100644 index 00000000000..7b752aed5da --- /dev/null +++ b/docs/release_notes/next/feature-2327-live-viewer-spectrum-experimental @@ -0,0 +1 @@ +#2327: The mean spectrum of live data in the Live Spectrum can be plotted via a right-click menu, using Dask. This feature is experimental. \ No newline at end of file diff --git a/mantidimaging/eyes_tests/live_viewer_window_test.py b/mantidimaging/eyes_tests/live_viewer_window_test.py index 7d9b49bea06..b43eb70db62 100644 --- a/mantidimaging/eyes_tests/live_viewer_window_test.py +++ b/mantidimaging/eyes_tests/live_viewer_window_test.py @@ -7,7 +7,7 @@ import numpy as np import os from mantidimaging.core.operations.loader import load_filter_packages -from mantidimaging.gui.windows.live_viewer.model import Image_Data +from mantidimaging.gui.windows.live_viewer.model import Image_Data, DaskImageDataStack from mantidimaging.test_helpers.unit_test_helper import FakeFSTestCase from pathlib import Path from mantidimaging.eyes_tests.base_eyes import BaseEyesTest @@ -56,36 +56,39 @@ def test_live_view_opens_without_data(self, _mock_time, _mock_image_watcher): self.imaging.show_live_viewer(self.live_directory) self.check_target(widget=self.imaging.live_viewer) - @mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image') + @mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image_from_path') @mock.patch('mantidimaging.gui.windows.live_viewer.model.ImageWatcher') @mock.patch("time.time", return_value=4000.0) def test_live_view_opens_with_data(self, _mock_time, _mock_image_watcher, mock_load_image): file_list = self._make_simple_dir(self.live_directory) image_list = [Image_Data(path) for path in file_list] + dask_image_stack = DaskImageDataStack(image_list, create_delayed_array=False) mock_load_image.return_value = self._generate_image() self.imaging.show_live_viewer(self.live_directory) - self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list) + self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list, dask_image_stack) self.check_target(widget=self.imaging.live_viewer) - @mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image') + @mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image_from_path') @mock.patch('mantidimaging.gui.windows.live_viewer.model.ImageWatcher') @mock.patch("time.time", return_value=4000.0) def test_live_view_opens_with_bad_data(self, _mock_time, _mock_image_watcher, mock_load_image): file_list = self._make_simple_dir(self.live_directory) image_list = [Image_Data(path) for path in file_list] + dask_image_stack = DaskImageDataStack(image_list, create_delayed_array=False) mock_load_image.side_effect = ValueError self.imaging.show_live_viewer(self.live_directory) - self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list) + self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list, dask_image_stack) self.check_target(widget=self.imaging.live_viewer) - @mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image') + @mock.patch('mantidimaging.gui.windows.live_viewer.presenter.LiveViewerWindowPresenter.load_image_from_path') @mock.patch('mantidimaging.gui.windows.live_viewer.model.ImageWatcher') @mock.patch("time.time", return_value=4000.0) def test_rotate_operation_rotates_image(self, _mock_time, _mock_image_watcher, mock_load_image): file_list = self._make_simple_dir(self.live_directory) image_list = [Image_Data(path) for path in file_list] + dask_image_stack = DaskImageDataStack(image_list, create_delayed_array=False) mock_load_image.return_value = self._generate_image() self.imaging.show_live_viewer(self.live_directory) - self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list) + self.imaging.live_viewer.presenter.model._handle_image_changed_in_list(image_list, dask_image_stack) self.imaging.live_viewer.rotate_angles_group.actions()[1].trigger() self.check_target(widget=self.imaging.live_viewer) diff --git a/mantidimaging/gui/windows/live_viewer/live_view_widget.py b/mantidimaging/gui/windows/live_viewer/live_view_widget.py index 3a02baec8ce..c7baafd1eec 100644 --- a/mantidimaging/gui/windows/live_viewer/live_view_widget.py +++ b/mantidimaging/gui/windows/live_viewer/live_view_widget.py @@ -2,10 +2,15 @@ # SPDX - License - Identifier: GPL-3.0-or-later from __future__ import annotations from typing import TYPE_CHECKING -from pyqtgraph import GraphicsLayoutWidget +from PyQt5.QtCore import pyqtSignal +from pyqtgraph import GraphicsLayoutWidget, mkPen + +from mantidimaging.core.utility.close_enough_point import CloseEnoughPoint +from mantidimaging.core.utility.sensible_roi import SensibleROI from mantidimaging.gui.widgets.mi_mini_image_view.view import MIMiniImageView from mantidimaging.gui.widgets.zslider.zslider import ZSlider +from mantidimaging.gui.windows.spectrum_viewer.spectrum_widget import SpectrumROI if TYPE_CHECKING: import numpy as np @@ -18,6 +23,10 @@ class LiveViewWidget(GraphicsLayoutWidget): @param parent: The parent widget """ image: MIMiniImageView + image_shape: tuple = (-1, -1) + roi_changed = pyqtSignal() + roi_object: SpectrumROI | None = None + sensible_roi: SensibleROI def __init__(self) -> None: super().__init__() @@ -48,3 +57,41 @@ def handle_deleted(self) -> None: def show_error(self, message: str | None): self.image.show_message(message) + + def add_roi(self): + if self.image_shape == (-1, -1): + return + height, width = self.image_shape + roi = SensibleROI.from_list([0, 0, width, height]) + self.roi_object = SpectrumROI('roi', roi, rotatable=False, scaleSnap=True, translateSnap=True) + self.roi_object.colour = (255, 194, 10, 255) + self.roi_object.hoverPen = mkPen(self.roi_object.colour, width=3) + self.roi_object.roi.sigRegionChangeFinished.connect(self.roi_changed.emit) + self.image.vb.addItem(self.roi_object.roi) + + def set_image_shape(self, shape: tuple) -> None: + self.image_shape = shape + + def get_roi(self) -> SensibleROI: + if not self.roi_object: + return SensibleROI() + roi = self.roi_object.roi + pos = CloseEnoughPoint(roi.pos()) + size = CloseEnoughPoint(roi.size()) + return SensibleROI.from_points(pos, size) + + def set_roi_alpha(self, alpha: int) -> None: + if not self.roi_object: + return + self.roi_object.colour = self.roi_object.colour[:3] + (alpha, ) + self.roi_object.setPen(self.roi_object.colour) + self.roi_object.hoverPen = mkPen(self.roi_object.colour, width=3) + self.set_roi_visibility_flags(bool(alpha)) + + def set_roi_visibility_flags(self, visible: bool) -> None: + if not self.roi_object: + return + handles = self.roi_object.getHandles() + for handle in handles: + handle.setVisible(visible) + self.roi_object.setVisible(visible) diff --git a/mantidimaging/gui/windows/live_viewer/model.py b/mantidimaging/gui/windows/live_viewer/model.py index dd0c78671f2..2596a1741c5 100644 --- a/mantidimaging/gui/windows/live_viewer/model.py +++ b/mantidimaging/gui/windows/live_viewer/model.py @@ -3,11 +3,21 @@ from __future__ import annotations import time +from functools import lru_cache from typing import TYPE_CHECKING from pathlib import Path from logging import getLogger + +import dask.array +import numpy as np from PyQt5.QtCore import QFileSystemWatcher, QObject, pyqtSignal, QTimer +import dask_image.imread +from astropy.io import fits + +from mantidimaging.core.utility import ExecutionProfiler +from mantidimaging.core.utility.sensible_roi import SensibleROI + if TYPE_CHECKING: from os import stat_result from mantidimaging.gui.windows.live_viewer.view import LiveViewerWindowPresenter @@ -15,6 +25,201 @@ LOG = getLogger(__name__) +class DaskImageDataStack: + """ + A Dask Image Data Stack Class to hold a delayed array of all the images in the Live Viewer Path + """ + delayed_stack: dask.array.Array | None = None + image_list: list[Image_Data] + image_paths: set[str] = set() + create_delayed_array: bool + _selected_index: int + mean: np.ndarray = np.array([]) + roi: SensibleROI | None = None + param_to_calc: list[str] = [] + max_cache_size: int = 100 + buffer_size: int = 10 + + def __init__(self, image_list: list[Image_Data], create_delayed_array: bool = False): + self.image_list = image_list + self.create_delayed_array = create_delayed_array + + if image_list and create_delayed_array: + self.create_and_set_delayed_stack() + + @property + def shape(self): + return self.delayed_stack.shape + + @property + def selected_index(self): + return self._selected_index + + @selected_index.setter + def selected_index(self, index): + self._selected_index = index + + def get_delayed_arrays(self, image_list: list[Image_Data]) -> list[dask.array.Array] | None: + if image_list: + if image_list[0].image_path.suffix.lower() in [".tif", ".tiff"] and self.create_delayed_array: + return [dask_image.imread.imread(image_data.image_path)[0] for image_data in image_list] + elif image_list[0].image_path.suffix.lower() == ".fits" and self.create_delayed_array: + return [dask.delayed(fits.open)(image_data.image_path)[0].data for image_data in image_list] + else: + return None + else: + return None + + def get_delayed_image(self, index: int) -> dask.array.Array | None: + return self.delayed_stack[index] if self.delayed_stack is not None else None + + def get_image_data(self, index: int) -> Image_Data | None: + return self.image_list[index] if self.image_list else None + + def get_fits_sample(self, image_data: Image_Data) -> np.ndarray: + with fits.open(image_data.image_path.__str__()) as fit: + return fit[0].data + + @lru_cache(maxsize=max_cache_size) # noqa: B019 + def get_computed_image(self, index: int): + if index < 0: + return None + try: + image_to_compute = self.get_delayed_image(index) + if image_to_compute is not None: + image_to_compute_opt = dask.optimize(image_to_compute) + computed_image = image_to_compute_opt[0].compute() + except dask_image.imread.pims.api.UnknownFormatError: + self.remove_image_data_by_index(index) + self.get_computed_image(index - 1) + except AttributeError: + return None + return computed_image + + def get_selected_computed_image(self): + try: + return self.get_computed_image(self.selected_index) + except dask_image.imread.pims.api.UnknownFormatError: + pass + + def remove_image_data_by_path(self, image_path: Path) -> None: + image_paths = [image.image_path for image in self.image_list] + index_to_remove = image_paths.index(image_path) + self.remove_image_data_by_index(index_to_remove) + + def remove_image_data_by_index(self, index_to_remove: int) -> None: + self.image_list.pop(index_to_remove) + self.delayed_stack = dask.array.delete(self.delayed_stack, index_to_remove, 0) + if index_to_remove == self.selected_index and self.selected_index > 0: + self.selected_index = self.selected_index - 1 + if not self.image_list: + self.delayed_stack = None + + def create_delayed_stack_from_image_data(self, image_list: list[Image_Data]) -> None | dask.array.Array: + delayed_stack = None + arrays = self.get_delayed_arrays(image_list) + if arrays: + if image_list[0].image_path.suffix.lower() in [".tif", ".tiff"]: + delayed_stack = dask.array.stack(dask.array.array(arrays)) + elif image_list[0].image_path.suffix.lower() in [".fits"]: + sample = self.get_fits_sample(image_list[0]) + lazy_arrays = [dask.array.from_delayed(x, shape=sample.shape, dtype=sample.dtype) for x in arrays] + delayed_stack = dask.array.stack(lazy_arrays) + else: + raise NotImplementedError(f"DaskImageDataStack does not support image with extension " + f"{image_list[0].image_path.suffix.lower()}") + return delayed_stack + + def update_delayed_stack(self, new_image_list) -> None: + if self.delayed_stack is None: + self.delayed_stack = self.create_delayed_stack_from_image_data(new_image_list) + else: + new_images = [image for image in new_image_list if image.image_path not in self.image_paths] + self.delayed_stack = dask.optimize( + dask.array.concatenate([self.delayed_stack, + self.create_delayed_stack_from_image_data(new_images)]))[0] + + def update_image_list(self, new_image_list: list, update_stack: bool = True) -> None: + if update_stack and self.create_delayed_array: + self.update_delayed_stack(new_image_list) + self.image_list = new_image_list + self.update_image_paths(new_image_list) + + def update_param_calculations(self) -> None: + if 'mean' in self.param_to_calc: + if len(self.mean) == len(self.image_list) - 1: + self.add_last_mean() + else: + if self.roi: + self.calc_mean_fully_roi() + else: + self.calc_mean_fully() + + def update_image_paths(self, new_image_list: list): + for image in new_image_list: + self.image_paths.add(image.image_path) + + def add_last_mean(self) -> None: + if self.delayed_stack is not None: + if self.roi: + left, top, right, bottom = self.roi + mean_to_add = dask.optimize(dask.array.mean(self.delayed_stack[-1, top:bottom, + left:right]))[0].compute() + else: + mean_to_add = dask.optimize(dask.array.mean(self.delayed_stack[-1]))[0].compute() + self.mean = np.append(self.mean, mean_to_add) + self.calc_mean_buffer() + + def calc_mean_fully(self) -> None: + if self.delayed_stack is not None: + self.mean = dask.array.mean(self.delayed_stack, axis=(1, 2)).compute() + + def calc_mean_fully_roi(self): + if self.delayed_stack is not None and self.image_list: + left, top, right, bottom = self.roi + current_cache_size = self.get_computed_image.cache_info()[3] + self.mean = np.full(len(self.image_list), np.nan) + np.put(self.mean, range(-current_cache_size, 0), self.calc_mean_cached_images(left, top, right, bottom)) + + def calc_mean_cached_images(self, left, top, right, bottom): + current_cache_size = self.get_computed_image.cache_info()[3] + cache_stack = [ + self.get_computed_image(index) + for index in range(self.selected_index - current_cache_size + 1, self.selected_index + 1, 1) + ] + cache_stack_array = np.stack(cache_stack) + cache_stack_mean = np.mean(cache_stack_array[:, top:bottom, left:right], axis=(1, 2)) + return cache_stack_mean + + def calc_mean_buffer(self): + nanInds = np.argwhere(np.isnan(self.mean)) + left, top, right, bottom = self.roi + if nanInds.size > 0: + print(f"{self.mean=}") + if nanInds.size < self.buffer_size: + buffer_start = 0 + else: + buffer_start = nanInds.size - self.buffer_size + dask_mean = dask.optimize( + dask.array.mean(self.delayed_stack[buffer_start:nanInds.size, top:bottom, left:right], + axis=(1, 2)))[0].compute() + np.put(self.mean, range(buffer_start, nanInds.size), dask_mean) + + def set_roi(self, roi: SensibleROI): + self.roi = roi + + def delete_all_data(self): + self.image_list = [] + self.delayed_stack = None + self.selected_index = 0 + + def create_and_set_delayed_stack(self): + self.delayed_stack = self.create_delayed_stack_from_image_data(self.image_list) + + def add_param_to_calc(self, param_name: str): + self.param_to_calc.append(param_name) + + class Image_Data: """ Image Data Class to store represent image data. @@ -32,6 +237,7 @@ class Image_Data: image_modified_time : float last modified time of image file """ + create_delayed_array: bool def __init__(self, image_path: Path): """ @@ -102,7 +308,8 @@ def __init__(self, presenter: LiveViewerWindowPresenter): self.presenter = presenter self._dataset_path: Path | None = None self.image_watcher: ImageWatcher | None = None - self.images: list[Image_Data] = [] + self._images: list[Image_Data] = [] + self.image_stack: DaskImageDataStack = DaskImageDataStack([]) @property def path(self) -> Path | None: @@ -114,9 +321,19 @@ def path(self, path: Path) -> None: self.image_watcher = ImageWatcher(path) self.image_watcher.image_changed.connect(self._handle_image_changed_in_list) self.image_watcher.recent_image_changed.connect(self.handle_image_modified) + self.image_watcher.update_spectrum.connect(self.presenter.update_spectrum) self.image_watcher._handle_notified_of_directry_change(str(path)) - def _handle_image_changed_in_list(self, image_files: list[Image_Data]) -> None: + @property + def images(self): + return self._images if self._images is not None else None + + @images.setter + def images(self, images): + self._images = images + + def _handle_image_changed_in_list(self, image_files: list[Image_Data], + dask_image_stack: DaskImageDataStack) -> None: """ Handle an image changed event. Update the image in the view. This method is called when the image_watcher detects a change @@ -125,10 +342,16 @@ def _handle_image_changed_in_list(self, image_files: list[Image_Data]) -> None: :param image_files: list of image files """ self.images = image_files + self.image_stack = dask_image_stack + # if dask_image_stack.image_list: + # self.image_stack = dask_image_stack self.presenter.update_image_list(image_files) + self.presenter.update_image_stack(self.image_stack) def handle_image_modified(self, image_path: Path): + self.image_stack.remove_image_data_by_path(image_path) self.presenter.update_image_modified(image_path) + self.presenter.update_image_stack(self.image_stack) def close(self) -> None: """Close the model.""" @@ -160,8 +383,11 @@ class ImageWatcher(QObject): sort_images_by_modified_time(images) Sort the images by modified time. """ - image_changed = pyqtSignal(list) # Signal emitted when an image is added or removed + image_changed = pyqtSignal(list, DaskImageDataStack) # Signal emitted when an image is added or removed + update_spectrum = pyqtSignal(np.ndarray) # Signal emitted to update the Live Viewer Spectrum recent_image_changed = pyqtSignal(Path) + create_delayed_array: bool = False + image_stack = DaskImageDataStack([]) def __init__(self, directory: Path): """ @@ -188,6 +414,8 @@ def __init__(self, directory: Path): self.sub_directories: dict[Path, SubDirectory] = {} self.add_sub_directory(SubDirectory(self.directory)) + self.image_stack.add_param_to_calc('mean') + def find_images(self, directory: Path) -> list[Image_Data]: """ Find all the images in the directory. @@ -266,10 +494,23 @@ def _handle_directory_change(self) -> None: if len(images) > 0: break - images = self.sort_images_by_modified_time(images) + if len(images) == 0: + self.image_stack.delete_all_data() + + if len(images) % 50 == 0: + print("\n") + with ExecutionProfiler(msg=f"self.image_stack.update_image_list(images): {len(images)=}"): + self.image_stack.update_image_list(images) + print("\n") + else: + self.image_stack.update_image_list(images) + + if 'mean' in self.image_stack.param_to_calc: + self.update_spectrum.emit(self.image_stack.mean) + self.update_recent_watcher(images[-1:]) - self.image_changed.emit(images) + self.image_changed.emit(images, self.image_stack) @staticmethod def _is_image_file(file_name: str) -> bool: diff --git a/mantidimaging/gui/windows/live_viewer/presenter.py b/mantidimaging/gui/windows/live_viewer/presenter.py index 34984d032b0..bf1fa0b8614 100644 --- a/mantidimaging/gui/windows/live_viewer/presenter.py +++ b/mantidimaging/gui/windows/live_viewer/presenter.py @@ -6,14 +6,17 @@ from typing import TYPE_CHECKING from collections.abc import Callable from logging import getLogger + +import dask_image.imread import numpy as np +import dask.array from imagecodecs._deflate import DeflateError -from tifffile import tifffile, TiffFileError +from tifffile import tifffile from astropy.io import fits from mantidimaging.gui.mvp_base import BasePresenter -from mantidimaging.gui.windows.live_viewer.model import LiveViewerWindowModel, Image_Data +from mantidimaging.gui.windows.live_viewer.model import LiveViewerWindowModel, Image_Data, DaskImageDataStack from mantidimaging.core.operations.loader import load_filter_packages from mantidimaging.core.data import ImageStack @@ -34,6 +37,7 @@ class LiveViewerWindowPresenter(BasePresenter): view: LiveViewerWindowView model: LiveViewerWindowModel op_func: Callable + image_stack: DaskImageDataStack def __init__(self, view: LiveViewerWindowView, main_window: MainWindowView): super().__init__(view) @@ -42,6 +46,8 @@ def __init__(self, view: LiveViewerWindowView, main_window: MainWindowView): self.main_window = main_window self.model = LiveViewerWindowModel(self) self.selected_image: Image_Data | None = None + self.selected_delayed_image: dask.array.Array | None + self.filters = {f.filter_name: f for f in load_filter_packages()} def close(self) -> None: @@ -77,38 +83,68 @@ def update_image_list(self, images_list: list[Image_Data]) -> None: def select_image(self, index: int) -> None: if not self.model.images: + self.update_image_list([]) return self.selected_image = self.model.images[index] + self.image_stack = self.model.image_stack + self.image_stack.selected_index = index + if not self.selected_image: + return image_timestamp = self.selected_image.image_modified_time_stamp self.view.label_active_filename.setText(f"{self.selected_image.image_name} - {image_timestamp}") - self.display_image(self.selected_image.image_path) + self.display_image(self.selected_image, self.image_stack) - def display_image(self, image_path: Path) -> None: + def display_image(self, image_data_obj: Image_Data, delayed_image_stack: DaskImageDataStack | None) -> None: """ Display image in the view after validating contents """ try: - image_data = self.load_image(image_path) - except (OSError, KeyError, ValueError, TiffFileError, DeflateError) as error: - message = f"{type(error).__name__} reading image: {image_path}: {error}" + if (delayed_image_stack is None or delayed_image_stack.delayed_stack is None + or not delayed_image_stack.create_delayed_array): + image_data = self.load_image_from_path(image_data_obj.image_path) + else: + try: + image_data = self.load_image_from_delayed_stack(delayed_image_stack) + except (AttributeError, dask_image.imread.pims.api.UnknownFormatError): + image_data = self.load_image_from_path(image_data_obj.image_path) + except (OSError, KeyError, ValueError, DeflateError) as error: + message = f"{type(error).__name__} reading image: {image_data_obj.image_path}: {error}" logger.error(message) self.view.remove_image() self.view.live_viewer.show_error(message) return + self.view.live_viewer.set_image_shape(image_data.shape) + if not self.view.live_viewer.roi_object and self.view.spectrum_action.isChecked(): + self.view.live_viewer.add_roi() + self.model.image_stack.set_roi(self.view.live_viewer.get_roi()) image_data = self.perform_operations(image_data) + self.model.image_stack.update_param_calculations() if image_data.size == 0: message = "reading image: {image_path}: Image has zero size" - logger.error("reading image: %s: Image has zero size", image_path) + logger.error("reading image: %s: Image has zero size", image_data_obj.image_path) self.view.remove_image() self.view.live_viewer.show_error(message) return - + # if np.any(np.isnan(self.model.image_stack.mean)): + # self.model.image_stack.calc_mean_fully_roi() self.view.show_most_recent_image(image_data) + self.update_spectrum(self.model.image_stack.mean) self.view.live_viewer.show_error(None) @staticmethod - def load_image(image_path: Path) -> np.ndarray: + def load_image_from_delayed_stack(delayed_image_stack: DaskImageDataStack | None) -> np.ndarray: + """ + Load a delayed stack from a DaskImageDataStack and compute + """ + if delayed_image_stack is not None: + image_data = delayed_image_stack.get_selected_computed_image() + else: + raise ValueError + return image_data + + @staticmethod + def load_image_from_path(image_path: Path) -> np.ndarray: """ Load a .Tif, .Tiff or .Fits file only if it exists and returns as an ndarray @@ -126,14 +162,14 @@ def update_image_modified(self, image_path: Path) -> None: Update the displayed image when the file is modified """ if self.selected_image and image_path == self.selected_image.image_path: - self.display_image(image_path) + self.display_image(self.selected_image, self.image_stack) def update_image_operation(self) -> None: """ Reload the current image if an operation has been performed on the current image """ if self.selected_image is not None: - self.display_image(self.selected_image.image_path) + self.display_image(self.selected_image, self.image_stack) def convert_image_to_imagestack(self, image_data) -> ImageStack: """ @@ -159,3 +195,16 @@ def load_as_dataset(self) -> None: if self.model.images: image_dir = self.model.images[0].image_path.parent self.main_window.show_image_load_dialog_with_path(str(image_dir)) + + def update_image_stack(self, image_stack: DaskImageDataStack): + self.image_stack = image_stack + + def update_spectrum(self, spec_data: list | np.ndarray): + self.view.spectrum.clearPlots() + self.view.spectrum.plot(spec_data) + + def handle_roi_moved(self, force_new_spectrums: bool = False): + roi = self.view.live_viewer.get_roi() + self.model.image_stack.set_roi(roi) + self.model.image_stack.calc_mean_fully_roi() + self.update_spectrum(self.model.image_stack.mean) diff --git a/mantidimaging/gui/windows/live_viewer/test/model_test.py b/mantidimaging/gui/windows/live_viewer/test/model_test.py index 0fbad537e2d..2c689d7c86f 100644 --- a/mantidimaging/gui/windows/live_viewer/test/model_test.py +++ b/mantidimaging/gui/windows/live_viewer/test/model_test.py @@ -4,12 +4,17 @@ import os import time +import unittest from pathlib import Path from unittest import mock +from numpy.testing import assert_array_equal +import numpy as np +import dask.array.random from PyQt5.QtCore import QFileSystemWatcher, pyqtSignal +from parameterized import parameterized -from mantidimaging.gui.windows.live_viewer.model import ImageWatcher +from mantidimaging.gui.windows.live_viewer.model import ImageWatcher, DaskImageDataStack, Image_Data from mantidimaging.test_helpers.unit_test_helper import FakeFSTestCase @@ -26,6 +31,7 @@ def setUp(self) -> None: mocker.side_effect = [mock_dir_watcher, mock_file_watcher] self.watcher = ImageWatcher(self.top_path) + self.watcher.create_delayed_array = False self.mock_signal_image = mock.create_autospec(pyqtSignal, emit=mock.Mock()) self.watcher.image_changed = self.mock_signal_image @@ -159,3 +165,71 @@ def test_WHEN_sub_directory_change_THEN_images_emitted(self, _mock_time): emitted_images = self._get_recent_emitted_files() self._file_list_count_equal(emitted_images, file_list2) + + +class DaskImageDataStackTest(unittest.TestCase): + + def setUp(self): + self.test_array = np.array([1, 3, 5, 12, 15]) + + def _get_fake_data(self, ext: str): + file_list = [Path(f"abc_{i:06d}" + ext) for i in range(5)] + with mock.patch("mantidimaging.gui.windows.live_viewer.model.Path.stat"): + image_data_list = [Image_Data(path) for path in file_list] + fake_data_array_list = [dask.array.random.random(5) for _ in image_data_list] + fake_data_stack = dask.array.stack(fake_data_array_list) + return image_data_list, fake_data_array_list, fake_data_stack + + def test_WHEN_not_create_delayed_array_THEN_no_delayed_array_created(self): + image_data_list, _, _ = self._get_fake_data('.tif') + self.delayed_image_stack = DaskImageDataStack(image_data_list, create_delayed_array=False) + self.assertIsNone(self.delayed_image_stack.delayed_stack) + self.assertEqual(self.delayed_image_stack.image_list, image_data_list) + + @mock.patch("mantidimaging.gui.windows.live_viewer.model.DaskImageDataStack.get_delayed_arrays") + def test_WHEN_create_delayed_array_THEN_delayed_array_created(self, mock_delayed_arrays): + image_data_list, fake_data_array_list, fake_data_stack = self._get_fake_data(".tif") + mock_delayed_arrays.return_value = fake_data_array_list + self.delayed_image_stack = DaskImageDataStack(image_data_list, create_delayed_array=True) + assert_array_equal(self.delayed_image_stack.delayed_stack, fake_data_stack) + assert_array_equal(self.delayed_image_stack.delayed_stack.compute(), fake_data_stack.compute()) + + @mock.patch("mantidimaging.gui.windows.live_viewer.model.dask_image.imread.imread") + def test_WHEN_tif_file_THEN_dask_image_imread_called(self, mock_imread): + image_data_list, _, _ = self._get_fake_data('.tif') + calls = [mock.call(image.image_path) for image in image_data_list] + self.delayed_image_stack = DaskImageDataStack(image_data_list, create_delayed_array=True) + mock_imread.assert_has_calls(calls, any_order=True) + + @mock.patch("mantidimaging.gui.windows.live_viewer.model.dask.delayed") + @mock.patch("mantidimaging.gui.windows.live_viewer.model.DaskImageDataStack.get_fits_sample") + def test_WHEN_fits_file_THEN_dask_delayed_called(self, mock_fits_sample, mock_dask_delayed): + mock_fits_sample.return_value = self.test_array + image_data_list, _, _ = self._get_fake_data('.fits') + calls = [mock.call()(image.image_path) for image in image_data_list] + with mock.patch("mantidimaging.gui.windows.live_viewer.model.fits.open"): + self.delayed_image_stack = DaskImageDataStack(image_data_list, create_delayed_array=True) + mock_dask_delayed.assert_has_calls(calls, any_order=True) + + @mock.patch("mantidimaging.gui.windows.live_viewer.model.DaskImageDataStack.get_delayed_arrays") + def test_WHEN_unsupported_file_THEN_raises_error(self, mock_delayed_arrays): + image_data_list, fake_data_array_list, _ = self._get_fake_data(".jpeg") + mock_delayed_arrays.return_value = fake_data_array_list + with self.assertRaises(NotImplementedError): + self.delayed_image_stack = DaskImageDataStack(image_data_list, create_delayed_array=True) + + @parameterized.expand([".tif", ".tiff", ".fits"]) + @mock.patch("mantidimaging.gui.windows.live_viewer.model.DaskImageDataStack.get_delayed_arrays") + @mock.patch("mantidimaging.gui.windows.live_viewer.model.dask.array.from_delayed") + @mock.patch("mantidimaging.gui.windows.live_viewer.model.dask.delayed") + @mock.patch("mantidimaging.gui.windows.live_viewer.model.DaskImageDataStack.get_fits_sample") + def test_WHEN_supported_file_THEN_no_error_raised(self, file_ext, mock_fits_sample, _, mock_from_delayed, + mock_delayed_arrays): + mock_fits_sample.return_value = self.test_array + image_data_list, fake_data_array_list, _ = self._get_fake_data(file_ext) + mock_delayed_arrays.return_value = fake_data_array_list + mock_from_delayed.return_value = fake_data_array_list + try: + self.delayed_image_stack = DaskImageDataStack(image_data_list, create_delayed_array=True) + except NotImplementedError: + self.fail("DaskImageDataStack raised NotImplementedError unexpectedly!") diff --git a/mantidimaging/gui/windows/live_viewer/view.py b/mantidimaging/gui/windows/live_viewer/view.py index c7d31c88058..1c41dd447d4 100644 --- a/mantidimaging/gui/windows/live_viewer/view.py +++ b/mantidimaging/gui/windows/live_viewer/view.py @@ -4,8 +4,8 @@ from pathlib import Path from typing import TYPE_CHECKING -from PyQt5.QtCore import QSignalBlocker -from PyQt5.QtWidgets import QVBoxLayout +from PyQt5.QtCore import QSignalBlocker, Qt +from PyQt5.QtWidgets import QVBoxLayout, QSplitter from PyQt5.Qt import QAction, QActionGroup from mantidimaging.gui.mvp_base import BaseMainWindowView @@ -14,6 +14,8 @@ import numpy as np +from ..spectrum_viewer.spectrum_widget import SpectrumPlotWidget + if TYPE_CHECKING: from mantidimaging.gui.windows.main import MainWindowView # noqa:F401 # pragma: no cover @@ -33,9 +35,19 @@ def __init__(self, main_window: MainWindowView, live_dir_path: Path) -> None: self.path = live_dir_path self.presenter = LiveViewerWindowPresenter(self, main_window) self.live_viewer = LiveViewWidget() - self.imageLayout.addWidget(self.live_viewer) + self.splitter = QSplitter(Qt.Vertical) + self.imageLayout.addWidget(self.splitter) self.live_viewer.z_slider.valueChanged.connect(self.presenter.select_image) + self.spectrum_plot_widget = SpectrumPlotWidget() + self.spectrum = self.spectrum_plot_widget.spectrum + self.live_viewer.roi_changed.connect(self.presenter.handle_roi_moved) + + self.splitter.addWidget(self.live_viewer) + self.splitter.addWidget(self.spectrum_plot_widget) + widget_height = self.frameGeometry().height() + self.splitter.setSizes([widget_height, 0]) + self.filter_params: dict[str, dict] = {} self.right_click_menu = self.live_viewer.image.vb.menu operations_menu = self.right_click_menu.addMenu("Operations") @@ -54,6 +66,14 @@ def __init__(self, main_window: MainWindowView, live_dir_path: Path) -> None: self.load_as_dataset_action = self.right_click_menu.addAction("Load as dataset") self.load_as_dataset_action.triggered.connect(self.presenter.load_as_dataset) + self.spectrum_action = QAction("Calculate Spectrum", self) + self.spectrum_action.setCheckable(True) + operations_menu.addAction(self.spectrum_action) + self.spectrum_action.triggered.connect(self.set_spectrum_visibility) + self.presenter.model.image_stack.create_delayed_array = False + self.live_viewer.set_roi_alpha(self.spectrum_action.isChecked() * 255) + self.live_viewer.set_roi_visibility_flags(False) + def show(self) -> None: """Show the window""" super().show() @@ -106,3 +126,20 @@ def set_image_rotation_angle(self) -> None: def set_load_as_dataset_enabled(self, enabled: bool): self.load_as_dataset_action.setEnabled(enabled) + + def set_spectrum_visibility(self): + widget_height = self.frameGeometry().height() + if self.spectrum_action.isChecked(): + if not self.live_viewer.roi_object: + self.live_viewer.add_roi() + self.live_viewer.set_roi_alpha(255) + self.splitter.setSizes([int(0.7 * widget_height), int(0.3 * widget_height)]) + self.presenter.model.image_stack.create_delayed_array = True + self.presenter.model.image_stack.set_roi(self.live_viewer.get_roi()) + self.presenter.model.image_stack.create_and_set_delayed_stack() + self.presenter.model.image_stack.calc_mean_fully_roi() + self.presenter.update_spectrum(self.presenter.model.image_stack.mean) + else: + self.live_viewer.set_roi_alpha(0) + self.splitter.setSizes([widget_height, 0]) + self.presenter.model.image_stack.create_delayed_array = False