From b160ef53433ac5a1640b7c296406ff8aabef2305 Mon Sep 17 00:00:00 2001 From: kc611 Date: Tue, 29 Mar 2022 00:01:20 +0530 Subject: [PATCH] Fixed datatype conversion in KanrenRelationSub --- aesara/graph/rewriting/kanren.py | 4 ++++ aesara/tensor/rewriting/math.py | 8 +++----- tests/tensor/rewriting/test_math.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/aesara/graph/rewriting/kanren.py b/aesara/graph/rewriting/kanren.py index 212d86d02d..7abcb3dcb0 100644 --- a/aesara/graph/rewriting/kanren.py +++ b/aesara/graph/rewriting/kanren.py @@ -95,6 +95,10 @@ def transform(self, fgraph, node): else: new_outputs = [eval_if_etuple(chosen_res)] + new_outputs = [ + _new_out.astype(_inp_expr.dtype) + for _inp_expr, _new_out in zip(node.outputs, new_outputs) + ] return new_outputs else: return False diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py index f4b750d864..fd440e84c5 100644 --- a/aesara/tensor/rewriting/math.py +++ b/aesara/tensor/rewriting/math.py @@ -9,9 +9,9 @@ from etuples import etuple from kanren import fact, heado, tailo from kanren.assoccomm import associative, commutative +from kanren.constraints import neq from kanren.core import lall, lany from kanren.graph import mapo -from kanten.constraints import neq from unification import vars as lvars import aesara.scalar.basic as aes @@ -3558,13 +3558,11 @@ def distributive_collect(in_lv, out_lv): A_lv, op_lv, all_term_lv, all_cdr_lv, cdr_lv, all_flat_lv = lvars(6) return lall( lany( - heado(at.add, all_term_lv), - heado(at.sub, all_term_lv), + lall(heado(at.add, all_term_lv), eq(cons(at.add, cdr_lv), out_lv)), + lall(heado(at.sub, all_term_lv), eq(cons(at.sub, cdr_lv), out_lv)), ), # Get the flattened `add` arguments tailo(all_cdr_lv, all_term_lv), - # Add all the arguments and set the output - lany(eq(cons(at.add, cdr_lv), out_lv), eq(cons(at.sub, cdr_lv), out_lv)), lany( lall( lany( diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index d83a5248d0..bab1202e9a 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1958,6 +1958,7 @@ def test_local_elemwise_sub_zeros(): "ShapeOpt", "local_fill_to_alloc", "local_elemwise_alloc", + "dist_collect_opt", ) .including("local_elemwise_sub_zeros") )