diff --git a/inference.py b/inference.py
index a1819b41..e29d338e 100644
--- a/inference.py
+++ b/inference.py
@@ -29,7 +29,12 @@ args = options()
 
 def main():
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    elif torch.backends.mps.is_available():
+        device = torch.device('mps')
+    else:
+        device = torch.device('cpu')
     print('[Info] Using {} for inference.'.format(device))
 
     os.makedirs(os.path.join('temp', args.tmp_dir), exist_ok=True)
 
diff --git a/models/ffc.py b/models/ffc.py
index 89a5c4c0..ca32f9aa 100644
--- a/models/ffc.py
+++ b/models/ffc.py
@@ -96,7 +96,10 @@ def forward(self, x):
         r_size = x.size()
         # (batch, c, h, w/2+1, 2)
         fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
-        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
+        if x.device.type == 'mps':
+            ffted = torch.fft.rfftn(x.cpu(), dim=fft_dim, norm=self.fft_norm).to(x.device)
+        else:
+            ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
         ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
         ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
         ffted = ffted.view((batch, -1,) + ffted.size()[3:])
@@ -115,10 +118,16 @@ def forward(self, x):
         ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
             0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
-        ffted = torch.complex(ffted[..., 0], ffted[..., 1])
-
+        if ffted.device.type == 'mps':
+            ffted = torch.complex(ffted[..., 0].cpu(), ffted[..., 1].cpu()).to(ffted.device)
+        else:
+            ffted = torch.complex(ffted[..., 0], ffted[..., 1])
         ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
-        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
+
+        if ffted.device.type == 'mps':
+            output = torch.fft.irfftn(ffted.cpu(), s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm).to(ffted.device)
+        else:
+            output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
 
         if self.spatial_scale_factor is not None:
             output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
 
diff --git a/third_part/GPEN/face_parse/face_parsing.py b/third_part/GPEN/face_parse/face_parsing.py
index 39d7cb7f..88ba419a 100644
--- a/third_part/GPEN/face_parse/face_parsing.py
+++ b/third_part/GPEN/face_parse/face_parsing.py
@@ -64,7 +64,9 @@ def process_tensor(self, imt):
     def img2tensor(self, img):
         img = img[..., ::-1] # BGR to RGB
         img = img / 255. * 2 - 1
-        img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
+        a = img.transpose(2, 0, 1)
+        a = np.float32(a)
+        img_tensor = torch.from_numpy(a).unsqueeze(0).to(self.device)
         return img_tensor.float()
 
     def tenor2mask(self, tensor, masks):
@@ -125,12 +127,6 @@ def process(self, im, masks=[0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255
         mask = self.tenor2mask(pred_mask, masks)
         return mask
 
-    # def img2tensor(self, img):
-    #     img = img[..., ::-1] # BGR to RGB
-    #     img = img / 255. * 2 - 1
-    #     img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
-    #     return img_tensor.float()
-
     def tenor2mask(self, tensor, masks):
         if len(tensor.shape) < 4:
             tensor = tensor.unsqueeze(0)