From 80e412edfe47946fa2c0975270c4c674c21ba57f Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 29 Aug 2022 15:04:34 -0500 Subject: [PATCH] Hash cons Apply nodes and Constants --- aesara/graph/basic.py | 191 ++++++++++++++++------------- aesara/graph/fg.py | 90 +++++++++++--- aesara/graph/rewriting/basic.py | 15 +-- aesara/link/c/basic.py | 4 +- aesara/link/c/params_type.py | 8 +- aesara/link/c/type.py | 10 +- aesara/sparse/basic.py | 45 ++++--- aesara/tensor/rewriting/shape.py | 4 +- aesara/tensor/type.py | 14 ++- aesara/tensor/type_other.py | 49 +++++--- aesara/tensor/var.py | 142 +-------------------- tests/graph/test_basic.py | 203 ++++++++++++++++++++----------- 12 files changed, 385 insertions(+), 390 deletions(-) diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index ba48f21bf2..b5b473575f 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -1,5 +1,4 @@ """Core graph classes.""" -import abc import warnings from collections import deque from copy import copy @@ -26,13 +25,12 @@ Union, cast, ) -from weakref import WeakKeyDictionary +from weakref import WeakValueDictionary import numpy as np from aesara.configdefaults import config from aesara.graph.utils import ( - MetaObject, MethodNotDefined, Scratchpad, TestValueError, @@ -53,32 +51,39 @@ _TypeType = TypeVar("_TypeType", bound="Type") _IdType = TypeVar("_IdType", bound=Hashable) -T = TypeVar("T", bound="Node") +T = TypeVar("T", bound=Union["Apply", "Variable"]) NoParams = object() NodeAndChildren = Tuple[T, Optional[Iterable[T]]] -class Node(MetaObject): - r"""A `Node` in an Aesara graph. +class UniqueInstanceFactory(type): - Currently, graphs contain two kinds of `Nodes`: `Variable`\s and `Apply`\s. - Edges in the graph are not explicitly represented. Instead each `Node` - keeps track of its parents via `Variable.owner` / `Apply.inputs`. + __instances__: WeakValueDictionary - """ - name: Optional[str] + def __new__(cls, name, bases, dct): + dct["__instances__"] = WeakValueDictionary() + res = super().__new__(cls, name, bases, dct) + return res - def get_parents(self): - """ - Return a list of the parents of this node. - Should return a copy--i.e., modifying the return - value should not modify the graph structure. + def __call__( + cls, + *args, + **kwargs, + ): + idp = cls.create_key(*args, **kwargs) - """ - raise NotImplementedError() + if idp not in cls.__instances__: + res = super(UniqueInstanceFactory, cls).__call__(*args, **kwargs) + cls.__instances__[idp] = res + return res + return cls.__instances__[idp] -class Apply(Node, Generic[OpType]): + +class Apply( + Generic[OpType], + metaclass=UniqueInstanceFactory, +): """A `Node` representing the application of an operation to inputs. Basically, an `Apply` instance is an object that represents the @@ -113,12 +118,19 @@ class Apply(Node, Generic[OpType]): """ + __slots__ = ("op", "inputs", "outputs", "__weakref__", "tag") + + @classmethod + def create_key(cls, op, inputs, outputs): + return (op,) + tuple(inputs) + def __init__( self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"], ): + if not isinstance(inputs, Sequence): raise TypeError("The inputs of an Apply must be a sequence type") @@ -154,6 +166,21 @@ def __init__( f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}" ) + def __eq__(self, other): + if isinstance(other, type(self)): + if ( + self.op == other.op + and self.inputs == other.inputs + # and self.outputs == other.outputs + ): + return True + return False + + return NotImplemented + + def __hash__(self): + return hash((type(self), self.op, tuple(self.inputs), tuple(self.outputs))) + def run_params(self): """ Returns the params for the node, or NoParams if no params is set. @@ -165,8 +192,7 @@ def run_params(self): return NoParams def __getstate__(self): - d = self.__dict__ - # ufunc don't pickle/unpickle well + d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)} if hasattr(self.tag, "ufunc"): d = copy(self.__dict__) t = d["tag"] @@ -174,6 +200,11 @@ def __getstate__(self): d["tag"] = t return d + def __setstate__(self, dct): + for k in self.__slots__: + if k in dct: + setattr(self, k, dct[k]) + def default_output(self): """ Returns the default output for this node. @@ -267,6 +298,7 @@ def clone_with_new_inputs( from aesara.graph.op import HasInnerGraph assert isinstance(inputs, (list, tuple)) + remake_node = False new_inputs: List["Variable"] = list(inputs) for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): @@ -280,17 +312,22 @@ def clone_with_new_inputs( else: remake_node = True - if remake_node: - new_op = self.op + new_op = self.op - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore + if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore + new_op = new_op.clone() # type: ignore + if remake_node: new_node = new_op.make_node(*new_inputs) new_node.tag = copy(self.tag).__update__(new_node.tag) + elif new_op == self.op and new_inputs == self.inputs: + new_node = self else: - new_node = self.clone(clone_inner_graph=clone_inner_graph) - new_node.inputs = new_inputs + new_node = self.__class__( + new_op, new_inputs, [output.clone() for output in self.outputs] + ) + new_node.tag = copy(self.tag) + return new_node def get_parents(self): @@ -316,7 +353,7 @@ def params_type(self): return self.op.params_type -class Variable(Node, Generic[_TypeType, OptionalApplyType]): +class Variable(Generic[_TypeType, OptionalApplyType]): r""" A :term:`Variable` is a node in an expression graph that represents a variable. @@ -411,7 +448,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): """ - # __slots__ = ['type', 'owner', 'index', 'name'] + __slots__ = ("_owner", "_index", "name", "type", "__weakref__", "tag", "auto_name") __count__ = count(0) _owner: OptionalApplyType @@ -487,26 +524,17 @@ def __str__(self): else: return f"<{self.type}>" - def __repr_test_value__(self): - """Return a ``repr`` of the test value. - - Return a printable representation of the test value. It can be - overridden by classes with non printable test_value to provide a - suitable representation of the test_value. - """ - return repr(self.get_test_value()) - def __repr__(self, firstPass=True): """Return a ``repr`` of the `Variable`. - Return a printable name or description of the Variable. If - ``config.print_test_value`` is ``True`` it will also print the test - value, if any. + Return a printable name or description of the `Variable`. If + `aesara.config.print_test_value` is ``True``, it will also print the + test value, if any. """ to_print = [str(self)] if config.print_test_value and firstPass: try: - to_print.append(self.__repr_test_value__()) + to_print.append(repr(self.get_test_value())) except TestValueError: pass return "\n".join(to_print) @@ -528,26 +556,6 @@ def clone(self): cp.tag = copy(self.tag) return cp - def __lt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __lt__", self.__class__.__name__ - ) - - def __le__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __le__", self.__class__.__name__ - ) - - def __gt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __gt__", self.__class__.__name__ - ) - - def __ge__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __ge__", self.__class__.__name__ - ) - def get_parents(self): if self.owner is not None: return [self.owner] @@ -605,7 +613,7 @@ def eval(self, inputs_to_values=None): return rval def __getstate__(self): - d = self.__dict__.copy() + d = {k: getattr(self, k) for k in self.__slots__ if k not in ("__weakref__",)} d.pop("_fn_cache", None) if (not config.pickle_test_value) and (hasattr(self.tag, "test_value")): if not type(config).pickle_test_value.is_default: @@ -618,6 +626,11 @@ def __getstate__(self): d["tag"] = t return d + def __setstate__(self, dct): + for k in self.__slots__: + if k in dct: + setattr(self, k, dct[k]) + class AtomicVariable(Variable[_TypeType, None]): """A node type that has no ancestors and should never be considered an input to a graph.""" @@ -625,19 +638,12 @@ class AtomicVariable(Variable[_TypeType, None]): def __init__(self, type: _TypeType, **kwargs): super().__init__(type, None, None, **kwargs) - @abc.abstractmethod - def signature(self): - ... - - def merge_signature(self): - return self.signature() - def equals(self, other): """ This does what `__eq__` would normally do, but `Variable` and `Apply` should always be hashable by `id`. """ - return isinstance(other, type(self)) and self.signature() == other.signature() + return self == other @property def owner(self): @@ -661,12 +667,15 @@ def index(self, value): class NominalVariable(AtomicVariable[_TypeType]): """A variable that enables alpha-equivalent comparisons.""" - __instances__: WeakKeyDictionary[ + __instances__: WeakValueDictionary[ Tuple["Type", Hashable], "NominalVariable" - ] = WeakKeyDictionary() + ] = WeakValueDictionary() def __new__(cls, id: _IdType, typ: _TypeType, **kwargs): - if (typ, id) not in cls.__instances__: + + idp = (typ, id) + + if idp not in cls.__instances__: var_type = typ.variable_type type_name = f"Nominal{var_type.__name__}" @@ -681,9 +690,9 @@ def _str(self): ) res: NominalVariable = super().__new__(new_type) - cls.__instances__[(typ, id)] = res + cls.__instances__[idp] = res - return cls.__instances__[(typ, id)] + return cls.__instances__[idp] def __init__(self, id: _IdType, typ: _TypeType, **kwargs): self.id = id @@ -708,11 +717,11 @@ def __hash__(self): def __repr__(self): return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})" - def signature(self) -> Tuple[_TypeType, _IdType]: - return (self.type, self.id) - -class Constant(AtomicVariable[_TypeType]): +class Constant( + AtomicVariable[_TypeType], + metaclass=UniqueInstanceFactory, +): """A `Variable` with a fixed `data` field. `Constant` nodes make numerous optimizations possible (e.g. constant @@ -725,19 +734,22 @@ class Constant(AtomicVariable[_TypeType]): """ - # __slots__ = ['data'] + __slots__ = ("type", "data") + + @classmethod + def create_key(cls, type, data, *args, **kwargs): + # TODO FIXME: This filters the data twice: once here, and again in + # `cls.__init__`. This might not be a big deal, though. + return (type, type.filter(data)) def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None): - super().__init__(type, name=name) + AtomicVariable.__init__(self, type, name=name) self.data = type.filter(data) add_tag_trace(self) def get_test_value(self): return self.data - def signature(self): - return (self.type, self.data) - def __str__(self): if self.name is not None: return self.name @@ -764,6 +776,15 @@ def owner(self, value) -> None: def value(self): return self.data + def __hash__(self): + return hash((type(self), self.type, self.data)) + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.type == other.type and self.data == other.data + + return NotImplemented + def walk( nodes: Iterable[T], diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 26fb74bd7f..36cb4edbd2 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -399,8 +399,8 @@ def change_node_input( ``old_var`` is the current value of ``node.inputs[i]`` which we want to replace. - For each feature that has an `on_change_input` method, this method calls: - ``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)`` + For each feature that has an `Feature.on_change_input` method, this method calls: + ``feature.on_change_input(function_graph, old_node, new_node, i, old_var, new_var, reason)`` Parameters ---------- @@ -420,35 +420,85 @@ def change_node_input( `History` `Feature`, which needs to revert types that have been narrowed and would otherwise fail this check. """ - # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) - if node == "output": - r = self.outputs[i] - if check and not r.type.is_super(new_var.type): + + is_output = node == "output" + + if is_output: + old_var = self.outputs[i] + + if old_var is new_var: + return + + if check and not old_var.type.is_super(new_var.type): raise TypeError( f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." + f"compatible with the type of the original Variable ({old_var.type})." ) self.outputs[i] = new_var + new_node = node + + self.import_var(new_var, reason=reason, import_missing=import_missing) + self.add_client(new_var, (new_node, i)) + self.remove_client(old_var, (node, i), reason=reason) + # Precondition: the substitution is semantically valid; however, it may + # introduce cycles to the graph, in which case the transaction will be + # reverted later. + self.execute_callbacks( + "on_change_input", new_node, i, old_var, new_var, reason=reason + ) else: assert isinstance(node, Apply) - r = node.inputs[i] - if check and not r.type.is_super(new_var.type): + old_var = node.inputs[i] + + if old_var is new_var: + return + + if check and not old_var.type.is_super(new_var.type): raise TypeError( f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." + f"compatible with the type of the original Variable ({old_var.type})." ) - node.inputs[i] = new_var - if r is new_var: - return + # In this case, we need to construct a new `Apply` node with + # `node.inputs[i] = new_var`, and swap the old `node` with the new + # node. + new_inputs = list(node.inputs) + new_inputs[i] = new_var + + new_node = Apply(node.op, new_inputs, [o.clone() for o in node.outputs]) + + # TODO FIXME: This is unnecessarily costly. + self.import_node(new_node, reason=reason) + + for old_out, new_out in zip(node.outputs, new_node.outputs): + # self.import_var(new_out, reason=reason, import_missing=False) + + old_out_clients = list(self.clients[old_out]) + for client_and_id in old_out_clients: + self.add_client(new_out, client_and_id) + # TODO: This is a little too much... + self.remove_client( + old_out, client_and_id, reason=reason, remove_if_empty=True + ) + # Perhaps we could slim it down to: + # self.clients[old_out].remove(client_and_id) + self.execute_callbacks( + "on_change_input", + *client_and_id, + old_out, + new_out, + reason=reason, + ) - self.import_var(new_var, reason=reason, import_missing=import_missing) - self.add_client(new_var, (node, i)) - self.remove_client(r, (node, i), reason=reason) - # Precondition: the substitution is semantically valid However it may - # introduce cycles to the graph, in which case the transaction will be - # reverted later. - self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason) + # We can keep the original mapping in case someone, say, recreates + # the old un-rewritten node, and, in which case, they'll get back + # the fully rewritten node. + # type(node).__instances__[ + # type(node).create_key(node.op, new_inputs, node.outputs) + # ] = new_node + # self.execute_callbacks( + # "on_change_input", new_node, i, old_var, new_var, reason=reason + # ) def replace( self, diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py index 586dca33d3..72be8fe275 100644 --- a/aesara/graph/rewriting/basic.py +++ b/aesara/graph/rewriting/basic.py @@ -34,7 +34,7 @@ from aesara.graph.features import AlreadyThere, Feature, NodeFinder from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.utils import AssocList, InconsistencyError +from aesara.graph.utils import InconsistencyError from aesara.misc.ordered_set import OrderedSet from aesara.utils import flatten @@ -532,8 +532,7 @@ def on_attach(self, fgraph): fgraph.merge_feature = self self.seen_atomics = set() - self.atomic_sig = AssocList() - self.atomic_sig_inv = AssocList() + self.canonical_atomics = {} # For all Apply nodes # Set of distinct (not mergeable) nodes @@ -587,17 +586,14 @@ def on_prune(self, fgraph, node, reason): for c in node.inputs: if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1: # This was the last node using this constant - sig = self.atomic_sig[c] - self.atomic_sig.discard(c) - self.atomic_sig_inv.discard(sig) + self.canonical_atomics.pop(c) self.seen_atomics.discard(id(c)) def process_atomic(self, fgraph, c): """Check if an atomic `c` can be merged, and queue that replacement.""" if id(c) in self.seen_atomics: return - sig = c.merge_signature() - other_c = self.atomic_sig_inv.get(sig, None) + other_c = self.canonical_atomics.get(c, None) if other_c is not None: # multiple names will clobber each other.. # we adopt convention to keep the last name @@ -606,8 +602,7 @@ def process_atomic(self, fgraph, c): self.scheduled.append([[(c, other_c, "merge")]]) else: # this is a new constant - self.atomic_sig[c] = sig - self.atomic_sig_inv[sig] = c + self.canonical_atomics[c] = c self.seen_atomics.add(id(c)) def process_node(self, fgraph, node): diff --git a/aesara/link/c/basic.py b/aesara/link/c/basic.py index 8aed25cd13..f601404086 100644 --- a/aesara/link/c/basic.py +++ b/aesara/link/c/basic.py @@ -1416,15 +1416,13 @@ def in_sig(i, topological_pos, i_idx): # yield a 'position' that reflects its role in code_gen() if isinstance(i, AtomicVariable): # orphans if id(i) not in constant_ids: - isig = (i.signature(), topological_pos, i_idx) + isig = (hash(i), topological_pos, i_idx) # If the Aesara constant provides a strong hash # (no collision for transpose, 2, 1, 0, -1, -2, # 2 element swapped...) we put this hash in the signature # instead of the value. This makes the key file much # smaller for big constant arrays. Before this, we saw key # files up to 80M. - if hasattr(isig[0], "aesara_hash"): - isig = (isig[0].aesara_hash(), topological_pos, i_idx) try: hash(isig) except Exception: diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index c48db53fc5..62df6d19e5 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -291,9 +291,11 @@ def __hash__(self): # NB: For writing, we must bypass setattr() which is always called by default by Python. self.__dict__["__signatures__"] = tuple( # NB: Params object should have been already filtered. - self.__params_type__.types[i] - .make_constant(self[self.__params_type__.fields[i]]) - .signature() + hash( + self.__params_type__.types[i].make_constant( + self[self.__params_type__.fields[i]] + ) + ) for i in range(self.__params_type__.length) ) return hash((type(self), self.__params_type__) + self.__signatures__) diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 33632fa1a6..e3dc2f4c9e 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -292,15 +292,7 @@ def __setstate__(self, dct): class CDataTypeConstant(Constant[T]): - def merge_signature(self): - # We don't want to merge constants that don't point to the - # same object. - return id(self.data) - - def signature(self): - # There is no way to put the data in the signature, so we - # don't even try - return (self.type,) + pass CDataType.constant_type = CDataTypeConstant diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index eba5fbe353..f5c4d7be42 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -23,7 +23,6 @@ from aesara.link.c.type import generic from aesara.misc.safe_asarray import _asarray from aesara.sparse.type import SparseTensorType, _is_sparse -from aesara.sparse.utils import hash_from_sparse from aesara.tensor import basic as at from aesara.tensor.basic import Split from aesara.tensor.math import _conj @@ -465,35 +464,33 @@ def __repr__(self): return str(self) -class SparseConstantSignature(tuple): +class SparseConstant(TensorConstant, _sparse_py_operators): + format = property(lambda self: self.type.format) + + # def __init__(self, *args): + # .view(HashableNDArray) + def __eq__(self, other): - (a, b), (x, y) = self, other - return ( - a == x - and (b.dtype == y.dtype) - and (type(b) == type(y)) - and (b.shape == y.shape) - and (abs(b - y).sum() < 1e-6 * b.nnz) - ) + if isinstance(other, type(self)): + b = self.data + y = other.data + if ( + self.type == other.type + and (b.dtype == y.dtype) + and (type(b) == type(y)) + and (b.shape == y.shape) + and (abs(b - y).sum() < 1e-6 * b.nnz) + ): + return True + return False + + return NotImplemented def __ne__(self, other): return not self == other def __hash__(self): - (a, b) = self - return hash(type(self)) ^ hash(a) ^ hash(type(b)) - - def aesara_hash(self): - (_, d) = self - return hash_from_sparse(d) - - -class SparseConstant(TensorConstant, _sparse_py_operators): - format = property(lambda self: self.type.format) - - def signature(self): - assert self.data is not None - return SparseConstantSignature((self.type, self.data)) + return hash((type(self), self.type, self.data)) def __str__(self): return "{}{{{},{},shape={},nnz={}}}".format( diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py index a3b30177f0..3d4fe36f35 100644 --- a/aesara/tensor/rewriting/shape.py +++ b/aesara/tensor/rewriting/shape.py @@ -366,8 +366,8 @@ def set_shape(self, r, s, override=False): assert all( not hasattr(r.type, "broadcastable") or not r.type.broadcastable[i] - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) + or self.lscalar_one == shape_vars[i] + or self.lscalar_one == extract_constant(shape_vars[i]) for i in range(r.type.ndim) ) self.shape_of[r] = tuple(shape_vars) diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 2617270614..93be2efa8a 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -12,7 +12,7 @@ from aesara.graph.utils import MetaType from aesara.link.c.type import CType from aesara.misc.safe_asarray import _asarray -from aesara.utils import apply_across_args +from aesara.utils import HashableNDArray, apply_across_args _logger = logging.getLogger("aesara.tensor.type") @@ -58,7 +58,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): filter_checks_isfinite = False """ When this is ``True``, strict filtering rejects data containing - ``numpy.nan`` or ``numpy.inf`` entries. (Used in `DebugMode`) + `numpy.nan` or `numpy.inf` entries. (Used in `DebugMode`) """ def __init__( @@ -247,6 +247,14 @@ def filter(self, data, strict=False, allow_downcast=None): if self.filter_checks_isfinite and not np.all(np.isfinite(data)): raise ValueError("Non-finite elements not allowed") + + if not isinstance(data, HashableNDArray): + return data.view(HashableNDArray) + + # TODO: Make sure it's read-only so that we can cache hash values and + # such + data.setflags(write=0) + return data def filter_variable(self, other, allow_convert=True): @@ -339,7 +347,7 @@ def value_zeros(self, shape): TODO: Remove this trivial method. """ - return np.zeros(shape, dtype=self.dtype) + return np.zeros(shape, dtype=self.dtype).view(HashableNDArray) @staticmethod def values_eq(a, b, force_same_dtype=True): diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index e0c438c5e5..00c7ed3048 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -57,12 +57,25 @@ def clone(self, **kwargs): def filter(self, x, strict=False, allow_downcast=None): if isinstance(x, slice): + + if isinstance(x.start, np.ndarray): + assert str(x.start.dtype) in integer_dtypes + x = slice(x.start.item(), x.stop, x.step) + + if isinstance(x.stop, np.ndarray): + assert str(x.stop.dtype) in integer_dtypes + x = slice(x.start, x.stop.item(), x.step) + + if isinstance(x.step, np.ndarray): + assert str(x.step.dtype) in integer_dtypes + x = slice(x.start, x.stop, x.step.item()) + return x else: raise TypeError("Expected a slice!") def __str__(self): - return "slice" + return f"{type(self)}()" def __eq__(self, other): return type(self) == type(other) @@ -80,25 +93,23 @@ def may_share_memory(a, b): class SliceConstant(Constant): + @classmethod + def create_key(cls, type, data, *args, **kwargs): + return (type, data.start, data.stop, data.step) + def __init__(self, type, data, name=None): - assert isinstance(data, slice) - # Numpy ndarray aren't hashable, so get rid of them. - if isinstance(data.start, np.ndarray): - assert data.start.ndim == 0 - assert str(data.start.dtype) in integer_dtypes - data = slice(int(data.start), data.stop, data.step) - elif isinstance(data.stop, np.ndarray): - assert data.stop.ndim == 0 - assert str(data.stop.dtype) in integer_dtypes - data = slice(data.start, int(data.stop), data.step) - elif isinstance(data.step, np.ndarray): - assert data.step.ndim == 0 - assert str(data.step.dtype) in integer_dtypes - data = slice(data.start, int(data.stop), data.step) - Constant.__init__(self, type, data, name) - - def signature(self): - return (SliceConstant, self.data.start, self.data.stop, self.data.step) + super().__init__(type, data, name) + + def __eq__(self, other): + if isinstance(other, type(self)): + if self.data == other.data: + return True + return False + + return NotImplemented + + def __hash__(self): + return hash(self.data.__reduce__()) def __str__(self): return "{}{{{}, {}, {}}}".format( diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index 8b281e6bd0..09544c42d5 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -1,4 +1,3 @@ -import copy import traceback as tb import warnings from collections.abc import Iterable @@ -16,7 +15,6 @@ from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.type import TensorType from aesara.tensor.type_other import NoneConst -from aesara.tensor.utils import hash_from_ndarray _TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType) @@ -877,119 +875,6 @@ def _get_vector_length_TensorVariable(op_or_var, var): TensorType.variable_type = TensorVariable -class TensorConstantSignature(tuple): - r"""A signature object for comparing `TensorConstant` instances. - - An instance is a pair with the type ``(Type, ndarray)``. - - TODO FIXME: Subclassing `tuple` is unnecessary, and it appears to be - preventing the use of a much more convenient `__init__` that removes the - need for all these lazy computations and their safety checks. - - Also, why do we even need this signature stuff? We could simply implement - good `Constant.__eq__` and `Constant.__hash__` implementations. - - We could also produce plain `tuple`\s with hashable values. - - """ - - def __eq__(self, other): - if type(self) != type(other): - return False - try: - (t0, d0), (t1, d1) = self, other - except Exception: - return False - - # N.B. compare shape to ensure no broadcasting in == - if t0 != t1 or d0.shape != d1.shape: - return False - - self.no_nan # Ensure has_nan is computed. - # Note that in the comparisons below, the elementwise comparisons - # come last because they are the most expensive checks. - if self.has_nan: - other.no_nan # Ensure has_nan is computed. - return ( - other.has_nan - and self.sum == other.sum - and (self.no_nan.mask == other.no_nan.mask).all() - and - # Note that the second test below (==) may crash e.g. for - # a single scalar NaN value, so we do not run it when all - # values are missing. - (self.no_nan.mask.all() or (self.no_nan == other.no_nan).all()) - ) - else: - # Simple case where we do not need to worry about NaN values. - # (note that if there are NaN values in d1, this will return - # False, which is why we do not bother with testing `other.has_nan` - # here). - return (self.sum == other.sum) and np.all(d0 == d1) - - def __ne__(self, other): - return not self == other - - def __hash__(self): - t, d = self - return hash((type(self), t, d.shape, self.sum)) - - def aesara_hash(self): - _, d = self - return hash_from_ndarray(d) - - @property - def sum(self): - """Compute sum of non NaN / Inf values in the array.""" - try: - return self._sum - except AttributeError: - - # Prevent warnings when there are `inf`s and `-inf`s present - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - self._sum = self.no_nan.sum() - - # The following 2 lines are needed as in Python 3.3 with NumPy - # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. - if isinstance(self._sum, np.memmap): - self._sum = np.asarray(self._sum).item() - - if self.has_nan and self.no_nan.mask.all(): - # In this case the sum is not properly computed by numpy. - self._sum = 0 - - if np.isinf(self._sum) or np.isnan(self._sum): - # NaN may happen when there are both -inf and +inf values. - if self.has_nan: - # Filter both NaN and Inf values. - mask = self.no_nan.mask + np.isinf(self[1]) - else: - # Filter only Inf values. - mask = np.isinf(self[1]) - if mask.all(): - self._sum = 0 - else: - self._sum = np.ma.masked_array(self[1], mask).sum() - # At this point there should be no more NaN. - assert not np.isnan(self._sum) - - if isinstance(self._sum, np.ma.core.MaskedConstant): - self._sum = 0 - - return self._sum - - @property - def no_nan(self): - try: - return self._no_nan - except AttributeError: - nans = np.isnan(self[1]) - self._no_nan = np.ma.masked_array(self[1], nans) - self.has_nan = np.any(nans) - return self._no_nan - - def get_unique_value(x: TensorVariable) -> Optional[Number]: """Return the unique value of a tensor, if there is one""" if isinstance(x, Constant): @@ -998,7 +883,7 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]: if isinstance(data, np.ndarray) and data.ndim > 0: flat_data = data.ravel() if flat_data.shape[0]: - if (flat_data == flat_data[0]).all(): + if np.all(flat_data == flat_data[0]): return flat_data[0] return None @@ -1039,31 +924,6 @@ def __str__(self): name = "TensorConstant" return "%s{%s}" % (name, val) - def signature(self): - return TensorConstantSignature((self.type, self.data)) - - def equals(self, other): - # Override Constant.equals to allow to compare with - # numpy.ndarray, and python type. - if isinstance(other, (np.ndarray, int, float)): - # Make a TensorConstant to be able to compare - other = at.basic.constant(other) - return ( - isinstance(other, TensorConstant) and self.signature() == other.signature() - ) - - def __copy__(self): - # We need to do this to remove the cached attribute - return type(self)(self.type, self.data, self.name) - - def __deepcopy__(self, memo): - # We need to do this to remove the cached attribute - return type(self)( - copy.deepcopy(self.type, memo), - copy.deepcopy(self.data, memo), - copy.deepcopy(self.name, memo), - ) - TensorType.constant_type = TensorConstant diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 4964e01f08..5818a8f223 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -8,6 +8,7 @@ from aesara import tensor as at from aesara.graph.basic import ( Apply, + Constant, NominalVariable, Variable, ancestors, @@ -41,10 +42,14 @@ ) from aesara.tensor.type_other import NoneConst from aesara.tensor.var import TensorVariable +from aesara.utils import HashableNDArray from tests import unittest_tools as utt from tests.graph.utils import MyInnerGraphOp +pytestmark = pytest.mark.filterwarnings("error") + + class MyType(Type): def __init__(self, thingy): self.thingy = thingy @@ -84,7 +89,7 @@ def perform(self, *args, **kwargs): raise NotImplementedError("No Python implementation available.") -MyOp = MyOp() +my_op = MyOp() def leaf_formatter(leaf): @@ -107,29 +112,29 @@ def format_graph(inputs, outputs): class TestStr: def test_as_string(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) + node = my_op.make_node(r1, r2) s = format_graph([r1, r2], node.outputs) assert s == ["MyOp(R1, R2)"] def test_as_string_deep(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], r5) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], r5) s = format_graph([r1, r2, r5], node2.outputs) assert s == ["MyOp(MyOp(R1, R2), R5)"] def test_multiple_references(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], node.outputs[0]) assert format_graph([r1, r2, r5], node2.outputs) == [ "MyOp(*1 -> MyOp(R1, R2), *1)" ] def test_cutoff(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], node.outputs[0]) assert format_graph(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] assert format_graph(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] @@ -137,43 +142,27 @@ def test_cutoff(self): class TestClone: def test_accurate(self): r1, r2 = MyVariable(1), MyVariable(2) - node = MyOp.make_node(r1, r2) - _, new = clone([r1, r2], node.outputs, False) + node = my_op.make_node(r1, r2) + _, new = clone([r1, r2], node.outputs, copy_inputs=False) assert format_graph([r1, r2], new) == ["MyOp(R1, R2)"] def test_copy(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(r1, r2) - node2 = MyOp.make_node(node.outputs[0], r5) - _, new = clone([r1, r2, r5], node2.outputs, False) - assert ( - node2.outputs[0].type == new[0].type and node2.outputs[0] is not new[0] - ) # the new output is like the old one but not the same object - assert node2 is not new[0].owner # the new output has a new owner + node = my_op.make_node(r1, r2) + node2 = my_op.make_node(node.outputs[0], r5) + _, new = clone([r1, r2, r5], node2.outputs, copy_inputs=False) + assert node2.outputs[0].type == new[0].type and node2.outputs[0] is new[0] + assert node2 is new[0].owner assert new[0].owner.inputs[1] is r5 # the inputs are not copied assert ( new[0].owner.inputs[0].type == node.outputs[0].type - and new[0].owner.inputs[0] is not node.outputs[0] - ) # check that we copied deeper too - - def test_not_destructive(self): - # Checks that manipulating a cloned graph leaves the original unchanged. - r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) - _, new = clone([r1, r2, r5], node.outputs, False) - new_node = new[0].owner - new_node.inputs = [MyVariable(7), MyVariable(8)] - assert format_graph(graph_inputs(new_node.outputs), new_node.outputs) == [ - "MyOp(R7, R8)" - ] - assert format_graph(graph_inputs(node.outputs), node.outputs) == [ - "MyOp(MyOp(R1, R2), R5)" - ] + and new[0].owner.inputs[0] is node.outputs[0] + ) def test_constant(self): r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) - _, new = clone([r1, r2, r5], node.outputs, False) + node = my_op.make_node(my_op.make_node(r1, r2).outputs[0], r5) + _, new = clone([r1, r2, r5], node.outputs, copy_inputs=False) new_node = new[0].owner new_node.inputs = [MyVariable(7), MyVariable(8)] c1 = at.constant(1.5) @@ -192,13 +181,13 @@ def test_constant(self): def test_clone_inner_graph(self): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" # Inner graph igo_in_1 = MyVariable(4) igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1 = my_op(igo_in_1, igo_in_2) igo_out_1.name = "igo1" igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) @@ -209,8 +198,8 @@ def test_clone_inner_graph(self): o2_node = o2.owner o2_node_clone = o2_node.clone(clone_inner_graph=True) - assert o2_node_clone is not o2_node - assert o2_node_clone.op.fgraph is not o2_node.op.fgraph + assert o2_node_clone is o2_node + assert o2_node_clone.op.fgraph is o2_node.op.fgraph assert equal_computations( o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs ) @@ -228,9 +217,9 @@ class TestToposort: def test_simple(self): # Test a simple graph r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) - o = MyOp(r1, r2) + o = my_op(r1, r2) o.name = "o1" - o2 = MyOp(o, r5) + o2 = my_op(o, r5) o2.name = "o2" clients = {} @@ -257,49 +246,50 @@ def test_simple(self): def test_double_dependencies(self): # Test a graph with double dependencies r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) - o2 = MyOp.make_node(o.outputs[0], r5) + o = my_op.make_node(r1, r1) + o2 = my_op.make_node(o.outputs[0], r5) all = general_toposort(o2.outputs, prenode) assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]] def test_inputs_owners(self): # Test a graph where the inputs have owners r1, r5 = MyVariable(1), MyVariable(5) - o = MyOp.make_node(r1, r1) + o = my_op.make_node(r1, r1) r2b = o.outputs[0] - o2 = MyOp.make_node(r2b, r2b) + o2 = my_op.make_node(r2b, r2b) all = io_toposort([r2b], o2.outputs) assert all == [o2] - o2 = MyOp.make_node(r2b, r5) + o2 = my_op.make_node(r2b, r5) all = io_toposort([r2b], o2.outputs) assert all == [o2] def test_not_connected(self): # Test a graph which is not connected r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(r3, r4) + o0 = my_op.make_node(r1, r2) + o1 = my_op.make_node(r3, r4) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) assert all == [o1, o0] or all == [o0, o1] def test_io_chain(self): # Test inputs and outputs mixed together in a chain graph r1, r2 = MyVariable(1), MyVariable(2) - o0 = MyOp.make_node(r1, r2) - o1 = MyOp.make_node(o0.outputs[0], r1) + o0 = my_op.make_node(r1, r2) + o1 = my_op.make_node(o0.outputs[0], r1) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) assert all == [o1] def test_outputs_clients(self): # Test when outputs have clients r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4) - o0 = MyOp.make_node(r1, r2) - MyOp.make_node(o0.outputs[0], r4) + o0 = my_op.make_node(r1, r2) + my_op.make_node(o0.outputs[0], r4) all = io_toposort([], o0.outputs) assert all == [o0] +@pytest.mark.skip(reason="Not finished") class TestEval: def setup_method(self): self.x, self.y = scalars("x", "y") @@ -391,9 +381,9 @@ def test_equal_computations(): def test_walk(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" def expand(r): @@ -422,9 +412,9 @@ def expand(r): def test_ancestors(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = ancestors([o2], blockers=None) @@ -444,9 +434,9 @@ def test_ancestors(): def test_graph_inputs(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = graph_inputs([o2], blockers=None) @@ -457,9 +447,9 @@ def test_graph_inputs(): def test_variables_and_orphans(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" vars_res = vars_between([r1, r2], [o2]) @@ -474,11 +464,11 @@ def test_variables_and_orphans(): def test_ops(): r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, r4) + o2 = my_op(r3, r4) o2.name = "o2" - o3 = MyOp(r3, o1, o2) + o3 = my_op(r3, o1, o2) o3.name = "o3" res = applys_between([r1, r2], [o3]) @@ -489,9 +479,9 @@ def test_ops(): def test_list_of_nodes(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" res = list_of_nodes([r1, r2], [o2]) @@ -501,9 +491,9 @@ def test_list_of_nodes(): def test_is_in_ancestors(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" - o2 = MyOp(r3, o1) + o2 = my_op(r3, o1) o2.name = "o2" assert is_in_ancestors(o2.owner, o1.owner) @@ -522,13 +512,13 @@ def test_view_roots(): def test_get_var_by_name(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) - o1 = MyOp(r1, r2) + o1 = my_op(r1, r2) o1.name = "o1" # Inner graph igo_in_1 = MyVariable(4) igo_in_2 = MyVariable(5) - igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1 = my_op(igo_in_1, igo_in_2) igo_out_1.name = "igo1" igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) @@ -661,6 +651,7 @@ def test_cloning_replace_not_strict_not_copy_inputs(self): assert x not in f2_inp assert y2 not in f2_inp + @pytest.mark.skip(reason="Not finished") def test_clone(self): def test(x, y, mention_y): if mention_y: @@ -765,7 +756,7 @@ def test_NominalVariable(): assert repr(nv5) == f"NominalVariable(2, {repr(type3)})" - assert nv5.signature() == (type3, 2) + assert hash(nv5) == hash((type(nv5), 2, type3)) nv5_pkld = pickle.dumps(nv5) nv5_unpkld = pickle.loads(nv5_pkld) @@ -801,5 +792,75 @@ def test_NominalVariable_create_variable_type(): ntv_unpkld = pickle.loads(ntv_pkld) assert type(ntv_unpkld) is type(ntv) - assert ntv_unpkld.equals(ntv) + assert ntv_unpkld == ntv assert ntv_unpkld is ntv + + +def test_Apply_equivalence(): + + type1 = MyType(1) + + in_1 = Variable(type1, None, name="in_1") + in_2 = Variable(type1, None, name="in_2") + out_10 = Variable(type1, None, name="out_10") + out_11 = Variable(type1, None, name="out_11") + out_12 = Variable(type1, None, name="out_12") + + apply_1 = Apply(my_op, [in_1], [out_10]) + apply_2 = Apply(my_op, [in_1], [out_11]) + apply_3 = Apply(my_op, [in_2], [out_12]) + + assert apply_1 is apply_2 + assert apply_1.inputs == apply_2.inputs + assert apply_1.outputs == apply_2.outputs + assert apply_1.outputs[0] is out_10 + assert apply_2.outputs[0] is out_10 + assert apply_1 == apply_2 + assert apply_1 != apply_3 + assert hash(apply_1) == hash(apply_2) + assert hash(apply_1) != hash(apply_3) + + apply_1_pkl = pickle.dumps(apply_1) + apply_1_2 = pickle.loads(apply_1_pkl) + + assert apply_1.op == apply_1_2.op + assert len(apply_1.inputs) == len(apply_1_2.inputs) + assert len(apply_1.outputs) == len(apply_1_2.outputs) + assert apply_1.inputs[0].type == apply_1_2.inputs[0].type + assert apply_1.inputs[0].name == apply_1_2.inputs[0].name + assert apply_1.outputs[0].type == apply_1_2.outputs[0].type + assert apply_1.outputs[0].name == apply_1_2.outputs[0].name + + +class MyType2(MyType): + def filter(self, value, **kwargs): + value = np.asarray(value).view(HashableNDArray) + return value + + +def test_Constant_equivalence(): + type1 = MyType2(1) + x = Constant(type1, 1.0) + y = Constant(type1, 1.0) + + assert x == y + assert x is y + + rng = np.random.default_rng(3209) + a_val = rng.normal(size=(2, 3)) + c_val = rng.normal(size=(2, 3)) + + a = Constant(type1, a_val) + b = Constant(type1, a_val) + c = Constant(type1, c_val) + + assert a == b + assert a is b + assert a != x + assert a != c + + a_pkl = pickle.dumps(a) + a_2 = pickle.loads(a_pkl) + + assert a.type == a_2.type + assert a.data == a_2.data