Skip to content

Commit

Permalink
remove need for backwards compatibility operation in feature nodes (l…
Browse files Browse the repository at this point in the history
…anl#65)

* remove need for backwards compatibility operation in feature nodes

eliminates unneeded warning when creating new networks
  • Loading branch information
lubbersnick authored Apr 9, 2024
1 parent e576555 commit ba363b3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 40 deletions.
4 changes: 0 additions & 4 deletions examples/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
# Here we get nodes associated with the features from each block of a HIP-NN model.
# Note: the first set is typically just a one-hot species representation, and does not reflect the atom's environment.

# Note!!!: For backwards compatibility, you may need to run this function on the network:
hippynn.graphs.networks._make_feature_nodes(network_node)
# If your network is created with this version of hippynn or later, it is not necessary to run this function.

feature_nodes = network_node.feature_nodes # list of feature nodes
feature_node_dict = {node.name: node for node in feature_nodes} # dictionary of feature nodes

Expand Down
68 changes: 32 additions & 36 deletions hippynn/graphs/nodes/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,37 @@ def expansion1(self, pidxer, pairfinder, **kwargs):
return pidxer.indexed_features, pairfinder


class Hipnn(DefaultNetworkExpansion, AutoKw, Network, SingleNode):
class _FeatureNodesMixin:
@property
def feature_nodes(self):
if not hasattr(self, "_feature_nodes"):
self._make_feature_nodes()
return self._feature_nodes

def _make_feature_nodes(self):
"""
This function can be used on a network to make nodes that refer to the individual feature blocks.
We use this function/class to provide backwards compatibility with models that did not have this
attribute when created.
:param self: the input network, which is modified in-place
:return: None
"""

net_module = self.torch_module
n_interactions = net_module.ni

feature_nodes = []

index_state = IdxType.Atoms
parents = (self,)
for i in range(n_interactions + 1):
name = f"{self.name}_features_{i}"
fnode = IndexNode(name=name, parents=parents, index=i, index_state=index_state)
feature_nodes.append(fnode)
self._feature_nodes = feature_nodes


class Hipnn(DefaultNetworkExpansion, AutoKw, Network, SingleNode, _FeatureNodesMixin):
"""
Node for HIP-NN neural networks
"""
Expand Down Expand Up @@ -81,11 +111,8 @@ def __init__(self, name, parents, periodic=False, module="auto", module_kwargs=N
)
super().__init__(name, parents, module=net_module)

_make_feature_nodes(self)



class HipnnVec(DefaultNetworkExpansion, AutoKw, Network, SingleNode):
class HipnnVec(DefaultNetworkExpansion, AutoKw, Network, SingleNode, _FeatureNodesMixin):
"""
Node for HIP-NN-TS neural networks, l=1
"""
Expand Down Expand Up @@ -114,41 +141,10 @@ def __init__(self, name, parents, periodic=False, module="auto", module_kwargs=N

super().__init__(name, parents, module=net_module)

_make_feature_nodes(self)


class HipnnQuad(HipnnVec):
"""
Node for HIP-NN-TS neural networks, l=2
"""

_auto_module_class = network_modules.hipnn.HipnnQuad


def _make_feature_nodes(network_node):
"""
This function can be used on a network to make nodes that refer to the individual feature blocks.
:param network_node: the input network, which is modified in-place
:return: None
"""
import warnings
warnings.warn("This function is included for backwards compatibility and may be removed in a future release. "
"The preferred way to access these nodes is through `network.feature_nodes`, which is available on "
"networks created with this version of hippynn or later.")

if hasattr(network_node, "feature_nodes"):
return network_node.feature_nodes

net_module = network_node.torch_module
n_interactions = net_module.ni

feature_nodes = []

index_state = IdxType.Atoms
parents = (network_node,)
for i in range(n_interactions + 1):
name = f"{network_node.name}_features_{i}"
fnode = IndexNode(name=name, parents=parents, index=i, index_state=index_state)
feature_nodes.append(fnode)
network_node.feature_nodes = feature_nodes

0 comments on commit ba363b3

Please sign in to comment.