Skip to content

Commit

Permalink
Update to PLDI artifact.
Browse files Browse the repository at this point in the history
  • Loading branch information
femtomc committed Mar 12, 2024
1 parent 81e7590 commit 18101f2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
22 changes: 20 additions & 2 deletions experiments.ipynb

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion experiments/fig_7_air_estimator_evaluation/air_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
hybrid_iwae_mvd_air = pd.read_csv(
"./training_runs/genjax_air_iwae_2_hybrid_mvd_enum_epochs_41.csv",
)
rws_air_mvd = pd.read_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41.csv")
rws_air_mvd = pd.read_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6.csv")
rws_air_mvd_bs1 = pd.read_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6_bs1.csv")
pyro_reinforce_air = pd.read_csv(
"./training_runs/pyro_air_reinforce_epochs_41.csv",
)
Expand Down Expand Up @@ -336,6 +337,17 @@ def go_plot_rws(ax, df, x, mean, label, cmap, color_idx, marker):
num_lines = 2
cmap = plt.cm.get_cmap("cividis", num_lines)

rws_air_l = go_plot_rws(
ax3,
rws_air_mvd_bs1,
"Epoch wall clock times",
"Accuracy",
"Ours (batch size = 1, RWS(K = 10))",
cmap,
0,
"x",
)

rws_air_l = go_plot_rws(
ax3,
rws_air_mvd,
Expand Down
8 changes: 4 additions & 4 deletions experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,22 +789,22 @@ def body_fn(carry, xs):

key, sub_key = jax.random.split(key)
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
sub_key, learning_rate=1.0e-3, n=10, batch_size=64, num_epochs=40
sub_key, learning_rate=1.0e-3, n=10, batch_size=64, num_epochs=5
)

arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])
df = pd.DataFrame(
arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"]
)
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41.csv", index=False)
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6.csv", index=False)

key, sub_key = jax.random.split(key)
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
sub_key, learning_rate=1.0e-4, n=10, batch_size=1, num_epochs=40
sub_key, learning_rate=1.0e-4, n=10, batch_size=1, num_epochs=5
)

arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])
df = pd.DataFrame(
arr.T, columns=["P Loss", "Q Loss", "Accuracy", "Epoch wall clock times"]
)
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_41_bs1.csv", index=False)
df.to_csv("./training_runs/genjax_air_rws_10_mvd_epochs_6_bs1.csv", index=False)

0 comments on commit 18101f2

Please sign in to comment.