fix pytorch tiledbsoma dependency (#957)
* fix PyTorch dependency on a private tiledbsoma API that is no longer available in 1.7.0
* implementation switched to use the blockwise iterator to retrieve X chunks as a CSR matrix
* modify np.array_split usage to ensure the dtype remains int64, while avoiding the NumPy warning for ragged arrays
atolopko-czi authored Jan 26, 2024
1 parent f9d72f7 commit 67982f1
Showing 1 changed file with 16 additions and 11 deletions.
@@ -20,7 +20,6 @@
 from numpy.random import Generator
 from scipy import sparse
 from sklearn.preprocessing import LabelEncoder
-from somacore.query import _fast_csr
 from torch import Tensor
 from torch import distributed as dist
 from torch.utils.data import DataLoader
@@ -122,7 +121,7 @@ def __init__(
         obs: soma.DataFrame,
         X: soma.SparseNDArray,
         obs_column_names: Sequence[str],
-        obs_joinids_chunked: npt.NDArray[np.int64],  # 2D
+        obs_joinids_chunked: List[npt.NDArray[np.int64]],
         var_joinids: npt.NDArray[np.int64],
         shuffle_rng: Optional[Generator] = None,
     ):
@@ -134,7 +133,7 @@

     @staticmethod
     def _maybe_local_shuffle_obs_joinids(
-        obs_joinids_chunked: npt.NDArray[np.int64], shuffle_rng: Optional[Generator] = None
+        obs_joinids_chunked: List[npt.NDArray[np.int64]], shuffle_rng: Optional[Generator] = None
     ) -> Iterator[npt.NDArray[np.int64]]:
         return (
             shuffle_rng.permutation(obs_joinid_chunk) if shuffle_rng else obs_joinid_chunk
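For context, here is a minimal standalone sketch of the per-chunk ("local") shuffle this method performs; the chunk values and seed are illustrative, not taken from the source:

```python
import numpy as np

rng = np.random.default_rng(42)  # plays the role of shuffle_rng
chunks = [np.array([0, 1, 2], dtype=np.int64), np.array([3, 4, 5], dtype=np.int64)]

# each chunk is permuted independently; the order of the chunks themselves is unchanged
shuffled = (rng.permutation(chunk) for chunk in chunks)
for chunk in shuffled:
    print(chunk)  # e.g. [1 0 2], then [5 3 4]
```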
@@ -145,7 +144,7 @@ def __next__(self) -> _SOMAChunk:
         pytorch_logger.debug("Retrieving next SOMA chunk...")
         start_time = time()

-        # If no more batches to iterate through, raise StopIteration, as all iterators do when at end
+        # If no more chunks to iterate through, raise StopIteration, as all iterators do when at end
         obs_joinids_chunk = next(self.obs_joinids_chunks_iter)

         obs_batch = (
@@ -166,8 +165,14 @@ def __next__(self) -> _SOMAChunk:
         # reorder obs rows to match obs_joinids_chunk ordering, which may be shuffled
         obs_batch = obs_batch.reindex(obs_joinids_chunk, copy=False)

-        # note: order of rows in returned CSR matches the order of the requested obs_joinids, so no need to reindex
-        X_batch = _fast_csr.read_scipy_csr(self.X, pa.array(obs_joinids_chunk), pa.array(self.var_joinids))
+        # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix,
+        # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block)
+        scipy_iter = (
+            self.X.read(coords=(obs_joinids_chunk, self.var_joinids))
+            .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False)
+            .scipy(compress=True)
+        )
+        X_batch, _ = next(scipy_iter)
         assert obs_batch.shape[0] == X_batch.shape[0]

         stats = Stats()
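For readers unfamiliar with the tiledbsoma blockwise API, here is a minimal standalone sketch of the same read pattern; the URI and joinid values are assumptions, and exact yield shapes should be confirmed against the tiledbsoma version in use:

```python
import numpy as np
import tiledbsoma as soma

obs_joinids = np.array([7, 3, 11], dtype=np.int64)  # requested rows (possibly shuffled)
var_joinids = np.arange(100, dtype=np.int64)        # requested columns

with soma.SparseNDArray.open("path/to/X") as X:     # placeholder URI
    blocks = (
        X.read(coords=(obs_joinids, var_joinids))
        # size covers every requested row, so the iterator yields a single block;
        # blockwise also reindexes both axes to positional 0..n-1 coordinates
        .blockwise(axis=0, size=len(obs_joinids), eager=False)
        .scipy(compress=True)  # compress=True yields scipy.sparse.csr_matrix blocks
    )
    X_batch, _joinids = next(blocks)  # second element carries the block's axis coordinates
    assert X_batch.shape == (len(obs_joinids), len(var_joinids))
```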
@@ -218,7 +223,7 @@ def __init__(
         obs: soma.DataFrame,
         X: soma.SparseNDArray,
         obs_column_names: Sequence[str],
-        obs_joinids_chunked: npt.NDArray[np.int64],  # 2D
+        obs_joinids_chunked: List[npt.NDArray[np.int64]],
         var_joinids: npt.NDArray[np.int64],
         batch_size: int,
         encoders: Dict[str, LabelEncoder],
@@ -499,14 +504,14 @@ def _init(self) -> None:
     @staticmethod
     def _subset_ids_to_partition(
         ids_chunked: List[npt.NDArray[np.int64]], partition_index: int, num_partitions: int
-    ) -> npt.NDArray[np.int64]:  # 2D NDArray
+    ) -> List[npt.NDArray[np.int64]]:
         """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), based upon the current process's distributed rank and world
         size."""

         # subset to a single partition
-        # typing does not reflect that is actually a List of 2D NDArrays
-        partitions: List[npt.NDArray[np.int64]] = np.array_split(np.array(ids_chunked, dtype=object), num_partitions)
-        partition = partitions[partition_index]
+        partition_indices = np.array_split(range(len(ids_chunked)), num_partitions)
+        partition = [ids_chunked[i] for i in partition_indices[partition_index]]

         if pytorch_logger.isEnabledFor(logging.DEBUG) and len(partition) > 0:
             pytorch_logger.debug(
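The rationale for the new split, as a small self-contained sketch (chunk values are illustrative): stacking ragged chunks into a single ndarray forces an object-dtype detour, whereas splitting chunk indices leaves each chunk untouched as int64:

```python
import numpy as np

chunks = [np.array([0, 1, 2], dtype=np.int64), np.array([3, 4], dtype=np.int64)]  # ragged

# old approach (removed above): wraps the ragged list in an object-dtype ndarray,
# which NumPy complains about unless dtype=object is forced, and which loses the int64 typing
# partitions = np.array_split(np.array(chunks, dtype=object), num_partitions)

# new approach: split positions 0..len-1 instead, then index into the original list
num_partitions = 2
partition_indices = np.array_split(range(len(chunks)), num_partitions)
partition = [chunks[i] for i in partition_indices[0]]
assert all(chunk.dtype == np.int64 for chunk in partition)
```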
Expand Down Expand Up @@ -562,7 +567,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:
dist_partition=dist.get_rank() if dist.is_initialized() else 0,
num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1,
)
obs_joinids_chunked_partition: npt.NDArray[np.int64] = self._subset_ids_to_partition(
obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = self._subset_ids_to_partition(
obs_joinids_chunked, partition, partitions
)
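A short usage note on the annotation change above: under torch.distributed, each worker derives its partition index from its rank. A sketch, where `obs_joinids_chunked` stands in for the real chunk list:

```python
import torch.distributed as dist

rank = dist.get_rank() if dist.is_initialized() else 0              # this worker's partition index
world_size = dist.get_world_size() if dist.is_initialized() else 1  # total number of partitions
# each worker then keeps only its own slice of the chunk list, e.g.:
# my_chunks = _subset_ids_to_partition(obs_joinids_chunked, rank, world_size)
```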
