Skip to content

Commit

Permalink
Add ResMLP testing
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and gussmith23 committed Nov 23, 2021
1 parent d30035f commit 454d1f6
Show file tree
Hide file tree
Showing 8 changed files with 442 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,6 @@ dmypy.json
# Cython debug symbols
cython_debug/

# don't commit data
*.csv
e2e/resmlp/data/
1 change: 1 addition & 0 deletions e2e/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This directory contains end-to-end evaluation trials.
6 changes: 6 additions & 0 deletions e2e/resmlp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
End-to-end evaluation for res-mlp. The repo contains a model trained on CIFAR-10 using the provided train script.
As it takes a long time to train, it is not recommended to train it in the Dockerfile.
The trial script provides options for selecting the number of trials to report.
The digest script processes and prints the results of the trials.

The ResMLP implementation is this one: https://github.com/lucidrains/res-mlp-pytorch
Binary file added e2e/resmlp/cifar_net.pth
Binary file not shown.
34 changes: 34 additions & 0 deletions e2e/resmlp/digest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Simply reads the TVM test results and prints a digest
"""
import pandas as pd


def main():
numerical = pd.read_csv("./numerical.csv")
pred = pd.read_csv("./pred.csv")

use_accel = (numerical["accel_time"][0] != "None")

print(f"Average Relay time: {numerical['relay_time'].mean()}")
print(f"Average PT time: {numerical['pt_time'].mean()}")
if use_accel:
print(f"Average accelerated time: {numerical['accel_time'].mean()}")

# cast to float so everything else becomes a float
total = float(len(pred["relay_faithful"]))
pt_correct = pred[pred.pt_correct == True].shape[0]
relay_correct = pred[pred.relay_correct == True].shape[0]
relay_faithful = pred[pred.relay_faithful == True].shape[0]
print(f"PT accuracy: {(pt_correct/total)*100}")
print(f"Relay accuracy: {(relay_correct/total)*100}%")
print(f"Relay faithfulness: {(relay_faithful/total) * 100}%")
if use_accel:
accel_faithful = pred[pred.accel_faithful == True].shape[0]
accel_correct = pred[pred.accel_correct == True].shape[0]
print(f"Accelerator faithfulness: {(accel_faithful/total)*100}%")
print(f"Accelerator accuracy: {(accel_correct/total)*100}%")


if __name__ == "__main__":
main()
59 changes: 59 additions & 0 deletions e2e/resmlp/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
from torch import nn, einsum
from einops.layers.torch import Rearrange, Reduce

# helpers

def pair(val):
return (val, val) if not isinstance(val, tuple) else val

# classes

class Affine(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(1, 1, dim))
self.b = nn.Parameter(torch.zeros(1, 1, dim))

def forward(self, x):
return x * self.g + self.b

class PreAffinePostLayerScale(nn.Module): # https://arxiv.org/abs/2103.17239
def __init__(self, dim, depth, fn):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif depth > 18 and depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6

scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.affine = Affine(dim)
self.fn = fn

def forward(self, x):
return self.fn(self.affine(x)) * self.scale + x

def ResMLP(*, image_size, patch_size, dim, depth, num_classes, expansion_factor = 4):
image_height, image_width = pair(image_size)
assert (image_height % patch_size) == 0 and (image_width % patch_size) == 0, 'image height and width must be divisible by patch size'
num_patches = (image_height // patch_size) * (image_width // patch_size)
wrapper = lambda i, fn: PreAffinePostLayerScale(dim, i + 1, fn)

return nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.Linear((patch_size ** 2) * 3, dim),
*[nn.Sequential(
wrapper(i, nn.Conv1d(num_patches, num_patches, 1)),
wrapper(i, nn.Sequential(
nn.Linear(dim, dim * expansion_factor),
nn.GELU(),
nn.Linear(dim * expansion_factor, dim)
))
) for i in range(depth)],
Affine(dim),
Reduce('b n c -> b c', 'mean'),
nn.Linear(dim, num_classes)
)
76 changes: 76 additions & 0 deletions e2e/resmlp/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import time

import torch
import torchvision
import torchvision.transforms as transforms

import torch.optim as optim
import torch.nn as nn
from model import ResMLP

# going by the book, following this CIFAR tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 32

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = ResMLP(image_size=32,
patch_size=16,
dim=512,
depth=12,
num_classes=len(classes))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

epochs = 90 # the least used in the ResMLP paper
start_time = time.time()
for epoch in range(epochs):

running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data[0].to(device), data[1].to(device)

# zero the parameter gradients
optimizer.zero_grad()

# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# print statistics
running_loss += loss.item()
if i % 250 == 249:
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 250))
running_loss = 0.0
end_time = time.time()

print('Finished Training')

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
print(end_time - start_time)
# took 7665 seconds the last time
Loading

0 comments on commit 454d1f6

Please sign in to comment.