Skip to content

Commit

Permalink
Error on mixed device concatenation (#1156)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup authored Oct 4, 2023
1 parent 85162c7 commit afc5930
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
6 changes: 6 additions & 0 deletions anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,12 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None):
elif any(isinstance(a, CupySparseMatrix) for a in arrays):
import cupyx.scipy.sparse as cpsparse

if not all(
isinstance(a, (CupySparseMatrix, CupyArray)) or 0 in a.shape for a in arrays
):
raise NotImplementedError(
"Cannot concatenate a cupy array with other array types."
)
sparse_stack = (cpsparse.vstack, cpsparse.hstack)[axis]
return sparse_stack(
[
Expand Down
43 changes: 42 additions & 1 deletion anndata/tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Hashable
from copy import deepcopy
from functools import partial, singledispatch
from itertools import chain, product
from itertools import chain, permutations, product
from typing import Any, Callable

import numpy as np
Expand Down Expand Up @@ -1454,3 +1454,44 @@ def test_concat_duplicated_columns(join_type):

with pytest.raises(pd.errors.InvalidIndexError, match=r"'a'"):
concat([a, b], join=join_type)


@pytest.mark.gpu
def test_error_on_mixed_device():
"""https://github.com/scverse/anndata/issues/1083"""
import cupy
import cupyx.scipy.sparse as cupy_sparse

cp_adata = AnnData(
cupy.random.randn(10, 10),
obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(10)]),
)
cp_sparse_adata = AnnData(
cupy_sparse.random(10, 10, format="csr", density=0.2),
obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(10, 20)]),
)
np_adata = AnnData(
np.random.randn(10, 10),
obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(20, 30)]),
)
sparse_adata = AnnData(
sparse.random(10, 10, format="csr", density=0.2),
obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(30, 40)]),
)

adatas = {
"cupy": cp_adata,
"cupy_sparse": cp_sparse_adata,
"numpy": np_adata,
"sparse": sparse_adata,
}

for p in map(dict, permutations(adatas.items())):
print(list(p.keys()))
with pytest.raises(
NotImplementedError, match="Cannot concatenate a cupy array with other"
):
concat(p)

for p in permutations([cp_adata, cp_sparse_adata]):
concat(p)

0 comments on commit afc5930

Please sign in to comment.