diff --git a/experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py b/experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py index 051f4a6..c63976e 100644 --- a/experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py +++ b/experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py @@ -796,3 +796,14 @@ def body_fn(carry, xs): arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"] ) df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41.csv", index=False) + +key, sub_key = jax.random.split(key) +(p_losses, q_losses), accuracy, wall_clock_times, params = train( + sub_key, learning_rate=1.0e-4, n=10, batch_size=1, num_epochs=40 +) + +arr = np.array([p_losses, q_losses, accuracy, wall_clock_times]) +df = pd.DataFrame( + arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"] +) +df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41_bs1.csv", index=False)