Commit

Dev/grad checkpoint (#48)
* Rollout with gradient checkpointing

* Dev/enable cpu (#49)

* enable cpu training

* Update README

* Remove grad checkpointing from this PR

---------

Co-authored-by: Krishna Kumar <[email protected]>

* Fix CPU CI testing (#50)

* Print test paths and sample folder

* Fix path to dataset in CI test

* Refactor training loop for CPU and GPU and cleanup only for GPU

* Add save files for cpu version (#51)

* Print test paths and sample folder

* Fix path to dataset in CI test

* Refactor training loop for CPU and GPU and cleanup only for GPU

* Test rollout prediction

* Try for EOF issue

* Try removing .git at clone step

* Fix rollout path

* Debug to see if model files are written

* Debug model path for rollout

* Save steps embedded in cpu mode as well

* Save model file in CPU

* Add instructions to test

* Dev/cpu env (#52)

* Rollout with gradient checkpointing

* enable cpu training

* Latest torch cluster installation in conda

* Remove unrelated files

* Add yes to installation with conda

* Use structured arrays to store positions and particle types (#53)

* Fixes #56 references.bib

* Fixes #57 parenthesis in references

* Fixes #54 DOIs

* Fixes #55 unclosed parenthesis in paper

* Update citation to v1.1.0

* Update title in citation

* Update citation file to point to DOI on Zenodo

* add a doc for explaining details about training data (#59)

* add a doc for explaining details about training data

* Add image

* Change heading levels

* Add MeshNet to _sidebar.md

* JOSS citation.cff file

* JOSS Citation in README

* Bug fix in cpu and gpu conditioning (#60)

* bug fix for cpu and gpu conditioning

* include loss in training_state file

* Use rank if CUDA, else use device (CPU), to handle different behaviors on CPU and GPU.

* reset flags

* JOSS citation.cff file

* JOSS Citation in README

* merge the recent upstream change

* rollback flags

* rollback flags

* typo

* bug fix for cpu and gpu conditioning

* include loss in training_state file

* Use rank if CUDA, else use device (CPU), to handle different behaviors on CPU and GPU.

* reset flags

* rollback flags

* rollback flags

* typo

* Define device id

---------

Co-authored-by: Krishna Kumar <[email protected]>

* Include grad checkpoint feature in the rollout function.

* `data_loader.py` accepts `.npz` with material property feature

* add boundary clamp limit option

* Enable `learned_simulator.py` to accept material property feature

* Add feature to take optional metadata depending on `rollout` or `train` mode

* Update render_rollout.py

* Rollback to regular rollout without grad checkpoint, enable taking material property feature

* Rollback to regular `predict.py` without grad checkpoint, enable taking material property feature

* Enable `train.py` to take material property feature

* fixup! add boundary clamp limit option

* add another way of init simulator

* remove redundant parameters in `predict`

* add example for solving inverse problem using gns

* Remove previous grad checkpoint file

* Update requirements.txt

* Add inputfile flag

* Add config file for the inverse analysis

* Add readme

* Fix minor errors related to material property feature conditioning

* Update test

* Add figs for inverse example

* Update animation

* Update figs

* fix fig for initial config

* Move inverse example doc to `/docs`.

* Add rollout function's doc string

* Fix dataloader

---------

Co-authored-by: Krishna Kumar <[email protected]>
Co-authored-by: Krishna Kumar <[email protected]>
3 people authored Oct 9, 2023
1 parent cb73499 commit 169713a
Showing 18 changed files with 803 additions and 119 deletions.
104 changes: 104 additions & 0 deletions docs/example-1.md
@@ -0,0 +1,104 @@
# Solving Inverse Problem in Granular Flow Using GNS
This example demonstrates solving an inverse problem in granular flow using GNS.
It uses gradient-based optimization, exploiting the fully differentiable nature
of GNS and automatic differentiation (AD).

## Problem Statement

Consider a multi-layered granular column in which each layer has a different initial velocity
(see the figure below).

![initial_condition](img/initial_vel.png)

The ground-truth simulation of this configuration
using the material point method (MPM) is shown below.

![simulation_mpm](img/true_ani.gif)

The objective of the inverse problem in this example is to estimate the initial velocity
of each layer using only the final deposit positions over the last few timesteps.

## Data
### Download link
The necessary data for this example can be downloaded [here](https://utexas.box.com/s/i4x1n1gzb7r27ccfqr963xpzc3jtxe59).

### Description
* `particle_group.txt`: Particle coordinate information for each layer
* `model.pt`: GNS simulator
* `gns_metadata.json`: Configuration file for GNS simulator
* `mpm_input.json`: Information about ground truth simulation (MPM)

With the default configuration, the above files are expected to be stored under the `./data` directory.

### Configuration file
```toml
# Top-level entries
path = "data/"

# Optimization sub-table
[optimization]
nepoch = 30
inverse_timestep_range = [300, 380]
checkpoint_interval = 1
lr = 0.1
initial_velocities = [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]

# Ground Truth sub-table
[ground_truth]
ground_truth_npz = "sand2d_inverse_eval30.npz"
ground_truth_mpm_inputfile = "mpm_input.json"

# Forward Simulator sub-table
[forward_simulator]
dt_mpm = 0.0025
model_path = "data/"
model_file = "model.pt"
simulator_metadata_path = "data/"
simulator_metadata_file = "gns_metadata.json"

# Resume sub-table
[resume]
resume = false
epoch = 1

# Output sub-table
[output]
output_dir = "data/outputs/"
save_step = 1
```
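
As an illustration, here is a minimal sketch of reading this file, assuming Python 3.11+ for the standard-library `tomllib` (earlier versions can use the third-party `tomli` package, which has the same API); the variable names are illustrative:

```python
# Minimal sketch of reading the configuration; assumes Python 3.11+ for the
# standard-library `tomllib` (on earlier versions: `import tomli as tomllib`).
import tomllib

with open("config.toml", "rb") as f:  # tomllib requires binary mode
    config = tomllib.load(f)

# Top-level entry
data_path = config["path"]  # "data/"

# Optimization sub-table
nepoch = config["optimization"]["nepoch"]  # 30
lr = config["optimization"]["lr"]          # 0.1
t_start, t_end = config["optimization"]["inverse_timestep_range"]  # 300, 380

# Forward simulator sub-table
model_file = config["forward_simulator"]["model_file"]  # "model.pt"
```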

## Core Features
### Gradient checkpoint
The downside of using AD is that it requires a significant amount of memory for
large-scale neural networks, because it retains the activations of all
intermediate layers for backpropagation.
Since GNS computes each update $\boldsymbol{X}_t\rightarrow \boldsymbol{X}_{t+1}$ using multiple MLPs
over a large number of edges, and the entire simulation chains this update
for $k$ steps, computing gradients through the rollout requires extensive memory capacity.

To mitigate this issue, we employ gradient checkpointing.
This technique significantly reduces memory consumption by selectively storing
only certain intermediate activations during the forward pass,
and then recomputing the omitted values as needed during the backward pass.
The resulting memory savings make backpropagation through the rollout
to the desired timestep $k$ feasible.
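
For illustration, a minimal sketch of this pattern with `torch.utils.checkpoint` (the wrapped function here is a placeholder for one GNS update, not the repository's API; the full rollout implementation is in `example/inverse_problem/forward.py` below):

```python
# Illustrative sketch of gradient checkpointing in PyTorch; `one_step` is a
# placeholder for one expensive GNS update, not the repository's API.
import torch
import torch.utils.checkpoint


def one_step(x: torch.Tensor) -> torch.Tensor:
    # Stands in for one GNS update X_t -> X_{t+1}.
    return torch.tanh(x @ x.T @ x)


x = torch.randn(8, 8, requires_grad=True)

# Without checkpointing, all activations inside `one_step` are kept alive
# until backward. With checkpointing, only the input is saved and `one_step`
# is re-run during the backward pass to recompute what it needs.
y = torch.utils.checkpoint.checkpoint(one_step, x, use_reentrant=False)
y.sum().backward()
```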


### Resume
The code saves the optimization status at the specified step interval.
By loading a saved status, we can resume the optimization from the
previous state.
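
A self-contained sketch of this save/resume pattern is below; the file names and the layout of the saved dictionary are illustrative assumptions, not the repository's exact format:

```python
import torch

# Sketch of save/resume; file names and the saved-dictionary layout are
# illustrative assumptions, not the repository's exact format.
velocities = torch.zeros(10, 1, requires_grad=True)  # one entry per layer
optimizer = torch.optim.Adam([velocities], lr=0.1)


def save_status(epoch: int) -> None:
    # Write the optimization status at the configured `save_step` interval.
    torch.save(
        {
            "epoch": epoch,
            "velocities": velocities.detach().clone(),
            "optimizer_state": optimizer.state_dict(),
        },
        f"data/outputs/optimizer_state-{epoch}.pt",
    )


def resume_status(epoch: int) -> None:
    # Reload a saved status and continue optimizing from that state.
    state = torch.load(f"data/outputs/optimizer_state-{epoch}.pt")
    with torch.no_grad():
        velocities.copy_(state["velocities"])
    optimizer.load_state_dict(state["optimizer_state"])
```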

## Result

The optimization results for the velocities are shown below. As the iterations progress,
the velocity profile becomes similar to the true values (black line).

![vel_hist](img/vel_hist.png)

The GNS simulation with the velocity values estimated at `iteration=29` is shown below.
The result shows good agreement with the ground-truth simulation.

![gns_ani](img/pred_ani.gif)

Empty file added docs/img/initial_condition.svg
Binary file added docs/img/initial_vel.png
Binary file added docs/img/loss_hist.png
Binary file added docs/img/pred_ani.gif
Binary file added docs/img/true_ani.gif
Binary file added docs/img/vel_hist.png
34 changes: 34 additions & 0 deletions example/inverse_problem/config.toml
@@ -0,0 +1,34 @@
# Top-level entries
path = "data/"

# Optimization sub-table
[optimization]
nepoch = 30
inverse_timestep_range = [300, 380]
checkpoint_interval = 1
lr = 0.1
initial_velocities = [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]

# Ground Truth sub-table
[ground_truth]
ground_truth_npz = "sand2d_inverse_eval30.npz"
ground_truth_mpm_inputfile = "mpm_input.json"

# Forward Simulator sub-table
[forward_simulator]
dt_mpm = 0.0025
model_path = "data/"
model_file = "model-7020000.pt"
simulator_metadata_path = "data/"
simulator_metadata_file = "gns_metadata.json"

# Resume sub-table
[resume]
resume = false
epoch = 1

# Output sub-table
[output]
output_dir = "data/outputs/"
save_step = 1

57 changes: 57 additions & 0 deletions example/inverse_problem/forward.py
@@ -0,0 +1,57 @@
from typing import Optional

import torch
import torch.utils.checkpoint
from gns import learned_simulator
from tqdm import tqdm


def rollout_with_checkpointing(
    simulator: learned_simulator.LearnedSimulator,
    initial_positions: torch.Tensor,
    particle_types: torch.Tensor,
    n_particles_per_example: torch.Tensor,
    nsteps: int,
    checkpoint_interval: int = 1,
    material_property: Optional[torch.Tensor] = None,
):
    """Rollout with gradient checkpointing to reduce memory accumulation
    over the forward steps during backpropagation.

    Args:
        simulator: learned_simulator.LearnedSimulator instance.
        initial_positions: Initial positions of particles for 6 timesteps
            with shape=(nparticles, 6, ndims).
        particle_types: Particle types with shape=(nparticles, ).
        n_particles_per_example: Number of particles.
        nsteps: Number of forward steps to rollout.
        checkpoint_interval: Frequency of gradient checkpointing.
        material_property: Friction angle normalized by tan() with
            shape=(nparticles, ).

    Returns:
        GNS rollout of particle positions.
    """
    current_positions = initial_positions
    predictions = []

    for step in tqdm(range(nsteps), total=nsteps):
        if step % checkpoint_interval == 0:
            # Checkpointed step: save only the inputs and recompute the
            # intermediate activations during the backward pass.
            next_position = torch.utils.checkpoint.checkpoint(
                simulator.predict_positions,
                current_positions,
                [n_particles_per_example],
                particle_types,
                material_property,
            )
        else:
            # Regular step: activations are retained for backpropagation.
            next_position = simulator.predict_positions(
                current_positions,
                [n_particles_per_example],
                particle_types,
                material_property,
            )

        predictions.append(next_position)

        # Shift `current_positions`, removing the oldest position in the
        # sequence and appending the next position at the end.
        current_positions = torch.cat(
            [current_positions[:, 1:], next_position[:, None, :]], dim=1)

    return torch.cat(
        (initial_positions.permute(1, 0, 2), torch.stack(predictions))
    )
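
A hypothetical usage sketch follows; the simulator-loading step, tensor shapes, and values are assumptions for illustration, not taken from the repository:

```python
# Hypothetical usage of `rollout_with_checkpointing`; shapes, values, and
# the simulator-loading step are assumptions, not the repository's code.
import torch
from forward import rollout_with_checkpointing  # assumes this dir is on sys.path

nparticles, ndims = 1000, 2

# Six most recent positions per particle: shape (nparticles, 6, ndims).
initial_positions = torch.rand(nparticles, 6, ndims)
particle_types = torch.zeros(nparticles, dtype=torch.long)
material_property = torch.full((nparticles,), 0.5)  # tan(friction angle)

# Assumed: `model.pt` stores a full LearnedSimulator module.
simulator = torch.load("data/model.pt")

positions = rollout_with_checkpointing(
    simulator=simulator,
    initial_positions=initial_positions,
    particle_types=particle_types,
    n_particles_per_example=torch.tensor([nparticles]),
    nsteps=380,
    checkpoint_interval=1,
    material_property=material_property,
)
# `positions` has shape (6 + nsteps, nparticles, ndims).
```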