darcy training
astanziola committed Apr 11, 2022
1 parent b03147c commit 64c901c
Showing 6 changed files with 222 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -12,3 +12,6 @@

# Python
*__pycache__

# Logging
wandb
7 changes: 1 addition & 6 deletions README.md
@@ -10,9 +10,4 @@ To install `fno`, you just need to have `flax` installed (see the `requirements.

Then:
- To generate the data you need `MATLAB` (or, probably, Octave: haven't tested it).
- To run the training scripts, you'll need:
  - `scipy` for loading the data in the `.mat` files
  - `torch` for the `Dataset` and `DataLoader` classes
  - [`wandb`](https://wandb.ai/site) for logging
  - `matplotlib` for producing the plots.
  - [`addict`](https://github.com/mewwts/addict) for cool `Dict` objects.
- To run the training scripts, you'll need to install the packages in `requirements-train.txt`, for example using `pip install -r requirements-train.txt`.
3 changes: 3 additions & 0 deletions requirements-train.txt
@@ -3,3 +3,6 @@ wandb
addict
scipy
pytest
optax
tqdm
wandb
Empty file added scripts/__init__.py
76 changes: 76 additions & 0 deletions scripts/utils.py
@@ -0,0 +1,76 @@
from typing import Tuple

import numpy as np
from matplotlib import pyplot as plt
from scipy.io import loadmat
from torch.utils.data import Dataset


def gaussian_normalize(
    x: np.ndarray,
    eps: float = 1e-5
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''Normalizes `x` to zero mean and unit variance along the first axis.
    Adapted from
    https://github.com/zongyi-li/fourier_neural_operator/blob/c13b475dcc9bcd855d959851104b770bbcdd7c79/utilities3.py#L73
    Args:
        x (np.ndarray): Input array
        eps (float): Small number to avoid division by zero
    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Normalized array, mean and
            standard deviation (arrays, since they are computed with
            `keepdims=True`)
    '''
    mean = np.mean(x, 0, keepdims=True)
    std = np.std(x, 0, keepdims=True)
    x = (x - mean) / (std + eps)
    return x, mean, std

class MatlabDataset(Dataset):
    '''Loads the `inputs` and `outputs` arrays from a `.mat` file and
    exposes them as (input, target) pairs of numpy arrays.'''
    def __init__(self, path: str):
        # Load matfile
        mat = loadmat(path)

        # Extract data and put batch dimension in front
        self.x = mat['inputs'].astype(np.float32)
        self.y = mat['outputs'].astype(np.float32)
        self.x = np.moveaxis(self.x, -1, 0)
        self.y = np.moveaxis(self.y, -1, 0)

        # Add channel dimension at the end
        self.x = self.x[..., np.newaxis]
        self.y = self.y[..., np.newaxis]

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def collate_fn(batch):
    '''Stacks a list of (x, y) samples into numpy batch arrays, so the torch
    DataLoader yields plain ndarrays that JAX can consume directly.'''
    x, y = zip(*batch)
    x = np.stack(x, axis=0)
    y = np.stack(y, axis=0)
    return x, y

def log_wandb_image(
    wandb,
    name: str,
    step: int,
    x: np.ndarray,
    y: np.ndarray,
    y_pred: np.ndarray
) -> None:
    '''Logs a side-by-side plot of input, target and prediction to wandb.'''
    # Drop the trailing channel dimension: imshow expects 2D arrays
    x, y, y_pred = np.squeeze(x), np.squeeze(y), np.squeeze(y_pred)

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    ax[0].imshow(x, cmap="inferno")
    ax[0].set_title("Input map")

    ax[1].imshow(y, cmap="inferno")
    ax[1].set_title("Target field")

    ax[2].imshow(y_pred, cmap="inferno")
    ax[2].set_title("Predicted field")

    img = wandb.Image(fig)
    wandb.log({name: img}, step=step)
    plt.close(fig)
139 changes: 139 additions & 0 deletions train_darcy.py
@@ -0,0 +1,139 @@
import jax
from addict import Dict
from jax import numpy as jnp
from jax import random
from optax import adamw, apply_updates
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import wandb
from fno import FNO2D
from scripts.utils import MatlabDataset, collate_fn, log_wandb_image

# Settings dictionary
SETTINGS = Dict()
SETTINGS.data_path = 'data/darcy/darcy_238.mat'
SETTINGS.n_train = 1000
SETTINGS.n_test = 200
SETTINGS.batch_size = 20
SETTINGS.learning_rate = 0.0001 # TODO: This should be scheduled
SETTINGS.weight_decay = 1e-4
SETTINGS.n_epochs = 100
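# JAX handles randomness through explicit PRNG keys; fixing the seed keeps runs reproducible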
SETTINGS.rng = random.PRNGKey(0)

SETTINGS.fno.modes = 12
SETTINGS.fno.width = 32
SETTINGS.fno.depth = 4
SETTINGS.fno.channels_last_proj = 128
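# Spatial zero-padding applied before the spectral convolutions: the FFT
# assumes periodic signals, while the Darcy inputs are not periodic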
SETTINGS.fno.padding = 18

def main():
    # Loading and splitting dataset
    dataset = MatlabDataset(SETTINGS.data_path)
    train_dataset, test_dataset = random_split(
        dataset,
        [SETTINGS.n_train, SETTINGS.n_test]
    )

    # Making dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=SETTINGS.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=SETTINGS.batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Initialize model
    model = FNO2D(
        modes1=SETTINGS.fno.modes,
        modes2=SETTINGS.fno.modes,
        width=SETTINGS.fno.width,
        depth=SETTINGS.fno.depth,
        channels_last_proj=SETTINGS.fno.channels_last_proj,
        padding=SETTINGS.fno.padding
    )
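    # Initialize the parameters from a single dummy sample: Flax infers all
    # layer shapes from this one forward pass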
    _x, _ = train_dataset[0]
    _x = jnp.expand_dims(_x, axis=0)
    _, model_params = model.init_with_output(SETTINGS.rng, _x)
    del _x

    # Initialize optimizer
    optimizer = adamw(
        SETTINGS.learning_rate,
        weight_decay=SETTINGS.weight_decay
    )
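    # NOTE: one way to address the learning-rate TODO above (an illustrative
    # sketch, not part of this commit) is to pass an optax schedule instead
    # of the constant value, e.g.:
    #   schedule = optax.exponential_decay(
    #       init_value=SETTINGS.learning_rate,
    #       transition_steps=1000,
    #       decay_rate=0.9
    #   )
    #   optimizer = adamw(schedule, weight_decay=SETTINGS.weight_decay)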
    opt_state = optimizer.init(model_params)

    # Define the loss function (mean squared error over the batch)
    def loss_fn(params, x, y):
        y_pred = model.apply(params, x)
        return jnp.mean(jnp.square(y - y_pred))

    @jax.jit
    def update(params, opt_state, x, y):
        # Loss and gradients in a single pass; the whole step is jit-compiled
        lossval, grads = jax.value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = apply_updates(params, updates)
        return params, opt_state, lossval

    # Initialize wandb
    print("Training...")
    wandb.init(project='fourier-neural-operator')
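    # One could also attach the hyperparameters to the run (an optional,
    # illustrative variant): wandb.init accepts a `config` dict, and addict
    # Dicts convert to plain dicts with .to_dict(), e.g.
    #   wandb.init(project='fourier-neural-operator', config=SETTINGS.to_dict())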

    # Training loop
    step = 0
    for epoch in range(SETTINGS.n_epochs):
        print(f'Epoch {epoch}')

        # Log a training image
        _x, _y = train_dataset[0]
        _x, _y = jnp.expand_dims(_x, axis=0), jnp.expand_dims(_y, axis=0)
        _y_pred = model.apply(model_params, _x)
        log_wandb_image(wandb, "Training image", step, _x[0], _y[0], _y_pred[0])

        # Perform one epoch of training
        with tqdm(train_loader, unit="batch") as tepoch:
            for batch in tepoch:
                tepoch.set_description(f"Epoch {epoch}")

                # Update parameters
                x, y = batch
                model_params, opt_state, lossval = update(
                    model_params, opt_state, x, y
                )

                # Log
                wandb.log({"loss": lossval}, step=step)
                tepoch.set_postfix(loss=lossval)
                step += 1

        # Validation
        avg_loss = 0
        n_val = 0
        with tqdm(test_loader, unit="batch") as tval:
            for batch in tval:
                tval.set_description(f"Epoch (val) {epoch}")
                x, y = batch
                lossval = loss_fn(model_params, x, y)

                # Weight each batch by its size: the last batch may be smaller
                avg_loss += lossval * len(x)
                n_val += len(x)

                tval.set_postfix(loss=lossval)

        wandb.log({"val_loss": avg_loss / n_val}, step=step)

        # Log validation image
        _x, _y = test_dataset[0]
        _x, _y = jnp.expand_dims(_x, axis=0), jnp.expand_dims(_y, axis=0)
        _y_pred = model.apply(model_params, _x)
        log_wandb_image(wandb, "Validation image", step, _x[0], _y[0], _y_pred[0])

if __name__ == '__main__':
    main()
