Skip to content

Commit

Permalink
update interface to ase and lammps for atomization consistent networks
Browse files Browse the repository at this point in the history
  • Loading branch information
lubbersnick committed Jan 7, 2025
1 parent 5bc48e8 commit 5aa82ed
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
18 changes: 11 additions & 7 deletions hippynn/interfaces/ase_interface/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from hippynn.graphs.nodes.misc import StrainInducer
from hippynn.graphs.nodes.physics import CoulombEnergyNode, DipoleNode, StressForceNode
from hippynn.graphs.nodes.pairs import PairFilter

from hippynn.graphs.nodes.targets import AtomizationEnergyNode
from hippynn.graphs.nodes.inputs import SpeciesNode, PositionsNode, CellNode


Expand All @@ -26,7 +26,9 @@
# This works for orthorhombic boxes and is much faster than ASE...


def setup_ASE_graph(energy, charges=None, extra_properties=None):
def setup_ASE_graph(energy, charges=None, extra_properties=None, species_set=None,indexer=None):
if isinstance(energy, AtomizationEnergyNode):
energy = energy.create_henergy_equivalent()

if charges is None:
required_nodes = [energy]
Expand Down Expand Up @@ -72,9 +74,11 @@ def setup_ASE_graph(energy, charges=None, extra_properties=None):
positions = find_unique_relative(new_required, search_fn(PositionsNode, new_subgraph), why_desc=why)

# TODO: is .clone necessary? Or good? Or torch.as_tensor instead?
encoder = find_unique_relative(species, search_fn(Encoder, new_subgraph), why_desc=why)
species_set = torch.as_tensor(encoder.species_set).to(torch.int64) # works with lists or tensors
indexer = find_unique_relative(species, search_fn(AtomIndexer, new_subgraph), why_desc=why)
if species_set is None:
encoder = find_unique_relative(species, search_fn(Encoder, new_subgraph), why_desc=why)
species_set = torch.as_tensor(encoder.species_set).to(torch.int64) # works with lists or tensors
if indexer is None:
indexer = find_unique_relative(species, search_fn(AtomIndexer, new_subgraph), why_desc=why)
min_radius = max(p.dist_hard_max for p in pair_indexers)
###############################################################

Expand Down Expand Up @@ -214,7 +218,7 @@ class HippynnCalculator(Calculator): # Calculator inheritance required for ASE M
ASE calculator based on hippynn graphs. Uses ASE neighbor lists. Not suitable for domain decomposition.
"""

def __init__(self, energy, charges=None, skin=1.0, extra_properties=None, en_unit=None, dist_unit=None):
def __init__(self, energy, charges=None, skin=1.0, extra_properties=None, en_unit=None, dist_unit=None, species_set=None, indexer=None):
"""
:param energy: Node for energy
:param charges: Node for charges (optional)
Expand All @@ -228,7 +232,7 @@ def __init__(self, energy, charges=None, skin=1.0, extra_properties=None, en_uni
"""

self.min_radius, self.species_set, self.implemented_properties, self.module, self.pbc = setup_ASE_graph(
energy, charges=charges, extra_properties=extra_properties
energy, charges=charges, extra_properties=extra_properties, species_set=species_set,indexer=indexer,
)

self.implemented_properties.append("energy") # Required for using mixing calculators in ASE
Expand Down
6 changes: 4 additions & 2 deletions hippynn/interfaces/lammps_interface/mliap_interface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Interface for creating LAMMPS MLIAP Unified models.
"""
import pickle
import warnings

import numpy as np
Expand All @@ -20,7 +19,7 @@
from hippynn.graphs.nodes.physics import GradientNode, VecMag
from hippynn.graphs.nodes.inputs import SpeciesNode
from hippynn.graphs.nodes.pairs import PairFilter

from hippynn.graphs.nodes.targets import AtomizationEnergyNode

class MLIAPInterface(MLIAPUnified):
"""
Expand Down Expand Up @@ -175,6 +174,9 @@ def setup_LAMMPS_graph(energy):
:param energy: energy node for lammps interface
:return: graph for computing from lammps MLIAP unified inputs.
"""
if isinstance(energy, AtomizationEnergyNode):
energy = energy.create_henergy_equivalent()

required_nodes = [energy]

why = "Generating LAMMPS Calculator interface"
Expand Down

0 comments on commit 5aa82ed

Please sign in to comment.