Skip to content

Commit

Permalink
No-cache fast painting (napari#6607)
Browse files Browse the repository at this point in the history
# References and relevant issues

closes napari#6579
supersedes napari#6583

# Description

napari#5732 introduced a cache of mapped data so that only changed indices
were mapped to texture dtypes/values and sent on to the GPU. In this PR,
an alternate strategy is introduced: rather than caching
previously-transformed data and then doing a diff with the cache, we
paint the data *and* the texture-mapped data directly.

The partial update of the on-GPU texture also introduced in napari#5732 is
maintained, as it can dramatically reduce the amount of data needing to
be transferred from CPU to GPU memory.

This PR is built on top of napari#6602.

---------

Co-authored-by: Juan Nunez-Iglesias <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 22, 2024
1 parent 41dcb89 commit ded3311
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 174 deletions.
38 changes: 13 additions & 25 deletions napari/layers/labels/_tests/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,19 @@ def test_contour_local_updates():
)


def test_data_setitem_multi_dim():
"""
this test checks if data_setitem works when some of the indices are
outside currently rendered slice
"""
# create zarr zeros array in memory
data = zarr.zeros((10, 10, 10), chunks=(5, 5, 5), dtype=np.uint32)
labels = Labels(data)
labels.data_setitem(
(np.array([0, 1]), np.array([1, 1]), np.array([0, 0])), [1, 2]
)


def test_selecting_label():
"""Test selecting label."""
np.random.seed(0)
Expand Down Expand Up @@ -1501,22 +1514,6 @@ def test_invalidate_cache_when_change_color_mode():
)


@pytest.mark.parametrize("dtype", np.sctypes['int'] + np.sctypes['uint'])
@pytest.mark.parametrize("mode", ["auto", "direct"])
def test_cache_for_dtypes(dtype, mode):
if np.dtype(dtype).itemsize <= 2:
pytest.skip("No cache")
data = np.zeros((10, 10), dtype=dtype)
labels = Labels(data)
labels.color_mode = mode
assert labels._cached_labels is None
labels._raw_to_displayed(
labels._slice.image.raw, (slice(None), slice(None))
)
assert labels._cached_labels is not None
assert labels._cached_mapped_labels.dtype == labels._slice.image.view.dtype


def test_color_mapping_when_color_is_changed():
"""Checks if the color mapping is computed correctly when the color palette is changed."""

Expand Down Expand Up @@ -1671,15 +1668,6 @@ def on_event():
assert event_emitted


def test_invalidate_cache_when_change_slice():
layer = Labels(np.zeros((2, 4, 5), dtype=np.uint32))
assert layer._cached_labels is None
layer._setup_cache(layer._slice.image.raw)
assert layer._cached_labels is not None
layer._set_view_slice()
assert layer._cached_labels is None


