diff --git a/anndata/_core/merge.py b/anndata/_core/merge.py index 0ecc9e913..3bb9970d5 100644 --- a/anndata/_core/merge.py +++ b/anndata/_core/merge.py @@ -765,6 +765,12 @@ def concat_arrays(arrays, reindexers, axis=0, index=None, fill_value=None): elif any(isinstance(a, CupySparseMatrix) for a in arrays): import cupyx.scipy.sparse as cpsparse + if not all( + isinstance(a, (CupySparseMatrix, CupyArray)) or 0 in a.shape for a in arrays + ): + raise NotImplementedError( + "Cannot concatenate a cupy array with other array types." + ) sparse_stack = (cpsparse.vstack, cpsparse.hstack)[axis] return sparse_stack( [ diff --git a/anndata/tests/test_concatenate.py b/anndata/tests/test_concatenate.py index 78caa8a19..ac9f7b0a9 100644 --- a/anndata/tests/test_concatenate.py +++ b/anndata/tests/test_concatenate.py @@ -4,7 +4,7 @@ from collections.abc import Hashable from copy import deepcopy from functools import partial, singledispatch -from itertools import chain, product +from itertools import chain, permutations, product from typing import Any, Callable import numpy as np @@ -1454,3 +1454,44 @@ def test_concat_duplicated_columns(join_type): with pytest.raises(pd.errors.InvalidIndexError, match=r"'a'"): concat([a, b], join=join_type) + + +@pytest.mark.gpu +def test_error_on_mixed_device(): + """https://github.com/scverse/anndata/issues/1083""" + import cupy + import cupyx.scipy.sparse as cupy_sparse + + cp_adata = AnnData( + cupy.random.randn(10, 10), + obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(10)]), + ) + cp_sparse_adata = AnnData( + cupy_sparse.random(10, 10, format="csr", density=0.2), + obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(10, 20)]), + ) + np_adata = AnnData( + np.random.randn(10, 10), + obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(20, 30)]), + ) + sparse_adata = AnnData( + sparse.random(10, 10, format="csr", density=0.2), + obs=pd.DataFrame(index=[f"cell_{i:02d}" for i in range(30, 40)]), + ) + + adatas = { + "cupy": cp_adata, + "cupy_sparse": cp_sparse_adata, + "numpy": np_adata, + "sparse": sparse_adata, + } + + for p in map(dict, permutations(adatas.items())): + print(list(p.keys())) + with pytest.raises( + NotImplementedError, match="Cannot concatenate a cupy array with other" + ): + concat(p) + + for p in permutations([cp_adata, cp_sparse_adata]): + concat(p)