From 46f4b06950c7dc366321975d3e23299a6877fbdd Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Tue, 9 Apr 2024 11:46:24 -0600 Subject: [PATCH] remove need for backwards compatibility operation in feature nodes --- examples/feature_extraction.py | 4 -- hippynn/graphs/nodes/networks.py | 63 +++++++++++++++++--------------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/examples/feature_extraction.py b/examples/feature_extraction.py index acd96fef..edd62dec 100644 --- a/examples/feature_extraction.py +++ b/examples/feature_extraction.py @@ -23,10 +23,6 @@ # Here we get nodes associated with the features from each block of a HIP-NN model. # Note: the first set is typically just a one-hot species representation, and does not reflect the atom's environment. -# Note!!!: For backwards compatibility, you may need to run this function on the network: -hippynn.graphs.networks._make_feature_nodes(network_node) -# If your network is created with this version of hippynn or later, it is not necessary to run this function. - feature_nodes = network_node.feature_nodes # list of feature nodes feature_node_dict = {node.name: node for node in feature_nodes} # dictionary of feature nodes diff --git a/hippynn/graphs/nodes/networks.py b/hippynn/graphs/nodes/networks.py index e7f9defb..08494a82 100644 --- a/hippynn/graphs/nodes/networks.py +++ b/hippynn/graphs/nodes/networks.py @@ -53,7 +53,37 @@ def expansion1(self, pidxer, pairfinder, **kwargs): return pidxer.indexed_features, pairfinder -class Hipnn(DefaultNetworkExpansion, AutoKw, Network, SingleNode): +class _FeatureNodesMixin: + @property + def feature_nodes(self): + if not hasattr(self, "_feature_nodes"): + self._make_feature_nodes() + return self._feature_nodes + + def _make_feature_nodes(self): + """ + This function can be used on a network to make nodes that refer to the individual feature blocks. + We use this function/class to provide backwards compatibility with models that did not have this + attribute when created. + :param self: the input network, which is modified in-place + :return: None + """ + + net_module = self.torch_module + n_interactions = net_module.ni + + feature_nodes = [] + + index_state = IdxType.Atoms + parents = (self,) + for i in range(n_interactions + 1): + name = f"{self.name}_features_{i}" + fnode = IndexNode(name=name, parents=parents, index=i, index_state=index_state) + feature_nodes.append(fnode) + self._feature_nodes = feature_nodes + + +class Hipnn(DefaultNetworkExpansion, AutoKw, Network, SingleNode,_FeatureNodesMixin): """ Node for HIP-NN neural networks """ @@ -81,11 +111,10 @@ def __init__(self, name, parents, periodic=False, module="auto", module_kwargs=N ) super().__init__(name, parents, module=net_module) - _make_feature_nodes(self) -class HipnnVec(DefaultNetworkExpansion, AutoKw, Network, SingleNode): +class HipnnVec(DefaultNetworkExpansion, AutoKw, Network, SingleNode,_FeatureNodesMixin): """ Node for HIP-NN-TS neural networks, l=1 """ @@ -114,7 +143,6 @@ def __init__(self, name, parents, periodic=False, module="auto", module_kwargs=N super().__init__(name, parents, module=net_module) - _make_feature_nodes(self) class HipnnQuad(HipnnVec): @@ -125,30 +153,5 @@ class HipnnQuad(HipnnVec): _auto_module_class = network_modules.hipnn.HipnnQuad -def _make_feature_nodes(network_node): - """ - This function can be used on a network to make nodes that refer to the individual feature blocks. - :param network_node: the input network, which is modified in-place - :return: None - """ - import warnings - warnings.warn("This function is included for backwards compatibility and may be removed in a future release. " - "The preferred way to access these nodes is through `network.feature_nodes`, which is available on " - "networks created with this version of hippynn or later.") - - if hasattr(network_node, "feature_nodes"): - return network_node.feature_nodes - - net_module = network_node.torch_module - n_interactions = net_module.ni - - feature_nodes = [] - - index_state = IdxType.Atoms - parents = (network_node,) - for i in range(n_interactions + 1): - name = f"{network_node.name}_features_{i}" - fnode = IndexNode(name=name, parents=parents, index=i, index_state=index_state) - feature_nodes.append(fnode) - network_node.feature_nodes = feature_nodes +