-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b03147c
commit 64c901c
Showing
6 changed files
with
222 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,6 @@ | |
|
||
# Python | ||
*__pycache__ | ||
|
||
# Logging | ||
wandb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,6 @@ wandb | |
addict | ||
scipy | ||
pytest | ||
optax | ||
tqdm | ||
wandb |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
from scipy.io import loadmat | ||
from torch.utils.data import Dataset | ||
|
||
|
||
def gaussian_normalize( | ||
x: np.ndarray, | ||
eps=0.00001 | ||
) -> Tuple[np.ndarray, float, float]: | ||
'''Adapted from | ||
https://github.com/zongyi-li/fourier_neural_operator/blob/c13b475dcc9bcd855d959851104b770bbcdd7c79/utilities3.py#L73 | ||
Attributes: | ||
x (np.ndarray): Input array | ||
eps (float): Small number to avoid division by zero | ||
Returns: | ||
Tuple[np.ndarray, float, float]: Normalized array, mean and standard deviation | ||
''' | ||
mean = np.mean(x, 0, keepdims=True) | ||
std = np.std(x, 0, keepdims=True) | ||
x = (x - mean) / (std + eps) | ||
return x, mean, std | ||
|
||
class MatlabDataset(Dataset): | ||
def __init__(self, path: str): | ||
# Load matfile | ||
mat = loadmat(path) | ||
|
||
# Extract data and put batch dimension in front | ||
self.x = mat['inputs'].astype(np.float32) | ||
self.y = mat['outputs'].astype(np.float32) | ||
self.x = np.moveaxis(self.x, -1, 0) | ||
self.y = np.moveaxis(self.y, -1, 0) | ||
|
||
# Add channel dimension at the end | ||
self.x = self.x[..., np.newaxis] | ||
self.y = self.y[..., np.newaxis] | ||
|
||
def __len__(self): | ||
return self.x.shape[0] | ||
|
||
def __getitem__(self, idx): | ||
return self.x[idx], self.y[idx] | ||
|
||
def collate_fn(batch): | ||
x, y = zip(*batch) | ||
x = np.stack(x, axis=0) | ||
y = np.stack(y, axis=0) | ||
return x, y | ||
|
||
def log_wandb_image( | ||
wandb, | ||
name: str, | ||
step: int, | ||
x: np.ndarray, | ||
y: np.ndarray, | ||
y_pred: np.ndarray | ||
) -> None: | ||
fig, ax = plt.subplots(1, 3, figsize=(12, 4)) | ||
|
||
ax[0].imshow(x, cmap="inferno") | ||
ax[0].set_title("Input map") | ||
|
||
ax[1].imshow(y, cmap="inferno") | ||
ax[1].set_title("Target field") | ||
|
||
ax[2].imshow(y_pred, cmap="inferno") | ||
ax[2].set_title("Predicted field") | ||
|
||
img = wandb.Image(plt) | ||
wandb.log({name: img}, step=step) | ||
plt.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import jax | ||
from addict import Dict | ||
from jax import numpy as jnp | ||
from jax import random | ||
from optax import adamw, apply_updates | ||
from torch.utils.data import DataLoader, random_split | ||
from tqdm import tqdm | ||
|
||
import wandb | ||
from fno import FNO2D | ||
from scripts.utils import MatlabDataset, collate_fn, log_wandb_image | ||
|
||
# Settings dictionary | ||
SETTINGS = Dict() | ||
SETTINGS.data_path = 'data/darcy/darcy_238.mat' | ||
SETTINGS.n_train = 1000 | ||
SETTINGS.n_test = 200 | ||
SETTINGS.batch_size = 20 | ||
SETTINGS.learning_rate = 0.0001 # TODO: This should be scheduled | ||
SETTINGS.weight_decay = 1e-4 | ||
SETTINGS.n_epochs = 100 | ||
SETTINGS.nrg = random.PRNGKey(0) | ||
|
||
SETTINGS.fno.modes = 12 | ||
SETTINGS.fno.width = 32 | ||
SETTINGS.fno.depth = 4 | ||
SETTINGS.fno.channels_last_proj = 128 | ||
SETTINGS.fno.padding = 18 | ||
|
||
def main(): | ||
# Loading and splitting dataset | ||
dataset = MatlabDataset(SETTINGS.data_path) | ||
train_dataset, test_dataset = random_split( | ||
dataset, | ||
[SETTINGS.n_train, SETTINGS.n_test] | ||
) | ||
|
||
# Making dataloaders | ||
train_loader = DataLoader( | ||
train_dataset, | ||
batch_size=SETTINGS.batch_size, | ||
shuffle=True, | ||
collate_fn=collate_fn, | ||
drop_last=True | ||
) | ||
test_loader = DataLoader( | ||
test_dataset, | ||
batch_size=SETTINGS.batch_size, | ||
shuffle=False, | ||
collate_fn=collate_fn | ||
) | ||
|
||
# Initialize model | ||
model = FNO2D( | ||
modes1=SETTINGS.fno.modes, | ||
modes2=SETTINGS.fno.modes, | ||
width=SETTINGS.fno.width, | ||
depth=SETTINGS.fno.depth, | ||
channels_last_proj=SETTINGS.fno.channels_last_proj, | ||
padding=SETTINGS.fno.padding | ||
) | ||
_x, _ = train_dataset[0] | ||
_x = jnp.expand_dims(_x, axis=0) | ||
_, model_params = model.init_with_output(SETTINGS.nrg, _x) | ||
del _x | ||
|
||
# Initialize optimizers | ||
optimizer = adamw( | ||
SETTINGS.learning_rate, | ||
weight_decay=SETTINGS.weight_decay | ||
) | ||
opt_state = optimizer.init(model_params) | ||
|
||
# Define loss function | ||
def loss_fn(params, x, y): | ||
y_pred = model.apply(params, x) | ||
return jnp.mean(jnp.square(y - y_pred)) | ||
|
||
@jax.jit | ||
def update(params, opt_state, x, y): | ||
lossval, grads = jax.value_and_grad(loss_fn)(params, x, y) | ||
updates, opt_state = optimizer.update(grads, opt_state, params) | ||
params = apply_updates(params, updates) | ||
return params, opt_state, lossval | ||
|
||
# Initialize wandb | ||
print("Training...") | ||
wandb.init('fourier-neural-operator') | ||
|
||
# Training loop | ||
step = 0 | ||
for epoch in range(SETTINGS.n_epochs): | ||
print('Epoch {}'.format(epoch)) | ||
|
||
# Log a training image | ||
_x, _y = train_dataset[0] | ||
_x, _y = jnp.expand_dims(_x, axis=0), jnp.expand_dims(_y, axis=0) | ||
_y_pred = model.apply(model_params, _x) | ||
log_wandb_image(wandb, "Training image", step, _x[0], _y[0], _y_pred[0]) | ||
|
||
# Perform one epoch of training | ||
with tqdm(train_loader, unit="batch") as tepoch: | ||
for batch in tepoch: | ||
tepoch.set_description(f"Epoch {epoch}") | ||
|
||
# Update parameters | ||
x, y = batch | ||
model_params, opt_state, lossval = update( | ||
model_params, opt_state, x, y | ||
) | ||
|
||
# Log | ||
wandb.log({"loss": lossval}, step=step) | ||
tepoch.set_postfix(loss=lossval) | ||
step += 1 | ||
|
||
# Validation | ||
avg_loss = 0 | ||
val_steps = 0 | ||
with tqdm(test_loader, unit="batch") as tval: | ||
for batch in tval: | ||
tval.set_description(f"Epoch (val) {epoch}") | ||
x, y = batch | ||
lossval = loss_fn(model_params, x, y) | ||
avg_loss += lossval*len(x) | ||
|
||
tval.set_postfix(loss=lossval) | ||
val_steps += 1 | ||
|
||
wandb.log({"val_loss": lossval/val_steps}, step=step) | ||
|
||
# Log validation image | ||
_x, _y = test_dataset[0] | ||
_x, _y = jnp.expand_dims(_x, axis=0), jnp.expand_dims(_y, axis=0) | ||
_y_pred = model.apply(model_params, _x) | ||
log_wandb_image(wandb, "Validation image", step, _x[0], _y[0], _y_pred[0]) | ||
|
||
if __name__ == '__main__': | ||
main() |