Skip to content

Commit

Permalink
fix roi_pool_backward npu
Browse files Browse the repository at this point in the history
  • Loading branch information
hust17yixuan committed Jan 21, 2025
1 parent 761d780 commit b17a99c
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,29 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
int64_t pooled_height_64 = pooled_height;
int64_t pooled_width_64 = pooled_width;
int64_t pooled_channel = 1;
at::Tensor argmax_trans = argmax.transpose(1, 2).transpose(2, 3);
at::Tensor grad_output_trans = grad_output.transpose(1, 2).transpose(2, 3);
at::Tensor roi_actual_num =
at::empty_like(rois, rois.options().dtype(at::kInt));
at::Tensor x = at::ones_like(grad_input);
at::Tensor x = at::ones_like(grad_input).transpose(1, 2).transpose(2, 3);
at::Tensor y = at::zeros_like(x);
OpCommand cmd;
cmd.Name("RoiPoolingGradWithArgMax")
.Input(grad_output)
.Input(grad_output_trans)
.Input(x)
.Input(rois)
.Input(roi_actual_num)
.Input(argmax)
.Output(grad_input)
.Input(argmax_trans)
.Output(y)
.Attr("pooled_h", pooled_height_64)
.Attr("pooled_w", pooled_width_64)
.Attr("spatial_scale_h", spatial_scale)
.Attr("spatial_scale_w", spatial_scale)
.Attr("pool_channel", pooled_channel)
.Run();
at::Tensor result = y.transpose(2, 3).transpose(1, 2);
at::Tensor res = result.contiguous();
grad_input.copy_(res);
}

void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
Expand Down

0 comments on commit b17a99c

Please sign in to comment.