From b14ffc895ce88122fc82e9435541a8b7d62c0f0a Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Tue, 6 Aug 2024 15:25:01 +0200 Subject: [PATCH] channel: add optional global cache --- lumicks/pylake/__init__.py | 1 + lumicks/pylake/channel.py | 148 ++++++++++++++---- lumicks/pylake/file.py | 3 +- .../tests/test_channels/test_channels.py | 6 +- .../pylake/tests/test_file/test_caching.py | 65 ++++++++ pyproject.toml | 2 +- 6 files changed, 188 insertions(+), 37 deletions(-) create mode 100644 lumicks/pylake/tests/test_file/test_caching.py diff --git a/lumicks/pylake/__init__.py b/lumicks/pylake/__init__.py index 1c936d22f..8adc57819 100644 --- a/lumicks/pylake/__init__.py +++ b/lumicks/pylake/__init__.py @@ -1,5 +1,6 @@ from . import simulation from .file import * +from .channel import set_cache_enabled from .scalebar import ScaleBar from .__about__ import ( __doc__, diff --git a/lumicks/pylake/channel.py b/lumicks/pylake/channel.py index 83b63d679..5c1192bc7 100644 --- a/lumicks/pylake/channel.py +++ b/lumicks/pylake/channel.py @@ -5,12 +5,80 @@ import numpy as np import numpy.typing as npt +from cachetools import LRUCache, cached from .detail.plotting import _annotate from .detail.timeindex import to_seconds, to_timestamp from .detail.utilities import downsample from .nb_widgets.range_selector import SliceRangeSelectorWidget +global_cache = False + + +def set_cache_enabled(enabled): + """Enable or disable the global cache + + Pylake offers a global cache. When the global cache is enabled, all `Slice` objects come from + the same cache. + + Parameters + ---------- + enabled : bool + Whether the caching should be enabled (by default it is off) + """ + global global_cache + global_cache = enabled + + +@cached(LRUCache(maxsize=1 << 30, getsizeof=lambda x: x.nbytes), info=True) # 1 GB of cache +def _get_array(cache_object): + return cache_object.read_array() + + +class LazyCache: + def __init__(self, location, dset, nbytes): + """A lazy globally cached wrapper around an object that is convertible to a numpy array""" + self._location = location + self._dset = dset + self._nbytes = nbytes + + def __len__(self): + return len(self._dset) + + @property + def nbytes(self): + return self._nbytes + + def __hash__(self): + return hash(self._location) + + @staticmethod + def from_h5py_dset(dset, field=None): + location = f"{dset.file.filename}{dset.name}" + if field: + location = f"{location}.{field}" + dset = dset.fields(field) + item_size = dset.read_dtype.itemsize + else: + item_size = dset.dtype.itemsize + + return LazyCache(location, dset, nbytes=item_size * len(dset)) + + def read_array(self): + # Note, we deliberately do _not_ allow additional arguments to asarray since we would + # have to hash those with and unless necessary, they would unnecessarily increase the + # cache (because of sometimes defensively adding an explicit type). It's better to raise + # in this case and end up at this comment. + arr = np.asarray(self._dset) + arr.flags.writeable = False + return arr + + def __eq__(self, other): + return self._location == other._location + + def __array__(self): + return _get_array(self) + class Slice: """A lazily evaluated slice of a timeline/HDF5 channel @@ -693,7 +761,7 @@ def from_dataset(dset, y_label="y", calibration=None): start = dset.attrs["Start time (ns)"] dt = int(1e9 / dset.attrs["Sample rate (Hz)"]) # ns return Slice( - Continuous(dset, start, dt), + Continuous(LazyCache.from_h5py_dset(dset) if global_cache else dset, start, dt), labels={"title": dset.name.strip("/"), "y": y_label}, calibration=calibration, ) @@ -719,9 +787,12 @@ def to_dataset(self, parent, name, **kwargs): @property def data(self) -> npt.ArrayLike: - if self._cached_data is None: - self._cached_data = np.asarray(self._src_data) - return self._cached_data + if global_cache: + return np.asarray(self._src_data) # Reads from cache if it exists + else: + if self._cached_data is None: + self._cached_data = np.asarray(self._src_data) + return self._cached_data @property def timestamps(self) -> npt.ArrayLike: @@ -796,32 +867,14 @@ def _apply_mask(self, mask): @staticmethod def from_dataset(dset, y_label="y", calibration=None) -> Slice: - class LazyLoadedCompoundField: - """Wrapper to enable lazy loading of HDF5 compound datasets - - Notes - ----- - We only need to support the methods `__array__()` and `__len__()`, as we only access - `LazyLoadedCompoundField` via the properties `TimeSeries.data`, `timestamps` and the - method `__len__()`. - - `LazyLoadCompoundField` might be replaced with `dset.fields(fieldname)` if and when the - returned `FieldsWrapper` object provides an `__array__()` method itself""" - - def __init__(self, dset, fieldname): - self._dset = dset - self._fieldname = fieldname - - def __array__(self): - """Get the data of the field as an array""" - return self._dset[self._fieldname] - - def __len__(self): - """Get the length of the underlying dataset""" - return len(self._dset) - - data = LazyLoadedCompoundField(dset, "Value") - timestamps = LazyLoadedCompoundField(dset, "Timestamp") + data = ( + LazyCache.from_h5py_dset(dset, field="Value") if global_cache else dset.fields("Value") + ) + timestamps = ( + LazyCache.from_h5py_dset(dset, field="Timestamp") + if global_cache + else dset.fields("Timestamp") + ) return Slice( TimeSeries(data, timestamps), labels={"title": dset.name.strip("/"), "y": y_label}, @@ -850,12 +903,18 @@ def to_dataset(self, parent, name, **kwargs): @property def data(self) -> npt.ArrayLike: + if global_cache: + return np.asarray(self._src_data) + if self._cached_data is None: self._cached_data = np.asarray(self._src_data) return self._cached_data @property def timestamps(self) -> npt.ArrayLike: + if global_cache: + return np.asarray(self._src_timestamps) + if self._cached_timestamps is None: self._cached_timestamps = np.asarray(self._src_timestamps) return self._cached_timestamps @@ -907,13 +966,31 @@ class TimeTags: """ def __init__(self, data, start=None, stop=None): - self.data = np.asarray(data, dtype=np.int64) - self.start = start if start is not None else (self.data[0] if self.data.size > 0 else 0) - self.stop = stop if stop is not None else (self.data[-1] + 1 if self.data.size > 0 else 0) + self._src_data = data + self._start = start + self._stop = stop def __len__(self): return self.data.size + @property + def start(self): + return ( + self._start if self._start is not None else (self.data[0] if self.data.size > 0 else 0) + ) + + @property + def stop(self): + return ( + self._stop + if self._stop is not None + else (self.data[-1] + 1 if self.data.size > 0 else 0) + ) + + @property + def data(self): + return np.asarray(self._src_data) + def _with_data(self, data): raise NotImplementedError("Time tags do not currently support this operation") @@ -922,7 +999,10 @@ def _apply_mask(self, mask): @staticmethod def from_dataset(dset, y_label="y"): - return Slice(TimeTags(dset)) + return Slice( + TimeTags(LazyCache.from_h5py_dset(dset) if global_cache else dset), + labels={"title": dset.name.strip("/"), "y": y_label}, + ) def to_dataset(self, parent, name, **kwargs): """Save this to an h5 dataset diff --git a/lumicks/pylake/file.py b/lumicks/pylake/file.py index 94d66a08d..55073e550 100644 --- a/lumicks/pylake/file.py +++ b/lumicks/pylake/file.py @@ -1,3 +1,4 @@ +import pathlib import warnings from typing import Dict @@ -50,7 +51,7 @@ class File(Group, Force, DownsampledFD, BaselineCorrectedForce, PhotonCounts, Ph def __init__(self, filename, *, rgb_to_detectors=None): import h5py - super().__init__(h5py.File(filename, "r"), lk_file=self) + super().__init__(h5py.File(pathlib.Path(filename).absolute(), "r"), lk_file=self) self._check_file_format() self._rgb_to_detectors = self._get_detector_mapping(rgb_to_detectors) diff --git a/lumicks/pylake/tests/test_channels/test_channels.py b/lumicks/pylake/tests/test_channels/test_channels.py index 349d941cc..e5d9d361b 100644 --- a/lumicks/pylake/tests/test_channels/test_channels.py +++ b/lumicks/pylake/tests/test_channels/test_channels.py @@ -7,6 +7,7 @@ import pytest import matplotlib as mpl +import lumicks.pylake.channel from lumicks.pylake import channel from lumicks.pylake.low_level import make_continuous_slice from lumicks.pylake.calibration import ForceCalibrationList @@ -893,7 +894,10 @@ def test_annotation_bad_item(): def test_regression_lazy_loading(channel_h5_file): ch = channel.Continuous.from_dataset(channel_h5_file["Force HF"]["Force 1x"]) - assert isinstance(ch._src._src_data, h5py.Dataset) + if lumicks.pylake.channel.global_cache: + assert isinstance(ch._src._src_data._dset, h5py.Dataset) + else: + assert isinstance(ch._src._src_data, h5py.Dataset) @pytest.mark.parametrize( diff --git a/lumicks/pylake/tests/test_file/test_caching.py b/lumicks/pylake/tests/test_file/test_caching.py new file mode 100644 index 000000000..dfcd0f6d9 --- /dev/null +++ b/lumicks/pylake/tests/test_file/test_caching.py @@ -0,0 +1,65 @@ +import pytest + +from lumicks import pylake +from lumicks.pylake.channel import _get_array + + +def test_global_cache_continuous(h5_file): + pylake.set_cache_enabled(True) + _get_array.cache_clear() + + # Load the file (never storing the file handle) + f1x1 = pylake.File.from_h5py(h5_file)["Force HF/Force 1x"] + f1x2 = pylake.File.from_h5py(h5_file).force1x + assert _get_array.cache_info().hits == 0 # No cache used yet (lazy loading) + + # These should point to the same data + assert id(f1x1.data) == id(f1x2.data) + assert _get_array.cache_info().hits == 1 + assert _get_array.cache_info().currsize == 40 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + f1x1.data[5:100] = 3 + + file = pylake.File.from_h5py(h5_file) + assert id(file.force1x.data) == id(file.force1x.data) + + +def test_global_cache_timeseries(h5_file): + pylake.set_cache_enabled(True) + _get_array.cache_clear() + + f1x1 = pylake.File.from_h5py(h5_file).downsampled_force1x + f1x2 = pylake.File.from_h5py(h5_file).downsampled_force1x + assert _get_array.cache_info().hits == 0 # No cache used yet (lazy loading) + + # These should point to the same data + assert id(f1x1.data) == id(f1x2.data) + assert _get_array.cache_info().hits == 1 + assert _get_array.cache_info().currsize == 16 + assert id(f1x1.timestamps) == id(f1x2.timestamps) + assert _get_array.cache_info().hits == 2 + assert _get_array.cache_info().currsize == 32 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + f1x1.data[5:100] = 3 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + f1x1.timestamps[5:100] = 3 + + +def test_global_cache_timetags(h5_file): + pylake.set_cache_enabled(True) + if pylake.File.from_h5py(h5_file).format_version == 2: + _get_array.cache_clear() + tags1 = pylake.File.from_h5py(h5_file)["Photon Time Tags"]["Red"] + tags2 = pylake.File.from_h5py(h5_file)["Photon Time Tags"]["Red"] + assert _get_array.cache_info().hits == 0 + + # These should point to the same data + assert id(tags1.data) == id(tags2.data) + assert _get_array.cache_info().hits == 1 + assert _get_array.cache_info().currsize == 72 + + with pytest.raises(ValueError, match="assignment destination is read-only"): + tags1.data[5:100] = 3 diff --git a/pyproject.toml b/pyproject.toml index 53447bb60..e10c2492b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers=[ ] dependencies = [ "pytest>=3.5", - "h5py>=3.4, <4", + "h5py>=3.8, <4", # Minimum bound needed for using __array__ on Dataset.fields() "numpy>=1.24, <2", # 1.24 is needed for dtype in vstack/hstack (Dec 18th, 2022) "scipy>=1.9, <2", # 1.9.0 needed for lazy imports (July 29th, 2022) "matplotlib>=3.8",