Skip to content

Commit

Permalink
🎨 add inner batch size suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
subercui committed Jun 10, 2022
1 parent a5a4b57 commit 42f97ce
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 11 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pip install deepvelo

### Using GPU

The `dgl` cpu version is installed by default. For GPU acceleration, please install the proper [dgl gpu](https://www.dgl.ai/pages/start.html) version compatible with your CUDA environment.
The `dgl` cpu version is installed by default. For GPU acceleration, please install a proper [dgl gpu](https://www.dgl.ai/pages/start.html) version compatible with your CUDA environment.

```bash
pip uninstall dgl # remove the cpu version
Expand Down Expand Up @@ -54,3 +54,9 @@ scv.pp.moments(adata, n_neighbors=30, n_pcs=30)
trainer = dv.train(adata, dv.Constants.default_configs)
# this will train the model and predict the velocity vector. The result is stored in adata.layers['velocity']. You can use trainer.model to access the model.
```

### Fitting a large number of cells

If you cannot fit a large dataset into (GPU) memory using the default configs, please try setting a smaller `inner_batch_size` in the configs, which reduces memory usage while maintaining the same performance.

Currently, training operates on the whole graph of cells; we plan to release a more flexible version using graph node sampling in the near future.
37 changes: 28 additions & 9 deletions deepvelo/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def integrate_mle(
idx: torch.LongTensor,
candidate_states: torch.Tensor,
n_spliced: int = None,
inner_batch_size: int = None,
*args,
**kwargs,
) -> torch.Tensor:
Expand All @@ -142,21 +143,37 @@ def integrate_mle(
The indices of future nearest neighbors, (batch_size, num_neighbors).
candidate_states (torch.Tensor):
The states of potential future nearest neighbors, (all_cells, genes).
n_spliced (int):
The number of spliced genes.
inner_batch_size (int):
The batch size for the inner loop.
Returns: (torch.Tensor)
"""
batch_size, genes = current_state.shape
if n_spliced is not None:
genes = n_spliced
with torch.no_grad():
delta_tp1, delta_tm1 = _find_candidates(
output,
current_state,
idx,
candidate_states,
n_genes=genes,
inner_batch_size=int(5e8 / idx.shape[1] / current_state.shape[1]),
) # (batch_size, genes)
if inner_batch_size is None:
inner_batch_size = int(max(5e8 / idx.shape[1] / current_state.shape[1], 1))
elif inner_batch_size > 0:
inner_batch_size = int(inner_batch_size)
else:
raise ValueError("inner_batch_size must be a positive integer.")
try:
with torch.no_grad():
delta_tp1, delta_tm1 = _find_candidates(
output,
current_state,
idx,
candidate_states,
n_genes=genes,
inner_batch_size=inner_batch_size,
) # (batch_size, genes)
except RuntimeError as e:
raise RuntimeError(
f"The current inner batch size {inner_batch_size} is too large. "
"Try reducing the 'inner_batch_size' in congigs."
) from e
loss_tp1 = torch.mean(torch.pow(output - delta_tp1, 2))
loss_tm1 = torch.mean(torch.pow(output + delta_tm1, 2))
loss = (loss_tp1 + loss_tm1) / 2 * np.sqrt(genes)
Expand Down Expand Up @@ -324,6 +341,7 @@ def mle_plus_direction(
pearson_scale: float = 10.0,
coeff_u: float = 1.0,
coeff_s: float = 1.0,
inner_batch_size: Optional[int] = None,
) -> torch.Tensor:
"""
The combination of maximum likelihood estimation loss and direction loss.
Expand All @@ -347,6 +365,7 @@ def mle_plus_direction(
idx,
candidate_states,
n_spliced=velocity.shape[1],
inner_batch_size=inner_batch_size,
)
loss_pearson = direction_loss(
velocity,
Expand Down
1 change: 1 addition & 0 deletions deepvelo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Constants(object, metaclass=MetaConstants):
"pearson_scale": 18.0,
"coeff_u": 1.0,
"coeff_s": 1.0,
"inner_batch_size": None, # if None, will autoset the size.
},
},
"constraint_loss": False,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "deepvelo"
version = "0.2.4"
version = "0.2.5-rc.1"
description = "Deep Velocity"
authors = ["subercui <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 42f97ce

Please sign in to comment.