-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathdata.py
66 lines (59 loc) · 2 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
def RGB_np2Tensor(imgIn, imgTar):
    """Convert an (input, target) pair of arrays into channel-first tensors.

    Each array has its axes permuted with ``(2, 0, 1)`` (i.e. the last axis
    becomes the first), is cast to Python ``float`` and wrapped in a
    ``torch.Tensor``.

    Args:
        imgIn: 3-D NumPy array holding the input image
            (presumably laid out as height x width x channel -- confirm
            against the caller).
        imgTar: 3-D NumPy array holding the target image, same layout.

    Returns:
        Tuple ``(imgIn, imgTar)`` of tensors with the permuted axis order.
    """
    def _to_channel_first(img):
        # astype() materialises a fresh array, so inputs that are reversed
        # views (as produced by slicing with a negative step) are accepted.
        as_float = img.transpose((2, 0, 1)).astype(float)
        # In-place multiply by 1.0 leaves the values unchanged; kept so the
        # scaling hook stays where the original code had it.
        return torch.Tensor(as_float).mul_(1.0)

    tensor_in = _to_channel_first(imgIn)
    tensor_tar = _to_channel_first(imgTar)
    return tensor_in, tensor_tar
def augment(imgIn, imgTar):
    """Randomly mirror an (input, target) image pair with identical flips.

    Two independent draws from ``random.random()`` decide, each with
    probability 0.3, whether to reverse axis 1 and whether to reverse
    axis 0. Whatever is chosen is applied to both images so the pair stays
    aligned.

    Args:
        imgIn: 3-D array-like supporting ``[:, ::-1, :]`` style slicing.
        imgTar: 3-D array-like, flipped the same way as ``imgIn``.

    Returns:
        Tuple ``(imgIn, imgTar)``; the original objects when no flip was
        drawn, otherwise reversed slices of them.
    """
    flip_prob = 0.3
    # Draw order matters for reproducibility under a fixed seed:
    # first the axis-1 decision, then the axis-0 decision.
    reverse_axis1 = random.random() < flip_prob
    reverse_axis0 = random.random() < flip_prob

    pair = (imgIn, imgTar)
    if reverse_axis1:
        pair = tuple(img[:, ::-1, :] for img in pair)
    if reverse_axis0:
        pair = tuple(img[::-1, :, :] for img in pair)
    return pair
def getPatch(imgIn, imgTar, args, scale):
    """Crop a random, spatially aligned patch from an input/target pair.

    The target patch is ``args.patchSize`` pixels on a side; the input patch
    is ``args.patchSize // scale`` on a side. The input patch origin is
    drawn uniformly at random and the target origin is that origin
    multiplied by ``scale``.

    Args:
        imgIn: 3-D array (rows, cols, channels) to crop the input patch from.
        imgTar: 3-D array to crop the target patch from; assumed to be
            ``scale`` times larger than ``imgIn`` in rows and cols
            (not checked here -- confirm against the data).
        args: object exposing an integer ``patchSize`` attribute.
        scale: integer ratio between target and input coordinates.

    Returns:
        Tuple ``(imgIn, imgTar)`` of the cropped patches.

    Raises:
        ValueError: if the input image is smaller than the input patch in
            either spatial dimension.
    """
    ih, iw, _ = imgIn.shape

    tp = args.patchSize
    ip = tp // scale

    # Fail with a readable message instead of randrange's
    # "empty range" error when the image cannot contain the patch.
    if ip > ih or ip > iw:
        raise ValueError(
            "input image of size {}x{} is smaller than the {}x{} input patch "
            "(patchSize={} // scale={})".format(ih, iw, ip, ip, tp, scale)
        )

    # Column offset is drawn before the row offset, matching the original
    # draw order so seeded runs pick the same patches.
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    tx, ty = scale * ix, scale * iy

    imgIn = imgIn[iy:iy + ip, ix:ix + ip, :]
    imgTar = imgTar[ty:ty + tp, tx:tx + tp, :]
    return imgIn, imgTar
class DIV2K(data.Dataset):
    """Dataset of paired images read from ``<dataDir>/LR`` and ``<dataDir>/HR``.

    The sample list is built from the file names found in the ``HR``
    directory; the matching ``LR`` file name is derived from each of them
    (see ``getFileName``).

    Args:
        args: configuration object. ``__init__`` reads ``args.scale`` and
            ``args.dataDir``; ``__getitem__`` reads ``args.need_patch`` and
            forwards ``args`` to ``getPatch`` (which reads ``patchSize``).
    """

    def __init__(self, args):
        super(DIV2K, self).__init__()
        self.args = args
        self.scale = args.scale
        apath = args.dataDir
        self.dirIn = os.path.join(apath, 'LR')
        self.dirTar = os.path.join(apath, 'HR')
        # Bug fix: the original listed ``self.dirHR``, an attribute that is
        # never assigned, so construction always raised AttributeError.
        # Sorting makes the index -> file mapping independent of the
        # arbitrary order os.listdir() returns.
        self.fileList = sorted(os.listdir(self.dirTar))
        self.nTrain = len(self.fileList)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(input, target)`` tensors.

        Raises:
            IOError: if either image file cannot be read.
        """
        scale = self.scale
        nameIn, nameTar = self.getFileName(idx)
        imgIn = cv2.imread(nameIn)
        imgTar = cv2.imread(nameTar)
        # cv2.imread signals failure by returning None rather than raising;
        # surface that here instead of failing later with an unrelated error.
        if imgIn is None:
            raise IOError("could not read input image: {}".format(nameIn))
        if imgTar is None:
            raise IOError("could not read target image: {}".format(nameTar))
        # NOTE(review): cv2.imread yields BGR channel order by default and no
        # conversion is done here -- confirm downstream code expects that.
        if self.args.need_patch:
            imgIn, imgTar = getPatch(imgIn, imgTar, self.args, scale)
        imgIn, imgTar = augment(imgIn, imgTar)
        return RGB_np2Tensor(imgIn, imgTar)

    def __len__(self):
        """Return the number of files found in the target directory."""
        return self.nTrain

    def getFileName(self, idx):
        """Return ``(input_path, target_path)`` for sample ``idx``.

        The input file name is the target file name with its last four
        characters (e.g. ``.png``) replaced by ``x3.png``.
        """
        name = self.fileList[idx]
        nameTar = os.path.join(self.dirTar, name)
        # NOTE(review): the 'x3' suffix is hard-coded and does not follow
        # self.scale -- confirm whether other scales are meant to be supported.
        name = name[0:-4] + 'x3' + '.png'
        nameIn = os.path.join(self.dirIn, name)
        return nameIn, nameTar