UNIT: Unsupervised image-to-image translation networks
eriklindernoren committed May 24, 2018
1 parent 3a00900 commit 49621f6
Showing 4 changed files with 468 additions and 0 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -32,6 +32,7 @@ Collection of PyTorch implementations of Generative Adversarial Network varieties
+ [Softmax GAN](#softmax-gan)
+ [StarGAN](#stargan)
+ [Super-Resolution GAN](#super-resolution-gan)
+ [UNIT](#unit)
+ [Wasserstein GAN](#wasserstein-gan)
+ [Wasserstein GAN GP](#wasserstein-gan-gp)

@@ -576,6 +577,25 @@ $ python3 srgan.py
resolution image
</p>

### UNIT
_Unsupervised Image-to-Image Translation Networks_

#### Authors
Ming-Yu Liu, Thomas Breuel, Jan Kautz

#### Abstract
Unsupervised image-to-image translation aims at learning a joint distribution of images in different domains by using images from the marginal distributions in individual domains. Since there exists an infinite set of joint distributions that can arrive at the given marginal distributions, one could infer nothing about the joint distribution from the marginal distributions without additional assumptions. To address the problem, we make a shared-latent space assumption and propose an unsupervised image-to-image translation framework based on Coupled GANs. We compare the proposed framework with competing approaches and present high-quality image translation results on various challenging unsupervised image translation tasks, including street scene image translation, animal image translation, and face image translation. We also apply the proposed framework to domain adaptation and achieve state-of-the-art performance on benchmark datasets. Code and additional results are available in this [https URL](https://github.com/mingyuliutw/unit).

[[Paper]](https://arxiv.org/abs/1703.00848) [[Code]](implementations/unit/unit.py)

#### Run Example
```
$ cd data/
$ bash download_cyclegan_dataset.sh apple2orange
$ cd implementations/unit/
$ python3 unit.py --dataset_name apple2orange
```
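
At inference time, the shared-latent assumption reduces translation to a simple pipeline: encode an image from one domain into the common latent code, then decode that code with the other domain's generator. Below is a minimal sketch using the `ResidualBlock`, `Encoder` and `Generator` classes from [implementations/unit/models.py](implementations/unit/models.py); the shared blocks, feature sizes, dummy input, and running from inside `implementations/unit/` are illustrative assumptions, not the actual training script.

```
import torch
from models import Encoder, Generator, ResidualBlock

# Blocks shared between both domains realize the shared-latent space:
# 256 = 64 * 2**2 channels after the encoder's two downsampling steps.
shared_E = ResidualBlock(features=256)
shared_G = ResidualBlock(features=256)

E1 = Encoder(shared_block=shared_E)    # domain A -> latent z
E2 = Encoder(shared_block=shared_E)    # domain B -> latent z
G1 = Generator(shared_block=shared_G)  # latent z -> domain A
G2 = Generator(shared_block=shared_G)  # latent z -> domain B

x_A = torch.randn(1, 3, 128, 128)      # stand-in for a domain-A image
mu_A, z_A = E1(x_A)                    # encode into the shared latent space
fake_B = G2(z_A)                       # translate: decode as a domain-B image
recon_A = G1(z_A)                      # decode back to domain A (reconstruction)
```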

### Wasserstein GAN
_Wasserstein GAN_

28 changes: 28 additions & 0 deletions implementations/unit/datasets.py
@@ -0,0 +1,28 @@
import glob
import random
import os

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ImageDataset(Dataset):
def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
self.transform = transforms.Compose(transforms_)
self.unaligned = unaligned

self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*'))
self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*'))

    def __getitem__(self, index):
        # Index modulo each domain's size, since the two folders may differ in length.
        # convert('RGB') guards against grayscale images breaking three-channel transforms.
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]).convert('RGB'))

        if self.unaligned:
            # Sample B at random so A/B pairs stay uncorrelated across epochs
            item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]).convert('RGB'))
        else:
            item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]).convert('RGB'))

return {'A': item_A, 'B': item_B}

def __len__(self):
return max(len(self.files_A), len(self.files_B))
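
For reference, below is a usage sketch of `ImageDataset`. The dataset path, directory layout (`train/A`, `train/B`), and image size are assumptions based on the run example above, not taken from the training script.

```
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import ImageDataset  # the class above

# Assumed layout left by download_cyclegan_dataset.sh: data/apple2orange/train/{A,B}
transforms_ = [ transforms.Resize((128, 128), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]

dataloader = DataLoader(ImageDataset('../../data/apple2orange', transforms_=transforms_, unaligned=True),
                        batch_size=1, shuffle=True)

batch = next(iter(dataloader))
print(batch['A'].shape, batch['B'].shape)  # torch.Size([1, 3, 128, 128]) each
```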
138 changes: 138 additions & 0 deletions implementations/unit/models.py
@@ -0,0 +1,138 @@
import torch.nn as nn
import torch.nn.functional as F
import torch

def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)

class LambdaLR():
def __init__(self, n_epochs, offset, decay_start_epoch):
assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
self.n_epochs = n_epochs
self.offset = offset
self.decay_start_epoch = decay_start_epoch

def step(self, epoch):
return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)


##############################
# RESNET
##############################

class ResidualBlock(nn.Module):
def __init__(self, features):
super(ResidualBlock, self).__init__()

conv_block = [ nn.ReflectionPad2d(1),
nn.Conv2d(features, features, 3),
nn.InstanceNorm2d(features),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(features, features, 3),
nn.InstanceNorm2d(features) ]

self.conv_block = nn.Sequential(*conv_block)

def forward(self, x):
return x + self.conv_block(x)

class Encoder(nn.Module):
def __init__(self, in_channels=3, dim=64, n_downsample=2, shared_block=None):
super(Encoder, self).__init__()

# Initial convolution block
layers = [ nn.ReflectionPad2d(3),
nn.Conv2d(in_channels, dim, 7),
                   nn.InstanceNorm2d(dim),
nn.LeakyReLU(0.2, inplace=True) ]

# Downsampling
for _ in range(n_downsample):
layers += [ nn.Conv2d(dim, dim * 2, 4, stride=2, padding=1),
nn.InstanceNorm2d(dim * 2),
nn.ReLU(inplace=True) ]
dim *= 2

# Residual blocks
for _ in range(3):
layers += [ResidualBlock(dim)]

self.model_blocks = nn.Sequential(*layers)
self.shared_block = shared_block

    def reparameterization(self, mu):
        # UNIT's VAE branch assumes a unit-variance Gaussian posterior, so
        # sampling reduces to adding standard normal noise to the predicted mean.
        z = torch.randn_like(mu)
        return z + mu

def forward(self, x):
x = self.model_blocks(x)
mu = self.shared_block(x)
z = self.reparameterization(mu)
return mu, z

class Generator(nn.Module):
def __init__(self, out_channels=3, dim=64, n_upsample=2, shared_block=None):
super(Generator, self).__init__()

self.shared_block = shared_block

layers = []
dim = dim * 2**n_upsample
# Residual blocks
for _ in range(3):
layers += [ResidualBlock(dim)]

# Upsampling
for _ in range(n_upsample):
layers += [ nn.ConvTranspose2d(dim, dim // 2, 4, stride=2, padding=1),
nn.InstanceNorm2d(dim // 2),
nn.LeakyReLU(0.2, inplace=True) ]
dim = dim // 2

# Output layer
layers += [ nn.ReflectionPad2d(3),
nn.Conv2d(dim, out_channels, 7),
nn.Tanh() ]

self.model_blocks = nn.Sequential(*layers)

def forward(self, x):
x = self.shared_block(x)
x = self.model_blocks(x)
return x

##############################
# Discriminator
##############################

class Discriminator(nn.Module):
def __init__(self, in_channels=3):
super(Discriminator, self).__init__()

def discriminator_block(in_filters, out_filters, normalize=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalize:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*discriminator_block(in_channels, 64, normalize=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.Conv2d(512, 1, 3, padding=1)
)

def forward(self, img):
return self.model(img)
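
To show how these pieces fit together, here is a sketch wiring `weights_init_normal` and `LambdaLR` into an optimizer and PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR`. It assumes it runs alongside the classes above; the hyperparameters (`n_epochs=200`, `decay_start_epoch=100`, the Adam settings) are illustrative assumptions, not necessarily those of the accompanying training script.

```
import itertools
import torch

shared_E = ResidualBlock(features=256)  # 64 * 2**2 channels after downsampling
shared_G = ResidualBlock(features=256)
E1, E2 = Encoder(shared_block=shared_E), Encoder(shared_block=shared_E)
G1, G2 = Generator(shared_block=shared_G), Generator(shared_block=shared_G)
for net in (E1, E2, G1, G2):
    net.apply(weights_init_normal)  # N(0, 0.02) init for all conv layers

optimizer_G = torch.optim.Adam(
    itertools.chain(E1.parameters(), E2.parameters(), G1.parameters(), G2.parameters()),
    lr=0.0002, betas=(0.5, 0.999))

# LambdaLR.step maps an epoch to an LR multiplier: 1.0 until epoch 100,
# then linearly down to 0.0 at epoch 200.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs=200, offset=0, decay_start_epoch=100).step)

for epoch in range(200):
    # ... one epoch of training ...
    lr_scheduler_G.step()  # apply the decayed learning rate for the next epoch
```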
