Skip to content

Commit

Permalink
Move test helper softmax_graph to test module
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Nov 25, 2021
1 parent 9f11adc commit e6542eb
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
1 change: 0 additions & 1 deletion aesara/tensor/nnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
sigmoid_binary_crossentropy,
softmax,
softmax_grad_legacy,
softmax_graph,
softmax_legacy,
softmax_simplifier,
softmax_with_bias,
Expand Down
4 changes: 0 additions & 4 deletions aesara/tensor/nnet/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,10 +1122,6 @@ def local_logsoftmax_grad(fgraph, node):
return [ret]


def softmax_graph(c):
return exp(c) / exp(c).sum(axis=-1, keepdims=True)


UNSET_AXIS = object()


Expand Down
3 changes: 2 additions & 1 deletion tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from aesara.tensor.math import dot, mean, sigmoid
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy, softmax_graph
from aesara.tensor.nnet import categorical_crossentropy
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape_i, reshape, shape, specify_shape
from aesara.tensor.sharedvar import SharedVariable
Expand All @@ -81,6 +81,7 @@
vector,
)
from tests import unittest_tools as utt
from tests.tensor.nnet.test_basic import softmax_graph


if config.mode == "FAST_COMPILE":
Expand Down
5 changes: 4 additions & 1 deletion tests/tensor/nnet/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
sigmoid_binary_crossentropy,
softmax,
softmax_grad_legacy,
softmax_graph,
softmax_legacy,
softmax_with_bias,
softsign,
Expand Down Expand Up @@ -83,6 +82,10 @@
)


def softmax_graph(c):
return exp(c) / exp(c).sum(axis=-1, keepdims=True)


def valid_axis_tester(Op):
with pytest.raises(TypeError):
Op(1.5)
Expand Down

0 comments on commit e6542eb

Please sign in to comment.