diff --git a/extension_cpp/csrc/lltm.cpp b/extension_cpp/csrc/lltm.cpp
index c915dd9..c3bb5bc 100644
--- a/extension_cpp/csrc/lltm.cpp
+++ b/extension_cpp/csrc/lltm.cpp
@@ -89,7 +89,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
 
 // Defines the operators
 TORCH_LIBRARY(extension_cpp, m) {
-  m.impl_abstract_pystub("extension_cpp.ops");
   m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 }
diff --git a/extension_cpp/ops.py b/extension_cpp/ops.py
index 16c0311..51ecc84 100644
--- a/extension_cpp/ops.py
+++ b/extension_cpp/ops.py
@@ -8,37 +8,41 @@
 def lltm(
     input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
 ) -> Tuple[Tensor, Tensor]:
-    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
-
-
-class LLTMFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weights, bias, old_h, old_cell):
-        outputs = torch.ops.extension_cpp.lltm_forward.default(
-            input, weights, bias, old_h, old_cell
-        )
-        new_h, new_cell = outputs[:2]
-        variables = list(outputs[1:]) + [weights]
-        ctx.save_for_backward(*variables)
-
-        return new_h, new_cell
-
-    @staticmethod
-    @torch.autograd.function.once_differentiable
-    def backward(ctx, grad_h, grad_cell):
-        (
-            d_old_h,
-            d_input,
-            d_weights,
-            d_bias,
-            d_old_cell,
-        ) = torch.ops.extension_cpp.lltm_backward.default(
-            grad_h, grad_cell, *ctx.saved_tensors
-        )
-        return d_input, d_weights, d_bias, d_old_h, d_old_cell
-
-
-@torch.library.impl_abstract("extension_cpp::lltm_forward")
+    """The lltm API"""
+    outputs = torch.ops.extension_cpp.lltm_forward.default(
+        input, weights, bias, old_h, old_cell
+    )
+    new_h, new_cell = outputs[:2]
+    return new_h, new_cell
+
+
+# This is the backward for lltm_forward.
+# lltm_forward has 7 returns so they all get gradients.
+def backward(ctx, grad_h, grad_cell, _0, _1, _2, _3, _4):
+    (
+        d_old_h,
+        d_input,
+        d_weights,
+        d_bias,
+        d_old_cell,
+    ) = torch.ops.extension_cpp.lltm_backward.default(
+        grad_h, grad_cell, *ctx.saved_tensors
+    )
+    return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+def setup_context(ctx, inputs, output):
+    weights = inputs[1]
+    new_h, new_cell = output[:2]
+    variables = list(output[1:]) + [weights]
+    ctx.save_for_backward(*variables)
+
+
+torch.library.register_autograd(
+    "extension_cpp::lltm_forward", backward, setup_context=setup_context)
+
+
+@torch.library.register_fake("extension_cpp::lltm_forward")
 def _(input, weights, bias, old_h, old_cell):
     X = torch.cat([old_h, input], dim=1)
     gate_weights = torch.nn.functional.linear(X, weights, bias)
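
Not part of the change above, but a quick sanity check for the migrated op could look something like the sketch below. It assumes the built extension_cpp package is importable and uses illustrative tensor shapes (batch 3, 17 input features, 5 state features); gradcheck exercises the backward wired up by torch.library.register_autograd, while torch.compile relies on the fake kernel added with torch.library.register_fake.

# Hypothetical smoke test (not part of the diff above); shapes are illustrative.
import torch

import extension_cpp.ops

batch, in_features, state = 3, 17, 5
opts = dict(dtype=torch.float64, requires_grad=True)

inputs = (
    torch.randn(batch, in_features, **opts),              # input
    torch.randn(3 * state, in_features + state, **opts),  # weights
    torch.randn(1, 3 * state, **opts),                    # bias
    torch.randn(batch, state, **opts),                    # old_h
    torch.randn(batch, state, **opts),                    # old_cell
)

# gradcheck differentiates through the backward registered with
# torch.library.register_autograd and compares it to numerical gradients.
torch.autograd.gradcheck(extension_cpp.ops.lltm, inputs)

# torch.compile traces the op using the kernel registered with
# torch.library.register_fake, so shape propagation happens without
# running the real C++ implementation at trace time.
compiled_lltm = torch.compile(extension_cpp.ops.lltm, fullgraph=True)
new_h, new_cell = compiled_lltm(*inputs)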