From 2aa7c57547ec43271a9f1c6906dd6989d8cbe34f Mon Sep 17 00:00:00 2001 From: Adrian Bulat Date: Sat, 7 Oct 2017 13:10:36 +0100 Subject: [PATCH] Fixes #16 --- examples/detect_landmarks_in_image.py | 2 +- face_alignment/utils.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/detect_landmarks_in_image.py b/examples/detect_landmarks_in_image.py index a77615fa..be11d7a6 100644 --- a/examples/detect_landmarks_in_image.py +++ b/examples/detect_landmarks_in_image.py @@ -38,4 +38,4 @@ ax.view_init(elev=90., azim=90.) ax.set_xlim(ax.get_xlim()[::-1]) -plt.show() \ No newline at end of file +plt.show() diff --git a/face_alignment/utils.py b/face_alignment/utils.py index 2ea40204..fbbdc34c 100644 --- a/face_alignment/utils.py +++ b/face_alignment/utils.py @@ -209,8 +209,13 @@ def shuffle_lr(parts, pairs=None): def flip(tensor, is_label=False): + was_cuda = False if isinstance(tensor, torch.Tensor): tensor = tensor.numpy() + elif isinstance(tensor, torch.cuda.FloatTensor): + tensor = tensor.cpu().numpy() + was_cuda = True + was_squeezed = False if tensor.ndim == 4: tensor = np.squeeze(tensor) @@ -223,4 +228,7 @@ def flip(tensor, is_label=False): tensor = cv2.flip(tensor, 1).reshape(tensor.shape) if was_squeezed: tensor = np.expand_dims(tensor, axis=0) - return torch.from_numpy(tensor) + tensor = torch.from_numpy(tensor) + if was_cuda: + tensor = tensor.cuda() + return tensor