Skip to content

Commit

Permalink
Backport PR #1321: (fix): empty boolean mask on backed sparse matrix (#…
Browse files Browse the repository at this point in the history
…1336)

Co-authored-by: Ilan Gold <[email protected]>
  • Loading branch information
meeseeksmachine and ilan-gold authored Jan 25, 2024
1 parent 9b3d50e commit 0082f17
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
14 changes: 10 additions & 4 deletions anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csr_matrix:

def _get_arrayXslice(self, row: Sequence[int], col: slice) -> ss.csr_matrix:
idxs = np.asarray(row)
if len(idxs) == 0:
return ss.csr_matrix((0, self.shape[1]))
if idxs.dtype == bool:
idxs = np.where(idxs)
return ss.csr_matrix(
Expand Down Expand Up @@ -214,6 +216,8 @@ def _get_sliceXslice(self, row: slice, col: slice) -> ss.csc_matrix:

def _get_sliceXarray(self, row: slice, col: Sequence[int]) -> ss.csc_matrix:
idxs = np.asarray(col)
if len(idxs) == 0:
return ss.csc_matrix((self.shape[0], 0))
if idxs.dtype == bool:
idxs = np.where(idxs)
return ss.csc_matrix(
Expand Down Expand Up @@ -290,10 +294,12 @@ def mean_slice_length(slices):
return floor((slices[-1].stop - slices[0].start) / len(slices))

# heuristic for whether slicing should be optimized
if mean_slice_length(slices) <= 7:
return get_compressed_vectors(mtx, np.where(mask)[0])
else:
return get_compressed_vectors_for_slices(mtx, slices)
if len(slices) > 0:
if mean_slice_length(slices) <= 7:
return get_compressed_vectors(mtx, np.where(mask)[0])
else:
return get_compressed_vectors_for_slices(mtx, slices)
return [], [], [0]


def get_format(data: ss.spmatrix) -> str:
Expand Down
25 changes: 24 additions & 1 deletion anndata/tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def diskfmt(request):
return request.param


M = 50
N = 50


@pytest.fixture(scope="function")
def ondisk_equivalent_adata(
tmp_path: Path, diskfmt: Literal["h5ad", "zarr"]
Expand All @@ -37,7 +41,7 @@ def ondisk_equivalent_adata(

write = lambda x, pth, **kwargs: getattr(x, f"write_{diskfmt}")(pth, **kwargs)

csr_mem = ad.AnnData(X=sparse.random(50, 50, format="csr", density=0.1))
csr_mem = ad.AnnData(X=sparse.random(M, N, format="csr", density=0.1))
csc_mem = ad.AnnData(X=csr_mem.X.tocsc())
dense_mem = ad.AnnData(X=csr_mem.X.toarray())

Expand Down Expand Up @@ -77,6 +81,25 @@ def callback(func, elem_name, elem, iospec):
return csr_mem, csr_disk, csc_disk, dense_disk


@pytest.mark.parametrize(
"empty_mask", [[], np.zeros(M, dtype=bool)], ids=["empty_list", "empty_bool_mask"]
)
def test_empty_backed_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
empty_mask,
):
csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata

assert_equal(csr_mem.X[empty_mask], csr_disk.X[empty_mask])
assert_equal(csr_mem.X[:, empty_mask], csc_disk.X[:, empty_mask])

# The following do not work because of https://github.com/scipy/scipy/issues/19919
# Our implementation returns a (0,0) sized matrix but scipy does (1,0).

# assert_equal(csr_mem.X[empty_mask, empty_mask], csr_disk.X[empty_mask, empty_mask])
# assert_equal(csr_mem.X[empty_mask, empty_mask], csc_disk.X[empty_mask, empty_mask])


def test_backed_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
subset_func,
Expand Down

0 comments on commit 0082f17

Please sign in to comment.