diff --git a/aeppl/mixture.py b/aeppl/mixture.py index c604e7dd..22f3afdf 100644 --- a/aeppl/mixture.py +++ b/aeppl/mixture.py @@ -345,7 +345,7 @@ def switch_ifelse_mixture_replace(fgraph, node): *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs) ) else: - new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs)) + new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs)) new_mixture_rv = new_node.default_output() diff --git a/tests/test_mixture.py b/tests/test_mixture.py index f0864ccc..d84d89bf 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -252,25 +252,6 @@ def test_hetero_mixture_binomial(p_val, size): (), 0, ), - ( - ( - np.array(0, dtype=aesara.config.floatX), - np.array(1, dtype=aesara.config.floatX), - ), - ( - np.array(0.5, dtype=aesara.config.floatX), - np.array(0.5, dtype=aesara.config.floatX), - ), - ( - np.array(100, dtype=aesara.config.floatX), - np.array(1, dtype=aesara.config.floatX), - ), - np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX), - (), - (), - (), - 0, - ), ( ( np.array(0, dtype=aesara.config.floatX), @@ -716,14 +697,118 @@ def test_mixture_with_DiracDelta(): assert m_vv in logp_res -@pytest.mark.parametrize("op", [at.switch, ifelse]) -def test_switch_ifelse_mixture(op): +@pytest.mark.parametrize( + "op, X_args, Y_args, p_val, comp_size, idx_size", + [ + [op] + list(test_args) + for op in [at.switch, ifelse] + for test_args in [ + ( + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (), + (), + ), + ( + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (), + (6,), + ), + ( + ( + np.array([10, 20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array([-10, -20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array([0.9, 0.1], dtype=aesara.config.floatX), + (2,), + (2,), + ), + ( + ( + np.array([10, 20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array([-10, -20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array([0.9, 0.1], dtype=aesara.config.floatX), + None, + None, + ), + ( + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (2, 3), + (2, 3), + ), + ( + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (2, 3), + (), + ), + ( + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (3,), + (3,), + ), + ] + if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse) + ], +) +def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size): + """ + The argument size is both the input to srng.normal and the expected + size of the mixture RV Z1_rv + """ srng = at.random.RandomStream(29833) - X_rv = srng.normal(-10.0, 0.1, name="X") - Y_rv = srng.normal(10.0, 0.1, name="Y") + X_rv = srng.normal(*X_args, size=comp_size, name="X") + Y_rv = srng.normal(*Y_args, size=comp_size, name="Y") - I_rv = srng.bernoulli(0.5, name="I") + I_rv = srng.bernoulli(p_val, size=idx_size, name="I") i_vv = I_rv.clone() i_vv.name = "i" @@ -755,9 +840,3 @@ def test_switch_ifelse_mixture(op): z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv}) z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv}) - - # below should follow immediately from the equal_computations assertion above - assert equal_computations([z1_logp], [z2_logp]) - - np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0})) - np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))