diff --git a/README.md b/README.md index 6aaa6646..bf7ed03a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ Collection of PyTorch implementations of Generative Adversarial Network varietie + [Softmax GAN](#softmax-gan) + [StarGAN](#stargan) + [Super-Resolution GAN](#super-resolution-gan) - + [UNIT](#UNIT) + + [UNIT](#unit) + [Wasserstein GAN](#wasserstein-gan) + [Wasserstein GAN GP](#wasserstein-gan-gp) diff --git a/implementations/pix2pix/models.py b/implementations/pix2pix/models.py index 7788cfcf..7df5547b 100644 --- a/implementations/pix2pix/models.py +++ b/implementations/pix2pix/models.py @@ -52,11 +52,11 @@ def __init__(self, in_channels=3, out_channels=3): self.down1 = UNetDown(in_channels, 64, normalize=False) self.down2 = UNetDown(64, 128) self.down3 = UNetDown(128, 256) - self.down4 = UNetDown(256, 512) - self.down5 = UNetDown(512, 512) - self.down6 = UNetDown(512, 512) - self.down7 = UNetDown(512, 512) - self.down8 = UNetDown(512, 512, normalize=False) + self.down4 = UNetDown(256, 512, dropout=0.5) + self.down5 = UNetDown(512, 512, dropout=0.5) + self.down6 = UNetDown(512, 512, dropout=0.5) + self.down7 = UNetDown(512, 512, dropout=0.5) + self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) self.up1 = UNetUp(512, 512, dropout=0.5) self.up2 = UNetUp(1024, 512, dropout=0.5) diff --git a/implementations/pix2pix/pix2pix.py b/implementations/pix2pix/pix2pix.py index a8ed7c6a..fe5f79a8 100644 --- a/implementations/pix2pix/pix2pix.py +++ b/implementations/pix2pix/pix2pix.py @@ -104,7 +104,7 @@ def sample_images(batches_done): # Training # ---------- -start_time = time.time() +prev_time = time.time() for epoch in range(opt.epoch, opt.n_epochs): for i, batch in enumerate(dataloader): @@ -164,7 +164,8 @@ def sample_images(batches_done): # Determine approximate time left batches_done = epoch * len(dataloader) + i batches_left = opt.n_epochs * len(dataloader) - batches_done - time_left = datetime.timedelta(seconds=batches_left * (time.time() - start_time)/ (batches_done + 1)) + time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) + prev_time = time.time() # Print log sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s" % diff --git a/requirements.txt b/requirements.txt index ad423af4..91545b78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ -torch +torch>=0.4.0 torchvision matplotlib numpy scipy pillow -urllib -skimage -gzip -pickle +urllib3 +scikit-image