Skip to content

Commit

Permalink
Fixed datatype conversion in KanrenRelationSub
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 authored and rlouf committed Oct 17, 2022
1 parent d019910 commit b160ef5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
4 changes: 4 additions & 0 deletions aesara/graph/rewriting/kanren.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def transform(self, fgraph, node):
else:
new_outputs = [eval_if_etuple(chosen_res)]

new_outputs = [
_new_out.astype(_inp_expr.dtype)
for _inp_expr, _new_out in zip(node.outputs, new_outputs)
]
return new_outputs
else:
return False
8 changes: 3 additions & 5 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from etuples import etuple
from kanren import fact, heado, tailo
from kanren.assoccomm import associative, commutative
from kanren.constraints import neq
from kanren.core import lall, lany
from kanren.graph import mapo
from kanten.constraints import neq
from unification import vars as lvars

import aesara.scalar.basic as aes
Expand Down Expand Up @@ -3558,13 +3558,11 @@ def distributive_collect(in_lv, out_lv):
A_lv, op_lv, all_term_lv, all_cdr_lv, cdr_lv, all_flat_lv = lvars(6)
return lall(
lany(
heado(at.add, all_term_lv),
heado(at.sub, all_term_lv),
lall(heado(at.add, all_term_lv), eq(cons(at.add, cdr_lv), out_lv)),
lall(heado(at.sub, all_term_lv), eq(cons(at.sub, cdr_lv), out_lv)),
),
# Get the flattened `add` arguments
tailo(all_cdr_lv, all_term_lv),
# Add all the arguments and set the output
lany(eq(cons(at.add, cdr_lv), out_lv), eq(cons(at.sub, cdr_lv), out_lv)),
lany(
lall(
lany(
Expand Down
1 change: 1 addition & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1958,6 +1958,7 @@ def test_local_elemwise_sub_zeros():
"ShapeOpt",
"local_fill_to_alloc",
"local_elemwise_alloc",
"dist_collect_opt",
)
.including("local_elemwise_sub_zeros")
)
Expand Down

0 comments on commit b160ef5

Please sign in to comment.