From e6542eb7142e75835460cc3b16879f80bb7d90d9 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 24 Nov 2021 17:11:18 +0100 Subject: [PATCH] Move test helper softmax_graph to test module --- aesara/tensor/nnet/__init__.py | 1 - aesara/tensor/nnet/basic.py | 4 ---- tests/scan/test_basic.py | 3 ++- tests/tensor/nnet/test_basic.py | 5 ++++- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/aesara/tensor/nnet/__init__.py b/aesara/tensor/nnet/__init__.py index 4a1f046bdc..57d7c8f186 100644 --- a/aesara/tensor/nnet/__init__.py +++ b/aesara/tensor/nnet/__init__.py @@ -35,7 +35,6 @@ sigmoid_binary_crossentropy, softmax, softmax_grad_legacy, - softmax_graph, softmax_legacy, softmax_simplifier, softmax_with_bias, diff --git a/aesara/tensor/nnet/basic.py b/aesara/tensor/nnet/basic.py index 43d5148f1c..b6702d6446 100644 --- a/aesara/tensor/nnet/basic.py +++ b/aesara/tensor/nnet/basic.py @@ -1122,10 +1122,6 @@ def local_logsoftmax_grad(fgraph, node): return [ret] -def softmax_graph(c): - return exp(c) / exp(c).sum(axis=-1, keepdims=True) - - UNSET_AXIS = object() diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 33c925afd6..1825ca3f26 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -56,7 +56,7 @@ from aesara.tensor.math import dot, mean, sigmoid from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import tanh -from aesara.tensor.nnet import categorical_crossentropy, softmax_graph +from aesara.tensor.nnet import categorical_crossentropy from aesara.tensor.random.utils import RandomStream from aesara.tensor.shape import Shape_i, reshape, shape, specify_shape from aesara.tensor.sharedvar import SharedVariable @@ -81,6 +81,7 @@ vector, ) from tests import unittest_tools as utt +from tests.tensor.nnet.test_basic import softmax_graph if config.mode == "FAST_COMPILE": diff --git a/tests/tensor/nnet/test_basic.py b/tests/tensor/nnet/test_basic.py index ddd7dbbfaf..061ce16323 100644 --- a/tests/tensor/nnet/test_basic.py +++ b/tests/tensor/nnet/test_basic.py @@ -52,7 +52,6 @@ sigmoid_binary_crossentropy, softmax, softmax_grad_legacy, - softmax_graph, softmax_legacy, softmax_with_bias, softsign, @@ -83,6 +82,10 @@ ) +def softmax_graph(c): + return exp(c) / exp(c).sum(axis=-1, keepdims=True) + + def valid_axis_tester(Op): with pytest.raises(TypeError): Op(1.5)