Skip to content

Commit

Permalink
MNIST normalization. Black refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
eriklindernoren committed Mar 28, 2019
1 parent 8e4b612 commit c5d6be1
Show file tree
Hide file tree
Showing 20 changed files with 801 additions and 667 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.DS_Store

data/*/
implementations/*/data
implementations/*/images
implementations/*/saved_models

Expand Down
75 changes: 45 additions & 30 deletions implementations/aae/aae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,34 @@
import torch.nn.functional as F
import torch

os.makedirs('images', exist_ok=True)
os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=10, help='dimensionality of the latent code')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=10, help="dimensionality of the latent code")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


def reparameterization(mu, logvar):
std = torch.exp(logvar / 2)
sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
z = sampled_z * std + mu
return z


class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
Expand All @@ -50,7 +52,7 @@ def __init__(self):
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2, inplace=True)
nn.LeakyReLU(0.2, inplace=True),
)

self.mu = nn.Linear(512, opt.latent_dim)
Expand All @@ -64,6 +66,7 @@ def forward(self, img):
z = reparameterization(mu, logvar)
return z


class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
Expand All @@ -75,14 +78,15 @@ def __init__(self):
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, int(np.prod(img_shape))),
nn.Tanh()
nn.Tanh(),
)

def forward(self, z):
img_flat = self.model(z)
img = img_flat.view(img_flat.shape[0], *img_shape)
return img


class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
Expand All @@ -93,13 +97,14 @@ def __init__(self):
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
nn.Sigmoid(),
)

def forward(self, z):
validity = self.model(z)
return validity


# Use binary cross-entropy loss
adversarial_loss = torch.nn.BCELoss()
pixelwise_loss = torch.nn.L1Loss()
Expand All @@ -117,29 +122,36 @@ def forward(self, z):
pixelwise_loss.cuda()

# Configure data loader
os.makedirs('../../data/mnist', exist_ok=True)
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST('../../data/mnist', train=True, download=True,
transform=transforms.Compose([
transforms.Resize(opt.img_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])),
batch_size=opt.batch_size, shuffle=True)
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam( itertools.chain(encoder.parameters(), decoder.parameters()),
lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_G = torch.optim.Adam(
itertools.chain(encoder.parameters(), decoder.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def sample_image(n_row, batches_done):
"""Saves a grid of generated digits"""
# Sample noise
z = Variable(Tensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
z = Variable(Tensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
gen_imgs = decoder(z)
save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
# Training
Expand All @@ -165,8 +177,9 @@ def sample_image(n_row, batches_done):
decoded_imgs = decoder(encoded_imgs)

# Loss measures generator's ability to fool the discriminator
g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + \
0.999 * pixelwise_loss(decoded_imgs, real_imgs)
g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * pixelwise_loss(
decoded_imgs, real_imgs
)

g_loss.backward()
optimizer_G.step()
Expand All @@ -188,8 +201,10 @@ def sample_image(n_row, batches_done):
d_loss.backward()
optimizer_D.step()

print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
d_loss.item(), g_loss.item()))
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)

batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
Expand Down
92 changes: 48 additions & 44 deletions implementations/acgan/acgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,43 @@
import torch.nn.functional as F
import torch

os.makedirs('images', exist_ok=True)
os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image sampling')
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()

self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

self.init_size = opt.img_size // 4 # Initial size before upsampling
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2))
self.init_size = opt.img_size // 4 # Initial size before upsampling
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
Expand All @@ -61,7 +63,7 @@ def __init__(self):
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
nn.Tanh()
nn.Tanh(),
)

def forward(self, noise, labels):
Expand All @@ -71,15 +73,14 @@ def forward(self, noise, labels):
img = self.conv_blocks(out)
return img


class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()

def discriminator_block(in_filters, out_filters, bn=True):
"""Returns layers of each discriminator block"""
block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout2d(0.25)]
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm2d(out_filters, 0.8))
return block
Expand All @@ -92,13 +93,11 @@ def discriminator_block(in_filters, out_filters, bn=True):
)

# The height and width of downsampled image
ds_size = opt.img_size // 2**4
ds_size = opt.img_size // 2 ** 4

# Output layers
self.adv_layer = nn.Sequential( nn.Linear(128*ds_size**2, 1),
nn.Sigmoid())
self.aux_layer = nn.Sequential( nn.Linear(128*ds_size**2, opt.n_classes),
nn.Softmax())
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())

def forward(self, img):
out = self.conv_blocks(img)
Expand All @@ -108,6 +107,7 @@ def forward(self, img):

return validity, label


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()
Expand All @@ -127,15 +127,19 @@ def forward(self, img):
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs('../../data/mnist', exist_ok=True)
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST('../../data/mnist', train=True, download=True,
transform=transforms.Compose([
transforms.Resize(opt.img_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])),
batch_size=opt.batch_size, shuffle=True)
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Expand All @@ -144,15 +148,17 @@ def forward(self, img):
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
# Training
Expand Down Expand Up @@ -186,8 +192,7 @@ def sample_image(n_row, batches_done):

# Loss measures generator's ability to fool the discriminator
validity, pred_label = discriminator(gen_imgs)
g_loss = 0.5 * (adversarial_loss(validity, valid) + \
auxiliary_loss(pred_label, gen_labels))
g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

g_loss.backward()
optimizer_G.step()
Expand All @@ -200,13 +205,11 @@ def sample_image(n_row, batches_done):

# Loss for real images
real_pred, real_aux = discriminator(real_imgs)
d_real_loss = (adversarial_loss(real_pred, valid) + \
auxiliary_loss(real_aux, labels)) / 2
d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

# Loss for fake images
fake_pred, fake_aux = discriminator(gen_imgs.detach())
d_fake_loss = (adversarial_loss(fake_pred, fake) + \
auxiliary_loss(fake_aux, gen_labels)) / 2
d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
Expand All @@ -219,9 +222,10 @@ def sample_image(n_row, batches_done):
d_loss.backward()
optimizer_D.step()

print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
d_loss.item(), 100 * d_acc,
g_loss.item()))
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
)
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
sample_image(n_row=10, batches_done=batches_done)
Loading

0 comments on commit c5d6be1

Please sign in to comment.