Skip to content

Commit

Permalink
Using hydra for function example
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 4, 2024
1 parent 8ca313f commit e7fcdb7
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 12 deletions.
3 changes: 3 additions & 0 deletions config/function.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
optimizer : sophia # lion, adam, sparse_lion
epochs: 20
segments: 5
29 changes: 25 additions & 4 deletions examples/function_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
functions using a single input and single output with polynomial
synaptic weights
"""

import matplotlib.pyplot as plt
import numpy as np
import torch
Expand All @@ -12,6 +13,10 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from lion_pytorch import Lion
from Sophia import SophiaG
import hydra
from omegaconf import DictConfig

from high_order_layers_torch.sparse_optimizers import SparseLion

from high_order_layers_torch.layers import *
Expand Down Expand Up @@ -139,6 +144,8 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
elif self.optimizer == "lion":
return Lion(self.parameters(), lr=0.001)
elif self.optimizer == "sophia":
return SophiaG(self.parameters(), lr=0.01, rho=0.035)
elif self.optimizer == "sparse_lion":
print(f"Using sparse lion")
return SparseLion(self.parameters(), lr=0.001)
Expand Down Expand Up @@ -248,7 +255,11 @@ def plot_approximation(


def plot_results(
epochs: int = 20, segments: int = 5, plot: bool = True, first_only: bool = False
epochs: int = 20,
segments: int = 5,
plot: bool = True,
first_only: bool = False,
optimizer="lion",
):
"""
plt.figure(0)
Expand Down Expand Up @@ -289,11 +300,11 @@ def plot_results(
plot_approximation(
element["layer"],
element["model_set"],
5,
segments,
epochs,
accelerator="cpu",
periodicity=2,
opt="sparse_lion",
opt="sophia",
first_only=first_only,
)

Expand All @@ -304,5 +315,15 @@ def plot_results(
plt.show()


@hydra.main(config_path="../config", config_name="function")
def run(cfg: DictConfig):
plot_results(
epochs=cfg.epochs,
segments=cfg.segments,
first_only=False,
optimizer=cfg.optimizer,
)


if __name__ == "__main__":
plot_results(first_only=False)
run()
10 changes: 2 additions & 8 deletions examples/xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@
synaptic weights
"""

import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset
from lion_pytorch import Lion
import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig


import high_order_layers_torch.PolynomialLayers as poly
Expand Down
Binary file modified plots/polynomial.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified plots/xor_polynomial.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e7fcdb7

Please sign in to comment.