Skip to content

Commit

Permalink
add species indexing node
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed Feb 6, 2025
1 parent 78059c7 commit c59f2e1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
2 changes: 1 addition & 1 deletion hippynn/graphs/nodes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .base import Node, SingleNode, InputNode, LossInputNode, LossPredNode, LossTrueNode, _BaseNode

# Node that provides multiple outputs
from .multi import MultiNode
from .multi import MultiNode, IndexNode

# Optional mixins for simplifying the process of defining BaseNode subclasses
from .definition_helpers import AutoKw, AutoNoKw, ExpandParents
37 changes: 35 additions & 2 deletions hippynn/graphs/nodes/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Nodes for indexing information.
"""
from .tags import Encoder, AtomIndexer
from .base import SingleNode, AutoNoKw, AutoKw, find_unique_relative, MultiNode, ExpandParents, _BaseNode
from .base import SingleNode, AutoNoKw, AutoKw, find_unique_relative, MultiNode, ExpandParents, _BaseNode, IndexNode
from .base.node_functions import NodeNotFound
from .inputs import SpeciesNode

# Index generating functions need access to appropriately raise this
from ..indextypes import IdxType
from ..indextypes.reduce_funcs import index_type_coercion
from ...layers import indexers as index_modules


Expand Down Expand Up @@ -208,4 +209,36 @@ def __init__(self, name, parents, length, vmin, vmax, module="auto", **kwargs):
self._output_index_state = parents[0]._index_state
self.module_kwargs = {"length": length, "vmin": vmin, "vmax": vmax}

super().__init__(name, parents, module=module, **kwargs)
super().__init__(name, parents, module=module, **kwargs)

class SpeciesIndexed(AutoNoKw, SingleNode, ExpandParents):
_input_names = "values", "onehot_encoding"
_auto_module_class = index_modules.SpeciesIndex
_index_state = IdxType.Atoms

@_parent_expander.match(_BaseNode)
def expansion0(self, node_to_index, species_set, **kwargs):
atom_node_to_index = index_type_coercion(node_to_index, IdxType.Atoms)
onehot = find_unique_relative(atom_node_to_index, OneHotEncoder)
self.species_set = species_set or onehot.species_set
return atom_node_to_index, onehot.encoding

# add asserts for parent expansion
_parent_expander.assertlen(2)
_parent_expander.get_main_outputs()
_parent_expander.require_idx_states(IdxType.Atoms, IdxType.Atoms)

def __init__(self, name, parents, *args, module="auto", species_set=None, **kwargs):
parents = self.expand_parents(parents, species_set=species_set)
super().__init__(name, parents, *args, module=module, **kwargs)

nonzero_species = [species for species in self.species_set if species != 0]
self.species_to_idx = {species: idx for idx, species in enumerate(nonzero_species)}

self.children = tuple(
IndexNode(name=f"{name}_{species}", parents=(self,), index=idx, index_state=IdxType.Atoms)
for species, idx in self.species_to_idx.items()
)

def with_species_equal(self, z_value):
return self.children[self.species_to_idx(z_value)]
12 changes: 11 additions & 1 deletion hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,14 @@ def forward(self, values):
values = values[...,None]
x = values - self.bins
histo = torch.exp(-((x / self.sigma) ** 2) / 4)
return torch.flatten(histo, end_dim=1)
return torch.flatten(histo, end_dim=1)

class SpeciesIndex(torch.nn.Module):
def forward(self, values, onehot_encoding):
n_species = onehot_encoding.shape[1]
values_by_species = []
for i in range(n_species):
species_mask = onehot_encoding[:,i]
species_values = values[species_mask]
values_by_species.append(species_values)
return values_by_species

0 comments on commit c59f2e1

Please sign in to comment.