diff --git a/src/dataset.py b/src/dataset.py index 96aaafebc..209a2f781 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -6,8 +6,8 @@ import numpy as np import torchvision.transforms.functional as F from torch.utils.data import DataLoader +from imageio import imread from PIL import Image -from scipy.misc import imread from skimage.feature import canny from skimage.color import rgb2gray, gray2rgb from .utils import create_mask @@ -54,7 +54,7 @@ def load_item(self, index): size = self.input_size # load image - img = imread(self.data[index]) + img = imread(self.data[index])[:,:,:3] # gray to rgb if len(img.shape) < 3: @@ -87,7 +87,7 @@ def load_edge(self, img, index, mask): # in test mode images are masked (with masked regions), # using 'mask' parameter prevents canny to detect edges for the masked regions - mask = None if self.training else (1 - mask / 255).astype(np.bool) + mask = None if self.training else (1 - mask / 255).astype(bool) # canny if self.edge == 1: @@ -99,12 +99,12 @@ def load_edge(self, img, index, mask): if sigma == 0: sigma = random.randint(1, 4) - return canny(img, sigma=sigma, mask=mask).astype(np.float) + return canny(img, sigma=sigma, mask=mask).astype(float) # external else: imgh, imgw = img.shape[0:2] - edge = imread(self.edge_data[index]) + edge = imread(self.edge_data[index])[:,:,:3] edge = self.resize(edge, imgh, imgw) # non-max suppression @@ -137,7 +137,7 @@ def load_mask(self, img, index): # external if mask_type == 3: mask_index = random.randint(0, len(self.mask_data) - 1) - mask = imread(self.mask_data[mask_index]) + mask = imread(self.mask_data[mask_index])[:,:,:3] mask = self.resize(mask, imgh, imgw) mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation return mask @@ -146,7 +146,8 @@ def load_mask(self, img, index): if mask_type == 6: mask = imread(self.mask_data[index]) mask = self.resize(mask, imgh, imgw, centerCrop=False) - mask = rgb2gray(mask) + if mask.shape[-1] == 3: + mask = rgb2gray(mask) mask = (mask > 0).astype(np.uint8) * 255 return mask @@ -165,7 +166,7 @@ def resize(self, img, height, width, centerCrop=True): i = (imgw - side) // 2 img = img[j:j + side, i:i + side, ...] - img = scipy.misc.imresize(img, [height, width]) + img = np.array(Image.fromarray(img).resize((width, height))) return img