fix: remove optax; it doesn't work with complex numbers yet
astanziola committed Apr 11, 2022
1 parent 88e611c commit 4a85333
Showing 4 changed files with 54 additions and 56 deletions.
10 changes: 5 additions & 5 deletions data/darcy/generate_darcy_dataset.m
@@ -1,13 +1,13 @@
%% Settings
filename ='darcy_238.mat'; % Name of the dataset file
s = 238; % Number of grid points on [0,1]^2
filename ='darcy_211.mat'; % Name of the dataset file
s = 211; % Number of grid points on [0,1]^2
num_samples = 1200; % Number of samples

% Parameters of covariance C = tau^(2*alpha-2)*(-Laplacian + tau^2 I)^(-alpha)
alpha = 2;
tau = 3;

%Forcing function, f(x) = 1
%Forcing function, f(x) = 1
f = ones(s,s);

%% Generation
@@ -25,7 +25,7 @@
thresh_a = zeros(s,s);
thresh_a(norm_a >= 0) = 12;
thresh_a(norm_a < 0) = 4;

%Solve PDE: - div(a(x)*grad(p(x))) = f(x)
p = solve_gwf(thresh_a,f);

@@ -38,4 +38,4 @@
end

%% Saving
save(filename, "outputs", "inputs");
save(filename, "outputs", "inputs");
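
The generated file can be sanity-checked from Python before training. This is a minimal sketch, assuming the default MAT format (the script does not pass '-v7.3', so scipy.io can read it) and the variable names from the save() call above; the array layout depends on how the generation loop fills inputs and outputs, so it is only printed here rather than assumed.

from scipy.io import loadmat

data = loadmat('data/darcy/darcy_211.mat')
inputs, outputs = data['inputs'], data['outputs']

# Permeability fields a(x) and pressure solutions p(x) from the Darcy solver.
print(inputs.shape, inputs.dtype)
print(outputs.shape, outputs.dtype)
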
55 changes: 28 additions & 27 deletions fno/modules.py
@@ -1,11 +1,16 @@
from turtle import width
from typing import Callable

import flax.linen as nn
from flax.linen.initializers import normal
from jax import numpy as jnp
from jax import random


def normal(stddev=1e-2, dtype = jnp.float32) -> Callable:
def init(key, shape, dtype=dtype):
keys = random.split(key)
return (random.normal(keys[0], shape) + 1j*random.normal(keys[1], shape)) * stddev
return init

class SpectralConv2d(nn.Module):
out_channels: int = 32
modes1: int = 12
@@ -35,41 +40,37 @@ def __call__(self, x):
# output signal will have dimensions (N, C, H, W//2+1).
# Therefore the kernel weights will have different dimensions
# for the two axes.
kernel = self.param(
'kernel',
kernel_1 = self.param(
'kernel_1',
normal(scale, jnp.complex64),
(in_channels, self.out_channels, self.modes1, self.modes2),
jnp.complex64
)
kernel_2 = self.param(
'kernel_2',
normal(scale, jnp.complex64),
(in_channels, self.out_channels, 2*self.modes1+1, self.modes2),
(in_channels, self.out_channels, self.modes1, self.modes2),
jnp.complex64
)

# Perform fft of the input
x_ft = jnp.fft.rfftn(x, axes=(1, 2))

# Shift the zero frequency to the center of the array, only for the
# first axis.
x_ft = jnp.fft.fftshift(x_ft, axes=1)

# Get the center of the spectrum
center_idx = x_ft.shape[1]//2
center_spect = x_ft[:, -self.modes1+center_idx:self.modes1+center_idx+1, :self.modes2,:]

# Multiply the center of the spectrum by the kernel
x_filt = jnp.einsum('bijc,coij->bijo', center_spect, kernel)

# Pad the kernel with zeros to restore the original size
pad_size_1 = (x_ft.shape[1] - 2*self.modes1)//2
pad_size_2 = x_ft.shape[2] - self.modes2
x_filt = jnp.pad(
x_filt,
((0, 0), (pad_size_1, pad_size_1-1), (0, pad_size_2), (0, 0)),
mode='constant'
)

# Restore the zero frequency to the beginning of the array
x_filt = jnp.fft.ifftshift(x_filt, axes=1)
out_ft = jnp.zeros_like(x_ft)
s1 = jnp.einsum(
'bijc,coij->bijo',
x_ft[:, :self.modes1, :self.modes2, :],
kernel_1)
s2 = jnp.einsum(
'bijc,coij->bijo',
x_ft[:, -self.modes1:, :self.modes2, :],
kernel_2)
out_ft = out_ft.at[:, :self.modes1, :self.modes2, :].set(s1)
out_ft = out_ft.at[:, -self.modes1:, :self.modes2, :].set(s2)

# Go back to the spatial domain
y = jnp.fft.irfftn(x_filt, axes=(1, 2))
y = jnp.fft.irfftn(out_ft, axes=(1, 2))

return y

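
The reworked layer keeps two complex-valued kernels, one for the lowest positive and one for the lowest negative Fourier modes along the first spatial axis, and writes the filtered modes into an otherwise zero spectrum before the inverse FFT, replacing the previous fftshift/pad/ifftshift sequence. Below is a minimal usage sketch; it assumes a channels-last input (as implied by the 'bijc' einsum), that the constructor fields used here (out_channels, modes1, modes2) are the ones referenced in the code above, and that the part of __call__ not shown only derives in_channels and the init scale from the input.

import jax
from jax import numpy as jnp
from jax import random

from fno.modules import SpectralConv2d

layer = SpectralConv2d(out_channels=32, modes1=12, modes2=12)

# Channels-last input: (batch, height, width, channels).
x = random.normal(random.PRNGKey(1), (4, 64, 64, 3))
variables = layer.init(random.PRNGKey(0), x)

# kernel_1 and kernel_2 are complex64, hence the custom complex-valued
# `normal` initializer defined at the top of this file.
print(jax.tree_util.tree_map(lambda p: p.dtype, variables))

y = layer.apply(variables, x)
print(y.shape)  # (4, 64, 64, 32): spatial size preserved, channels -> out_channels
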
1 change: 0 additions & 1 deletion requirements-train.txt
@@ -3,6 +3,5 @@ wandb
addict
scipy
pytest
optax
tqdm
wandb
44 changes: 21 additions & 23 deletions train_darcy.py
@@ -2,7 +2,7 @@
from addict import Dict
from jax import numpy as jnp
from jax import random
from optax import adamw, apply_updates
from jax.example_libraries import optimizers
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

@@ -12,20 +12,20 @@

# Settings dictionary
SETTINGS = Dict()
SETTINGS.data_path = 'data/darcy/darcy_238.mat'
SETTINGS.data_path = 'data/darcy/darcy_211.mat'
SETTINGS.n_train = 1000
SETTINGS.n_test = 200
SETTINGS.batch_size = 20
SETTINGS.learning_rate = 0.0001 # TODO: This should be scheduled
SETTINGS.weight_decay = 1e-4
SETTINGS.n_epochs = 100
SETTINGS.n_epochs = 1000
SETTINGS.nrg = random.PRNGKey(0)

SETTINGS.fno.modes = 12
SETTINGS.fno.width = 32
SETTINGS.fno.depth = 4
SETTINGS.fno.channels_last_proj = 128
SETTINGS.fno.padding = 18
SETTINGS.fno.padding = 45

def main():
# Loading and splitting dataset
@@ -46,7 +46,7 @@ def main():
test_loader = DataLoader(
test_dataset,
batch_size=SETTINGS.batch_size,
shuffle=False,
shuffle=True,
collate_fn=collate_fn
)

@@ -57,31 +57,30 @@ def main():
width=SETTINGS.fno.width,
depth=SETTINGS.fno.depth,
channels_last_proj=SETTINGS.fno.channels_last_proj,
padding=SETTINGS.fno.padding
padding=SETTINGS.fno.padding,
)
_x, _ = train_dataset[0]
_x = jnp.expand_dims(_x, axis=0)
_, model_params = model.init_with_output(SETTINGS.nrg, _x)
del _x

# Initialize optimizers
optimizer = adamw(
SETTINGS.learning_rate,
weight_decay=SETTINGS.weight_decay
init_fun, update_fun, get_params = optimizers.adam(
SETTINGS.learning_rate
)
opt_state = optimizer.init(model_params)
opt_state = init_fun(model_params)

# Define loss function
def loss_fn(params, x, y):
y_pred = model.apply(params, x)
return jnp.mean(jnp.square(y - y_pred))

@jax.jit
def update(params, opt_state, x, y):
def update(opt_state, x, y, step):
params = get_params(opt_state)
lossval, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = apply_updates(params, updates)
return params, opt_state, lossval
opt_state = update_fun(step, grads, opt_state)
return opt_state, lossval

# Initialize wandb
print("Training...")
@@ -92,28 +91,27 @@ def update(params, opt_state, x, y):
for epoch in range(SETTINGS.n_epochs):
print('Epoch {}'.format(epoch))

# Log a training image
_x, _y = train_dataset[0]
_x, _y = jnp.expand_dims(_x, axis=0), jnp.expand_dims(_y, axis=0)
_y_pred = model.apply(model_params, _x)
log_wandb_image(wandb, "Training image", step, _x[0], _y[0], _y_pred[0])

# Perform one epoch of training
with tqdm(train_loader, unit="batch") as tepoch:
for batch in tepoch:
tepoch.set_description(f"Epoch {epoch}")

# Update parameters
x, y = batch
model_params, opt_state, lossval = update(
model_params, opt_state, x, y
)
opt_state, lossval = update(opt_state, x, y, step)

# Log
wandb.log({"loss": lossval}, step=step)
tepoch.set_postfix(loss=lossval)
step += 1

# Get new parameters
model_params = get_params(opt_state)

# Log a training image
y_pred = model.apply(model_params, x)
log_wandb_image(wandb, "Training image", step, x[0], y[0], y_pred[0])

# Validation
avg_loss = 0
val_steps = 0
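
For context, jax.example_libraries.optimizers exposes each optimizer as an (init_fun, update_fun, get_params) triple instead of optax's gradient-transformation objects, which is why update() now threads the optimizer state and step counter rather than the parameters themselves. Below is a minimal sketch of that pattern on a toy complex-valued parameter standing in for the FNO's complex64 kernels; the parameter and loss are illustrative, not the repository's.

import jax
from jax import numpy as jnp
from jax.example_libraries import optimizers

# Toy complex parameter; a real-valued loss of a complex input keeps
# jax.value_and_grad well-defined.
params = {'w': jnp.array([1.0 + 2.0j, 0.5 - 1.0j], dtype=jnp.complex64)}

def loss_fn(params):
    return jnp.sum(jnp.abs(params['w']) ** 2)

init_fun, update_fun, get_params = optimizers.adam(1e-3)
opt_state = init_fun(params)

@jax.jit
def update(step, opt_state):
    lossval, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
    return update_fun(step, grads, opt_state), lossval

for step in range(10):
    opt_state, lossval = update(step, opt_state)

# Parameters live inside opt_state and are pulled out when needed,
# like the get_params call after the inner loop in train_darcy.py.
params = get_params(opt_state)
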
