diff --git a/DCVC-DC/src/models/video_model.py b/DCVC-DC/src/models/video_model.py index 7cf5bc7..31388fb 100644 --- a/DCVC-DC/src/models/video_model.py +++ b/DCVC-DC/src/models/video_model.py @@ -53,8 +53,8 @@ def forward(self, x, aux_feature, flow): # warp offset = offset.view(B * self.group_num * self.offset_num, 2, H, W) mask = mask.view(B * self.group_num * self.offset_num, 1, H, W) - x = x.view(B * self.group_num, C // self.group_num, H, W) - x = x.repeat(self.offset_num, 1, 1, 1) + x = x.repeat(1, self.offset_num, 1, 1) + x = x.view(B * self.group_num * self.offset_num, C // self.group_num, H, W) x = flow_warp(x, offset) x = x * mask x = x.view(B, C * self.offset_num, H, W)