Skip to content

Commit

Permalink
chore: update lifecycle tags (#509)
Browse files Browse the repository at this point in the history
* Update lifecycle tags for non-experimental Python API to "maturing"
* Update lifecycle tags for experimental Python API to "experimental"
* export public names for experimental ml package
  • Loading branch information
atolopko-czi authored May 30, 2023
1 parent e5b59d5 commit 1c264c4
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _get_experiment(census: soma.Collection, organism: str) -> soma.Experiment:
ValueError: if unable to find the specified organism.
Lifecycle:
Experimental.
maturing
Examples:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_anndata(
An :class:`anndata.AnnData` object containing the census slice.
Lifecycle:
Experimental.
maturing
Examples:
>>> get_anndata(census, "Mus musculus", obs_value_filter="tissue_general in ['brain', 'lung']")
Expand Down
6 changes: 3 additions & 3 deletions api/python/cellxgene_census/src/cellxgene_census/_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def open_soma(
or a version are specified.
Lifecycle:
Experimental.
maturing
Examples:
Open the default Census version, using a context manager which will automatically
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_source_h5ad_uri(dataset_id: str, *, census_version: str = "latest") -> C
KeyError: if either `dataset_id` or `census_version` do not exist.
Lifecycle:
Experimental.
maturing
Examples:
>>> cellxgene_census.get_source_h5ad_uri("cb5efdb0-f91c-4cbd-9ad4-9d4fa41c572d")
Expand Down Expand Up @@ -206,7 +206,7 @@ def download_source_h5ad(dataset_id: str, to_path: str, *, census_version: str =
an existing file), or is not a file.
Lifecycle:
Experimental.
maturing
See Also:
:func:`get_source_h5ad_uri`: Look up the location of the source H5AD.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_presence_matrix(
ValueError: if the organism cannot be found.
Lifecycle:
Experimental.
maturing
Examples:
>>> get_presence_matrix(census, "Homo sapiens", "RNA")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_census_version_description(census_version: str) -> CensusVersionDescript
KeyError: if unknown census_version value.
Lifecycle:
Experimental.
maturing
See Also:
:func:`get_census_version_directory`: returns the entire directory as a dict.
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_census_version_directory() -> Dict[CensusVersionName, CensusVersionDescr
A dictionary that contains release names and their corresponding release description.
Lifecycle:
Experimental.
maturing
See Also:
:func:`get_census_version_description`: get description by census_version.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census.
"""

from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader

__all__ = ["Stats", "ExperimentDataPipe", "experiment_dataloader"]
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@

@attrs
class Stats:
"""
Statistics about the data retrieved by ``ExperimentDataPipe`` via SOMA API.
Lifecycle:
experimental
"""

n_obs: int = 0
"""The total number of obs rows retrieved"""

Expand Down Expand Up @@ -71,21 +78,21 @@ def _open_experiment(

class _ObsAndXIterator(Iterator[ObsDatum]):
"""
Iterates through a set of obs and related X rows, specified as `soma_joinid`s. Encapsulates the batch-based data
fetching from TileDB-SOMA objects, providing row-based iteration.
Iterates through a set of obs and related X rows, specified as ``soma_joinid``s. Encapsulates the batch-based data
fetching from SOMA objects, providing row-based iteration.
"""

obs_tables_iter: somacore.ReadIter[pa.Table]
"""Iterates the TileDB-SOMA batches (tables) of obs data"""
"""Iterates the SOMA batches (tables) of obs data"""

obs_batch_: pd.DataFrame = pd.DataFrame()
"""The current TileDB-SOMA batch of obs data"""
"""The current SOMA batch of obs data"""

X_batch: scipy.matrix = None
"""All X data for the soma_joinids of the current obs - batch"""
"""All X data for the ``soma_joinid``s of the current obs - batch"""

i: int = -1
"""Index into current obs TileDB-SOMA batch"""
"""Index into current obs ``SOMA`` batch"""

def __init__(
self,
Expand Down Expand Up @@ -177,7 +184,7 @@ def obs_batch(self) -> pd.DataFrame:
Returns the current SOMA batch of obs rows.
If the current SOMA batch has been fully iterated, loads the next SOMA batch of both obs and X data and returns
the new obs batch (only).
Raises StopIteration if there are no more SOMA batches to retrieve.
Raises ``StopIteration`` if there are no more SOMA batches to retrieve.
"""
if 0 <= self.i < len(self.obs_batch_):
return self.obs_batch_
Expand All @@ -203,30 +210,33 @@ def obs_batch(self) -> pd.DataFrame:

class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsDatum]]): # type: ignore
"""
An iterable-style PyTorch data pipe that reads obs and X data from a SOMA Experiment, and returns an iterator of
tuples of torch tensors:
An iterable-style PyTorch ``DataPipe`` that reads obs and X data from a SOMA Experiment, and returns an iterator of
tuples of PyTorch ``Tensor``s:
(tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data
tensor([2415, 0, 0], dtype=torch.int32)) # obs data, encoded
>>> (tensor([0., 0., 0., 0., 0., 1., 0., 0., 0.]), # X data
tensor([2415, 0, 0], dtype=torch.int32)) # obs data, encoded
Supports batching via `batch_size` param:
DataLoader(..., batch_size=3, ...):
(tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
tensor([[2415, 0, 0], # obs batch
[2416, 0, 4],
[2417, 0, 3]], dtype=torch.int32))
>>> DataLoader(..., batch_size=3, ...):
(tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0.], # X batch
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
tensor([[2415, 0, 0], # obs batch
[2416, 0, 4],
[2417, 0, 3]], dtype=torch.int32))
Obs attribute values are encoded as categoricals. Values can be decoded by obtaining the encoder for a given
attribute:
exp_data_pipe.obs_encoders()["<obs_attr_name>"].inverse_transform(encoded_values)
>>> exp_data_pipe.obs_encoders()["<obs_attr_name>"].inverse_transform(encoded_values)
Lifecycle:
experimental
"""

_query: Optional[soma.ExperimentAxisQuery]
"""In multi-processing mode (i.e. num_workers > 0), this ExperimentAxisQuery object will *not* be pickled;
"""In multi-processing mode (i.e. num_workers > 0), this ``ExperimentAxisQuery`` object will *not* be pickled;
each worker will instantiate a new query"""

_obs_joinids_partitioned: Optional[pa.Array]
Expand All @@ -252,6 +262,16 @@ def __init__(
num_workers: int = 0,
soma_buffer_bytes: Optional[int] = None,
) -> None:
"""
Construct a new ``ExperimentDataPipe``.
Returns:
``ExperimentDataPipe``.
Lifecycle:
experimental
"""

self.exp_uri = experiment.uri
self.aws_region = experiment.context.tiledb_ctx.config().get("vfs.s3.region")
self.measurement_name = measurement_name
Expand Down Expand Up @@ -356,6 +376,15 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self._query = None

def obs_encoders(self) -> Encoders:
"""
Returns the encoders that were used to encode obs column values and that are needed to decode them.
Returns:
``Dict[str, LabelEncoder]`` mapping column names to ``LabelEncoder``s.
Lifecycle:
experimental
"""
if self._encoders is not None:
return self._encoders

Expand All @@ -373,18 +402,41 @@ def obs_encoders(self) -> Encoders:
return self._encoders

def stats(self) -> Stats:
"""
Get data loading stats for this ``ExperimentDataPipe``.
Args: None.
Returns:
``Stats`` object.
Lifecycle:
experimental
"""
return self._stats

@property
def shape(self) -> Tuple[int, int]:
"""
Get the shape of the data that will be returned by this ExperimentDataPipe. This is the number of
obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode
(i.e. DataLoader instantiated with num_workers > 0), the obs (cell) count will reflect the size of the
partition of the data assigned to the active process.
Returns:
2-tuple of ``int``s, for obs and var counts, respectively.
Lifecycle:
experimental
"""
self._init()
assert self._query is not None

return self._query.n_obs, self._query.n_vars


# Note: must be a top-level function (and not a lambda), to play nice with multiprocessing pickling
def collate_noop(x: Any) -> Any:
def _collate_noop(x: Any) -> Any:
return x


Expand All @@ -395,9 +447,15 @@ def experiment_dataloader(
**dataloader_kwargs: Any,
) -> DataLoader:
"""
Factory method for PyTorch DataLoader. Provides a safer, more convenient interface for instantiating a DataLoader
that works with the ExperimentDataPipe, since not all of DataLoader's params can be used (batch_size, sampler,
batch_sampler, collate_fn).
Factory method for PyTorch ``DataLoader``. Provides a safer, more convenient interface for instantiating a
``DataLoader`` that works with the ``ExperimentDataPipe``, since not all of ``DataLoader``'s params can be
used (``batch_size``, ``sampler``, ``batch_sampler``, ``collate_fn``).
Returns:
PyTorch ``DataLoader``.
Lifecycle:
experimental
"""

unsupported_dataloader_args = ["batch_size", "sampler", "batch_sampler", "collate_fn"]
Expand All @@ -409,7 +467,7 @@ def experiment_dataloader(
batch_size=None, # batching is handled by our ExperimentDataPipe
num_workers=num_workers,
# avoid use of default collator, which adds an extra (3rd) dimension to the tensor batches
collate_fn=collate_noop,
collate_fn=_collate_noop,
**dataloader_kwargs,
)

Expand Down
6 changes: 3 additions & 3 deletions api/python/notebooks/ml_demo/pytorch_lr_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

import cellxgene_census
from cellxgene_census.experimental.ml.pytorch import experiment_dataloader, ExperimentDataPipe
import cellxgene_census.experimental.ml as census_ml

# TODO: Convert this to a notebook

Expand Down Expand Up @@ -82,7 +82,7 @@ def run():
obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True"
var_value_filter = ""

exp_dp = ExperimentDataPipe(
exp_dp = census_ml.ExperimentDataPipe(
census["census_data"]["homo_sapiens"],
measurement_name="RNA",
X_name="raw",
Expand All @@ -95,7 +95,7 @@ def run():
dp = exp_dp.shuffle(buffer_size=len(exp_dp))
dp_train, dp_test = dp.random_split(weights={"train": 0.7, "test": 0.3}, seed=RANDOM_SEED)

dl_train = experiment_dataloader(
dl_train = census_ml.experiment_dataloader(
dp_train,
# >= 1 uses multiprocessing to load data
num_workers=0,
Expand Down

0 comments on commit 1c264c4

Please sign in to comment.