Review note (text before the first file header is ignored by patch tools):
The optimizer argument added to plot_results was never used; the call to
plot_approximation hard-coded opt="sophia", so the optimizer key in
config/function.yaml had no effect. The call now forwards the argument.
Whitespace inside hunks was reconstructed from a whitespace-collapsed copy,
and the post-image blob hash for examples/function_example.py is stale
because the patched content changed.

diff --git a/config/function.yaml b/config/function.yaml
new file mode 100644
index 0000000..e0dec4a
--- /dev/null
+++ b/config/function.yaml
@@ -0,0 +1,3 @@
+optimizer : sophia # lion, adam, sparse_lion
+epochs: 20
+segments: 5
diff --git a/examples/function_example.py b/examples/function_example.py
index 3895be2..cbb0ec0 100644
--- a/examples/function_example.py
+++ b/examples/function_example.py
@@ -3,6 +3,7 @@
 functions using a single input and single output with polynomial
 synaptic weights
 """
+
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
@@ -12,6 +13,10 @@
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 from lion_pytorch import Lion
+from Sophia import SophiaG
+import hydra
+from omegaconf import DictConfig
+
 from high_order_layers_torch.sparse_optimizers import SparseLion
 from high_order_layers_torch.layers import *
 
@@ -139,6 +144,8 @@ def configure_optimizers(self):
             return torch.optim.Adam(self.parameters(), lr=0.001)
         elif self.optimizer == "lion":
             return Lion(self.parameters(), lr=0.001)
+        elif self.optimizer == "sophia":
+            return SophiaG(self.parameters(), lr=0.01, rho=0.035)
         elif self.optimizer == "sparse_lion":
             print(f"Using sparse lion")
             return SparseLion(self.parameters(), lr=0.001)
@@ -248,7 +255,11 @@ def plot_approximation(
 
 
 def plot_results(
-    epochs: int = 20, segments: int = 5, plot: bool = True, first_only: bool = False
+    epochs: int = 20,
+    segments: int = 5,
+    plot: bool = True,
+    first_only: bool = False,
+    optimizer: str = "lion",
 ):
     """
     plt.figure(0)
@@ -289,11 +300,11 @@ def plot_results(
         plot_approximation(
             element["layer"],
             element["model_set"],
-            5,
+            segments,
             epochs,
             accelerator="cpu",
             periodicity=2,
-            opt="sparse_lion",
+            opt=optimizer,
             first_only=first_only,
         )
 
@@ -304,5 +315,16 @@ def plot_results(
         plt.show()
 
 
+@hydra.main(config_path="../config", config_name="function")
+def run(cfg: DictConfig):
+    """Hydra entry point: pass epochs, segments and optimizer from cfg to plot_results."""
+    plot_results(
+        epochs=cfg.epochs,
+        segments=cfg.segments,
+        first_only=False,
+        optimizer=cfg.optimizer,
+    )
+
+
 if __name__ == "__main__":
-    plot_results(first_only=False)
+    run()
diff --git a/examples/xor.py b/examples/xor.py
index d912c62..5c030d0 100644
--- a/examples/xor.py
+++ b/examples/xor.py
@@ -4,20 +4,14 @@
 synaptic weights
 """
 
-import math
-import os
-
 import matplotlib.pyplot as plt
-import numpy as np
 import torch
 from pytorch_lightning import LightningModule, Trainer
 from torch.nn import functional as F
-from torch.utils.data import DataLoader, Dataset, random_split
-from torchvision import transforms
-from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader, Dataset
 from lion_pytorch import Lion
 import hydra
-from omegaconf import DictConfig, OmegaConf
+from omegaconf import DictConfig
 
 import high_order_layers_torch.PolynomialLayers as poly
 
diff --git a/plots/polynomial.png b/plots/polynomial.png
index bd32fe7..184b31f 100644
Binary files a/plots/polynomial.png and b/plots/polynomial.png differ
diff --git a/plots/xor_polynomial.png b/plots/xor_polynomial.png
index 4f01259..bd32fe7 100644
Binary files a/plots/xor_polynomial.png and b/plots/xor_polynomial.png differ