Skip to content

Commit

Permalink
Added support for Apple Metal Performance Shaders (mps) on M3
Browse files Browse the repository at this point in the history
  • Loading branch information
zoharbabin committed Mar 9, 2024
1 parent d32e8e5 commit dc69e77
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
7 changes: 6 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
args = options()

def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')
print('[Info] Using {} for inference.'.format(device))
os.makedirs(os.path.join('temp', args.tmp_dir), exist_ok=True)

Expand Down
17 changes: 13 additions & 4 deletions models/ffc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ def forward(self, x):
r_size = x.size()
# (batch, c, h, w/2+1, 2)
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
if x.device.type == 'mps':
ffted = torch.fft.rfftn(x.cpu(), dim=fft_dim, norm=self.fft_norm).to(x.device)
else:
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
Expand All @@ -115,10 +118,16 @@ def forward(self, x):

ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
ffted = torch.complex(ffted[..., 0], ffted[..., 1])

if ffted.device.type == 'mps':
ffted = torch.complex(ffted[..., 0].cpu(), ffted[..., 1].cpu()).to(ffted.device)
else:
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

if ffted.device.type == 'mps':
output = torch.fft.irfftn(ffted.cpu(), s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm).to(ffted.device)
else:
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

if self.spatial_scale_factor is not None:
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
Expand Down
10 changes: 3 additions & 7 deletions third_part/GPEN/face_parse/face_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def process_tensor(self, imt):
def img2tensor(self, img):
    """Convert a BGR uint8 HxWx3 image into a normalized 1x3xHxW float tensor.

    The channel order is flipped to RGB, pixel values are rescaled from
    [0, 255] to [-1, 1], and the array is cast to float32 before the
    tensor conversion (required for MPS, which lacks float64 support).
    """
    rgb = img[..., ::-1]                       # BGR -> RGB
    normalized = rgb / 255. * 2 - 1            # rescale to [-1, 1]
    chw = np.float32(normalized.transpose(2, 0, 1))
    tensor = torch.from_numpy(chw).unsqueeze(0).to(self.device)
    return tensor.float()

def tenor2mask(self, tensor, masks):
Expand Down Expand Up @@ -125,12 +127,6 @@ def process(self, im, masks=[0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255
mask = self.tenor2mask(pred_mask, masks)
return mask

# def img2tensor(self, img):
# img = img[..., ::-1] # BGR to RGB
# img = img / 255. * 2 - 1
# img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
# return img_tensor.float()

def tenor2mask(self, tensor, masks):
if len(tensor.shape) < 4:
tensor = tensor.unsqueeze(0)
Expand Down

0 comments on commit dc69e77

Please sign in to comment.