From 7d3e0f3af24e2d7b4d83dddb2292fa2dfe464df2 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> Date: Sat, 13 Apr 2024 01:16:37 -0600 Subject: [PATCH] Ensemble creation feature (#68) * first draft of ensemble code * add ensemble.py * working version, still working on merging identical nodes * add merging of identical nodes and some cleaning * try to add type annotations * adding ensemble usage example * update docs for ensembles and more * further docs * update example wording --- docs/source/examples/ensembles.rst | 25 ++ docs/source/examples/index.rst | 8 +- docs/source/index.rst | 21 +- docs/source/installation.rst | 19 +- examples/ensembling_models.py | 56 ++++ hippynn/graphs/__init__.py | 2 + hippynn/graphs/ensemble.py | 240 ++++++++++++++++++ hippynn/graphs/gops.py | 109 +++++++- hippynn/graphs/graph.py | 13 +- hippynn/graphs/indextypes/reduce_funcs.py | 2 +- hippynn/graphs/indextypes/registry.py | 34 +++ .../graphs/nodes/base/definition_helpers.py | 28 +- hippynn/graphs/nodes/misc.py | 26 +- hippynn/graphs/viz.py | 2 +- .../interfaces/ase_interface/ase_database.py | 21 +- hippynn/layers/__init__.py | 2 +- hippynn/layers/algebra.py | 9 + hippynn/tools.py | 33 +++ 18 files changed, 625 insertions(+), 25 deletions(-) create mode 100644 docs/source/examples/ensembles.rst create mode 100644 examples/ensembling_models.py create mode 100644 hippynn/graphs/ensemble.py diff --git a/docs/source/examples/ensembles.rst b/docs/source/examples/ensembles.rst new file mode 100644 index 00000000..6389bd50 --- /dev/null +++ b/docs/source/examples/ensembles.rst @@ -0,0 +1,25 @@ +Ensembling Models +################# + + +Using the :func:`~hippynn.graphs.make_ensemble` function makes it easy to combine models. + +By default, ensembling is based on the db_name for the nodes in each input graph. +Nodes which have the same name will be assigned an ensemble node which combines +the different versions of that quantity, and additionally calculates the +mean and standard deviation. + +It is easy to make an ensemble from a glob string or a list of directories where +the models are saved:: + + from hippynn.graphs import make_ensemble + model_form = '../../collected_models/quad0_b512_p5_GPU*' + ensemble_graph, ensemble_info = make_ensemble(model_form) + +The ensemble graph takes the inputs which are required for all of the models in the ensemble. +The ``ensemble_info`` object provides the counts for the inputs and targets of the ensemble +and the counts of those corresponding quantities across the ensemble members. + +A typical use case would be to then build a Predictor or ASE Calculator from the ensemble. +See :file:`~examples/ensembling_models.py` for a detailed example. + diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index cd3002db..78703eac 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -3,8 +3,10 @@ Examples Here are some examples about how to use various features in ``hippynn``. Besides the :doc:`/examples/minimal_workflow` example, -the examples are just snippets. For fully-fledged examples see the -``examples`` directory in the repository. +the examples are just snippets. For runnable example scripts, see +`the examples at the hippynn github repository`_ + +.. _`the examples at the hippynn github repository`: https://github.com/lanl/hippynn/tree/development/examples .. toctree:: :maxdepth: 1 @@ -13,6 +15,7 @@ the examples are just snippets. For fully-fledged examples see the controller plotting predictor + ensembles periodic forces restarting @@ -21,3 +24,4 @@ the examples are just snippets. For fully-fledged examples see the excited_states weighted_loss + diff --git a/docs/source/index.rst b/docs/source/index.rst index 5ee05e79..50bcd450 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -12,10 +12,28 @@ What is hippynn? We aim to provide high-performance modular design so that different components can be re-used, extended, or added to. You can find more information at the :doc:`/user_guide/features` page. The development home is located -at `the hippynn github repository`_. +at `the hippynn github repository`_, which also contains `many example files`_ +The main components of hippynn are constructing models, loading databases, +training the models to those databases, making predictions on new databases, +and interfacing with other atomistic codes. In particular, we provide interfaces +to `ASE`_ (prediction), `PYSEQM`_ (training/prediction), and `LAMMPS`_ (prediction). +hippynn is also used within `ALF`_ for generating machine learned potentials +along with their training data completely from scratch. + +Multiple formats for training data are supported, including +Numpy arrays, the ASE Database, `fitSNAP`_ JSON format, and `ANI HDF5 files`_. + +.. _`ASE`: https://wiki.fysik.dtu.dk/ase/ +.. _`PYSEQM`: https://github.com/lanl/PYSEQM/ +.. _`LAMMPS`: https://www.lammps.org +.. _`fitSNAP`: https://github.com/FitSNAP/FitSNAP +.. _`ANI HDF5 files`: https://doi.org/10.1038/s41597-020-0473-z +.. _`ALF`: https://github.com/lanl/ALF/ .. _`the hippynn github repository`: https://github.com/lanl/hippynn/ +.. _`many example files`: https://github.com/lanl/hippynn/tree/development/examples + .. toctree:: :maxdepth: 1 @@ -27,7 +45,6 @@ at `the hippynn github repository`_. hippynn API documentation license - Indices and tables ================== diff --git a/docs/source/installation.rst b/docs/source/installation.rst index dea3ae1d..3f3a3a27 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -43,6 +43,21 @@ Interfacing codes: Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ +Conda +----- +Install using conda:: + + conda install -c conda-forge hippynn + +Pip +--- +Install using pip:: + + pip install hippynn + +Install from source: +-------------------- + Clone the hippynn_ repository and navigate into it, e.g.:: $ git clone https://github.com/lanl/hippynn.git @@ -55,14 +70,14 @@ Clone the hippynn_ repository and navigate into it, e.g.:: out ``cupy`` from the conda_requirements.txt file. Dependencies using conda -------------------------- +........................ Install dependencies from conda using recommended channels:: $ conda install -c pytorch -c conda-forge --file conda_requirements.txt Dependencies using pip ------------------------ +....................... Minimum dependencies using pip:: diff --git a/examples/ensembling_models.py b/examples/ensembling_models.py new file mode 100644 index 00000000..3055f649 --- /dev/null +++ b/examples/ensembling_models.py @@ -0,0 +1,56 @@ +import torch +import hippynn + +if torch.cuda.is_available(): + device = 0 +else: + device = 'cpu' + +### Building the ensemble just requires calling one function call. +model_form = '../../collected_models/quad0_b512_p5_GPU*' +ensemble_graph, ensemble_info = hippynn.graphs.make_ensemble(model_form) + +# Retrieve the ensemble node which has just been created. +# The name will be the prefix 'ensemble' followed by the db_name from the ensemble members. +ensemble_energy = ensemble_graph.node_from_name("ensemble_T") + +### Building an ASE calculator for the ensemble + +import ase.build + +from hippynn.interfaces.ase_interface import HippynnCalculator + +# The ensemble node has `mean`, `std`, and `all` outputs. +energy_node = ensemble_energy.mean +extra_properties = {"ens_predictions": ensemble_energy.all, "ens_std": ensemble_energy.std} +calc = HippynnCalculator(energy=energy_node, extra_properties=extra_properties) +calc.to(device) + +# build something and attach the calculator +molecule = ase.build.molecule("CH4") +molecule.calc = calc + +energy_value = molecule.get_potential_energy() # Activate calculation to get results dict + +print("Got energy", energy_value) +print("In units of kcal/mol", energy_value / (ase.units.kcal/ase.units.mol)) + +# All outputs from the ensemble members. Because the model was trained in kcal/mol, this is too. +# The name in the results dictionary comes from the key in the 'extra_properties' dictionary. +print("All predictions:", calc.results["ens_predictions"]) + + +### Building a Predictor object for the ensemble +pred = hippynn.graphs.Predictor.from_graph(ensemble_graph) + +# get batch-like inputs to the ensemble +z_vals = torch.as_tensor(molecule.get_atomic_numbers()).unsqueeze(0) +r_vals = torch.as_tensor(molecule.positions).unsqueeze(0) + +pred.to(r_vals.dtype) +pred.to(device) +# Do some computation +output = pred(Z=z_vals, R=r_vals) +# Print the output of a node using the node or the db_name. +print(output[ensemble_energy.all]) +print(output["T_all"]) \ No newline at end of file diff --git a/hippynn/graphs/__init__.py b/hippynn/graphs/__init__.py index 4fc3bd66..b1c2f1d1 100644 --- a/hippynn/graphs/__init__.py +++ b/hippynn/graphs/__init__.py @@ -27,6 +27,7 @@ from .graph import GraphModule from .predictor import Predictor +from .ensemble import make_ensemble __all__ = [ "get_subgraph", @@ -39,4 +40,5 @@ "GraphModule", "Predictor", "IdxType", + "make_ensemble", ] diff --git a/hippynn/graphs/ensemble.py b/hippynn/graphs/ensemble.py new file mode 100644 index 00000000..9fe002fb --- /dev/null +++ b/hippynn/graphs/ensemble.py @@ -0,0 +1,240 @@ +import collections +import glob + +from ..tools import device_fallback, active_directory +from . import GraphModule, replace_node, get_subgraph + +from .indextypes import get_reduced_index_state, index_type_coercion +from .indextypes.reduce_funcs import db_state_of +from .indextypes.registry import assign_index_aliases + +from .nodes.base import _BaseNode, InputNode +from .nodes.misc import EnsembleTarget + +from .gops import merge_children_recursive + +from typing import List, Dict, Union, Tuple + + +def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "auto", + prefix: str = "ensemble_", quiet=False, + ) -> Tuple[GraphModule, Tuple[Dict[str, int], Dict[str, int]]]: + + """ + :param models: list containing str, node, or graphmodule, or str to glob for model directories. + :param targets: list of db_name strings or the string 'auto', which will attempt to infer. + :param inputs: list of db_name strings of the string 'auto', which will attempt to infer. + :param prefix: specifies the prefix for the db_name of created ensemble nodes. + :param quiet: whether to print information about the constructed ensemble. + :return: ensemble GraphModule, (intput_info, output_info) + """ + + # Phase 0: Make sure we are dealing with GraphModules + graphs: List[GraphModule] = get_graphs(models) + + # Phase 1: Figure out what the ensemble will look like. + if inputs == "auto": + inputs = identify_inputs(graphs) + if not quiet: + print("Identified input quantities:", inputs) + + if targets == "auto": + targets = identify_targets(graphs) + if not quiet: + print("Identified output quantities:", targets) + + input_classes: Dict[str, List[_BaseNode]] = collate_inputs(graphs, inputs) + target_classes: Dict[str, List[_BaseNode]] = collate_targets(graphs, targets) + + ensemble_info = make_ensemble_info(input_classes, target_classes, quiet=quiet) + + # Phase 2 build ensemble graph and GraphModule. + ensemble_outputs: List[EnsembleTarget] = construct_outputs(target_classes, prefix=prefix) + ensemble_inputs: List[_BaseNode] = replace_inputs(input_classes) + merged_inputs: List[_BaseNode] = merge_children_recursive(ensemble_inputs) + + if not quiet: + print("Merged the following nodes from the ensemble members:") + for node in merged_inputs: + print("\t", node) + + ensemble_graph = make_ensemble_graph(ensemble_inputs, ensemble_outputs) + + return ensemble_graph, ensemble_info + + +# TODO: Potentially move this function, or part of it, into experiment.serialization? +# TODO ; It seems possible that someone might want to load several models without ensembling them. +def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) -> List[GraphModule]: + """ + + :param models: + :return: + """ + + graphs = [] + if isinstance(models, str): + models = glob.glob(models) + + device = None + for model in models: + if isinstance(model, str): + from ..experiment.serialization import load_model_from_cwd + + # Get graph from disk + if device is None: + device = device_fallback() + with active_directory(model, create=False): + try: + model = load_model_from_cwd(map_location=device) + except FileNotFoundError: + import warnings + warnings.warn(f"Model not found in directory: {model}") + else: + graphs.append(model) + + elif isinstance(model, _BaseNode): + subgraph = get_subgraph([model]) + subgraph_inputs = list({x for x in subgraph if isinstance(x, InputNode)}) + model = GraphModule(subgraph_inputs, [model]) + graphs.append(model) + + elif isinstance(model, GraphModule): + graphs.append(model) + + return graphs + + +def identify_targets(models: List[GraphModule]) -> set[str]: + + targets: set[str] = set() + + for model in models: + for node in model.nodes_to_compute: + if node.db_name is not None: + targets.add(node.db_name) + + return targets + + +def identify_inputs(models: list[GraphModule]) -> set[str]: + + inputs: set[str] = set() + + for model in models: + for node in model.input_nodes: + inputs.add(node.db_name) + + return inputs + + +def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, List[GraphModule]]: + """ + + :param models: + :param inputs: + :return: + """ + input_classes = collections.defaultdict(list) + + for m in models: + for n in m.input_nodes: + if n.db_name not in inputs: + raise ValueError("Input not allowed: '{n.db_name}' (Allowed targets were {inputs}") + input_classes[n.db_name].append(n) + + input_classes = dict(input_classes.items()) + return input_classes + + +def collate_targets(models: List[GraphModule], targets: List[str]) -> Dict[str, List[_BaseNode]]: + target_classes = collections.defaultdict(list) + + for m in models: + for n in m.nodes_to_compute: + if not hasattr(n, "db_name"): + continue + if n.db_name is None: + continue + if n.db_name in targets: + target_classes[n.db_name].append(n) + + target_classes = dict(target_classes.items()) + + return target_classes + + +def make_ensemble_info(input_classes: Dict[str, List[GraphModule]], output_classes: Dict[str, List[GraphModule]], quiet=False): + + input_info = {k: len(v) for k, v in input_classes.items()} + output_info = {k: len(v) for k, v in output_classes.items()} + + if not quiet: + print("Inputs needed and respective model counts:") + for k, v in input_info.items(): + print(f"\t{k} : {v}") + print("Outputs generated and respective model counts:") + for k, v in output_info.items(): + print(f"\t{k} : {v}") + + ensemble_info = input_info, output_info + + return ensemble_info + + +def construct_outputs(output_classes: Dict[str, List[GraphModule]], prefix: str) -> List[EnsembleTarget]: + ensemble_outputs = {} + + for db_name, parents in sorted(output_classes.items(), key=lambda x: x[0]): + + # To facilitate conversion of index states of ensembled nodes, we will build + # an ensemble target for both the db_form and the reduced form for each node. + # The ensemble will return the db_form when they differ, + # but the index cache will still register the reduced form (when it is different) + + reduced_index_state = get_reduced_index_state(*parents) + db_index_state = db_state_of(reduced_index_state) + + # Note: We want to run these before linking the separate models together, + # because the automation algorithms of hippynn currently handle cases + # where there is a unique type for some nodes in the graph, e.g. one pair indexer + # or one padding indexer. + db_state_parents = [index_type_coercion(p, db_index_state) for p in parents] + reduced_parents = [index_type_coercion(p, reduced_index_state) for p in parents] + + # Build db_form output + ensemble_node = EnsembleTarget(name=f"{prefix}{db_name}", parents=db_state_parents) + ensemble_outputs[db_name] = ensemble_node + + if reduced_index_state != db_index_state: + name = f"ensemble_{db_name}[{reduced_index_state}]" + + reduced_ensemble_node = EnsembleTarget(name=name, parents=reduced_parents) + + for db_child, reduced_child in zip(ensemble_node.children, reduced_ensemble_node.children): + assign_index_aliases(db_child, reduced_child) + + return ensemble_outputs + + +def replace_inputs(input_classes: Dict[str, List[GraphModule]]) -> List[InputNode]: + + ensemble_inputs = [] + + for db_name, node_list in input_classes.items(): + first_node = node_list[0] + ensemble_inputs.append(first_node) + rest_nodes = node_list[1:] + for node in rest_nodes: + replace_node(node, first_node) + + return ensemble_inputs + + +def make_ensemble_graph(ensemble_inputs: List[InputNode], ensemble_outputs: List[EnsembleTarget]) -> GraphModule: + + ensemble_output_list = [c for k,out in ensemble_outputs.items() for c in out.children] + ensemble_graph = GraphModule(ensemble_inputs, ensemble_output_list) + + return ensemble_graph + diff --git a/hippynn/graphs/gops.py b/hippynn/graphs/gops.py index df4c15b8..ee98ddc1 100644 --- a/hippynn/graphs/gops.py +++ b/hippynn/graphs/gops.py @@ -1,6 +1,7 @@ """ Graph Operations ("gops") that process or transform a set of nodes. """ +import collections import copy from .nodes.base import InputNode, MultiNode @@ -8,7 +9,7 @@ from .nodes.base.node_functions import NodeNotFound, NodeOperationError from .indextypes import soft_index_type_coercion from . import get_connected_nodes, find_unique_relative - +from ..tools import is_equal_state_dict def get_subgraph(required_nodes): """ @@ -305,3 +306,109 @@ def search_by_name(nodes, name_or_dbname): return find_unique_relative(nodes, lambda n: n.db_name == name_or_dbname) except NodeNotFound: return find_unique_relative(nodes, lambda n: n.name == name_or_dbname) + + +def merge_children_recursive(start_nodes): + """ + Merge children of some seed nodes if they are identical computations, + and apply to future children until no more merges can be performed. + + This function changes a graph in-place. + + :param start_nodes: + :return: Merged nodes. + """ + + all_merged_nodes = [] + while start_nodes: + merged_nodes = merge_children(start_nodes) + all_merged_nodes += merged_nodes + next_nodes = [] + for node in merged_nodes: + if isinstance(node, MultiNode): + next_nodes += node.children + else: + next_nodes.append(node) + + start_nodes = next_nodes + + return all_merged_nodes + + +def merge_children(start_nodes): + """ + Merge the children of some seed nodes if those children are identical. + + This function changes a graph in-place. + + :param start_nodes: + :return: child_nodes: the merged children, post merge + + """ + + from .nodes.tags import PairIndexer + next_generation = list(set([c for s in start_nodes for c in s.children])) + + # Check same parents + # Check same node type + # Check same module type + node_class_map = collections.defaultdict(list) + for node in next_generation: + node_class = (type(node), type(node.torch_module), node.parents) + node_class_map[node_class].append(node) + + # Only merge things when there is more than one node of the same class + considered_node_classes = {k: v for k, v in node_class_map.items() if len(v) > 1} + + # Check all nodes from the class have the same module state dict + # Add exceptional code for PairIndexer by finding the one with the maximum distance threshold. + mergeable_node_classes = [] + for nodes_to_merge in considered_node_classes.values(): # CODA + + first, *rest = nodes_to_merge + # check if the state dict of nodes is all equal. + d1 = first.torch_module.state_dict() + for node in rest: + d2 = node.torch_module.state_dict() + if not is_equal_state_dict(d1, d2): + nodes_can_merge = False + break + else: + nodes_can_merge = True + + if not nodes_can_merge: + continue # DS AL CODA (back to considered_node_classes.values() iteration.) + + # Extra code to handle merging of pair indexers. + # Even if we clean up the pair indexer with an extra state we would still need logic to merge them. + # TODO: Clean up hard_dist_cutoff vs dist_hard max and make it impossible for the node and + # TODO module to disagree. + if isinstance(first, PairIndexer): + max_r = first.torch_module.hard_dist_cutoff + max_r_node = first + swap_first_in = False + for other_node in rest: + this_r = other_node.torch_module.hard_dist_cutoff + if this_r > max_r: + max_r_node = other_node + max_r = this_r + swap_first_in = True + + if swap_first_in: + # Make the max_radius node the first one. + first = max_r_node + rest = [n for n in nodes_to_merge if n is not first] + nodes_to_merge = first, *rest + + mergeable_node_classes.append(nodes_to_merge) + + # actually perform the merging after all the analysis is done. + new_children = [] # These are the nodes that have been swapped into the graphs. + for (first, *rest) in mergeable_node_classes: + for equivalent_node in rest: + replace_node(equivalent_node, first) + equivalent_node.disconnect_recursive() + new_children.append(first) + + return new_children + diff --git a/hippynn/graphs/graph.py b/hippynn/graphs/graph.py index 4aee1e3e..8b3e7b6b 100644 --- a/hippynn/graphs/graph.py +++ b/hippynn/graphs/graph.py @@ -33,6 +33,12 @@ def __init__(self, required_inputs, nodes_to_compute): """ super().__init__() + nodes_to_compute = list(nodes_to_compute) + + if len(nodes_to_compute) == 0: + raise ValueError("Length of `nodes_to_compute` was zero. A graph module " + "must receive a list of outputs with length greater than zero.") + assert all(isinstance(n, InputNode) for n in required_inputs) self.input_nodes = required_inputs @@ -76,7 +82,7 @@ def __init__(self, required_inputs, nodes_to_compute): def get_module(self, node): return self.moddict[self.names_dict[node]] - def print_structure(self): + def print_structure(self, suppress=True): """Pretty-print the structure of the nodes and links comprising this graph.""" in_nodes = {n: "I{}".format(i) for i, n in enumerate(self.input_nodes)} out_nodes = {n: "O{}".format(i) for i, n in enumerate(self.nodes_to_compute)} @@ -90,7 +96,12 @@ def print_structure(self): for k, v in out_nodes.items(): print("\t", v, ":", k) print("Order:") + all_inputs = set(n for this_list in self.forward_inputs_list for n in this_list) for computed, inputs_for_computed in zip(self.forward_output_list, self.forward_inputs_list): + if computed not in all_inputs and computed not in self.nodes_to_compute: + # most likely this is just the child of MultiNode which is still unpacked. + continue + pre = ",".join({node_map[n] for n in inputs_for_computed}) mid = "{:3} : {}".format(node_map[computed], computed.name) print("{:-<20}-> {}".format(pre, mid)) diff --git a/hippynn/graphs/indextypes/reduce_funcs.py b/hippynn/graphs/indextypes/reduce_funcs.py index bd78f6a7..838a6ed0 100644 --- a/hippynn/graphs/indextypes/reduce_funcs.py +++ b/hippynn/graphs/indextypes/reduce_funcs.py @@ -89,7 +89,7 @@ def get_reduced_index_state(*nodes_to_reduce): Find the index state for comparison between values in a loss function or plot. .. Note:: - This function is unlikely to be directed needed as a user. + This function is unlikely to be directly needed as a user. it's more likely you want to use :func:`elementwise_compare_reduce`. :param nodes_to_reduce: diff --git a/hippynn/graphs/indextypes/registry.py b/hippynn/graphs/indextypes/registry.py index 5cfb0425..76132ae8 100644 --- a/hippynn/graphs/indextypes/registry.py +++ b/hippynn/graphs/indextypes/registry.py @@ -80,6 +80,40 @@ def clear_index_cache(): _index_cache = {} +def assign_index_aliases(*nodes): + """ + Store the input set of nodes in the index cache as index aliases of each other. + + Errors if the nodes contain two different nodes with the same index state. + + :param nodes: + :return: None + """ + # Developer node: + # In the interest of safety this function currently errors rather than over-write information. + # Operationally, it is not clear if it really needs to be safe or not. + # If there becomes a convenient reason to make this function overwrite current info, + # it could be changed. + + nodes = set(nodes) + state_map = {n._index_state: n for n in nodes} + if len(state_map) != len(nodes): + raise ValueError(f"Input nodes did not each have a unique index state!\n" + f"Nodes and corresponding states: \n" + f"\t{[(n, n._index_state) for n in nodes]}") + + for target_state, target_node in state_map.items(): + for n in nodes: + if n is target_node: + continue + idxcache_info = (target_state, n) + if idxcache_info in _index_cache: + raise ValueError(f"Index state for node is already cached: {idxcache_info,_index_cache[idxcache_info]}") + else: + _index_cache[idxcache_info] = target_node + + + def register_index_transformer(input_idxstate, output_idxstate): """ Decorator for registering a transformer from one IdxType to another. diff --git a/hippynn/graphs/nodes/base/definition_helpers.py b/hippynn/graphs/nodes/base/definition_helpers.py index fb39adf0..ff0aa894 100644 --- a/hippynn/graphs/nodes/base/definition_helpers.py +++ b/hippynn/graphs/nodes/base/definition_helpers.py @@ -14,7 +14,7 @@ from .. import _debprint from . import _BaseNode -from ...indextypes import index_type_coercion +from ...indextypes import index_type_coercion, elementwise_compare_reduce, get_reduced_index_state class AutoNoKw: @@ -38,7 +38,7 @@ def temporary_parents(child, parents): Context manager for temporarily connecting a node to a set of parents. This is used during parent expansion so that `find_relatives` and `find_unique_relatives` can treat the nodes as connected even though - they they are not fully formed. + they are not fully formed. :param child: :param parents: @@ -200,6 +200,14 @@ def matched_idx_coercion(self, form, needed_index_states): """ return IndexFormTransformer(form, needed_index_states) + @adds_to_forms + def require_compatible_idx_states(self): + """ + Ensure that all parents have commensurate index states. + :return: + """ + return CompatibleIdxTypeTransformer(AlwaysMatch) + @adds_to_forms def require_idx_states(self, *needed_index_states): """ @@ -359,6 +367,22 @@ def fn(node_self, *parents, **kwargs): return tuple(p.main_output for p in parents) +class CompatibleIdxTypeTransformer(FormTransformer): + def __init__(self, form): + super().__init__(form, self.fn) + + def add_class_doc(self): + return """Attempts coercion of all inputs to the same index state.""" + + @staticmethod + def fn(node_self, *parents, **kwargs): + """ + Enforces that all parents have compatible index states. + """ + index_state = get_reduced_index_state(*parents) + return parents + + class FormAssertion(FormHandler): def __init__(self, form): self.form = form diff --git a/hippynn/graphs/nodes/misc.py b/hippynn/graphs/nodes/misc.py index 5197afa2..91d8713f 100644 --- a/hippynn/graphs/nodes/misc.py +++ b/hippynn/graphs/nodes/misc.py @@ -2,9 +2,9 @@ Nodes not otherwise categorized. """ from ..indextypes import IdxType -from .base import AutoNoKw, SingleNode, MultiNode +from .base import AutoNoKw, SingleNode, MultiNode, ExpandParents from ...layers import indexers as index_modules, algebra as algebra_modules - +from ..indextypes import elementwise_compare_reduce class StrainInducer(AutoNoKw, MultiNode): _input_names = "coordinates", "cell" @@ -25,3 +25,25 @@ class ListNode(AutoNoKw, SingleNode): def __init__(self, name, parents, module="auto"): super().__init__(name, parents, module=module) + +class EnsembleTarget(ExpandParents, AutoNoKw, MultiNode): + _auto_module_class = algebra_modules.EnsembleTarget + _input_names = NotImplemented # NotImplemented tells __init_subclass__ that we will fill this in later. + _output_names = "mean", "std", "all" + + _parent_expander.get_main_outputs() + _parent_expander.require_compatible_idx_states() + + def __init__(self, name, parents, module="auto"): + + parents = self.expand_parents(parents) + + index_state = parents[0]._index_state + db_name = parents[0].db_name # assumes that all are the same! + + self._output_index_states = (index_state,)*3 + self._input_names = [f"input_{i}" for i in range(len(parents))] + + super().__init__(name, parents, module=module) + for c, out_name in zip(self.children, self._output_names): + c.db_name = f'{db_name}_{out_name}' diff --git a/hippynn/graphs/viz.py b/hippynn/graphs/viz.py index 18d9901e..1e11df65 100644 --- a/hippynn/graphs/viz.py +++ b/hippynn/graphs/viz.py @@ -103,7 +103,7 @@ def get_viz_node_names(node_set): unique_names = {} for node in node_set: if node.name in nonunique_names: - unique_names[node] = "{} (id={})".format(node.name, id(node)) + unique_names[node] = "{}".format(node.name, hex(id(node))) else: unique_names[node] = node.name return unique_names diff --git a/hippynn/interfaces/ase_interface/ase_database.py b/hippynn/interfaces/ase_interface/ase_database.py index e193f347..b3c05057 100644 --- a/hippynn/interfaces/ase_interface/ase_database.py +++ b/hippynn/interfaces/ase_interface/ase_database.py @@ -69,17 +69,18 @@ def __init__(self, directory: str, name: Union[str, List[str]], inputs, targets, ) def load_arrays(self, directory, filename, inputs, targets, quiet=False, allow_unfound=False): - """load arrays load ase database into hippynn database arrays - - Parameters - ---------- - filename : str - filename or path of database to convert - prefix : str, optional - prefix for output numpy arrays, by default None - return_data : bool, optional - whether or not to return the data or write to files, by default False """ + load arrays load ase database into hippynn database arrays + + :param directory: directory where database is stored + :param filename: file or path to file from directory + :param inputs: + :param targets: + :param quiet: + :param allow_unfound: + :return: + """ + var_list = inputs + targets try: if isinstance(filename, str): diff --git a/hippynn/layers/__init__.py b/hippynn/layers/__init__.py index c2180c26..558093c9 100644 --- a/hippynn/layers/__init__.py +++ b/hippynn/layers/__init__.py @@ -6,4 +6,4 @@ from . import targets from . import transform from . import physics -from . import excited \ No newline at end of file +from . import excited diff --git a/hippynn/layers/algebra.py b/hippynn/layers/algebra.py index b7dab0fc..fa3258ce 100644 --- a/hippynn/layers/algebra.py +++ b/hippynn/layers/algebra.py @@ -74,3 +74,12 @@ def extra_repr(self): def forward(self, bundled_inputs): return bundled_inputs[self.index] + +class EnsembleTarget(torch.nn.Module): + def forward(self,*input_tensors): + n_members = len(input_tensors) + + all = torch.stack(input_tensors, dim=1) + mean = torch.mean(all, dim=1) + std = torch.std(all, dim=1) + return mean, std, all \ No newline at end of file diff --git a/hippynn/tools.py b/hippynn/tools.py index 5550062e..6e653128 100644 --- a/hippynn/tools.py +++ b/hippynn/tools.py @@ -153,3 +153,36 @@ def pad_np_array_to_length_with_zeros(array, length, axis=0): def np_of_torchdefaultdtype(): return torch.ones(1, dtype=torch.get_default_dtype()).numpy().dtype + +def is_equal_state_dict(d1, d2): + """ + Checks if two pytorch state dictionaries are equal. Calls itself recursively + if the value for a parameter is a dictionary. + + + :param d1: + :param d2: + :return: + """ + if set(d1.keys()) != set(d2.keys()): + # They have different sets of keys. + return False + for k in d1: + v1 = d1[k] + v2 = d2[k] + if type(v1) != type(v2): + return False + if isinstance(v1, torch.Tensor): + if torch.equal(v1, v2): + continue + else: + return False + elif isinstance(v1, dict): + # call recursive: + return is_equal_state_dict(v1, v2) + elif v1 != v2: + return False + + return True + +