Sometimes a model has two branches after a shared backbone. In that case, the computation graphs of both branches are held in GPU memory at the same time. With limited GPU memory, this made my model impossible to train.
Assume the network architecture is:
a ----a1_net----> a1 -----a21_net-----> a21 --> loss1
                  └------a22_net-----> a22 --> loss2
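By default, calling backward() on loss1 will destroy the computation graph of a->a1, so loss2 can no longer be backpropagated through it. However, if we set retain_graph=True, the computation graph of a1->a21 will also be preserved, which is exactly the memory we wanted to free.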
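The default behaviour described above can be reproduced with a minimal sketch. The layer sizes, the use of nn.Linear for the three networks, and a plain sum as the loss are assumptions made only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny networks standing in for a1_net, a21_net and a22_net.
a1_net, a21_net, a22_net = nn.Linear(4, 8), nn.Linear(8, 2), nn.Linear(8, 2)

a = torch.randn(3, 4)
a1 = a1_net(a)
loss1 = a21_net(a1).sum()
loss2 = a22_net(a1).sum()

loss1.backward()  # frees the saved tensors of a1->a21 and of a->a1
try:
    loss2.backward()  # needs the a->a1 graph a second time
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

print(second_backward_failed)
```

Passing retain_graph=True to the first call would let the second call succeed, but the a1->a21 graph would then stay in memory as well.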
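To control the backward pass precisely, there are two options:
- In earlier PyTorch versions, we could use torch.autograd.grad and set only_inputs=False. However, that argument is deprecated in current versions.
- Now we can use torch.autograd.backward with its inputs argument, which restricts gradient accumulation to the listed tensors. Example code follows.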
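a1 = a1_net(a)
# Cut the graph at a1: a1_clone is a new leaf that collects the gradients of both branches.
a1_clone = a1.detach().clone().requires_grad_(True)
a21 = loss(a21_net(a1_clone))
# Backward through the first branch only; its graph is freed afterwards.
torch.autograd.backward(a21, grad_tensors=torch.ones_like(a21), inputs=[a1_clone, *a21_net.parameters()])
a22 = loss(a22_net(a1_clone))
# Backward through the second branch; its gradient is accumulated into a1_clone.grad.
torch.autograd.backward(a22, grad_tensors=torch.ones_like(a22), inputs=[a1_clone, *a22_net.parameters()])
# Push the summed gradient through a->a1. a is listed in inputs so that a.grad is filled (assumes a.requires_grad is True).
torch.autograd.backward(a1, grad_tensors=a1_clone.grad, inputs=[a, *a1_net.parameters()])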
We can verify the result by comparing it with an ordinary single backward pass:
b = a.detach().clone()
b.requires_grad_(True) # suppose a.requires_grad = True
b1 = b1_net(b) # suppose that b1_net has the same parameters as a1_net
b21 = loss(b21_net(b1)) # suppose that b21_net has the same parameters as a21_net
b22 = loss(b22_net(b1)) # suppose that b22_net has the same parameters as a22_net
total = b21 + b22 # named total so it does not shadow the loss function
total.backward()
print(a.grad, b.grad) # should be the same
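The staged backward and the check against a single backward pass can be combined into one runnable sketch. The nn.Linear layers, their sizes, and the squared-mean loss function are hypothetical stand-ins; the b-networks are deep copies so that they start with the same parameters as the a-networks.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical small networks; the sizes are arbitrary.
a1_net, a21_net, a22_net = nn.Linear(4, 8), nn.Linear(8, 2), nn.Linear(8, 2)
b1_net, b21_net, b22_net = (copy.deepcopy(m) for m in (a1_net, a21_net, a22_net))

def loss(x):
    # Stand-in scalar loss, assumed for illustration.
    return x.pow(2).mean()

a = torch.randn(3, 4, requires_grad=True)

# Staged backward: one branch graph at a time.
a1 = a1_net(a)
a1_clone = a1.detach().clone().requires_grad_(True)
for branch in (a21_net, a22_net):
    out = loss(branch(a1_clone))
    # out is a scalar, so grad_tensors can be omitted.
    torch.autograd.backward(out, inputs=[a1_clone, *branch.parameters()])
# a1_clone.grad now holds d(loss1 + loss2) / d(a1).
torch.autograd.backward(a1, grad_tensors=a1_clone.grad, inputs=[a, *a1_net.parameters()])

# Reference: ordinary backward over the summed loss.
b = a.detach().clone().requires_grad_(True)
b1 = b1_net(b)
total = loss(b21_net(b1)) + loss(b22_net(b1))
total.backward()

a_params = [*a1_net.parameters(), *a21_net.parameters(), *a22_net.parameters()]
b_params = [*b1_net.parameters(), *b21_net.parameters(), *b22_net.parameters()]
grads_match = torch.allclose(a.grad, b.grad, atol=1e-6) and all(
    torch.allclose(p.grad, q.grad, atol=1e-6) for p, q in zip(a_params, b_params)
)
print(grads_match)
```

In this sketch each branch graph is released as soon as its backward call returns, so at most one branch graph plus the backbone graph should be alive at any time.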