def test_copy():
l1 = Labels(np.zeros((2, 4, 5), dtype=np.uint8))
l2 = copy.copy(l1)
Expand Down
143 changes: 38 additions & 105 deletions napari/layers/labels/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -48,6 +49,7 @@
from napari.layers.utils.color_transformations import transform_color
from napari.layers.utils.layer_utils import _FeatureTable
from napari.utils._dtype import normalize_dtype, vispy_texture_dtype
from napari.utils._indexing import elements_in_slice, index_in_slice
from napari.utils.colormaps import (
direct_colormap,
ensure_colormap,
Expand All @@ -56,15 +58,11 @@
from napari.utils.colormaps.colormap import (
LabelColormap,
LabelColormapBase,
_cast_labels_data_to_texture_dtype_auto,
_cast_labels_data_to_texture_dtype_direct,
_texture_dtype,
)
from napari.utils.colormaps.colormap_utils import shuffle_and_extend_colormap
from napari.utils.events import EmitterGroup, Event
from napari.utils.events.custom_types import Array
from napari.utils.geometry import clamp_point_to_bounding_box
from napari.utils.indexing import index_in_slice
from napari.utils.migrations import deprecated_constructor_arg_by_attr
from napari.utils.misc import StringEnum, _is_array_type
from napari.utils.naming import magic_name
Expand Down Expand Up @@ -322,8 +320,6 @@ def __init__(
self._color_mode = LabelColorMode.AUTO
self._show_selected_label = False
self._contour = 0
self._cached_labels = None
self._cached_mapped_labels = np.zeros((0, 4), dtype=np.uint8)

data = self._ensure_int_labels(data)

Expand Down Expand Up @@ -492,7 +488,6 @@ def seed(self, seed):
self.colormap = label_colormap(
self.num_colors, self.seed, self._background_label
)
self._cached_labels = None # invalidate the cached color mapping
self._selected_color = self.get_color(self.selected_label)
self.events.colormap() # Will update the LabelVispyColormap shader
self.refresh()
Expand All @@ -516,7 +511,6 @@ def seed_rng(self, seed_rng: Optional[int]) -> None:
self._random_colormap = shuffle_and_extend_colormap(
self._original_random_colormap, self._seed_rng
)
self._cached_labels = None # invalidate the cached color mapping
self._selected_color = self.get_color(self.selected_label)
self.events.colormap() # Will update the LabelVispyColormap shader
self.events.selected_label()
Expand Down Expand Up @@ -570,8 +564,6 @@ def num_colors(self, num_colors):
num_colors, self.seed, self._background_label
)
self._num_colors = num_colors
self._cached_labels = None # invalidate the cached color mapping
self._cached_mapped_labels = None
self.refresh()
self._selected_color = self.get_color(self.selected_label)
self.events.selected_label()
Expand Down Expand Up @@ -767,7 +759,6 @@ def selected_label(self, selected_label):
self.events.selected_label()

if self.show_selected_label:
self._cached_labels = None # invalidates labels cache
self.refresh()

def swap_selected_and_background_labels(self):
Expand All @@ -790,7 +781,6 @@ def color_mode(self):
@color_mode.setter
def color_mode(self, color_mode: Union[str, LabelColorMode]):
color_mode = LabelColorMode(color_mode)
self._cached_labels = None # invalidates labels cache
self._color_mode = color_mode
if color_mode == LabelColorMode.AUTO:
self._colormap = ensure_colormap(self._random_colormap)
Expand All @@ -813,7 +803,6 @@ def show_selected_label(self, show_selected):
self.colormap.use_selection = show_selected
self.colormap.selection = self.selected_label
self.events.show_selected_label(show_selected_label=show_selected)
self._cached_labels = None
self.refresh()

# Only overriding to change the docstring
Expand Down Expand Up @@ -914,9 +903,12 @@ def _partial_labels_refresh(self):

offset = [axis_slice.start for axis_slice in updated_slice]

colors_sliced = self._raw_to_displayed(
raw_displayed, data_slice=updated_slice
)
if self.contour > 0:
colors_sliced = self._raw_to_displayed(
raw_displayed, data_slice=updated_slice
)
else:
colors_sliced = self._slice.image.view[updated_slice]
# The next line is needed to make the following tests pass in
# napari/_vispy/_tests/:
# - test_vispy_labels_layer.py::test_labels_painting
Expand Down Expand Up @@ -972,44 +964,6 @@ def _calculate_contour(
)
return sliced_labels[delta_slice]

def _get_cache_dtype(self, raw_dtype: np.dtype) -> np.dtype:
if self.color_mode == LabelColorMode.DIRECT:
return _texture_dtype(
self._direct_colormap._num_unique_colors + 2,
raw_dtype,
)
return _texture_dtype(self.num_colors, raw_dtype)

def _setup_cache(self, labels):
"""
Initializes the cache for the Labels layer
Parameters
----------
labels : numpy array
The labels data to be cached
"""
if self._cached_labels is not None:
return

if isinstance(self._colormap, LabelColormap):
mapped_background = _cast_labels_data_to_texture_dtype_auto(
labels.dtype.type(self.colormap.background_value),
self._random_colormap,
)
else: # direct
mapped_background = _cast_labels_data_to_texture_dtype_direct(
labels.dtype.type(self.colormap.background_value),
self._direct_colormap,
)

self._cached_labels = np.zeros_like(labels)
self._cached_mapped_labels = np.full(
shape=labels.shape,
fill_value=mapped_background,
dtype=self._get_cache_dtype(labels.dtype),
)

def _raw_to_displayed(
self, raw, data_slice: Optional[Tuple[slice, ...]] = None
) -> np.ndarray:
Expand All @@ -1036,9 +990,6 @@ def _raw_to_displayed(

if data_slice is None:
data_slice = tuple(slice(0, size) for size in raw.shape)
setup_cache = False
else:
setup_cache = True

labels = raw # for readability

Expand All @@ -1049,44 +1000,7 @@ def _raw_to_displayed(
if sliced_labels is None:
sliced_labels = labels[data_slice]

if sliced_labels.dtype.itemsize <= 2:
return self.colormap._data_to_texture(sliced_labels)

if setup_cache:
self._setup_cache(raw)
else:
self._cached_labels = None

# cache the labels and keep track of when values are changed
update_mask = None
if (
self._cached_labels is not None
and self._cached_mapped_labels is not None
and self._cached_labels.shape == labels.shape
):
update_mask = self._cached_labels[data_slice] != sliced_labels
# Select only a subset with changes for further computations
labels_to_map = sliced_labels[update_mask]
# Update the cache
self._cached_labels[data_slice][update_mask] = labels_to_map
else:
labels_to_map = sliced_labels

# If there are no changes, just return the cached image
if labels_to_map.size == 0:
return self._cached_mapped_labels[data_slice]

mapped_labels = self.colormap._data_to_texture(labels_to_map)

if self._cached_labels is not None:
if update_mask is not None:
self._cached_mapped_labels[data_slice][
update_mask
] = mapped_labels
else:
self._cached_mapped_labels[data_slice] = mapped_labels
return self._cached_mapped_labels[data_slice]
return mapped_labels
return self.colormap._data_to_texture(sliced_labels)

def _update_thumbnail(self):
"""Update the thumbnail with current data and colormap.
Expand Down Expand Up @@ -1588,6 +1502,16 @@ def _get_shape_and_dims_to_paint(self) -> Tuple[list, list]:
def _get_dims_to_paint(self) -> list:
return list(self._slice_input.order[-self.n_edit_dimensions :])

def _get_pt_not_disp(self) -> Dict[int, int]:
"""
Get indices of current visible slice.
"""
slice_input = self._slice.slice_input
point = np.round(
self.world_to_data(slice_input.world_slice.point)
).astype(int)
return {dim: point[dim] for dim in slice_input.not_displayed}

def data_setitem(self, indices, value, refresh=True):
"""Set `indices` in `data` to `value`, while writing to edit history.
Expand All @@ -1607,7 +1531,12 @@ def data_setitem(self, indices, value, refresh=True):
..[1] https://numpy.org/doc/stable/user/basics.indexing.html
"""
changed_indices = self.data[indices] != value
indices = tuple([x[changed_indices] for x in indices])
indices = tuple(x[changed_indices] for x in indices)

if isinstance(value, Sequence):
value = np.asarray(value, dtype=self._slice.image.raw.dtype)
else:
value = self._slice.image.raw.dtype.type(value)

if not indices or indices[0].size == 0:
return
Expand All @@ -1623,6 +1552,13 @@ def data_setitem(self, indices, value, refresh=True):
# update the labels image
self.data[indices] = value

pt_not_disp = self._get_pt_not_disp()
displayed_indices = index_in_slice(indices, pt_not_disp)
if isinstance(value, np.ndarray):
visible_values = value[elements_in_slice(indices, pt_not_disp)]
else:
visible_values = value

if not ( # if not a numpy array or numpy-backed xarray
isinstance(self.data, np.ndarray)
or isinstance(getattr(self.data, 'data', None), np.ndarray)
Expand All @@ -1632,15 +1568,7 @@ def data_setitem(self, indices, value, refresh=True):
# array, or a NumPy-array-backed Xarray, is the slice a view and
# therefore updated automatically.
# For other types, we update it manually here.
slice_input = self._slice.slice_input
point = np.round(
self.world_to_data(slice_input.world_slice.point)
).astype(int)
pt_not_disp = {
dim: point[dim] for dim in slice_input.not_displayed
}
displayed_indices = index_in_slice(indices, pt_not_disp)
self._slice.image.raw[displayed_indices] = value
self._slice.image.raw[displayed_indices] = visible_values

# tensorstore and xarray do not return their indices in
# np.ndarray format, so they need to be converted explicitly
Expand All @@ -1659,6 +1587,11 @@ def data_setitem(self, indices, value, refresh=True):
# the original slice because of the morphological dilation
# (1 pixel because get_countours always applies 1 pixel dilation)
updated_slice = expand_slice(updated_slice, self.data.shape, 1)
else:
# update data view
self._slice.image.view[
displayed_indices
] = self.colormap._data_to_texture(visible_values)

if self._updated_slice is None:
self._updated_slice = updated_slice
Expand Down
Loading

0 comments on commit ded3311

Please sign in to comment.