Skip to content

Commit

Permalink
quick update of utils directory
Browse files Browse the repository at this point in the history
  • Loading branch information
zeonzir committed Dec 17, 2021
1 parent f343bd8 commit 379ab15
Show file tree
Hide file tree
Showing 17 changed files with 241 additions and 1,033 deletions.
4 changes: 2 additions & 2 deletions resnet50/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch

#### Save Current Checkpoint Details ####
def save_checkpoint(epoch, step, model, optimizer, filename):
state = { 'epoch':epoch,
'step': step,
Expand All @@ -12,8 +13,7 @@ def save_checkpoint(epoch, step, model, optimizer, filename):

torch.save(state, path)



#### Load state dict ####
def load_checkpoint(name):
checkpoint = torch.load(os.path.join('results',name))
return checkpoint["state_dict"]
211 changes: 0 additions & 211 deletions resnet50/utils/count_params.py

This file was deleted.

14 changes: 7 additions & 7 deletions resnet50/utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch

import numpy as np
import torch.nn as nn

#### Compute And Return Accuracy ####
def accuracy(net, testloader, device):
correct = 0
total = 0
Expand All @@ -12,16 +14,14 @@ def accuracy(net, testloader, device):
images, labels = data
images, labels = images.to(device), labels.to(device)

#Var = []
#for looper in range(10):
outputs = net(images, labels=True)
_, predicted = torch.max(outputs.data, 1)
#if len(Var):
# Var = np.dstack((outputs.cpu().numpy(), Var))
#else:
# Var = outputs.cpu().numpy()

total += labels.size(0)

correct += (predicted == labels).sum().item()

# END FOR

# END WITH

return float(correct) / total
Loading

0 comments on commit 379ab15

Please sign in to comment.