Skip to content

Commit

Permalink
add apex support with opt level o1
Browse files Browse the repository at this point in the history
  • Loading branch information
lbin committed Nov 18, 2019
1 parent 6d012ee commit bc8d039
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions dcn_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@

import _ext as _backend

try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")


class _DCNv2(Function):
@staticmethod
@amp.float_function
def forward(ctx, input, offset, mask, weight, bias,
stride, padding, dilation, deformable_groups):
ctx.stride = _pair(stride)
Expand All @@ -34,6 +40,7 @@ def forward(ctx, input, offset, mask, weight, bias,

@staticmethod
@once_differentiable
@amp.float_function
def backward(ctx, grad_output):
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \
Expand Down

0 comments on commit bc8d039

Please sign in to comment.