Skip to content

Commit

Permalink
Merge pull request alinlab#2 from icks/torch-v1
Browse files Browse the repository at this point in the history
PyTorch v1.x compatibility
  • Loading branch information
pokaxpoka authored Dec 18, 2019
2 parents 91753e0 + 8277732 commit 462db01
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
5 changes: 2 additions & 3 deletions src/run_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torchvision.utils as vutils
import models

from torch.utils.serialization import load_lua
from torchvision import datasets, transforms
from torch.autograd import Variable

Expand Down Expand Up @@ -74,7 +73,7 @@ def train(epoch):
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data[0]))
100. * batch_idx / len(train_loader), loss.data.item()))

def test(epoch):
model.eval()
Expand All @@ -87,7 +86,7 @@ def test(epoch):
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = F.log_softmax(model(data))
test_loss += F.nll_loss(output, target).data[0]
test_loss += F.nll_loss(output, target).data.item()
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data).cpu().sum()

Expand Down
5 changes: 2 additions & 3 deletions src/run_joint_confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torchvision.utils as vutils
import models

from torch.utils.serialization import load_lua
from torchvision import datasets, transforms
from torch.autograd import Variable

Expand Down Expand Up @@ -162,7 +161,7 @@ def train(epoch):
if batch_idx % args.log_interval == 0:
print('Classification Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, KL fake Loss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data[0], KL_loss_fake.data[0]))
100. * batch_idx / len(train_loader), loss.data.item(), KL_loss_fake.data.item()))
fake = netG(fixed_noise)
vutils.save_image(fake.data, '%s/gan_samples_epoch_%03d.png'%(args.outf, epoch), normalize=True)

Expand All @@ -177,7 +176,7 @@ def test(epoch):
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
output = F.log_softmax(model(data))
test_loss += F.nll_loss(output, target).data[0]
test_loss += F.nll_loss(output, target).data.item()
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data).cpu().sum()

Expand Down

0 comments on commit 462db01

Please sign in to comment.