Skip to content

Commit

Permalink
add no_grad context to reduce memory consumption
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Feb 26, 2022
1 parent 24d9496 commit 51e5495
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 15 deletions.
3 changes: 2 additions & 1 deletion pytorch-lightning_ipynb/cnn/cnn-lenet5-quickdraw.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,7 @@
" \n",
" with torch.no_grad(): # since we don't need to backprop\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -1502,7 +1503,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions pytorch-lightning_ipynb/cnn/cnn-mobilenet-v2-cifar10-2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1725,7 +1725,10 @@
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
" \n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -2106,7 +2109,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions pytorch-lightning_ipynb/cnn/cnn-mobilenet-v2-cifar10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,10 @@
"\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
" \n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -2117,7 +2120,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3126,7 +3126,10 @@
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
" \n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -3514,7 +3517,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1724,7 +1724,10 @@
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
"\n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -2112,7 +2115,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions pytorch-lightning_ipynb/cnn/cnn-nin-cifar10.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2435,7 +2435,10 @@
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
" \n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -2823,7 +2826,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions pytorch-lightning_ipynb/cnn/cnn-vgg16.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,10 @@
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
" \n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -1895,7 +1898,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions pytorch-lightning_ipynb/cnn/cnn-vgg19.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1533,7 +1533,10 @@
"all_predicted_labels = []\n",
"for batch in test_dataloader:\n",
" features, labels = batch\n",
" logits = lightning_model(features)\n",
" \n",
" with torch.no_grad():\n",
" logits = lightning_model(features)\n",
"\n",
" predicted_labels = torch.argmax(logits, dim=1)\n",
" all_predicted_labels.append(predicted_labels)\n",
" all_true_labels.append(labels)\n",
Expand Down Expand Up @@ -1928,7 +1931,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
"version": "3.9.7"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 51e5495

Please sign in to comment.