Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): cache indptr for backed sparse matrices #1266

Merged
merged 21 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import collections.abc as cabc
import warnings
from abc import ABC
from functools import cached_property
from itertools import accumulate, chain
from math import floor
from pathlib import Path
Expand All @@ -41,6 +42,8 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence

from .._types import GroupStorageType


class BackedFormat(NamedTuple):
format: str
Expand Down Expand Up @@ -138,7 +141,7 @@ def _offsets(
def _get_contiguous_compressed_slice(
self, s: slice
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
new_indptr = self.indptr[s.start : s.stop + 1]
new_indptr = self.indptr[s.start : s.stop + 1].copy()

start = new_indptr[0]
stop = new_indptr[-1]
Expand Down Expand Up @@ -325,13 +328,26 @@ def _get_group_format(group) -> str:
class BaseCompressedSparseDataset(ABC):
"""Analogous to :class:`h5py.Dataset <h5py:Dataset>` or `zarr.Array`, but for sparse matrices."""

def __init__(self, group: h5py.Group | ZarrGroup):
_group: GroupStorageType

def __init__(self, group: GroupStorageType):
type(self)._check_group_format(group)
self.group = group
self._group = group

shape: tuple[int, int]
"""Shape of the matrix."""

@property
def group(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def group(self):
def group(self) -> GroupStorageType:

Is this the change that was giving you problems with the docs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

"""The group underlying the backed matrix."""
return self._group

@group.setter
def group(self, val):
raise AttributeError(
f"Do not reset group on a {type(self)} with {val}. Instead use `sparse_dataset` to make a new class."
)

@property
def backend(self) -> Literal["zarr", "hdf5"]:
if isinstance(self.group, ZarrGroup):
Expand Down Expand Up @@ -489,20 +505,25 @@ def append(self, sparse_matrix: ss.spmatrix):
indices.resize((orig_data_size + sparse_matrix.indices.shape[0],))
indices[orig_data_size:] = sparse_matrix.indices

@cached_property
def indptr(self) -> np.ndarray:
arr = self.group["indptr"][...]
return arr

def _to_backed(self) -> BackedSparseMatrix:
format_class = get_backed_class(self.format)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"]
mtx.indices = self.group["indices"]
mtx.indptr = self.group["indptr"][:]
mtx.indptr = self.indptr
return mtx

def to_memory(self) -> ss.spmatrix:
format_class = get_memory_class(self.format)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"][...]
mtx.indices = self.group["indices"][...]
mtx.indptr = self.group["indptr"][...]
mtx.indptr = self.indptr
return mtx


Expand Down Expand Up @@ -530,7 +551,7 @@ class CSCDataset(BaseCompressedSparseDataset):
format = "csc"


def sparse_dataset(group: ZarrGroup | H5Group) -> CSRDataset | CSCDataset:
def sparse_dataset(group: GroupStorageType) -> CSRDataset | CSCDataset:
"""Generates a backed mode-compatible sparse dataset class.

Parameters
Expand Down
20 changes: 20 additions & 0 deletions anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
import pytest
import zarr
from pandas.api.types import is_numeric_dtype
from scipy import sparse

Expand Down Expand Up @@ -743,3 +744,22 @@ def shares_memory_sparse(x, y):
marks=pytest.mark.gpu,
),
]


class AccessTrackingStore(zarr.DirectoryStore):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._access_count = {}

def __getitem__(self, key):
for tracked in self._access_count:
if tracked in key:
self._access_count[tracked] += 1
return super().__getitem__(key)

def get_access_count(self, key):
return self._access_count[key]

def set_key_trackers(self, keys_to_track):
for k in keys_to_track:
self._access_count[k] = 0
45 changes: 44 additions & 1 deletion anndata/tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from anndata._core.anndata import AnnData
from anndata._core.sparse_dataset import sparse_dataset
from anndata.experimental import read_dispatched
from anndata.tests.helpers import assert_equal, subset_func
from anndata.tests.helpers import AccessTrackingStore, assert_equal, subset_func

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -199,6 +199,34 @@ def test_dataset_append_disk(
assert_equal(fromdisk, frommem)


@pytest.mark.parametrize(
["sparse_format"],
[
pytest.param(sparse.csr_matrix),
pytest.param(sparse.csc_matrix),
],
)
def test_indptr_cache(
tmp_path: Path,
sparse_format: Callable[[ArrayLike], sparse.spmatrix],
):
path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr
a = sparse_format(sparse.random(10, 10))
f = zarr.open_group(path, "a")
ad._io.specs.write_elem(f, "X", a)
store = AccessTrackingStore(path)
store.set_key_trackers(["X/indptr"])
f = zarr.open_group(store, "a")
a_disk = sparse_dataset(f["X"])
a_disk[:1]
a_disk[3:5]
a_disk[6:7]
a_disk[8:9]
assert (
store.get_access_count("X/indptr") == 2
) # one each for .zarray and actual access


@pytest.mark.parametrize(
["sparse_format", "a_shape", "b_shape"],
[
Expand Down Expand Up @@ -233,6 +261,21 @@ def test_wrong_shape(
a_disk.append(b_disk)


def test_reset_group(tmp_path: Path):
path = tmp_path / "test.zarr" # diskfmt is either h5ad or zarr
base = sparse.random(100, 100, format="csr")

if diskfmt == "zarr":
f = zarr.open_group(path, "a")
else:
f = h5py.File(path, "a")

ad._io.specs.write_elem(f, "base", base)
disk_mtx = sparse_dataset(f["base"])
with pytest.raises(AttributeError):
disk_mtx.group = f


def test_wrong_formats(tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]):
path = (
tmp_path / f"test.{diskfmt.replace('ad', '')}"
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/0.10.5.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
```{rubric} Performance
```

* `BaseCompressedSparseDataset`'s `indptr` is cached {pr}`1266` {user}`ilan-gold`
* Improved performance when indexing backed sparse matrices with boolean masks along their major axis {pr}`1233` {user}`ilan-gold`
Loading