Autograd graphtask trim unnecessary edges (pytorch#82544)
### Introduction
<!-- What did you change and why was it needed? -->

Removing unnecessary weight-gradient computation is very important for applications that need high-order derivatives during training. However, the current autograd engine does not support this.

In more detail: the backward of a `matmul`-based operator (e.g. `linear`, `addmm`, `mm`) performs two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to compute the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of each `matmul` operator is needed and the `weight gradient` is discarded. However, the current PyTorch autograd engine always computes the `weight gradient` whenever `weight` requires gradient (which is the case when the high-order derivative is computed during training).
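
As a simplified illustration (shapes and values here are arbitrary, and the bias term is ignored), the backward of `linear` contains two independent matmuls:
```py
import torch

# Simplified sketch of the two matmuls inside linear's backward
# (y = x @ weight.T + bias); grad_y is the incoming gradient.
x = torch.randn(8, 4)
weight = torch.randn(3, 4)
grad_y = torch.ones(8, 3)

grad_x = grad_y @ weight          # input gradient: needed for d(y)/d(x)
grad_weight = grad_y.t() @ x      # weight gradient: wasted work when only
                                  # torch.autograd.grad(y, x) is requested
```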

The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The paths marked with ❌ are not needed when calculating these derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: pytorch#56500

### Method
When calling `torch.autograd.grad`, an `exec_info_` is created for each GraphTask, which allows filtering out paths of the graph that are not needed. However, when the GraphTask calls into a node, the node still does not know which of its output edges are actually needed. In the case of matmul, `weight.requires_grad` is `True`, so the weight gradient is always calculated.

Following pytorch#56500 (comment), this PR passes the graph task's thread-local `exec_info_` into the node, so that it can trim unnecessary edges during `torch.autograd.grad` calls.
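
A rough conceptual model of the check, in plain Python (illustrative only; the real logic lives in the C++ engine and codegen shown below, and the names `exec_info_`, `needed_`, and `task_should_compute_output` mirror the identifiers in the diff):
```py
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass(frozen=True)
class Node:
    """A backward node in the autograd graph (toy stand-in)."""
    name: str


@dataclass
class ExecInfo:
    needed_: bool = False


@dataclass
class GraphTask:
    # Populated only for torch.autograd.grad(outputs, inputs): marks the
    # nodes that lie on a path from `outputs` to `inputs`.
    exec_info_: Dict[Node, ExecInfo] = field(default_factory=dict)


# Thread-local in the real engine; a module-level global in this toy model.
current_graph_task: Optional[GraphTask] = None


def task_should_compute_output(next_node: Optional[Node]) -> bool:
    """Should the current node compute the gradient for this output edge?"""
    if next_node is None:
        # The edge is not valid, e.g. the corresponding input never required
        # grad (this is all that should_compute_output checked before).
        return False
    if current_graph_task is None or not current_graph_task.exec_info_:
        # .backward() call: the whole graph is executed.
        return True
    info = current_graph_task.exec_info_.get(next_node)
    return info is not None and info.needed_
```
In the actual change, the generated derivative code (see the `gen_autograd_functions.py` templates below) calls `task_should_compute_output` instead of `should_compute_output`, so a node such as matmul's backward skips the weight-gradient matmul whenever the current graph task does not need it.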

### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result:
6 hidden layers, batch size 10000, on an A100 GPU

FP32 result
| Hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result
| Hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For FP32, we get a 1.9X speedup for the Hessian calculation and a 1.47X speedup during training, which is even faster than the functorch `vmap(jacfwd(jacrev))` implementation. (functorch has a performance regression in v0.2.0, pytorch/functorch#989, so v0.1.1 is used for the benchmark.)
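
For reference, a minimal sketch of the functorch per-sample Hessian pattern referred to above (functorch v0.1.1 API; the small model here is illustrative, not the exact network from the benchmark gist):
```py
import torch
from functorch import vmap, jacfwd, jacrev  # functorch v0.1.1

D = 16
model = torch.nn.Sequential(
    torch.nn.Linear(D, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 1),
)

def f(x):
    # x: a single sample of shape (D,); returns a scalar
    return model(x).squeeze(-1)

x = torch.randn(10000, D)
# Per-sample Hessians of f w.r.t. each input sample, shape (10000, D, D).
hessian = vmap(jacfwd(jacrev(f)))(x)
```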

@zou3519 does functorch also include similar optimizations during Hessian calculation? If not, what do we need to do so that functorch can also benefit from this PR?

### Testing
<!-- How did you test your change? -->

- [x] We need to figure out a way to unit test this (a C++ test, `GraphTaskTrimEdges`, is added below).

### Thanks
Thanks for the great blog post: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD
Pull Request resolved: pytorch#82544
Approved by: https://github.com/soulitzer
yueyericardo authored and pytorchmergebot committed Aug 11, 2022
1 parent d438e86 commit 382ef1f
Showing 15 changed files with 402 additions and 179 deletions.
3 changes: 2 additions & 1 deletion functorch/functorch/csrc/CustomFunction.cpp
@@ -2,6 +2,7 @@
#include <functorch/csrc/CustomFunction.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/FunctionsManual.h>
@@ -192,7 +193,7 @@ variable_list GenericPythonBackward::apply(variable_list&& grads) {
args.emplace_back(saved.unpack(shared_from_this()));
}

if (should_compute_output({ tensors_ix })) {
if (task_should_compute_output({ tensors_ix })) {
auto handle = backward_fn_->typed<custom_function_t>();
auto grad_result = handle.call(args);
grad_inputs = grad_result;
97 changes: 97 additions & 0 deletions test/cpp/api/autograd.cpp
@@ -4,6 +4,7 @@
#include <ATen/core/op_registration/op_registration.h>
#include <torch/torch.h>

#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>

#include <test/cpp/api/support.h>
@@ -276,6 +277,102 @@ TEST(CustomAutogradTest, CustomFunction) {
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, GraphTaskTrimEdges) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
Variable var2,
int mul,
bool needs_input1_grad,
bool needs_input2_grad) {
// setup the expected should and should not compute idx
ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
ctx->saved_data["needs_input2_grad"] = needs_input2_grad;

ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul * var2 + var1 * var2;
}

static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Test `needs_input_grad` method is working correctly.
// We have to test this within the backward function.
auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
IndexRange var1_idx = {0, 1};
IndexRange var2_idx = {1, 2};
EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
EXPECT_EQ(
ctx->needs_input_grad({var1_idx, var2_idx}),
needs_input1_grad || needs_input2_grad);

// calculate gradients
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];

Variable grad_var1, grad_var2;
if (ctx->needs_input_grad(0)) {
grad_var1 = grad_output[0] + grad_output[0] * var2;
}
if (ctx->needs_input_grad(1)) {
grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
}
variable_list output = {
grad_var1,
grad_var2,
Variable(),
Variable(),
Variable(),
};
return output;
}
};

Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto go = torch::ones_like(x);
Variable out;

// grad_x
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ false);
auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));

// grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ false,
/* needs_input2_grad= */ true);
auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);

// grad_x and grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ true);
auto grads = torch::autograd::grad({out}, {x, y}, {go});
ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var1) {
10 changes: 6 additions & 4 deletions tools/autograd/gen_autograd_functions.py
@@ -87,7 +87,7 @@

DERIVATIVE_SINGLE = CodeTemplate(
"""\
if (should_compute_output({ ${name}_ix })) {
if (task_should_compute_output({ ${name}_ix })) {
auto grad_result = ${derivative};
copy_range(grad_inputs, ${name}_ix, grad_result);
}
@@ -96,15 +96,15 @@

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
"""\
if (should_compute_output({ ${name}_ix })) {
if (task_should_compute_output({ ${name}_ix })) {
copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
}
"""
)

DERIVATIVE_MULTI = CodeTemplate(
"""\
if (should_compute_output({ ${idx_ranges} })) {
if (task_should_compute_output({ ${idx_ranges} })) {
${grad_input_mask}
auto grad_result = ${derivative};
${copy_ranges}
@@ -673,7 +673,9 @@ def emit_derivative(
)
else:
if "grad_input_mask" in formula:
masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names]
masks = [
f"task_should_compute_output({{ {n}_ix }})," for n in var_names
]
grad_input_mask = GRAD_INPUT_MASK.substitute(
masks=masks, n=len(var_names)
)
6 changes: 6 additions & 0 deletions torch/csrc/autograd/autograd.cpp
@@ -1,6 +1,12 @@
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/variable.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#endif

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
13 changes: 13 additions & 0 deletions torch/csrc/autograd/custom_function.cpp
@@ -509,6 +509,19 @@ variable_list AutogradContext::get_saved_variables() const {
return saved;
}

bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
auto ptr = grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
return ptr->task_should_compute_output(output_edge_index);
}

bool AutogradContext::needs_input_grad(
std::initializer_list<IndexRange> idxs) const {
auto ptr = grad_fn_.lock();
TORCH_INTERNAL_ASSERT(ptr);
return ptr->task_should_compute_output(idxs);
}

void AutogradContext::mark_dirty(const variable_list& inputs) {
dirty_inputs_.clear();
dirty_inputs_.reserve(inputs.size());
5 changes: 5 additions & 0 deletions torch/csrc/autograd/custom_function.h
@@ -131,6 +131,11 @@ struct TORCH_API AutogradContext {
const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

/// Expose the Node's `task_should_compute_output` method to the cpp
/// custom autograd Function as `needs_input_grad`.
bool needs_input_grad(size_t output_edge_index) const;
bool needs_input_grad(std::initializer_list<IndexRange> idxs) const;

private:
std::unordered_set<at::TensorImpl*> non_differentiable_;
std::unordered_set<at::TensorImpl*> dirty_inputs_;
17 changes: 17 additions & 0 deletions torch/csrc/autograd/engine.cpp
@@ -13,6 +13,12 @@
#include <ATen/Parallel.h>
#include <ATen/detail/CUDAHooksInterface.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/isnan.h>
#endif

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
@@ -368,6 +374,17 @@ void GraphTaskGuard::restore_current_graph_task() {
current_graph_task = std::move(last_graph_task_);
}

// The current graph task's exec_info is being used to trim unnecessary edges
// during node evaluation, see `Node.task_should_compute_output()` function.
const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info() {
return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
}

void add_node_to_current_graph_task_exec_info(Node* fn) {
current_graph_task->exec_info_[fn].needed_ = true;
}

// NOTE: graph_tasks do not necessarily form a stack. Imagine this
// case:
//
(The remaining changed files are not shown.)