From ded3311b1fae815ef0af3572870bf50e67fdcbda Mon Sep 17 00:00:00 2001 From: Grzegorz Bokota Date: Mon, 22 Jan 2024 18:09:19 +0100 Subject: [PATCH] No-cache fast painting (#6607) # References and relevant issues closes #6579 supersedes #6583 # Description #5732 introduced a cache of mapped data so that only changed indices were mapped to texture dtypes/values and sent on to the GPU. In this PR, an alternate strategy is introduced: rather than caching previously-transformed data and then doing a diff with the cache, we paint the data *and* the texture-mapped data directly. The partial update of the on-GPU texture also introduced in #5732 is maintained, as it can dramatically reduce the amount of data needing to be transferred from CPU to GPU memory. This PR is built on top of #6602. --------- Co-authored-by: Juan Nunez-Iglesias Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- napari/layers/labels/_tests/test_labels.py | 38 ++---- napari/layers/labels/labels.py | 143 ++++++--------------- napari/utils/_indexing.py | 72 +++++++++++ napari/utils/colormaps/colormap.py | 19 ++- napari/utils/indexing.py | 50 ++----- setup.cfg | 1 + 6 files changed, 149 insertions(+), 174 deletions(-) create mode 100644 napari/utils/_indexing.py diff --git a/napari/layers/labels/_tests/test_labels.py b/napari/layers/labels/_tests/test_labels.py index 21737e5901f..1f0cebcb75a 100644 --- a/napari/layers/labels/_tests/test_labels.py +++ b/napari/layers/labels/_tests/test_labels.py @@ -664,6 +664,19 @@ def test_contour_local_updates(): ) +def test_data_setitem_multi_dim(): + """ + this test checks if data_setitem works when some of the indices are + outside currently rendered slice + """ + # create zarr zeros array in memory + data = zarr.zeros((10, 10, 10), chunks=(5, 5, 5), dtype=np.uint32) + labels = Labels(data) + labels.data_setitem( + (np.array([0, 1]), np.array([1, 1]), np.array([0, 0])), [1, 2] + ) + + def test_selecting_label(): """Test selecting label.""" np.random.seed(0) @@ -1501,22 +1514,6 @@ def test_invalidate_cache_when_change_color_mode(): ) -@pytest.mark.parametrize("dtype", np.sctypes['int'] + np.sctypes['uint']) -@pytest.mark.parametrize("mode", ["auto", "direct"]) -def test_cache_for_dtypes(dtype, mode): - if np.dtype(dtype).itemsize <= 2: - pytest.skip("No cache") - data = np.zeros((10, 10), dtype=dtype) - labels = Labels(data) - labels.color_mode = mode - assert labels._cached_labels is None - labels._raw_to_displayed( - labels._slice.image.raw, (slice(None), slice(None)) - ) - assert labels._cached_labels is not None - assert labels._cached_mapped_labels.dtype == labels._slice.image.view.dtype - - def test_color_mapping_when_color_is_changed(): """Checks if the color mapping is computed correctly when the color palette is changed.""" @@ -1671,15 +1668,6 @@ def on_event(): assert event_emitted -def test_invalidate_cache_when_change_slice(): - layer = Labels(np.zeros((2, 4, 5), dtype=np.uint32)) - assert layer._cached_labels is None - layer._setup_cache(layer._slice.image.raw) - assert layer._cached_labels is not None - layer._set_view_slice() - assert layer._cached_labels is None - - def test_copy(): l1 = Labels(np.zeros((2, 4, 5), dtype=np.uint8)) l2 = copy.copy(l1) diff --git a/napari/layers/labels/labels.py b/napari/layers/labels/labels.py index 6c6078c2796..04956bc3de6 100644 --- a/napari/layers/labels/labels.py +++ b/napari/layers/labels/labels.py @@ -7,6 +7,7 @@ Dict, List, Optional, + Sequence, Tuple, Union, cast, @@ -48,6 +49,7 @@ from napari.layers.utils.color_transformations import transform_color from napari.layers.utils.layer_utils import _FeatureTable from napari.utils._dtype import normalize_dtype, vispy_texture_dtype +from napari.utils._indexing import elements_in_slice, index_in_slice from napari.utils.colormaps import ( direct_colormap, ensure_colormap, @@ -56,15 +58,11 @@ from napari.utils.colormaps.colormap import ( LabelColormap, LabelColormapBase, - _cast_labels_data_to_texture_dtype_auto, - _cast_labels_data_to_texture_dtype_direct, - _texture_dtype, ) from napari.utils.colormaps.colormap_utils import shuffle_and_extend_colormap from napari.utils.events import EmitterGroup, Event from napari.utils.events.custom_types import Array from napari.utils.geometry import clamp_point_to_bounding_box -from napari.utils.indexing import index_in_slice from napari.utils.migrations import deprecated_constructor_arg_by_attr from napari.utils.misc import StringEnum, _is_array_type from napari.utils.naming import magic_name @@ -322,8 +320,6 @@ def __init__( self._color_mode = LabelColorMode.AUTO self._show_selected_label = False self._contour = 0 - self._cached_labels = None - self._cached_mapped_labels = np.zeros((0, 4), dtype=np.uint8) data = self._ensure_int_labels(data) @@ -492,7 +488,6 @@ def seed(self, seed): self.colormap = label_colormap( self.num_colors, self.seed, self._background_label ) - self._cached_labels = None # invalidate the cached color mapping self._selected_color = self.get_color(self.selected_label) self.events.colormap() # Will update the LabelVispyColormap shader self.refresh() @@ -516,7 +511,6 @@ def seed_rng(self, seed_rng: Optional[int]) -> None: self._random_colormap = shuffle_and_extend_colormap( self._original_random_colormap, self._seed_rng ) - self._cached_labels = None # invalidate the cached color mapping self._selected_color = self.get_color(self.selected_label) self.events.colormap() # Will update the LabelVispyColormap shader self.events.selected_label() @@ -570,8 +564,6 @@ def num_colors(self, num_colors): num_colors, self.seed, self._background_label ) self._num_colors = num_colors - self._cached_labels = None # invalidate the cached color mapping - self._cached_mapped_labels = None self.refresh() self._selected_color = self.get_color(self.selected_label) self.events.selected_label() @@ -767,7 +759,6 @@ def selected_label(self, selected_label): self.events.selected_label() if self.show_selected_label: - self._cached_labels = None # invalidates labels cache self.refresh() def swap_selected_and_background_labels(self): @@ -790,7 +781,6 @@ def color_mode(self): @color_mode.setter def color_mode(self, color_mode: Union[str, LabelColorMode]): color_mode = LabelColorMode(color_mode) - self._cached_labels = None # invalidates labels cache self._color_mode = color_mode if color_mode == LabelColorMode.AUTO: self._colormap = ensure_colormap(self._random_colormap) @@ -813,7 +803,6 @@ def show_selected_label(self, show_selected): self.colormap.use_selection = show_selected self.colormap.selection = self.selected_label self.events.show_selected_label(show_selected_label=show_selected) - self._cached_labels = None self.refresh() # Only overriding to change the docstring @@ -914,9 +903,12 @@ def _partial_labels_refresh(self): offset = [axis_slice.start for axis_slice in updated_slice] - colors_sliced = self._raw_to_displayed( - raw_displayed, data_slice=updated_slice - ) + if self.contour > 0: + colors_sliced = self._raw_to_displayed( + raw_displayed, data_slice=updated_slice + ) + else: + colors_sliced = self._slice.image.view[updated_slice] # The next line is needed to make the following tests pass in # napari/_vispy/_tests/: # - test_vispy_labels_layer.py::test_labels_painting @@ -972,44 +964,6 @@ def _calculate_contour( ) return sliced_labels[delta_slice] - def _get_cache_dtype(self, raw_dtype: np.dtype) -> np.dtype: - if self.color_mode == LabelColorMode.DIRECT: - return _texture_dtype( - self._direct_colormap._num_unique_colors + 2, - raw_dtype, - ) - return _texture_dtype(self.num_colors, raw_dtype) - - def _setup_cache(self, labels): - """ - Initializes the cache for the Labels layer - - Parameters - ---------- - labels : numpy array - The labels data to be cached - """ - if self._cached_labels is not None: - return - - if isinstance(self._colormap, LabelColormap): - mapped_background = _cast_labels_data_to_texture_dtype_auto( - labels.dtype.type(self.colormap.background_value), - self._random_colormap, - ) - else: # direct - mapped_background = _cast_labels_data_to_texture_dtype_direct( - labels.dtype.type(self.colormap.background_value), - self._direct_colormap, - ) - - self._cached_labels = np.zeros_like(labels) - self._cached_mapped_labels = np.full( - shape=labels.shape, - fill_value=mapped_background, - dtype=self._get_cache_dtype(labels.dtype), - ) - def _raw_to_displayed( self, raw, data_slice: Optional[Tuple[slice, ...]] = None ) -> np.ndarray: @@ -1036,9 +990,6 @@ def _raw_to_displayed( if data_slice is None: data_slice = tuple(slice(0, size) for size in raw.shape) - setup_cache = False - else: - setup_cache = True labels = raw # for readability @@ -1049,44 +1000,7 @@ def _raw_to_displayed( if sliced_labels is None: sliced_labels = labels[data_slice] - if sliced_labels.dtype.itemsize <= 2: - return self.colormap._data_to_texture(sliced_labels) - - if setup_cache: - self._setup_cache(raw) - else: - self._cached_labels = None - - # cache the labels and keep track of when values are changed - update_mask = None - if ( - self._cached_labels is not None - and self._cached_mapped_labels is not None - and self._cached_labels.shape == labels.shape - ): - update_mask = self._cached_labels[data_slice] != sliced_labels - # Select only a subset with changes for further computations - labels_to_map = sliced_labels[update_mask] - # Update the cache - self._cached_labels[data_slice][update_mask] = labels_to_map - else: - labels_to_map = sliced_labels - - # If there are no changes, just return the cached image - if labels_to_map.size == 0: - return self._cached_mapped_labels[data_slice] - - mapped_labels = self.colormap._data_to_texture(labels_to_map) - - if self._cached_labels is not None: - if update_mask is not None: - self._cached_mapped_labels[data_slice][ - update_mask - ] = mapped_labels - else: - self._cached_mapped_labels[data_slice] = mapped_labels - return self._cached_mapped_labels[data_slice] - return mapped_labels + return self.colormap._data_to_texture(sliced_labels) def _update_thumbnail(self): """Update the thumbnail with current data and colormap. @@ -1588,6 +1502,16 @@ def _get_shape_and_dims_to_paint(self) -> Tuple[list, list]: def _get_dims_to_paint(self) -> list: return list(self._slice_input.order[-self.n_edit_dimensions :]) + def _get_pt_not_disp(self) -> Dict[int, int]: + """ + Get indices of current visible slice. + """ + slice_input = self._slice.slice_input + point = np.round( + self.world_to_data(slice_input.world_slice.point) + ).astype(int) + return {dim: point[dim] for dim in slice_input.not_displayed} + def data_setitem(self, indices, value, refresh=True): """Set `indices` in `data` to `value`, while writing to edit history. @@ -1607,7 +1531,12 @@ def data_setitem(self, indices, value, refresh=True): ..[1] https://numpy.org/doc/stable/user/basics.indexing.html """ changed_indices = self.data[indices] != value - indices = tuple([x[changed_indices] for x in indices]) + indices = tuple(x[changed_indices] for x in indices) + + if isinstance(value, Sequence): + value = np.asarray(value, dtype=self._slice.image.raw.dtype) + else: + value = self._slice.image.raw.dtype.type(value) if not indices or indices[0].size == 0: return @@ -1623,6 +1552,13 @@ def data_setitem(self, indices, value, refresh=True): # update the labels image self.data[indices] = value + pt_not_disp = self._get_pt_not_disp() + displayed_indices = index_in_slice(indices, pt_not_disp) + if isinstance(value, np.ndarray): + visible_values = value[elements_in_slice(indices, pt_not_disp)] + else: + visible_values = value + if not ( # if not a numpy array or numpy-backed xarray isinstance(self.data, np.ndarray) or isinstance(getattr(self.data, 'data', None), np.ndarray) @@ -1632,15 +1568,7 @@ def data_setitem(self, indices, value, refresh=True): # array, or a NumPy-array-backed Xarray, is the slice a view and # therefore updated automatically. # For other types, we update it manually here. - slice_input = self._slice.slice_input - point = np.round( - self.world_to_data(slice_input.world_slice.point) - ).astype(int) - pt_not_disp = { - dim: point[dim] for dim in slice_input.not_displayed - } - displayed_indices = index_in_slice(indices, pt_not_disp) - self._slice.image.raw[displayed_indices] = value + self._slice.image.raw[displayed_indices] = visible_values # tensorstore and xarray do not return their indices in # np.ndarray format, so they need to be converted explicitly @@ -1659,6 +1587,11 @@ def data_setitem(self, indices, value, refresh=True): # the original slice because of the morphological dilation # (1 pixel because get_countours always applies 1 pixel dilation) updated_slice = expand_slice(updated_slice, self.data.shape, 1) + else: + # update data view + self._slice.image.view[ + displayed_indices + ] = self.colormap._data_to_texture(visible_values) if self._updated_slice is None: self._updated_slice = updated_slice diff --git a/napari/utils/_indexing.py b/napari/utils/_indexing.py new file mode 100644 index 00000000000..5b6fe9ef8e0 --- /dev/null +++ b/napari/utils/_indexing.py @@ -0,0 +1,72 @@ +from typing import Dict, Tuple + +import numpy as np +import numpy.typing as npt + + +def elements_in_slice( + index: Tuple[npt.NDArray[np.int_], ...], position_in_axes: Dict[int, int] +) -> npt.NDArray[np.bool_]: + """Mask elements from a multi-dimensional index not in a given slice. + + Some n-D operations may edit data that is not visible in the current slice. + Given slice position information (as a dictionary mapping axis to index on that + axis), this function returns a boolean mask for the possibly higher-dimensional + multi-index so that elements not currently visible are masked out. The + resulting multi-index can then be subset and used to index into a texture or + other lower-dimensional view. + + Parameters + ---------- + index : tuple of array of int + A NumPy fancy indexing expression [1]_. + position_in_axes : dict[int, int] + A dictionary mapping sliced (non-displayed) axes to a slice position. + + Returns + ------- + visible : array of bool + A boolean array indicating which items are visible in the current view. + """ + queries = [ + index[ax] == position for ax, position in position_in_axes.items() + ] + return np.logical_and.reduce(queries, axis=0) + + +def index_in_slice( + index: Tuple[npt.NDArray[np.int_], ...], position_in_axes: Dict[int, int] +) -> Tuple[npt.NDArray[np.int_], ...]: + """Convert a NumPy fancy indexing expression from data to sliced space. + + Parameters + ---------- + index : tuple of array of int + A NumPy fancy indexing expression [1]_. + position_in_axes : dict[int, int] + A dictionary mapping sliced (non-displayed) axes to a slice position. + + Returns + ------- + sliced_index : tuple of array of int + The indexing expression (nD) restricted to the current slice (usually + 2D or 3D). + + Examples + -------- + >>> index = (np.arange(5), np.full(5, 1), np.arange(4, 9)) + >>> index_in_slice(index, {0: 3}) + (array([1]), array([7])) + >>> index_in_slice(index, {1: 1, 2: 8}) + (array([4]),) + + References + ---------- + [1]: https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing + """ + index_in_slice = elements_in_slice(index, position_in_axes) + return tuple( + ix[index_in_slice] + for i, ix in enumerate(index) + if i not in position_in_axes + ) diff --git a/napari/utils/colormaps/colormap.py b/napari/utils/colormaps/colormap.py index fbc31aa9285..d24d29c6a6b 100644 --- a/napari/utils/colormaps/colormap.py +++ b/napari/utils/colormaps/colormap.py @@ -694,21 +694,27 @@ def _cast_labels_data_to_texture_dtype_auto( data_arr = np.atleast_1d(data) num_colors = len(colormap.colors) - 1 + zero_preserving_modulo_func = _zero_preserving_modulo + if isinstance(data, np.integer): + zero_preserving_modulo_func = _zero_preserving_modulo_numpy dtype = minimum_dtype_for_labels(num_colors + 1) if colormap.use_selection: - selection_in_texture = _zero_preserving_modulo( + selection_in_texture = _zero_preserving_modulo_numpy( np.array([colormap.selection]), num_colors, dtype ) converted = np.where( data_arr == colormap.selection, selection_in_texture, dtype.type(0) ) else: - converted = _zero_preserving_modulo( + converted = zero_preserving_modulo_func( data_arr, num_colors, dtype, colormap.background_value ) + if isinstance(data, np.integer): + return dtype.type(converted[0]) + return np.reshape(converted, original_shape) @@ -862,6 +868,15 @@ def _cast_labels_data_to_texture_dtype_direct( if data.itemsize <= 2: return data + if isinstance(data, np.integer): + mapper = direct_colormap._label_mapping_and_color_dict[0] + target_dtype = minimum_dtype_for_labels( + direct_colormap._num_unique_colors + 2 + ) + return target_dtype.type( + mapper.get(int(data), MAPPING_OF_UNKNOWN_VALUE) + ) + original_shape = np.shape(data) array_data = np.atleast_1d(data) return np.reshape( diff --git a/napari/utils/indexing.py b/napari/utils/indexing.py index 44606f794e4..110ff7aae87 100644 --- a/napari/utils/indexing.py +++ b/napari/utils/indexing.py @@ -1,45 +1,11 @@ -from typing import Dict, Tuple +import warnings -import numpy as np -import numpy.typing as npt +from napari.utils._indexing import index_in_slice +__all__ = ['index_in_slice'] -def index_in_slice( - index: Tuple[npt.NDArray[np.int_], ...], position_in_axes: Dict[int, int] -) -> Tuple[npt.NDArray[np.int_], ...]: - """Convert a NumPy fancy indexing expression from data to sliced space. - - Parameters - ---------- - index : tuple of array of int - A NumPy fancy indexing expression [1]_. - position_in_axes : dict[int, int] - A dictionary mapping sliced (non-displayed) axes to a slice position. - - Returns - ------- - sliced_index : tuple of array of int - The indexing expression (nD) restricted to the current slice (usually - 2D or 3D). - - Examples - -------- - >>> index = (np.arange(5), np.full(5, 1), np.arange(4, 9)) - >>> index_in_slice(index, {0: 3}) - (array([1]), array([7])) - >>> index_in_slice(index, {1: 1, 2: 8}) - (array([4]),) - - References - ---------- - [1]: https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing - """ - queries = [ - index[ax] == position for ax, position in position_in_axes.items() - ] - index_in_slice = np.logical_and.reduce(queries, axis=0) - return tuple( - ix[index_in_slice] - for i, ix in enumerate(index) - if i not in position_in_axes - ) +warnings.warn( + "napari.utils.indexing is deprecated since 0.4.19 and will be removed in 0.5.0.", + FutureWarning, + stacklevel=2, +) diff --git a/setup.cfg b/setup.cfg index d2ac6e45d66..15b80611f6c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -181,6 +181,7 @@ omit = */_vendor/* */_version.py */benchmarks/* + napari/utils/indexing.py source = napari napari_builtins