diff --git a/Dataloader.py b/Dataloader.py new file mode 100644 index 0000000..5eac9fe --- /dev/null +++ b/Dataloader.py @@ -0,0 +1,82 @@ +from torch.utils.data import Dataset, DataLoader +import pandas as pd +import cv2 +from os.path import join +import torch +import numpy as np +from sklearn import preprocessing + +Scaler = preprocessing.MinMaxScaler() + +dataindex = 201 + +class MMIUnseenDataset(Dataset): + + def __init__(self, z_dim,points_path): + self.data = pd.read_csv(points_path,header=None).to_numpy() + self.z_dim = z_dim + + def __getitem__(self,index): + item = self.data[index] + # print(item) + # print(item.shape) + # points = item[0:dataindex-1].astype(np.float64) + points = torch.from_numpy(item.astype(np.float64)) + points = torch.hstack([points, torch.randn(self.z_dim - len(points))]) + points = points.reshape([self.z_dim, 1, 1]) + # print(points.shape) + return points + + + +class MMIDataset(Dataset): + + def __init__(self, img_size, z_dim, points_path, img_folder): + self.data = pd.read_csv(points_path, header=0, index_col=None).to_numpy() + # self.data = pd.read_csv(points_path, header=0).to_numpy() + self.img_folder = img_folder + self.img_size = img_size + self.z_dim = z_dim + + def __getitem__(self, index): + item = self.data[index] + img = cv2.imread(self.img_folder + '\\{}.png'.format(item[0]), cv2.IMREAD_GRAYSCALE) + img = cv2.resize(img, (self.img_size, self.img_size))[:, :, np.newaxis] + img = img / 255.0 + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img) + points21 = item[1:dataindex].astype(np.float64).reshape(-1, 1) + # points21 = item[1:dataindex].astype(np.float64) + points21 = Scaler.fit_transform(points21) + points21 = torch.from_numpy(points21).flatten(0) + # points21 = torch.from_numpy(points21) + + points = item[1:dataindex].astype(np.float64).reshape(-1,1) + # points = item[1:dataindex].astype(np.float64) + # points = Scaler.fit_transform(points) + points = torch.from_numpy(points).flatten(0) + # points = torch.from_numpy(points) + assert len(points) <= self.z_dim + points = torch.hstack([points, torch.randn(self.z_dim - len(points))]) + points = points.reshape([self.z_dim, 1, 1]) + # the shape of points should be [Z_DIM, CHANNELS_IMG, FEATURES_GEN] + + return points, img, points21 + + def __len__(self): + return len(self.data) + + +def get_loader( + img_size, + batch_size, + z_dim, + points_path='C:/Users/Administrator/Desktop/pythonProject/pr1/new1e0.csv', + img_folder='C:/Users/Administrator/Desktop/pythonProject/pr1/Training_Data/image/new', + shuffle=True, +): + return DataLoader(MMIDataset(img_size, z_dim, points_path, img_folder), + batch_size=batch_size, shuffle=shuffle) + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/Pred_Known.py b/Pred_Known.py new file mode 100644 index 0000000..550ed22 --- /dev/null +++ b/Pred_Known.py @@ -0,0 +1,47 @@ +import torch +from torch import nn +import cv2 +from Dataloader import MMIDataset +import numpy as np + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# load the saved trained Generator info +#model_path = r'C:/Users/Administrator/Desktop/pythonProject/pr1/logs/WGAN77/netG230.pt' +model_path = r'E:/newline1/netG10.pt' + +# Load the dataset +dataset = MMIDataset(img_size=64, + z_dim=250, + points_path=r'C:/Users/Administrator/Desktop/pythonProject/pr1/2000test.csv', + img_folder=r'C:/Users/Administrator/Desktop/pythonProject/pr1/Training_Data/image/new', + ) + +# Output the results path & load the data into Generator +results_folder = r'C:\Users\Administrator\Desktop\pythonProject\pr1' +gen = torch.load(model_path) +gen = gen.to(device) +gen = gen.eval() + +# Generate the image array from given dataset +def predict(net: nn.Module, points): + return net(points).squeeze(0).squeeze(0).cpu().detach().numpy() + +# Generate the desired number of results and save to path +# 0 means to data 1st row +# 40000 means last row in dataset +stop_p = 100 +i = 0 + +for p in dataset: + if i >= stop_p: + break + data = p[0].to(device, dtype=torch.float).unsqueeze(0) + img_out = predict(gen, data) + img = (img_out + 1) / 2 + img = np.round(255 * img) + img = cv2.normalize(img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX) + + cv2.imwrite(results_folder + '\\' + 'map200_' + str(i + 1) + '.png', img) + # cv2.imwrite(results_folder + '\\' + str(i) + '-test.png', img) + i += 1 + diff --git a/gradientpenalty.py b/gradientpenalty.py new file mode 100644 index 0000000..750367c --- /dev/null +++ b/gradientpenalty.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +def gradient_penalty(critic, points21, real,fake, device="cpu"): + # torch.autograd.set_detect_anomaly(True) + BATCH_SIZE, C, H, W = real.shape + alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device) + interpolated_images = real * alpha + fake * (1 - alpha) + + # Calculate critic scores + mixed_scores = critic(interpolated_images, points21) + + # Take the gradient of the scores with respect to the images + gradient = torch.autograd.grad( + inputs=interpolated_images, + outputs=mixed_scores, + grad_outputs=torch.ones_like(mixed_scores), + create_graph=True, + retain_graph=True, + )[0] + gradient = gradient.view(gradient.shape[0], -1) + gradient_norm = gradient.norm(2, dim=1) + gradient_penalty = torch.mean((gradient_norm - 1) ** 2) + return gradient_penalty + + +def save_checkpoint(state, filename="wgan_gp.pth.tar"): + print("=> Saving checkpoint") + torch.save(state, filename) + + +def load_checkpoint(checkpoint, gen, disc): + print("=> Loading checkpoint") + gen.load_state_dict(checkpoint['gen']) + disc.load_state_dict(checkpoint['disc']) \ No newline at end of file