Introduce graph rewrite for mixture sub-graphs defined via IfElse Op #169
Conversation
Force-pushed d0d4c7b to 3458414
Codecov Report
Base: 95.15% // Head: 94.94% // Decreases project coverage by 0.22%.

```diff
@@            Coverage Diff             @@
##             main     #169      +/-   ##
==========================================
- Coverage   95.15%   94.94%   -0.22%
==========================================
  Files          12       12
  Lines        2023     1878     -145
  Branches      253      280      +27
==========================================
- Hits         1925     1783     -142
+ Misses         56       53       -3
  Partials       42       42
==========================================
```

☔ View full report at Codecov.
Force-pushed 28b20db to 47b2117
I added many tests; I used …
Here are some notes regarding why some tests are failing. The discrepancy between the graphs occurs here: …

This comment serves as a reminder of what I am stuck on before the upcoming meeting.
Force-pushed 47b2117 to b284e35
Force-pushed b284e35 to 3f5ae48
I'm revisiting this PR slowly and after quite some time. I'm investigating one of my failing test cases, and I have probably forgotten many details just due to time passing... Consider the following code:

```python
import aesara
import aesara.tensor as at

from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)

X_rv = srng.normal(loc=[10, 20], scale=0.1, size=(2,), name="X")
Y_rv = srng.normal(loc=[-10, -20], scale=0.1, size=(2,), name="Y")
I_rv = srng.bernoulli([0.9, 0.1], size=(2,), name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
aesara.dprint(fgraph.outputs[0])
```

This yields

```
SpecifyShape [id A]
|MixtureRV{indices_end_idx=2, out_dtype='float64', out_broadcastable=(False,)} [id B]
| |TensorConstant{0} [id C]
| |bernoulli_rv{0, (0,), int64, False}.1 [id D] 'I'
| | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16488AB20>) [id E]
| | |TensorConstant{(1,) of 2} [id F]
| | |TensorConstant{4} [id G]
| | |TensorConstant{[0.9 0.1]} [id H]
| |normal_rv{0, (0, 0), floatX, False}.1 [id I] 'X'
| | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x1648899A0>) [id J]
| | |TensorConstant{(1,) of 2} [id F]
| | |TensorConstant{11} [id K]
| | |TensorConstant{[10 20]} [id L]
| | |TensorConstant{0.1} [id M]
| |normal_rv{0, (0, 0), floatX, False}.1 [id N] 'Y'
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x16488A340>) [id O]
| |TensorConstant{(1,) of 2} [id F]
| |TensorConstant{11} [id K]
| |TensorConstant{[-10 -20]} [id P]
| |TensorConstant{0.1} [id M]
|TensorConstant{2} [id Q]
bernoulli_rv{0, (0,), int64, False}.1 [id D] 'I'
```

Where does the `SpecifyShape` come from?
Hey, thanks for revisiting the PR! Do you mean running the code on this PR branch? If I run your code snippet on …, I get:

```
Elemwise{switch,no_inplace} [id A]
|bernoulli_rv{0, (0,), int64, False}.1 [id B] 'I'
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0239D0E880>) [id C]
| |TensorConstant{(1,) of 2} [id D]
| |TensorConstant{4} [id E]
| |TensorConstant{[0.9 0.1]} [id F]
|normal_rv{0, (0, 0), floatX, False}.1 [id G] 'X'
| |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F023B681B60>) [id H]
| |TensorConstant{(1,) of 2} [id D]
| |TensorConstant{11} [id I]
| |TensorConstant{[10 20]} [id J]
| |TensorConstant{0.1} [id K]
|normal_rv{0, (0, 0), floatX, False}.1 [id L] 'Y'
|RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0239D0DEE0>) [id M]
|TensorConstant{(1,) of 2} [id D]
|TensorConstant{11} [id I]
|TensorConstant{[-10 -20]} [id N]
|TensorConstant{0.1} [id K]
```

Have you considered dispatching …?

You will also need to resolve the (small) merge conflict due to the new ….
Yes, running the code on this branch! The graph rewrite for …
Yes, but I felt like the graph rewrites for both would be very similar. I can separate them as I work through them, for now...
Okay, sounds good!
They likely will, and we may want to merge them later. But I think it would be easier for you to move from one stable state to another, changing one thing at a time, and always keeping a reference implementation (…).
Force-pushed 3f5ae48 to 5ff4105
I just rebased my code. I am working on having Switch/IfElse-induced mixture subgraphs yield the same canonical (or IR? the graph obtained after running …) representation.
Of course: https://aesara.readthedocs.io/en/latest/extending/graph_rewriting.html#detailed-profiling-of-aesara-rewrites. Alternatively, you can add a breakpoint here, ….
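As a lighter-weight alternative to full profiling (my own suggestion, not part of the linked docs excerpt), Aesara's `optimizer_verbose` config flag, if available in your version, prints each node replacement as rewrites are applied, which can help spot where the two graphs start to diverge. A minimal sketch, reusing the thread's example:

```python
import aesara
import aesara.tensor as at

from aeppl.rewriting import construct_ir_fgraph

# Assumption: `optimizer_verbose` exists in your Aesara version; when enabled,
# the rewriter should log each node replacement as it is applied.
aesara.config.optimizer_verbose = True

srng = at.random.RandomStream(29833)
X_rv = srng.normal(0.0, 1.0, name="X")
Y_rv = srng.normal(10.0, 1.0, name="Y")
I_rv = srng.bernoulli(0.5, name="I")

Z_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z_rv.clone()
i_vv = I_rv.clone()

# The rewrites applied while constructing the IR graph are now printed,
# making it easier to see which one introduces an unexpected node.
fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, I_rv: i_vv})
```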
There's also …. In general, as @rlouf said, don't be afraid to put ….
Adding to this, you can set up any reasonable IDE so that when you run tests it opens a debugger console whenever it hits a breakpoint or a test fails. If you don't have that in place already, spend some time setting it up; it was a huge boost to my productivity.
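For anyone without an IDE debugger configured, a command-line equivalent (my own illustration, not from the original discussion; the test name and file path are hypothetical) is to place a builtin `breakpoint()` and run pytest with `--pdb`:

```python
import aesara.tensor as at

from aeppl.rewriting import construct_ir_fgraph


def test_switch_mixture_ir():
    srng = at.random.RandomStream(29833)
    X_rv = srng.normal(0.0, 1.0, name="X")
    Y_rv = srng.normal(10.0, 1.0, name="Y")
    I_rv = srng.bernoulli(0.5, name="I")
    i_vv = I_rv.clone()

    Z_rv = at.switch(I_rv, X_rv, Y_rv)
    z_vv = Z_rv.clone()

    breakpoint()  # pytest suspends output capturing and drops into pdb here

    fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, I_rv: i_vv})
    assert fgraph.outputs


# Run with, e.g.:  pytest -x --pdb tests/test_mixture.py
# `--pdb` additionally opens a post-mortem debugger on any test failure.
```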
Force-pushed 5ff4105 to f475bc9
Thanks for the tip. I also saw your recent related tweet 😅

As for this PR, I am thinking that it's best to close it in order to 1) split the tasks into smaller sub-PRs (I felt like too much was going on at once) and 2) address some other issues that came up. As for the latter, I divided them into subsections below. Any guidance would be helpful...

Reworking …

```python
new_node = mix_op.make_node(
    *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)
```
SpecifyShape Op

The appearance of the `SpecifyShape` `Op` seems to be new... perhaps due to this recent addition to Aesara? Maybe replacing `out_broadcastable` in `MixtureRV` with the corresponding static shapes, if available, would be a good first step (see the sketch after the class listing below)?
Lines 180 to 197 in 473c1e6 (aeppl/mixture.py):

```python
class MixtureRV(Op):
    """A placeholder used to specify a log-likelihood for a mixture sub-graph."""

    __props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")

    def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
        super().__init__()
        self.indices_end_idx = indices_end_idx
        self.out_dtype = out_dtype
        self.out_broadcastable = out_broadcastable

    def make_node(self, *inputs):
        return Apply(
            self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
        )

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This is a stand-in Op.")  # pragma: no cover
```
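A rough sketch of what the proposed replacement might look like (my own illustration, not code from this PR; it assumes an Aesara version whose `TensorType` accepts a `shape` argument with `None` entries for unknown dimensions):

```python
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.tensor.type import TensorType


class MixtureRV(Op):
    """A placeholder used to specify a log-likelihood for a mixture sub-graph."""

    # `out_shape` (a tuple with `None` for unknown dims) replaces `out_broadcastable`
    __props__ = ("indices_end_idx", "out_dtype", "out_shape")

    def __init__(self, indices_end_idx, out_dtype, out_shape):
        super().__init__()
        self.indices_end_idx = indices_end_idx
        self.out_dtype = out_dtype
        self.out_shape = tuple(out_shape)

    def make_node(self, *inputs):
        # The output type now carries whatever static shape information the
        # original mixture output had, rather than only broadcast flags.
        out_type = TensorType(self.out_dtype, shape=self.out_shape)
        return Apply(self, list(inputs), [out_type()])

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This is a stand-in Op.")  # pragma: no cover
```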
Mismatch in MixtureRV shapes generated by Switch vs. at.stack

With the hot fix replacing `broadcastable` with `shape`, the `MixtureRV` shapes seem to differ depending on whether they were generated by a `Switch` or a `Join`. Is this because subtensors don't have static shape inference yet? That would be my guess (Aesara issue #922?), but I'm not sure. Below is an example that I created using this branch's additions.
```python
import aesara.tensor as at

from aeppl.rewriting import construct_ir_fgraph
from aeppl.mixture import MixtureRV

srng = at.random.RandomStream(29833)

X_rv = srng.normal([10, 20], 0.1, size=(2,), name="X")
Y_rv = srng.normal([-10, -20], 0.1, size=(2,), name="Y")
I_rv = srng.bernoulli([0.99, 0.01], size=(2,), name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
z_vv = Z1_rv.clone()
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

assert isinstance(fgraph.outputs[0].owner.op, MixtureRV)
assert not hasattr(
    fgraph.outputs[0].tag, "test_value"
)  # aesara.config.compute_test_value == "off"
assert fgraph.outputs[0].name is None

Z1_rv.name = "Z1"
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
assert fgraph.outputs[0].name == "Z1-mixture"

# building the identical graph but with a stack to check that mixture computations are identical
Z2_rv = at.stack((X_rv, Y_rv))[I_rv]
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})

fgraph.outputs[0].type.shape   # (2,)
fgraph2.outputs[0].type.shape  # (None, None)
```
IfElse mixture subgraphs

Given that `IfElse` requires scalar conditions, maybe it would be good to start with those instead of refining switch-mixtures...

Happy to hear any thoughts on the points above. I feel like there's a lot going on, and it can be challenging to address it all at once (especially given that this is a continuation of this summer's work...).
PRs that touch core mechanisms in Aesara, or simply implement big changes, can easily get frustrating. Breaking the problem down like you did is a great reaction to this situation. Do you mind if I keep it open and come back to you next week, at least with some questions and maybe some insight?
Of course, not a problem at all!
Force-pushed f475bc9 to 8c9c0f3
@rlouf Just a quick update: @brandonwillard and I conversed recently, hence the recent force-push. The current focus is to ensure that the current mixture indexing operations via ….
Glad to hear this is back on track!
Force-pushed b6aa902 to 0dd44af
Closes #76.

Akin to #154, this PR introduces a `node_rewriter` for `IfElse`. Effectively, this builds on the recently added `switch_mixture_replace` to accommodate mixture sub-graphs of the same essence but defined with a different `Op`: `IfElse`. Below is an example of the new functionality.
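The example itself was not preserved here, so the following is only an illustrative sketch of the kind of graph this rewrite targets, mirroring the `Switch` example earlier in the thread (the scalar Bernoulli index and the expectation that the IR contains a stand-in mixture node are my assumptions, not the PR's actual test case):

```python
import aesara
import aesara.tensor as at

from aesara.ifelse import ifelse

from aeppl.rewriting import construct_ir_fgraph

srng = at.random.RandomStream(29833)

X_rv = srng.normal(loc=[10, 20], scale=0.1, size=(2,), name="X")
Y_rv = srng.normal(loc=[-10, -20], scale=0.1, size=(2,), name="Y")

# `IfElse` takes a scalar condition, so the mixture index is a scalar Bernoulli
I_rv = srng.bernoulli(0.5, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

# The same kind of mixture as before, but defined via `IfElse` instead of `Switch`
Z_rv = ifelse(I_rv, X_rv, Y_rv)
z_vv = Z_rv.clone()
z_vv.name = "z"

# With the new rewrite, the `IfElse` sub-graph should be replaced in the IR
# (e.g., by a `MixtureRV`), just like the `Switch`-based mixtures
fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, I_rv: i_vv})
aesara.dprint(fgraph.outputs[0])
```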