Adding block_mnist, runs, but needs some work!
jloveric committed Jun 19, 2024
1 parent 93c4bfa commit b3f86f2
Showing 7 changed files with 279 additions and 10 deletions.
9 changes: 9 additions & 0 deletions config/block_mnist.yaml
@@ -0,0 +1,9 @@
max_epochs: 1
accelerator: 'cpu'
n: 5
batch_size: 16
layer_type: polynomial_3d
train_fraction: 1.0

defaults:
  - optimizer: sophia
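
A quick way to sanity-check the new config outside of a training run is to load it directly with OmegaConf. This is a minimal sketch, not part of the commit, and it assumes it is executed from the repository root; the defaults entry additionally pulls in the optimizer config named sophia, which configure_optimizers below expects to provide at least name, lr and rho. When the example itself is run through Hydra, the same keys can be overridden on the command line, e.g. python examples/block_mnist.py max_epochs=5 batch_size=64.

import torch  # not required here, only to confirm the environment matches the example
from omegaconf import OmegaConf

# Load the raw YAML (Hydra, not OmegaConf, resolves the defaults list).
cfg = OmegaConf.load("config/block_mnist.yaml")
print(OmegaConf.to_yaml(cfg))
assert cfg.layer_type == "polynomial_3d"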
250 changes: 250 additions & 0 deletions examples/block_mnist.py
@@ -0,0 +1,250 @@
"""
Using a polynomial 3d to solve this problem
"""

import os

import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from lion_pytorch import Lion
import torchvision
import torchvision.transforms as transforms
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from torchmetrics.functional import accuracy
from Sophia import SophiaG

from high_order_layers_torch.layers import *

transformStandard = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
transformPoly = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
)

normalization = {
    "max_abs": MaxAbsNormalizationND,
    "max_center": MaxCenterNormalizationND,
}

grid_x, grid_y = torch.meshgrid(
    (torch.arange(28) - 14) // 14, (torch.arange(28) - 14) // 14, indexing="ij"
)
grid = torch.stack([grid_x, grid_y])
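# grid stacks the per-pixel (x, y) indices into a [2, 28, 28] tensor; note that
# the integer floor division above collapses the coordinates to just -1 and 0.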


def collate_fn(batch):
    """Convert a batch of (image, label) pairs into per-pixel (value, x, y) points."""
    input = []
    classification = []
    for element in batch:
        # [1, 28, 28] image + [2, 28, 28] grid -> [28, 28, 3] -> [784, 3]
        color_and_xy = torch.cat([element[0], grid]).permute(1, 2, 0).view(-1, 3)
        input.append(color_and_xy)

        classification.append(element[1])

    batch_input = torch.stack(input)  # [batch, 784, 3]
    batch_output = torch.tensor(classification)  # [batch]

    return batch_input, batch_output


class Net(LightningModule):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.save_hyperparameters(cfg)

        self._cfg = cfg
        try:
            self._data_dir = f"{hydra.utils.get_original_cwd()}/data"
        except:
            self._data_dir = "../data"

        n = cfg.n
        self._batch_size = cfg.batch_size
        self._layer_type = cfg.layer_type
        self._train_fraction = cfg.train_fraction

        self._transform = transformPoly

        # A single 3D polynomial layer: each (intensity, x, y) point is treated
        # as one 3-dimensional input and mapped to 10 class scores.
        layer1 = high_order_fc_layers(
            layer_type=cfg.layer_type,
            n=n,
            in_features=1,
            out_features=10,
            intialization="constant_random",
            device=cfg.accelerator,
        )
        self.model = nn.Sequential(*[layer1])

    def forward(self, x):
        # x: [batch, pixels, 3], the last dimension holding (value, x, y)
        batch_size, inputs = x.shape[:2]
        xin = x.view(-1, 1, 3)
        res = self.model(xin)
        res = res.reshape(batch_size, inputs, -1)  # [batch, pixels, classes]
        output = torch.sum(res, dim=1)  # sum per-pixel logits into class logits
        return output

    def setup(self, stage):
        num_train = int(self._train_fraction * 50000)
        num_val = 10000

        # extra only exists if we aren't training on the full dataset
        num_extra = 50000 - num_train

        train = torchvision.datasets.MNIST(
            root=self._data_dir, train=True, download=True, transform=self._transform
        )
        self._train_subset, self._val_subset, extra = torch.utils.data.random_split(
            train,
            [num_train, num_val, num_extra],
            generator=torch.Generator().manual_seed(1),
        )

    def training_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "train")

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self._train_subset,
            batch_size=self._batch_size,
            shuffle=True,
            num_workers=10,
            collate_fn=collate_fn,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self._val_subset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=10,
            collate_fn=collate_fn,
        )

    def test_dataloader(self):
        testset = torchvision.datasets.MNIST(
            root=self._data_dir,
            train=False,
            download=True,
            transform=self._transform,
        )
        return torch.utils.data.DataLoader(
            testset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=10,
            collate_fn=collate_fn,
        )

    def validation_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx, "val")

    def eval_step(self, batch, batch_idx, name):
        x, y = batch

        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=10)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log(f"{name}_loss", loss, prog_bar=True)
        self.log(f"{name}_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.eval_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        if self._cfg.optimizer.name == "adam":
            optimizer = optim.Adam(self.parameters(), lr=self._cfg.optimizer.lr)
            lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                patience=self._cfg.optimizer.patience,
                factor=self._cfg.optimizer.factor,
                verbose=True,
            )
            return [optimizer], [
                {
                    "scheduler": lr_scheduler,
                    "monitor": "val_loss",
                    "interval": "epoch",
                    "reduce_on_plateau": True,
                    "frequency": 1,
                }
            ]
        elif self._cfg.optimizer.name == "lion":
            optimizer = Lion(self.parameters(), lr=self._cfg.optimizer.lr)
            lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode="min",
                patience=self._cfg.optimizer.patience,
                factor=self._cfg.optimizer.factor,
                verbose=True,
            )
            return [optimizer], [
                {
                    "scheduler": lr_scheduler,
                    "monitor": "val_loss",
                    "interval": "epoch",
                    "reduce_on_plateau": True,
                    "frequency": 1,
                }
            ]
        elif self._cfg.optimizer.name == "sophia":
            optimizer = SophiaG(
                self.parameters(),
                lr=self._cfg.optimizer.lr,
                rho=self._cfg.optimizer.rho,
            )
            return optimizer


def mnist(cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))
    print("Working directory : {}".format(os.getcwd()))

    try:
        print(f"Orig working directory : {hydra.utils.get_original_cwd()}")
    except:
        pass

    lr_monitor = LearningRateMonitor(logging_interval="step")

    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.00, patience=20, verbose=False, mode="min"
    )
    # Note: early_stop_callback is constructed but not currently passed to the Trainer.

    trainer = Trainer(
        max_epochs=cfg.max_epochs,
        accelerator=cfg.accelerator,
        callbacks=[lr_monitor],
    )
    model = Net(cfg)
    trainer.fit(model)

    print("testing")
    results = trainer.test(model)

    print("finished testing")
    return results


@hydra.main(config_path="../config", config_name="block_mnist")
def run(cfg: DictConfig):
    mnist(cfg)


if __name__ == "__main__":
    run()
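
For reference, a minimal standalone sketch of the data flow implemented above (not part of the commit; it assumes high_order_fc_layers accepts exactly the arguments used in Net.__init__): each 28x28 image becomes 784 (intensity, x, y) points, every point is mapped to 10 logits by the 3D polynomial layer, and the per-point logits are summed as in forward().

import torch
from high_order_layers_torch.layers import high_order_fc_layers

layer = high_order_fc_layers(
    layer_type="polynomial_3d",
    n=5,
    in_features=1,
    out_features=10,
    intialization="constant_random",  # spelling mirrors the call in Net.__init__
    device="cpu",
)
x = torch.rand(16, 784, 3)  # [batch, pixels, (value, x, y)] as produced by collate_fn
batch_size, points = x.shape[:2]
res = layer(x.view(-1, 1, 3))
logits = res.reshape(batch_size, points, -1).sum(dim=1)
print(logits.shape)  # expected: torch.Size([16, 10])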
1 change: 0 additions & 1 deletion examples/mnist.py
@@ -98,7 +98,6 @@ def __init__(self, cfg: DictConfig):
        )

        self.normalize = normalization[cfg.normalization]()
-        # self.normalize = MaxAbsNormalizationND()

        # self.pool = nn.MaxPool2d(2, 2)
        self.pool = nn.AvgPool2d(2, 2)
1 change: 0 additions & 1 deletion examples/xor.py
@@ -187,7 +187,6 @@ def plot_approximation(
    trainer.fit(model)
    if layer_type == "polynomial_2d" :
        thisTest = xTest.reshape(xTest.size(0),1, -1)
-        print('xtest.shape', thisTest.shape)
        predictions = model(thisTest)
    else :
        thisTest = xTest.reshape(xTest.size(0), -1)
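
The reshape kept above matches the input convention of the ND layers in this commit, x with shape [batch, inputs, dimensions] (see the comment in LagrangeBasisND.__call__ below); a small hypothetical illustration, assuming xTest stores one 2D point per row:

import torch

xTest = torch.rand(100, 2) * 2 - 1  # hypothetical test points in [-1, 1]^2
thisTest = xTest.reshape(xTest.size(0), 1, -1)
print(thisTest.shape)  # torch.Size([100, 1, 2]) -> [batch, inputs, dimensions]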
12 changes: 10 additions & 2 deletions high_order_layers_torch/Basis.py
@@ -414,13 +414,21 @@ class BasisFlatND:
"""

def __init__(
self, n: int, basis: Callable[[Tensor, list[int]], float], dimensions: int
self,
n: int,
basis: Callable[[Tensor, list[int]], float],
dimensions: int,
**kwargs
):
self.n = n
self.basis = basis
self.dimensions = dimensions
a = torch.arange(n)
self.indexes = torch.stack(torch.meshgrid([a]*dimensions)).reshape(dimensions, -1).T.long()
self.indexes = (
torch.stack(torch.meshgrid([a] * dimensions))
.reshape(dimensions, -1)
.T.long()
)
self.num_basis = basis.num_basis

def interpolate(self, x: Tensor, w: Tensor) -> Tensor:
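
A small illustration (not part of the commit) of the index grid the reformatted expression builds: for n = 3 and dimensions = 2 it enumerates all n**dimensions basis-index tuples. The indexing="ij" argument is added here only to avoid the warning newer PyTorch versions emit; it matches the default behaviour used above.

import torch

n, dimensions = 3, 2
a = torch.arange(n)
indexes = (
    torch.stack(torch.meshgrid([a] * dimensions, indexing="ij"))
    .reshape(dimensions, -1)
    .T.long()
)
print(indexes.shape)  # torch.Size([9, 2])
print(indexes[:4])    # tensor([[0, 0], [0, 1], [0, 2], [1, 0]])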
14 changes: 9 additions & 5 deletions high_order_layers_torch/LagrangePolynomial.py
@@ -31,7 +31,7 @@ def __init__(self, length: float):
        of 1 means there is periodicity 1
        """
        self.length = length
-        self.num_basis = None # Apparently defined elsewhere? How does this work!
+        self.num_basis = None  # Apparently defined elsewhere? How does this work!

    def __call__(self, x: Tensor, j: int):
        """
@@ -77,15 +77,18 @@ def __call__(self, x, j: int):

class LagrangeBasisND:

-    def __init__(self, n: int, length: float = 2.0, dimensions: int = 2):
+    def __init__(
+        self, n: int, length: float = 2.0, dimensions: int = 2, device: str = "cpu", **kwargs
+    ):
        self.n = n
        self.dimensions = dimensions
        self.X = (length / 2.0) * chebyshevLobatto(n)
+        self.device = device
        self.denominators = self._compute_denominators()
        self.num_basis = int(math.pow(n, dimensions))

    def _compute_denominators(self):
-        denom = torch.ones([self.n, self.n], dtype=torch.float32)
+        denom = torch.ones([self.n, self.n], dtype=torch.float32, device=self.device)

        for j in range(self.n):
            for m in range(self.n):
@@ -99,6 +102,7 @@ def __call__(self, x, index: list[int]):
        :param index : [dimensions]
        :returns: basis value [batch, inputs]
        """
+
        x_diff = x.unsqueeze(-1) - self.X  # [batch, inputs, dimensions, basis]
        r = 1.0
        for i, basis_i in enumerate(index):
@@ -121,7 +125,7 @@ class LagrangeBasis1:
    def __init__(self, length: float = 2.0):
        self.n = 1
        self.X = torch.tensor([0.0])
-        self.num_basis=1
+        self.num_basis = 1

    def __call__(self, x, j: int):
        b = torch.ones_like(x)
@@ -182,7 +186,7 @@ class LagrangePolyFlatND(BasisFlatND):
    def __init__(self, n: int, length: float = 2.0, dimensions: int = 2, **kwargs):
        super().__init__(
            n,
-            LagrangeBasisND(n, length, dimensions=dimensions),
+            LagrangeBasisND(n, length, dimensions=dimensions, **kwargs),
            dimensions=dimensions,
            **kwargs
        )
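
A minimal sketch (not part of the commit) exercising the new device/**kwargs plumbing: the basis can be built directly with a device argument, and kwargs given to LagrangePolyFlatND are forwarded to LagrangeBasisND as in the hunk above. Only shapes and the device of the precomputed denominators are checked; the values depend on the Chebyshev-Lobatto nodes.

import torch
from high_order_layers_torch.LagrangePolynomial import LagrangeBasisND

basis = LagrangeBasisND(n=5, length=2.0, dimensions=3, device="cpu")
x = torch.rand(4, 7, 3) * 2 - 1  # [batch, inputs, dimensions] in [-1, 1]
value = basis(x, index=[0, 2, 4])  # one node index per dimension
print(value.shape)  # expected: torch.Size([4, 7])
print(basis.denominators.device)  # cpu, allocated with the new device argument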
2 changes: 1 addition & 1 deletion high_order_layers_torch/PolynomialLayers.py
@@ -90,7 +90,7 @@ def __init__(
            n=n,
            in_features=in_features,
            out_features=out_features,
-            basis=LagrangePolyFlatND(n, length=length, dimensions=dimensions),
+            basis=LagrangePolyFlatND(n, length=length, dimensions=dimensions, **kwargs),
            **kwargs,
        )

