Merge pull request #2 from bowang-lab/1-getting-runtime-error-when-running-on-cuda-102

🔧 fix GPU support for dgl>0.6 by explicitly converting the graph device
subercui authored Jun 10, 2022
2 parents e0e0a6d + 9bc06a6 commit a5a4b57
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -16,8 +16,10 @@ pip install deepvelo
 The `dgl` cpu version is installed by default. For GPU acceleration, please install the proper [dgl gpu](https://www.dgl.ai/pages/start.html) version compatible with your CUDA environment.
 
 ```bash
-pip uninstall dgl # [optional] remove the cpu version
-pip install dgl-cu101>=0.4.3 # an example for CUDA 10.1
+pip uninstall dgl # remove the cpu version
+# replace cu101 with your desired CUDA version and run the following
+pip install "dgl-cu101>=0.4.3,!=0.8.0.post1"
+
 ```
 
 ### Install the development version
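After switching to a CUDA build, a quick way to confirm the GPU wheel is active is to move a toy graph onto the GPU. A minimal sketch, assuming `torch` and a matching `dgl-cuXX` wheel are installed (not part of the repo):

```python
# Minimal check that the CUDA build of dgl is installed and working.
import dgl
import torch

g = dgl.graph(([0, 1], [1, 2]))  # tiny 3-node graph, created on CPU
if torch.cuda.is_available():
    g = g.to(torch.device("cuda:0"))  # dgl graphs support .to(device) since dgl >= 0.5
print(g.device)  # expect cuda:0 on a working GPU install
```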
10 changes: 7 additions & 3 deletions deepvelo/model/model.py
@@ -126,11 +126,15 @@ def __init__(
         else:
             self.layers.append(GraphConv(layers[-1], out_layer_dim))
 
-    def forward(self, x_u, x_s):
+    def to(self, device):
         """
-        right now it is jus mlp, and the complexity of the middle part does not make sense;
-        Change it to the attention model and constrain the information flow
+        Move the model and graph to the specified device.
         """
+        super().to(device)
+        self.g = self.g.to(device)
+        return self
+
+    def forward(self, x_u, x_s):
         batch, n_gene = x_u.shape
         # h should be (batch, features=2*n_gene)
         h = torch.cat([x_u, x_s], dim=1)  # features
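The override is needed because `nn.Module.to()` only moves registered parameters and buffers; a `dgl` graph stored as a plain attribute stays on its original device, which triggers a device-mismatch error under dgl > 0.6. A minimal sketch of the same pattern outside the DeepVelo codebase (the `TinyGNN` class and its fields are illustrative, not part of the repo):

```python
import dgl
import torch
import torch.nn as nn
from dgl.nn import GraphConv


class TinyGNN(nn.Module):
    """Illustrative module that holds a dgl graph as a plain attribute."""

    def __init__(self, g: dgl.DGLGraph, in_dim: int, out_dim: int):
        super().__init__()
        self.g = g  # not a parameter or buffer, so nn.Module.to() skips it
        self.conv = GraphConv(in_dim, out_dim)

    def to(self, device):
        super().to(device)          # moves parameters and buffers
        self.g = self.g.to(device)  # moves the graph explicitly
        return self

    def forward(self, feat):
        return self.conv(self.g, feat)
```

With this override, a single `model.to("cuda:0")` keeps the graph and the weights on the same device, so `forward` can run without a per-call graph transfer.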
