Commit

Add files via upload
ZooBeasts authored Mar 24, 2024
1 parent 5f786df commit 1910045
Showing 3 changed files with 165 additions and 0 deletions.
82 changes: 82 additions & 0 deletions Dataloader.py
@@ -0,0 +1,82 @@
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
from os.path import join
import torch
import numpy as np
from sklearn import preprocessing

# module-level scaler, re-fit per sample in MMIDataset.__getitem__
Scaler = preprocessing.MinMaxScaler()

# columns 1..dataindex-1 of each CSV row hold the point features; column 0 is the image id
dataindex = 201

class MMIUnseenDataset(Dataset):
    """Points-only dataset (no target images), used to generate unseen designs."""

    def __init__(self, z_dim, points_path):
        self.data = pd.read_csv(points_path, header=None).to_numpy()
        self.z_dim = z_dim

    def __getitem__(self, index):
        item = self.data[index]
        # points = item[0:dataindex-1].astype(np.float64)
        points = torch.from_numpy(item.astype(np.float64))
        assert len(points) <= self.z_dim
        # pad the point vector with Gaussian noise up to the latent size z_dim
        points = torch.hstack([points, torch.randn(self.z_dim - len(points))])
        points = points.reshape([self.z_dim, 1, 1])
        return points

    def __len__(self):
        # required so the dataset can be queried with len() or wrapped in a DataLoader
        return len(self.data)



class MMIDataset(Dataset):

def __init__(self, img_size, z_dim, points_path, img_folder):
self.data = pd.read_csv(points_path, header=0, index_col=None).to_numpy()
# self.data = pd.read_csv(points_path, header=0).to_numpy()
self.img_folder = img_folder
self.img_size = img_size
self.z_dim = z_dim

    def __getitem__(self, index):
        item = self.data[index]
        # column 0 of the CSV row is the image id; load the matching grayscale image
        img = cv2.imread(join(self.img_folder, '{}.png'.format(item[0])), cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.img_size, self.img_size))[:, :, np.newaxis]
        img = img / 255.0
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img)
        # per-sample min-max scaling of the point columns to [0, 1]
        points21 = item[1:dataindex].astype(np.float64).reshape(-1, 1)
        points21 = Scaler.fit_transform(points21)
        points21 = torch.from_numpy(points21).flatten(0)

        # unscaled copy of the points, padded with Gaussian noise up to z_dim;
        # this padded vector is the latent input handed to the generator
        points = item[1:dataindex].astype(np.float64).reshape(-1, 1)
        points = torch.from_numpy(points).flatten(0)
        assert len(points) <= self.z_dim
        points = torch.hstack([points, torch.randn(self.z_dim - len(points))])
        points = points.reshape([self.z_dim, 1, 1])
# the shape of points should be [Z_DIM, CHANNELS_IMG, FEATURES_GEN]

return points, img, points21

def __len__(self):
return len(self.data)


def get_loader(
img_size,
batch_size,
z_dim,
points_path='C:/Users/Administrator/Desktop/pythonProject/pr1/new1e0.csv',
img_folder='C:/Users/Administrator/Desktop/pythonProject/pr1/Training_Data/image/new',
shuffle=True,
):
return DataLoader(MMIDataset(img_size, z_dim, points_path, img_folder),
batch_size=batch_size, shuffle=shuffle)

if __name__ == "__main__":
pass
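
For reference, a minimal usage sketch of the loader; the CSV name, image folder, and the 64/250 sizes below are placeholders taken from defaults elsewhere in this commit, not values fixed by Dataloader.py, and the shape comments assume the 201-column CSV layout implied by dataindex.

from Dataloader import get_loader

loader = get_loader(img_size=64, batch_size=8, z_dim=250,
                    points_path='new1e0.csv',              # hypothetical local copy of the points CSV
                    img_folder='Training_Data/image/new')  # hypothetical local image folder
points, img, points21 = next(iter(loader))
print(points.shape)    # [8, 250, 1, 1]  noise-padded point vector (generator input)
print(img.shape)       # [8, 1, 64, 64]  grayscale target image
print(points21.shape)  # [8, 200]        per-sample min-max scaled points
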
47 changes: 47 additions & 0 deletions Pred_Known.py
@@ -0,0 +1,47 @@
import torch
from torch import nn
import cv2
from Dataloader import MMIDataset
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# path to the saved, trained Generator checkpoint
#model_path = r'C:/Users/Administrator/Desktop/pythonProject/pr1/logs/WGAN77/netG230.pt'
model_path = r'E:/newline1/netG10.pt'

# Load the dataset
dataset = MMIDataset(img_size=64,
z_dim=250,
points_path=r'C:/Users/Administrator/Desktop/pythonProject/pr1/2000test.csv',
img_folder=r'C:/Users/Administrator/Desktop/pythonProject/pr1/Training_Data/image/new',
)

# Results folder for the generated images; load the trained Generator
results_folder = r'C:\Users\Administrator\Desktop\pythonProject\pr1'
gen = torch.load(model_path, map_location=device)  # the checkpoint stores the full Generator module
gen = gen.to(device)
gen = gen.eval()

# Run the generator on one input tensor and return the image as a 2-D NumPy array
def predict(net: nn.Module, points):
return net(points).squeeze(0).squeeze(0).cpu().detach().numpy()

# Generate images for the first stop_p rows of the dataset and save them to the results folder
stop_p = 100
i = 0

for p in dataset:
if i >= stop_p:
break
    # p[0] is the noise-padded point tensor; add a batch dimension before the forward pass
    data = p[0].to(device, dtype=torch.float).unsqueeze(0)
    img_out = predict(gen, data)
    # map the generator output from [-1, 1] to [0, 255] grey levels
    img = (img_out + 1) / 2
    img = np.round(255 * img)
    img = cv2.normalize(img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)

cv2.imwrite(results_folder + '\\' + 'map200_' + str(i + 1) + '.png', img)
# cv2.imwrite(results_folder + '\\' + str(i) + '-test.png', img)
i += 1
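
Pred_Known.py regenerates images for designs that have rows in the training/test CSV. A similar loop can be pointed at point sets with no reference image by swapping in MMIUnseenDataset; a minimal sketch, continuing from the script above and assuming a hypothetical unseen.csv of raw point rows:

from Dataloader import MMIUnseenDataset

unseen = MMIUnseenDataset(z_dim=250, points_path='unseen.csv')  # hypothetical CSV, one point row per design
for j, points in enumerate(unseen):
    if j >= 10:
        break
    x = points.to(device, dtype=torch.float).unsqueeze(0)  # [1, 250, 1, 1]
    out = predict(gen, x)
    img = np.round(255 * (out + 1) / 2)  # same [-1, 1] -> [0, 255] mapping as above
    cv2.imwrite('unseen_{}.png'.format(j + 1), img)
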

36 changes: 36 additions & 0 deletions gradientpenalty.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn


def gradient_penalty(critic, points21, real, fake, device="cpu"):
    # points21 is the conditioning point vector that the critic receives alongside the images
    # torch.autograd.set_detect_anomaly(True)
BATCH_SIZE, C, H, W = real.shape
alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * alpha + fake * (1 - alpha)

# Calculate critic scores
mixed_scores = critic(interpolated_images, points21)

# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty


def save_checkpoint(state, filename="wgan_gp.pth.tar"):
print("=> Saving checkpoint")
torch.save(state, filename)


def load_checkpoint(checkpoint, gen, disc):
print("=> Loading checkpoint")
gen.load_state_dict(checkpoint['gen'])
disc.load_state_dict(checkpoint['disc'])
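
For context, a minimal sketch of how this penalty typically enters a conditional WGAN-GP critic update; LAMBDA_GP, the critic/gen modules, and opt_critic below are assumptions for illustration and are not part of this commit.

import torch
from gradientpenalty import gradient_penalty

LAMBDA_GP = 10  # assumed penalty weight (the value suggested in the WGAN-GP paper)

def critic_step(critic, gen, opt_critic, real, points, points21, device):
    # one conditional WGAN-GP critic update: maximise D(real) - D(fake) minus the penalty
    fake = gen(points)  # keep the graph so the interpolated images inside gradient_penalty require grad
    critic_real = critic(real, points21).reshape(-1)
    critic_fake = critic(fake, points21).reshape(-1)
    gp = gradient_penalty(critic, points21, real, fake, device=device)
    loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
    opt_critic.zero_grad()
    loss_critic.backward()
    opt_critic.step()
    return loss_critic.item()
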
