From 036b153f7975719f5bfd8050a848207d5fe36262 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Wed, 24 Apr 2024 16:57:10 -0600 Subject: [PATCH] Add notebook for graph visualization --- examples/graph_exploration.ipynb | 968 ++++++++++++++++++ hippynn/graphs/ensemble.py | 2 +- .../interfaces/ase_interface/calculator.py | 20 +- .../lammps_interface/mliap_interface.py | 16 +- 4 files changed, 987 insertions(+), 19 deletions(-) create mode 100644 examples/graph_exploration.ipynb diff --git a/examples/graph_exploration.ipynb b/examples/graph_exploration.ipynb new file mode 100644 index 00000000..42c241ee --- /dev/null +++ b/examples/graph_exploration.ipynb @@ -0,0 +1,968 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6574d8b4", + "metadata": {}, + "source": [ + "# Exploration of hippynn graph system\n", + "\n", + "## Let's revisit the simple training script \"barebones.py\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08632064", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "'''\n", + "To obtain the data files needed for this example, use the script process_QM7_data.py, \n", + "also located in this folder. 
The script contains further instructions for use.\n", + "'''\n", + "\n", + "import torch\n", + "\n", + "# Setup pytorch things\n", + "torch.set_default_dtype(torch.float32)\n", + "\n", + "import hippynn\n", + "hippynn.settings.WARN_LOW_DISTANCES=False\n", + "\n", + "# Hyperparameters for the network\n", + "# These are set deliberately small so that you can easily run the example on a laptop or similar.\n", + "network_params = {\n", + " \"possible_species\": [0, 1, 6, 7, 8, 16], # Z values of the elements in QM7\n", + " \"n_features\": 20, # Number of neurons at each layer\n", + " \"n_sensitivities\": 20, # Number of sensitivity functions in an interaction layer\n", + " \"dist_soft_min\": 1.6, # qm7 is in Bohr!\n", + " \"dist_soft_max\": 10.0,\n", + " \"dist_hard_max\": 12.5,\n", + " \"n_interaction_layers\": 2, # Number of interaction blocks\n", + " \"n_atom_layers\": 3, # Number of atom layers in an interaction block\n", + "}\n", + "\n", + "# Define a model\n", + "from hippynn.graphs import inputs, networks, targets, physics\n", + "\n", + "species = inputs.SpeciesNode(db_name=\"Z\")\n", + "positions = inputs.PositionsNode(db_name=\"R\")\n", + "\n", + "network = networks.Hipnn(\"hipnn_model\", (species, positions), module_kwargs=network_params)\n", + "henergy = targets.HEnergyNode(\"HEnergy\", network, db_name=\"T\")\n", + "# hierarchicality = henergy.hierarchicality\n", + "\n", + "# define loss quantities\n", + "from hippynn.graphs import loss\n", + "\n", + "mse_energy = loss.MSELoss.of_node(henergy)\n", + "mae_energy = loss.MAELoss.of_node(henergy)\n", + "rmse_energy = mse_energy ** (1 / 2)\n", + "\n", + "# Validation losses are what we check on the data between epochs -- we can only train to\n", + "# a single loss, but we can check other metrics too to better understand how the model is training.\n", + "# There will also be plots of these things over time when training completes.\n", + "validation_losses = {\n", + " \"RMSE\": rmse_energy,\n", + " \"MAE\": 
mae_energy,\n", + " \"MSE\": mse_energy,\n", + "}\n", + "\n", + "# This piece of code glues the stuff together as a pytorch model,\n", + "# dropping things that are irrelevant for the losses defined.\n", + "training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses)\n", + "\n", + "# Go to a directory for the model.\n", + "# hippynn will save training files in the current working directory.\n", + "# Log the output of python to `training_log.txt`\n", + "database = hippynn.databases.DirectoryDatabase(\n", + " name=\"data-qm7\", # Prefix for arrays in the directory\n", + " directory=\"../../datasets/qm7_processed\",\n", + " test_size=0.1, # Fraction or number of samples to test on\n", + " valid_size=0.1, # Fraction or number of samples to validate on\n", + " seed=2001, # Random seed for splitting data\n", + " **db_info, # Adds the inputs and targets db_names from the model as things to load\n", + ")\n", + "\n", + "# Now that we have a database and a model, we can\n", + "# Fit the non-interacting energies by examining the database.\n", + "# This tends to stabilize training a lot.\n", + "from hippynn.pretraining import hierarchical_energy_initialization\n", + "\n", + "hierarchical_energy_initialization(henergy, database, trainable_after=False)\n", + "\n", + "# Parameters describing the training procedure.\n", + "from hippynn.experiment import setup_and_train\n", + "\n", + "experiment_params = hippynn.experiment.SetupParams(\n", + " stopping_key=\"MSE\", # The name in the validation_losses dictionary.\n", + " batch_size=12,\n", + " optimizer=torch.optim.Adam,\n", + " max_epochs=1,\n", + " learning_rate=0.001,\n", + ")\n", + "netname = \"TEST_BAREBONES_SCRIPT\"\n", + "with hippynn.tools.active_directory(netname):\n", + " \n", + " setup_and_train(\n", + " training_modules=training_modules,\n", + " database=database,\n", + " setup_params=experiment_params,\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "id": "7bb9e9f3", + 
"metadata": {}, + "source": [ + "## Assembling a graph for training\n", + "\n", + "Perhaps one of the more mysterious lines is:\n", + "\n", + "`training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1138b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "db_info" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88253632", + "metadata": {}, + "outputs": [], + "source": [ + "type(training_modules)" + ] + }, + { + "cell_type": "markdown", + "id": "c158e0f7", + "metadata": {}, + "source": [ + "`training_modules` contain 3 objects: A `model`, a (training) `loss`, and an `evaluator` (which computes the validation losses)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2883803f", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "for x in [training_modules.model,training_modules.loss,training_modules.evaluator]:\n", + " print(type(x))" + ] + }, + { + "cell_type": "markdown", + "id": "955d5772", + "metadata": {}, + "source": [ + "With the python graphviz interface installed, it is easy to visualize what a GraphModule does:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97405f18", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.graphs.viz import visualize_connected_nodes, visualize_graph_module, visualize_node_set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ecdf96a6", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(training_modules.model)" + ] + }, + { + "cell_type": "markdown", + "id": "a04e3aa2", + "metadata": {}, + "source": [ + "Hidden in the multiple arrows are child nodes. Each node with multiple arrows was a MultiNode, that actually outputs multiple tensors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c9b5120", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "visualize_graph_module(training_modules.model,compactify=False)" + ] + }, + { + "cell_type": "markdown", + "id": "8b84e5df", + "metadata": {}, + "source": [ + "Let's take a look at just the one-hot encoder. The `node_from_name` method will make it easy to get a reference to a particular node from the printed or visualized information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "faf1f5c6", + "metadata": {}, + "outputs": [], + "source": [ + "onehot = training_modules.model.node_from_name(\"OneHot\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a639d14", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_node_set([onehot,*onehot.children],compactify=False)" + ] + }, + { + "cell_type": "markdown", + "id": "15f73107", + "metadata": {}, + "source": [ + "A Predictor interface can make it simpler to compute the value of some nodes over a database." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0f73ed9", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.graphs import Predictor\n", + "onehot_predictor = Predictor([species],[*onehot.children,network.input_features])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69dfaefb", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fd79d05", + "metadata": {}, + "outputs": [], + "source": [ + "outputs = onehot_predictor.apply_to_database(database,batch_size=512)\n", + "outputs.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69eb998c", + "metadata": {}, + "outputs": [], + "source": [ + "train_outs = outputs['train']" + ] + }, + { + "cell_type": "markdown", + "id": "f4132d15", + "metadata": {}, + "source": [ + "The outputs can be indexed by the node name to get the output value:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5d8c5df", + "metadata": {}, + "outputs": [], + "source": [ + "train_outs[\"OneHot.encoding\"]" + ] + }, + { + "cell_type": "markdown", + "id": "b83580bf", + "metadata": {}, + "source": [ + "You can also get the value for a node using the node directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7fbfb38a", + "metadata": {}, + "outputs": [], + "source": [ + "onehot_train = train_outs[onehot.encoding]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1d74a1c", + "metadata": {}, + "outputs": [], + "source": [ + "input_features = train_outs[\"PaddingIndexer.indexed_features\"]" + ] + }, + { + "cell_type": "markdown", + "id": "1b2da7bd", + "metadata": {}, + "source": [ + "In this context, the input features look to be the same." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b752abe8", + "metadata": {}, + "outputs": [], + "source": [ + "print(input_features.shape,input_features.dtype)\n", + "print(onehot_train.shape,onehot_train.dtype)\n", + "print(torch.equal(onehot_train,input_features))" + ] + }, + { + "cell_type": "markdown", + "id": "8f3ee519", + "metadata": {}, + "source": [ + "But actually, the predictor is hiding some complexity. Let's take a look at a more rudimentary GraphModule constructed directly - we will manually specify the set of inputs and outputs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbe8e546", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.graphs import GraphModule\n", + "onehot_graphmodule = GraphModule([species],[onehot.encoding,onehot.nonblank,network.input_features])\n", + "visualize_graph_module(onehot_graphmodule)" + ] + }, + { + "cell_type": "markdown", + "id": "cd8031ff", + "metadata": {}, + "source": [ + "We will also manually grab the input array:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf0f7c04", + "metadata": {}, + "outputs": [], + "source": [ + "arrays = database.splits['train']\n", + "\n", + "outputs_graph = onehot_graphmodule(arrays['Z'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b23d457", + "metadata": {}, + "outputs": [], + "source": [ + "type(outputs_graph)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f96585f", + "metadata": {}, + "outputs": [], + "source": [ + "len(outputs_graph)" + ] + }, + { + "cell_type": "markdown", + "id": "9fd3a06e", + "metadata": {}, + "source": [ + "Each one corresponds to one of the outputs directly: `[onehot.encoding,onehot.nonblank,network.input_features]`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f3d831c", + "metadata": {}, + "outputs": [], + "source": [ + "features_graph = outputs_graph[-1]" + ] + }, + { + "cell_type": 
"markdown", + "id": "95ea67c2", + "metadata": {}, + "source": [ + "What shape are the features now?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b472288", + "metadata": {}, + "outputs": [], + "source": [ + "features_graph.shape" + ] + }, + { + "cell_type": "markdown", + "id": "fe8a6d99", + "metadata": {}, + "source": [ + "Hmm, that's not familiar.\n", + "\n", + "Let's compare this to the output of the Predictor interface:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "576eb0dc", + "metadata": {}, + "outputs": [], + "source": [ + "input_features.shape,features_graph.shape" + ] + }, + { + "cell_type": "markdown", + "id": "d2bee871", + "metadata": {}, + "source": [ + "What's going on? It has to do with the fact that we have batches of systems, but each system has a different number of atoms: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e6b446d", + "metadata": {}, + "outputs": [], + "source": [ + "database.splits['train']['Z']" + ] + }, + { + "cell_type": "markdown", + "id": "4db32782", + "metadata": {}, + "source": [ + "The predictor actually uses the graph system too:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c592d929", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(onehot_predictor.graph,compactify=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7b9fab7", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.graphs import IdxType\n", + "type(IdxType)" + ] + }, + { + "cell_type": "markdown", + "id": "794fffbd", + "metadata": {}, + "source": [ + "IdxType is an enumeration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e32a080a", + "metadata": {}, + "outputs": [], + "source": [ + "dir(IdxType)" + ] + }, + { + "cell_type": "markdown", + "id": "0a53989e", + "metadata": {}, + "source": [ + "## IdxType tags the \"batch information\" for the tensor\n", + "\n", + "* 
On each tensor, the batch might refer to a different quantity\n", + "* We can have a batch of atoms, or a batch of molecules\n", + "* or a batch of MolAtom, meaning molecules on the first batch axis, followed by atoms on the second batch axis\n", + "* Index Types like MolAtom and MolAtomAtom can be conveniently batched over\n", + "* Index Types like Atoms and Pair are sparse, and so make for more efficient computation\n", + "* To track the relationship between the different batch-types, we need _indexing_ information.\n", + "* `hippynn` looks at the index types associated with inputs and outputs and can automatically construct conversions between the types whenever the answer is unambiguous.\n", + "* In cases where the automatic construction fails, an advanced user can directly specify the intended result." + ] + }, + { + "cell_type": "markdown", + "id": "643e928f", + "metadata": {}, + "source": [ + "Under the hood, the loss and evaluator also use graphs! \n", + "\n", + "- This is what allows us to use python syntax to build a loss function from algebraic operations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48995013", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(training_modules.loss)" + ] + }, + { + "cell_type": "markdown", + "id": "311681ea", + "metadata": {}, + "source": [ + "Every model quantity with a db_name can be an input into the loss graph, either in 'true' (database) form, or 'predicted' (model) form:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05b1c449", + "metadata": {}, + "outputs": [], + "source": [ + "henergy.mol_energy.true" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "871f6b9e", + "metadata": {}, + "outputs": [], + "source": [ + "henergy.mol_energy.pred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f401f054", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(training_modules.evaluator.loss)" + ] + }, + { + "cell_type": "markdown", + "id": "08320160", + "metadata": {}, + "source": [ + "# Graph transformations\n", + "\n", + "## ASE Interface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "768353cb", + "metadata": {}, + "outputs": [], + "source": [ + "# To run this, train a model using ani_aluminum_example.py! 
\n", + "with hippynn.tools.active_directory('./TEST_ALUMINUM_MODEL/'):\n", + " model=hippynn.experiment.serialization.load_model_from_cwd()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a451fb9", + "metadata": {}, + "outputs": [], + "source": [ + "type(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "259c037d", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(model)" + ] + }, + { + "cell_type": "markdown", + "id": "f4468d46", + "metadata": {}, + "source": [ + "Notice the graph structure is somewhat different here, because, for example, we have per-atom energies to train to, and periodic boundary conditions.\n", + "\n", + "Let's send this to the Atomic Simulation Environment, a code for performing molecular dynamics in python." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1987c0e2", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.interfaces.ase_interface import calculator_from_model\n", + "\n", + "calc = calculator_from_model(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52e9ff90", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(calc.module)" + ] + }, + { + "cell_type": "markdown", + "id": "3977fb73", + "metadata": {}, + "source": [ + "Very similarly, we can send the model to an MLIAPInterface for the LAMMPS molecular dynamics code, which is very useful for highly parallel simulations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9afe971b", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.interfaces.lammps_interface import MLIAPInterface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d906fe1", + "metadata": {}, + "outputs": [], + "source": [ + "lammps_interface = MLIAPInterface(model.node_from_name(\"HEnergy\"),element_types=['Al'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60b97d16", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(lammps_interface.graph)" + ] + }, + { + "cell_type": "markdown", + "id": "a04dfb45", + "metadata": {}, + "source": [ + "# Ensembling\n", + "\n", + "Often it is useful to ensemble multiple models in machine learning. `hippynn` has some tools to automatically ensemble nodes and graphs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b7af57d", + "metadata": {}, + "outputs": [], + "source": [ + "n_ensemble=5\n", + "useful_nodes = []\n", + "\n", + "for i in range(n_ensemble):\n", + " this_species = inputs.SpeciesNode(db_name=\"Z\")\n", + " this_positions = inputs.PositionsNode(db_name=\"R\")\n", + " this_network = networks.Hipnn(\"hipnn_model\", (this_species, this_positions), module_kwargs=network_params)\n", + " this_henergy = targets.HEnergyNode(\"HEnergy\", this_network, db_name=\"T\")\n", + " this_force = physics.GradientNode(\"Force\",(this_henergy,this_positions),sign=-1,db_name=\"F\")\n", + " \n", + " useful_nodes.append(this_henergy)\n", + " useful_nodes.append(this_force)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6647618c", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_connected_nodes(useful_nodes)" + ] + }, + { + "cell_type": "markdown", + "id": "91a9361f", + "metadata": {}, + "source": [ + "Note that due to the presence of multiple nodes with the same name in this visualization, each one is tagged with its id. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11c63acb", + "metadata": {}, + "outputs": [], + "source": [ + "from hippynn.graphs import make_ensemble\n", + "\n", + "ensemble,ensemble_info = make_ensemble(useful_nodes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f35096be", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(ensemble)" + ] + }, + { + "cell_type": "markdown", + "id": "d4cef63c", + "metadata": {}, + "source": [ + "The graph interface allows us to easily glue these models together and share intermediate computations where possible.\n", + "\n", + "Now, the models are merged as far as possible, sharing inputs and early calculations.\n", + " \n", + "At the same time, the ensemble quantities for energy (\"T\") and force (\"F\") have been constructed as nodes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd75d779", + "metadata": {}, + "outputs": [], + "source": [ + "ensemble_T = ensemble.node_from_name(\"ensemble_T\")\n", + "ensemble_F = ensemble.node_from_name(\"ensemble_F\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b571d284", + "metadata": {}, + "outputs": [], + "source": [ + "ensemble_predictor = Predictor.from_graph(ensemble)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9964981f", + "metadata": {}, + "outputs": [], + "source": [ + "outputs = ensemble_predictor.apply_to_database(database,batch_size=128)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bea6eb42", + "metadata": {}, + "outputs": [], + "source": [ + "ensemble_T.mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "744bada1", + "metadata": {}, + "outputs": [], + "source": [ + "outputs['test'][ensemble_T.mean].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9e47c6d", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + 
"outputs['test'][ensemble_T.std].shape" + ] + }, + { + "cell_type": "markdown", + "id": "5ecace03", + "metadata": {}, + "source": [ + "The \"all\" node outputs each individual prediction, stacked:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf1a1b6c", + "metadata": {}, + "outputs": [], + "source": [ + "outputs['test'][ensemble_T.all].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "599cd25e", + "metadata": {}, + "outputs": [], + "source": [ + "outputs['test'][ensemble_F.all].shape" + ] + }, + { + "cell_type": "markdown", + "id": "214056a6", + "metadata": {}, + "source": [ + "The features above can be intermixed, for example, building an ASE calculator using the ensemble module." + ] + }, + { + "cell_type": "markdown", + "id": "d6e9549a", + "metadata": {}, + "source": [ + "from hippynn.interfaces.ase_interface import HippynnCalculator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29ee4bca", + "metadata": {}, + "outputs": [], + "source": [ + "ensemble_calculator =HippynnCalculator(ensemble_T.mean)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cde6addd", + "metadata": {}, + "outputs": [], + "source": [ + "visualize_graph_module(ensemble_calculator.module)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45afeb81", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:hippynn]", + "language": "python", + "name": "conda-env-hippynn-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/hippynn/graphs/ensemble.py b/hippynn/graphs/ensemble.py index 9fe002fb..c9fc68c3 100644 --- 
a/hippynn/graphs/ensemble.py +++ b/hippynn/graphs/ensemble.py @@ -96,7 +96,7 @@ def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) -> elif isinstance(model, _BaseNode): subgraph = get_subgraph([model]) subgraph_inputs = list({x for x in subgraph if isinstance(x, InputNode)}) - model = GraphModule(subgraph_inputs, [model]) + model = GraphModule(subgraph_inputs, [model.main_output]) graphs.append(model) elif isinstance(model, GraphModule): diff --git a/hippynn/interfaces/ase_interface/calculator.py b/hippynn/interfaces/ase_interface/calculator.py index b4d1f671..bafed038 100644 --- a/hippynn/interfaces/ase_interface/calculator.py +++ b/hippynn/interfaces/ase_interface/calculator.py @@ -64,7 +64,7 @@ def setup_ASE_graph(energy, charges=None, extra_properties=None): # The required nodes passed back are copies of the ones passed in. # We use assume_inputed to avoid grabbing pieces of the graph # that are only prerequisites for the pair indexer. - new_required, new_subgraph = copy_subgraph(required_nodes, assume_inputed=pair_indexers, tag="ASE") + new_required, new_subgraph = copy_subgraph(required_nodes, assume_inputed=pair_indexers) # We now need access to the copied indexers, rather than the originals pair_indexers = find_relatives(new_required, search_fn(PairIndexer, new_subgraph), why_desc=why) @@ -81,12 +81,12 @@ def setup_ASE_graph(energy, charges=None, extra_properties=None): ############################################################### # Set up graph to accept external pair indices and shifts - in_shift = InputNode("(ASE)shift_vector") - in_cell = CellNode("(ASE)cell") - in_pair_first = InputNode("(ASE)pair_first") - in_pair_second = InputNode("(ASE)pair_second") + in_shift = InputNode("shift_vector") + in_cell = CellNode("cell") + in_pair_first = InputNode("pair_first") + in_pair_second = InputNode("pair_second") external_pairs = ExternalNeighborIndexer( - "(ASE)EXTERNAL_NEIGHBORS", + "external_neighbors", (positions, 
indexer.real_atoms, in_shift, in_cell, in_pair_first, in_pair_second), hard_dist_cutoff=min_radius, ) @@ -102,7 +102,7 @@ def setup_ASE_graph(energy, charges=None, extra_properties=None): mapped_node = external_pairs else: mapped_node = PairFilter( - "DistanceFilter-(ASE)EXTERNAL_NEIGHBORS", + "DistanceFilter_external_neighbors", (external_pairs), dist_hard_max=pi.dist_hard_max, ) @@ -114,9 +114,9 @@ def setup_ASE_graph(energy, charges=None, extra_properties=None): energy, *new_required = new_required - cellscaleinducer = StrainInducer("(ASE)Strain_inducer", (positions, in_cell)) + cellscaleinducer = StrainInducer("Strain_inducer", (positions, in_cell)) strain = cellscaleinducer.strain - derivatives = StressForceNode("(ASE)StressForceCalculator", (energy, strain, positions, in_cell)) + derivatives = StressForceNode("StressForceCalculator", (energy, strain, positions, in_cell)) replace_node(positions, cellscaleinducer.strained_coordinates) replace_node(in_cell, cellscaleinducer.strained_cell) @@ -128,7 +128,7 @@ def setup_ASE_graph(energy, charges=None, extra_properties=None): if charges is not None: charges, *new_required = new_required - dipole_moment = DipoleNode("(ASE)DIPOLE", charges) + dipole_moment = DipoleNode("Dipole", charges) implemented_nodes = *implemented_nodes, charges.main_output, dipole_moment implemented_properties = implemented_properties + ["charges", "dipole_moment"] diff --git a/hippynn/interfaces/lammps_interface/mliap_interface.py b/hippynn/interfaces/lammps_interface/mliap_interface.py index 0e9c7335..583fe5c7 100644 --- a/hippynn/interfaces/lammps_interface/mliap_interface.py +++ b/hippynn/interfaces/lammps_interface/mliap_interface.py @@ -137,7 +137,7 @@ def setup_LAMMPS_graph(energy): search_fn = lambda targ, sg: lambda n: n in sg and isinstance(n, targ) pair_indexers = find_relatives(required_nodes, search_fn(PairIndexer, subgraph), why_desc=why) - new_required, new_subgraph = copy_subgraph(required_nodes, assume_inputed=pair_indexers, 
tag="LAMMPS") + new_required, new_subgraph = copy_subgraph(required_nodes, assume_inputed=pair_indexers) pair_indexers = find_relatives(new_required, search_fn(PairIndexer, new_subgraph), why_desc=why) species = find_unique_relative(new_required, search_fn(SpeciesNode, new_subgraph), why_desc=why) @@ -152,15 +152,15 @@ def setup_LAMMPS_graph(energy): ############################################################### # Set up graph to accept external pair indices and shifts - in_pair_first = InputNode("(LAMMPS)pair_first") + in_pair_first = InputNode("pair_first") in_pair_first._index_state = hippynn.graphs.IdxType.Pair - in_pair_second = InputNode("(LAMMPS)pair_second") + in_pair_second = InputNode("pair_second") in_pair_second._index_state = hippynn.graphs.IdxType.Pair - in_pair_coord = InputNode("(LAMMPS)pair_coord") + in_pair_coord = InputNode("pair_coord") in_pair_coord._index_state = hippynn.graphs.IdxType.Pair - in_nlocal = InputNode("(LAMMPS)nlocal") + in_nlocal = InputNode("nlocal") in_nlocal._index_state = hippynn.graphs.IdxType.Scalar - pair_dist = VecMag("(LAMMPS)pair_dist", in_pair_coord) + pair_dist = VecMag("pair_dist", in_pair_coord) mapped_pair_first = ReIndexAtomNode("pair_first_internal", (in_pair_first, inv_real_atoms)) mapped_pair_second = ReIndexAtomNode("pair_second_internal", (in_pair_second, inv_real_atoms)) @@ -201,8 +201,8 @@ def setup_LAMMPS_graph(energy): "an object with an `atom_energies` attribute." ) - local_atom_energy = LocalAtomEnergyNode("(LAMMPS)local_atom_energy", (atom_energies, in_nlocal)) - grad_rij = GradientNode("(LAMMPS)grad_rij", (local_atom_energy.total_local_energy, in_pair_coord), -1) + local_atom_energy = LocalAtomEnergyNode("local_atom_energy", (atom_energies, in_nlocal)) + grad_rij = GradientNode("grad_rij", (local_atom_energy.total_local_energy, in_pair_coord), -1) implemented_nodes = local_atom_energy.local_atom_energies, local_atom_energy.total_local_energy, grad_rij