diff --git a/README.md b/README.md index eb74217..c174592 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ pip install deepvelo ### Using GPU -The `dgl` cpu version is installed by default. For GPU acceleration, please install the proper [dgl gpu](https://www.dgl.ai/pages/start.html) version compatible with your CUDA environment. +The `dgl` cpu version is installed by default. For GPU acceleration, please install a proper [dgl gpu](https://www.dgl.ai/pages/start.html) version compatible with your CUDA environment. ```bash pip uninstall dgl # remove the cpu version @@ -54,3 +54,9 @@ scv.pp.moments(adata, n_neighbors=30, n_pcs=30) trainer = dv.train(adata, dv.Constants.default_configs) # this will train the model and predict the velocity vectore. The result is stored in adata.layers['velocity']. You can use trainer.model to access the model. ``` + +### Fitting a large number of cells + +If you cannot fit a large dataset into (GPU) memory using the default configs, please try setting a small `inner_batch_size` in the configs, which can reduce memory usage while maintaining the same performance. + +Currently, the training works on the whole graph of cells; we plan to release a flexible version using graph node sampling in the near future. diff --git a/deepvelo/model/loss.py b/deepvelo/model/loss.py index fd470f7..2cd5b67 100644 --- a/deepvelo/model/loss.py +++ b/deepvelo/model/loss.py @@ -126,6 +126,7 @@ def integrate_mle( idx: torch.LongTensor, candidate_states: torch.Tensor, n_spliced: int = None, + inner_batch_size: int = None, *args, **kwargs, ) -> torch.Tensor: @@ -142,21 +143,37 @@ def integrate_mle( The indices future nearest neighbors, (batch_size, num_neighbors). candidate_states (torch.Tensor): The states of potential future nearest neighbors, (all_cells, genes). + n_spliced (int): + The number of spliced genes. + inner_batch_size (int): + The batch size for the inner loop. 
Returns: (torch.Tensor) """ batch_size, genes = current_state.shape if n_spliced is not None: genes = n_spliced - with torch.no_grad(): - delta_tp1, delta_tm1 = _find_candidates( - output, - current_state, - idx, - candidate_states, - n_genes=genes, - inner_batch_size=int(5e8 / idx.shape[1] / current_state.shape[1]), - ) # (batch_size, genes) + if inner_batch_size is None: + inner_batch_size = int(max(5e8 / idx.shape[1] / current_state.shape[1], 1)) + elif inner_batch_size > 0: + inner_batch_size = int(inner_batch_size) + else: + raise ValueError("inner_batch_size must be a positive integer.") + try: + with torch.no_grad(): + delta_tp1, delta_tm1 = _find_candidates( + output, + current_state, + idx, + candidate_states, + n_genes=genes, + inner_batch_size=inner_batch_size, + ) # (batch_size, genes) + except RuntimeError as e: + raise RuntimeError( + f"The current inner batch size {inner_batch_size} is too large. " + "Try reducing the 'inner_batch_size' in congigs." + ) from e loss_tp1 = torch.mean(torch.pow(output - delta_tp1, 2)) loss_tm1 = torch.mean(torch.pow(output + delta_tm1, 2)) loss = (loss_tp1 + loss_tm1) / 2 * np.sqrt(genes) @@ -324,6 +341,7 @@ def mle_plus_direction( pearson_scale: float = 10.0, coeff_u: float = 1.0, coeff_s: float = 1.0, + inner_batch_size: Optional[int] = None, ) -> torch.Tensor: """ The combination of maximum likelihood estimation loss and direction loss. @@ -347,6 +365,7 @@ def mle_plus_direction( idx, candidate_states, n_spliced=velocity.shape[1], + inner_batch_size=inner_batch_size, ) loss_pearson = direction_loss( velocity, diff --git a/deepvelo/train.py b/deepvelo/train.py index d9328a9..64a11ff 100644 --- a/deepvelo/train.py +++ b/deepvelo/train.py @@ -53,6 +53,7 @@ class Constants(object, metaclass=MetaConstants): "pearson_scale": 18.0, "coeff_u": 1.0, "coeff_s": 1.0, + "inner_batch_size": None, # if None, will autoset the size. 
}, }, "constraint_loss": False, diff --git a/pyproject.toml b/pyproject.toml index 56705cd..1f5f15d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "deepvelo" -version = "0.2.4" +version = "0.2.5-rc.1" description = "Deep Velocity" authors = ["subercui "] readme = "README.md"