diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index c7a11e8c6d..b7015439b9 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -50,23 +50,29 @@ void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, int64_t pooled_height_64 = pooled_height; int64_t pooled_width_64 = pooled_width; int64_t pooled_channel = 1; + at::Tensor argmax_trans = argmax.transpose(1, 2).transpose(2, 3); + at::Tensor grad_output_trans = grad_output.transpose(1, 2).transpose(2, 3); at::Tensor roi_actual_num = at::empty_like(rois, rois.options().dtype(at::kInt)); - at::Tensor x = at::ones_like(grad_input); + at::Tensor x = at::ones_like(grad_input).transpose(1, 2).transpose(2, 3); + at::Tensor y = at::zeros_like(x); OpCommand cmd; cmd.Name("RoiPoolingGradWithArgMax") - .Input(grad_output) + .Input(grad_output_trans) .Input(x) .Input(rois) .Input(roi_actual_num) - .Input(argmax) - .Output(grad_input) + .Input(argmax_trans) + .Output(y) .Attr("pooled_h", pooled_height_64) .Attr("pooled_w", pooled_width_64) .Attr("spatial_scale_h", spatial_scale) .Attr("spatial_scale_w", spatial_scale) .Attr("pool_channel", pooled_channel) .Run(); + at::Tensor result = y.transpose(2, 3).transpose(1, 2); + at::Tensor res = result.contiguous(); + grad_input.copy_(res); } void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,