Commit

fix: Added weight decay. All weights are real, otherwise many optimization libraries and algorithms (e.g. Adam, momentum-based, etc.) fail
astanziola committed Apr 12, 2022
1 parent 89156c0 commit 26c5064
Showing 2 changed files with 34 additions and 11 deletions.
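
The rationale in the commit message — every parameter the optimizer touches must be real-valued — can be checked mechanically. A minimal sketch of such a check (the helper name is illustrative and not part of this repository):

import jax
import jax.numpy as jnp

def all_params_real(params) -> bool:
    # True only if no leaf of the parameter pytree is complex-valued,
    # which is the property the optimizer update relies on.
    return all(jnp.isrealobj(leaf) for leaf in jax.tree_util.tree_leaves(params))
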
34 changes: 23 additions & 11 deletions fno/modules.py
@@ -8,7 +8,7 @@
 def normal(stddev=1e-2, dtype = jnp.float32) -> Callable:
     def init(key, shape, dtype=dtype):
         keys = random.split(key)
-        return (random.normal(keys[0], shape) + 1j*random.normal(keys[1], shape)) * stddev
+        return random.normal(keys[0], shape) * stddev
     return init
 
 class SpectralConv2d(nn.Module):
@@ -40,17 +40,29 @@ def __call__(self, x):
         # output signal will have dimensions (N, C, H, W//2+1).
         # Therefore the kernel weigths will have different dimensions
         # for the two axis.
-        kernel_1 = self.param(
-            'kernel_1',
-            normal(scale, jnp.complex64),
+        kernel_1_r = self.param(
+            'kernel_1_r',
+            normal(scale, jnp.float32),
             (in_channels, self.out_channels, self.modes1, self.modes2),
-            jnp.complex64
+            jnp.float32
         )
-        kernel_2 = self.param(
-            'kernel_2',
-            normal(scale, jnp.complex64),
+        kernel_1_i = self.param(
+            'kernel_1_i',
+            normal(scale, jnp.float32),
             (in_channels, self.out_channels, self.modes1, self.modes2),
-            jnp.complex64
+            jnp.float32
         )
+        kernel_2_r = self.param(
+            'kernel_2_r',
+            normal(scale, jnp.float32),
+            (in_channels, self.out_channels, self.modes1, self.modes2),
+            jnp.float32
+        )
+        kernel_2_i = self.param(
+            'kernel_2_i',
+            normal(scale, jnp.float32),
+            (in_channels, self.out_channels, self.modes1, self.modes2),
+            jnp.float32
+        )
 
         # Perform fft of the input
@@ -61,11 +73,11 @@ def __call__(self, x):
         s1 = jnp.einsum(
             'bijc,coij->bijo',
             x_ft[:, :self.modes1, :self.modes2, :],
-            kernel_1)
+            kernel_1_r + 1j*kernel_1_i)
         s2 = jnp.einsum(
             'bijc,coij->bijo',
             x_ft[:, -self.modes1:, :self.modes2, :],
-            kernel_2)
+            kernel_2_r + 1j*kernel_2_i)
         out_ft = out_ft.at[:, :self.modes1, :self.modes2, :].set(s1)
         out_ft = out_ft.at[:, -self.modes1:, :self.modes2, :].set(s2)
11 changes: 11 additions & 0 deletions train_darcy.py
@@ -1,5 +1,6 @@
 import jax
 from addict import Dict
+from flax.core.frozen_dict import freeze
 from jax import numpy as jnp
 from jax import random
 from jax.example_libraries import optimizers
@@ -79,6 +80,16 @@ def loss_fn(params, x, y):
 def update(opt_state, x, y, step):
     params = get_params(opt_state)
     lossval, grads = jax.value_and_grad(loss_fn)(params, x, y)
+
+    # Add weight decay
+    grads = {'params' : jax.tree_multimap(
+        lambda g, p: g + SETTINGS.weight_decay * p,
+        grads['params'].unfreeze(), params['params'].unfreeze()
+        )
+    }
+    grads = freeze(grads)
+
+
     opt_state = update_fun(step, grads, opt_state)
     return opt_state, lossval
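
A usage note on the added block: jax.tree_multimap was later deprecated in favour of jax.tree_util.tree_map, which accepts multiple pytrees. A sketch of the same gradient-side weight decay written against that API (the weight_decay default is illustrative and stands in for SETTINGS.weight_decay):

import jax

def add_weight_decay(grads, params, weight_decay=1e-4):
    # Leaf-wise g <- g + lambda * p across two matching pytrees.
    return jax.tree_util.tree_map(
        lambda g, p: g + weight_decay * p, grads, params)

Because the previous change makes every parameter leaf real-valued, this term acts as plain L2 regularisation and composes with the Adam and momentum-based updates from jax.example_libraries.optimizers, which is exactly what the commit message is after.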

