Skip to content

Commit

Permalink
Small README nit.
Browse files Browse the repository at this point in the history
  • Loading branch information
femtomc committed Mar 12, 2024
1 parent 6a50998 commit 41f41f8
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,14 @@ def body_fn(carry, xs):
arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"]
)
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41.csv", index=False)

key, sub_key = jax.random.split(key)
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
sub_key, learning_rate=1.0e-4, n=10, batch_size=1, num_epochs=40
)

arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])
df = pd.DataFrame(
arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"]
)
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41_bs1.csv", index=False)

0 comments on commit 41f41f8

Please sign in to comment.