Skip to content

Commit

Permalink
Add tests with non-null size for ifelse and switch mixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored and brandonwillard committed Sep 7, 2022
1 parent dc39ef8 commit b284e35
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 31 deletions.
2 changes: 1 addition & 1 deletion aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def switch_ifelse_mixture_replace(fgraph, node):
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)
else:
new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs))
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))

new_mixture_rv = new_node.default_output()

Expand Down
139 changes: 109 additions & 30 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,25 +252,6 @@ def test_hetero_mixture_binomial(p_val, size):
(),
0,
),
(
(
np.array(0, dtype=aesara.config.floatX),
np.array(1, dtype=aesara.config.floatX),
),
(
np.array(0.5, dtype=aesara.config.floatX),
np.array(0.5, dtype=aesara.config.floatX),
),
(
np.array(100, dtype=aesara.config.floatX),
np.array(1, dtype=aesara.config.floatX),
),
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
(),
(),
(),
0,
),
(
(
np.array(0, dtype=aesara.config.floatX),
Expand Down Expand Up @@ -716,14 +697,118 @@ def test_mixture_with_DiracDelta():
assert m_vv in logp_res


@pytest.mark.parametrize("op", [at.switch, ifelse])
def test_switch_ifelse_mixture(op):
@pytest.mark.parametrize(
"op, X_args, Y_args, p_val, comp_size, idx_size",
[
[op] + list(test_args)
for op in [at.switch, ifelse]
for test_args in [
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(),
(),
),
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(),
(6,),
),
(
(
np.array([10, 20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array([-10, -20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array([0.9, 0.1], dtype=aesara.config.floatX),
(2,),
(2,),
),
(
(
np.array([10, 20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array([-10, -20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array([0.9, 0.1], dtype=aesara.config.floatX),
None,
None,
),
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(2, 3),
(2, 3),
),
(
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(2, 3),
(),
),
(
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(3,),
(3,),
),
]
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
],
)
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
"""
The argument size is both the input to srng.normal and the expected
size of the mixture RV Z1_rv
"""
srng = at.random.RandomStream(29833)

X_rv = srng.normal(-10.0, 0.1, name="X")
Y_rv = srng.normal(10.0, 0.1, name="Y")
X_rv = srng.normal(*X_args, size=comp_size, name="X")
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")

I_rv = srng.bernoulli(0.5, name="I")
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Expand Down Expand Up @@ -755,9 +840,3 @@ def test_switch_ifelse_mixture(op):

z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})

# below should follow immediately from the equal_computations assertion above
assert equal_computations([z1_logp], [z2_logp])

np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))

0 comments on commit b284e35

Please sign in to comment.