# Autograd graphtask trim unnecessary edges (pytorch#82544)
### Introduction

Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current autograd engine.

In more detail: the backward function of a `matmul` operator (e.g., `linear`, `addmm`, `mm`) contains two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine always calculates the `weight gradient` whenever `weight` requires gradient (which is the case when high-order derivatives are computed during training).

The figure attached shows the autograd graph of the following code snippet:

```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first-order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second-order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```

The path marked with ❌ is not needed when calculating the derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue

Related issue: pytorch#56500

### Method

When calling `torch.autograd.grad`, an `exec_info_` map is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into a node, the node still does not know whether its outgoing edges are needed or not. In the case of matmul, `weight.requires_grad` is `True`, so the weight gradient is always calculated. Following pytorch#56500 (comment), this PR passes the graph task's thread-local `exec_info_` into the node, so that it can trim unnecessary edges during `torch.autograd.grad` calls.
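To make the idea concrete, here is a minimal, self-contained Python sketch of the trimming logic. It is purely illustrative: the class and helper names below are invented for this sketch and do not correspond to the engine's actual C++ types; in the real implementation the check is made against the graph task's thread-local `exec_info_`.

```py
# Toy model of the edge-trimming idea, written for illustration only:
# the names (Node, mark_needed, AccumulateGrad, AddmmBackward) mimic the
# autograd graph but do not correspond to PyTorch's actual C++ classes.

class Node:
    def __init__(self, name, next_nodes=()):
        self.name = name
        self.next_nodes = list(next_nodes)  # outgoing edges in the backward graph

    def backward(self, needed):
        # Only produce a gradient for edges whose target node is marked as
        # needed by the current graph task; other branches are skipped,
        # which is where the extra matmul would be avoided.
        return [nxt for nxt in self.next_nodes if nxt in needed]


def mark_needed(root, requested_inputs):
    """Mark every node from which one of the requested inputs is reachable
    (a simplified stand-in for the graph task's exec_info_)."""
    needed = set()

    def visit(node):
        reachable = node in requested_inputs
        for nxt in node.next_nodes:
            reachable |= visit(nxt)
        if reachable:
            needed.add(node)
        return reachable

    visit(root)
    return needed


# Backward graph of y = linear(x, weight, bias); y = y.pow(2):
# AddmmBackward has one edge toward x and one toward weight.
acc_x = Node("AccumulateGrad(x)")
acc_w = Node("AccumulateGrad(weight)")
addmm_backward = Node("AddmmBackward", [acc_x, acc_w])
pow_backward = Node("PowBackward", [addmm_backward])

# torch.autograd.grad(y, x) only asks for the gradient w.r.t. x.
needed = mark_needed(pow_backward, {acc_x})
assert acc_w not in needed  # the weight branch is trimmed
print([n.name for n in addmm_backward.backward(needed)])  # ['AccumulateGrad(x)']
```

With this per-call information available inside the node, a backward function such as addmm's can skip the weight-gradient matmul when only the input gradient is requested.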
### Benchmark

Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark setup: 6 hidden layers, batch size 10000, on A100.

FP32 result

| hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result

| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

In FP32 we get a 1.9X speed-up for the hessian calculation and a 1.47X speed-up during training, which is even faster than functorch's `vmap(jacfwd(jacrev` implementation. (functorch has a performance regression in v0.2.0, pytorch/functorch#989, so v0.1.1 is used for the benchmark.)

@zou3519 does functorch also include similar optimizations during hessian calculation? If not, what do we need to do so that functorch can also benefit from this PR?

### Testing

- [x] we need to figure out a way to unit test this

### Thanks

Thanks for the great blog post: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD

Pull Request resolved: pytorch#82544
Approved by: https://github.com/soulitzer
1 parent d438e86 · commit 382ef1f
Showing 15 changed files with 402 additions and 179 deletions.