diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
index b1c015836..19eedf122 100644
--- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
+++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
@@ -20,7 +20,6 @@
 from numpy.random import Generator
 from scipy import sparse
 from sklearn.preprocessing import LabelEncoder
-from somacore.query import _fast_csr
 from torch import Tensor
 from torch import distributed as dist
 from torch.utils.data import DataLoader
@@ -122,7 +121,7 @@ def __init__(
         obs: soma.DataFrame,
         X: soma.SparseNDArray,
         obs_column_names: Sequence[str],
-        obs_joinids_chunked: npt.NDArray[np.int64],  # 2D
+        obs_joinids_chunked: List[npt.NDArray[np.int64]],
         var_joinids: npt.NDArray[np.int64],
         shuffle_rng: Optional[Generator] = None,
     ):
@@ -134,7 +133,7 @@

     @staticmethod
     def _maybe_local_shuffle_obs_joinids(
-        obs_joinids_chunked: npt.NDArray[np.int64], shuffle_rng: Optional[Generator] = None
+        obs_joinids_chunked: List[npt.NDArray[np.int64]], shuffle_rng: Optional[Generator] = None
     ) -> Iterator[npt.NDArray[np.int64]]:
         return (
             shuffle_rng.permutation(obs_joinid_chunk) if shuffle_rng else obs_joinid_chunk
@@ -145,7 +144,7 @@ def __next__(self) -> _SOMAChunk:
         pytorch_logger.debug("Retrieving next SOMA chunk...")
         start_time = time()

-        # If no more batches to iterate through, raise StopIteration, as all iterators do when at end
+        # If no more chunks to iterate through, raise StopIteration, as all iterators do when at end
         obs_joinids_chunk = next(self.obs_joinids_chunks_iter)

         obs_batch = (
@@ -166,8 +165,14 @@
         # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled
         obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False)

-        # note: order of rows in returned CSR matches the order of the requested obs_joinids, so no need to reindex
-        X_batch = _fast_csr.read_scipy_csr(self.X, pa.array(obs_joinids_chunk), pa.array(self.var_joinids))
+        # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix,
+        # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block)
+        scipy_iter = (
+            self.X.read(coords=(obs_joinids_chunk, self.var_joinids))
+            .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False)
+            .scipy(compress=True)
+        )
+        X_batch, _ = next(scipy_iter)
         assert obs_batch.shape[0] == X_batch.shape[0]

         stats = Stats()
@@ -218,7 +223,7 @@ def __init__(
         obs: soma.DataFrame,
         X: soma.SparseNDArray,
         obs_column_names: Sequence[str],
-        obs_joinids_chunked: npt.NDArray[np.int64],  # 2D
+        obs_joinids_chunked: List[npt.NDArray[np.int64]],
         var_joinids: npt.NDArray[np.int64],
         batch_size: int,
         encoders: Dict[str, LabelEncoder],
@@ -499,14 +504,14 @@ def _init(self) -> None:

     @staticmethod
     def _subset_ids_to_partition(
         ids_chunked: List[npt.NDArray[np.int64]], partition_index: int, num_partitions: int
-    ) -> npt.NDArray[np.int64]:  # 2D NDArray
+    ) -> List[npt.NDArray[np.int64]]:
         """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), based upon the current process's
         distributed rank and world size."""
         # subset to a single partition
         # typing does not reflect that is actually a List of 2D NDArrays
-        partitions: List[npt.NDArray[np.int64]] = np.array_split(np.array(ids_chunked, dtype=object), num_partitions)
-        partition = partitions[partition_index]
+        partition_indices = np.array_split(range(len(ids_chunked)), num_partitions)
+        partition = [ids_chunked[i] for i in partition_indices[partition_index]]

         if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0:
             pytorch_logger.debug(
@@ -562,7 +567,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:
             dist_partition=dist.get_rank() if dist.is_initialized() else 0,
             num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1,
         )
-        obs_joinids_chunked_partition: npt.NDArray[np.int64] = self._subset_ids_to_partition(
+        obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = self._subset_ids_to_partition(
            obs_joinids_chunked, partition, partitions
         )
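
For context on the new read path: the snippet below is a minimal, self-contained sketch of the read(...).blockwise(...).scipy(...) call chain the diff introduces, using only the public tiledbsoma API in place of the removed somacore private helper. The array URI and joinid values are hypothetical placeholders; only the call chain mirrors the diff.

# Sketch of the blockwise read pattern adopted above. Assumptions: `X_uri`
# names an existing 2D tiledbsoma.SparseNDArray; joinids are made up.
import numpy as np
import tiledbsoma as soma

X_uri = "path/to/experiment/ms/RNA/X/raw"  # hypothetical
obs_joinids = np.array([7, 2, 5], dtype=np.int64)  # rows, possibly shuffled
var_joinids = np.array([0, 3, 9], dtype=np.int64)  # columns to keep

with soma.SparseNDArray.open(X_uri) as X:
    scipy_iter = (
        X.read(coords=(obs_joinids, var_joinids))
        # one block spanning all requested rows, so the iterator yields
        # exactly one (matrix, (row_ids, col_ids)) item
        .blockwise(axis=0, size=len(obs_joinids), eager=False)
        # compress=True requests CSR rather than COO
        .scipy(compress=True)
    )
    X_batch, (row_ids, col_ids) = next(scipy_iter)

# blockwise reindexes both axes, so X_batch is a
# len(obs_joinids) x len(var_joinids) CSR matrix
assert X_batch.shape == (len(obs_joinids), len(var_joinids))

That axis reindexing is what makes the shape assert in __next__ sufficient: obs_batch is reindexed to the (possibly shuffled) obs_joinids_chunk ordering, and the blockwise read returns X rows for the same coordinates.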
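Similarly, the _subset_ids_to_partition rewrite splits chunk *indices* and gathers from the list, rather than coercing the ragged List of chunk arrays into an object-dtype ndarray before np.array_split. A runnable illustration under invented data (the real code receives chunks of obs joinids plus the torch.distributed rank and world size):

# Hypothetical standalone version of the new partitioning logic.
from typing import List

import numpy as np
import numpy.typing as npt


def subset_ids_to_partition(
    ids_chunked: List[npt.NDArray[np.int64]], partition_index: int, num_partitions: int
) -> List[npt.NDArray[np.int64]]:
    # split chunk indices, then gather: each chunk stays a plain int64 array,
    # with no detour through np.array(ids_chunked, dtype=object)
    partition_indices = np.array_split(range(len(ids_chunked)), num_partitions)
    return [ids_chunked[i] for i in partition_indices[partition_index]]


chunks = [
    np.array([0, 1, 2], dtype=np.int64),
    np.array([3, 4], dtype=np.int64),
    np.array([5, 6, 7, 8], dtype=np.int64),
]

# rank 0 of 2 gets the first two chunks; rank 1 gets the last one
assert [c.tolist() for c in subset_ids_to_partition(chunks, 0, 2)] == [[0, 1, 2], [3, 4]]
assert [c.tolist() for c in subset_ids_to_partition(chunks, 1, 2)] == [[5, 6, 7, 8]]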