Skip to content

Commit

Permalink
Update to PLDI artifact.
Browse files Browse the repository at this point in the history
  • Loading branch information
femtomc committed Mar 12, 2024
1 parent 41f41f8 commit 81e7590
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
6 changes: 5 additions & 1 deletion experiments.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1939,7 +1939,11 @@
"Epoch=28, current_epoch_step_time=96.85, loss=-619.26\n",
"accuracy=0.821066677570343, counts=tensor([[19725, 223, 0, 0],\n",
" [ 252, 11009, 8669, 90],\n",
" [ 0, 1008, 18530, 494]])\n"
" [ 0, 1008, 18530, 494]])\n",
"Epoch=29, current_epoch_step_time=97.58, loss=-619.89\n",
"accuracy=0.7887166738510132, counts=tensor([[19670, 278, 0, 0],\n",
" [ 112, 8512, 11172, 224],\n",
" [ 0, 317, 19141, 574]])\n"
]
}
],
Expand Down
17 changes: 4 additions & 13 deletions experiments/fig_7_air_estimator_evaluation/genjax_reinforce_air.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,24 +427,15 @@ def guide_step(
rnn_input = jnp.concatenate([data, prev_z_where, prev_z_what, prev_z_pres])
h, c = rnn(rnn_input, (prev_h, prev_c))
z_pres_p, z_where_loc, z_where_scale = predict(h)
z_pres_p = z_pres_p[0] * prev_z_pres[0]
z_pres_p = jnp.clip(z_pres_p, 0.001, 1.0)
z_pres = vi.flip_reinforce(z_pres_p) @ f"z_pres_{t}"
(z_where_loc, z_where_scale) = jtu.tree_map(
lambda v1, v2: z_pres * v1 + (1 - z_pres) * v2,
(z_where_loc, z_where_scale),
(z_where_prior_loc, z_where_prior_scale),
z_pres = (
vi.flip_reinforce((eps + (z_pres_p[0] * prev_z_pres[0])) / (1 + 1.01 * eps))
@ f"z_pres_{t}"
)
z_pres = jnp.array([z_pres.astype(int)])
z_where = vi.mv_normal_diag_reparam(z_where_loc, z_where_scale) @ f"z_where_{t}"
x_att = image_to_object(z_where, data)
z_what_loc, z_what_scale = encoder(x_att)
(z_what_loc, z_what_scale) = jtu.tree_map(
lambda v1, v2: z_pres * v1 + (1 - z_pres) * v2,
(z_what_loc, z_what_scale),
(z_what_prior_loc, z_what_prior_scale),
)
z_what = vi.mv_normal_diag_reparam(z_what_loc, z_what_scale) @ f"z_what_{t}"
z_pres = jnp.array([z_pres.astype(int)])
return z_where, z_what, z_pres, h, c


Expand Down
9 changes: 5 additions & 4 deletions experiments/fig_7_air_estimator_evaluation/genjax_rws_air.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def image_to_object(z_where, image):
z_where_prior_scale = jnp.array([0.2, 1.0, 1.0])
z_what_prior_loc = jnp.zeros(50, dtype=float)
z_what_prior_scale = jnp.ones(50, dtype=float)
z_pres_prior = 0.008
z_pres_prior = 0.05
# z_pres_prior = [0.05, 0.05**2.3, 0.05 ** (5)]
eps = 1e-4


Expand All @@ -371,7 +372,7 @@ def step(
prev_x: FloatArray,
prev_z_pres: IntArray,
):
z_pres = vi.flip_mvd(z_pres_prior ** (2 * t + 1)) @ f"z_pres_{t}"
z_pres = vi.flip_mvd(z_pres_prior) @ f"z_pres_{t}"
z_pres = jnp.array([z_pres.astype(int)])
z_where = (
vi.mv_normal_diag_reparam(z_where_prior_loc, z_where_prior_scale)
Expand Down Expand Up @@ -428,7 +429,7 @@ def guide_step(
h, c = rnn(rnn_input, (prev_h, prev_c))
z_pres_p, z_where_loc, z_where_scale = predict(h)
z_pres = (
vi.flip_mvd((eps + (z_pres_p[0] * prev_z_pres[0])) / (1 + 2 * eps))
vi.flip_mvd((eps + (z_pres_p[0] * prev_z_pres[0])) / (1 + 1.01 * eps))
@ f"z_pres_{t}"
)
z_pres = jnp.array([z_pres.astype(int)])
Expand Down Expand Up @@ -788,7 +789,7 @@ def body_fn(carry, xs):

key, sub_key = jax.random.split(key)
(p_losses, q_losses), accuracy, wall_clock_times, params = train(
sub_key, learning_rate=3.0e-3, n=10, batch_size=64, num_epochs=40
sub_key, learning_rate=1.0e-3, n=10, batch_size=64, num_epochs=40
)

arr = np.array([p_losses, q_losses, accuracy, wall_clock_times])
Expand Down

0 comments on commit 81e7590

Please sign in to comment.