* first draft of ensemble code
* add ensemble.py
* working version, still working on merging identical nodes
* add merging of identical nodes and some cleaning
* try to add type annotations
* adding ensemble usage example
* update docs for ensembles and more
* further docs
* update example wording
1 parent 604a93c, commit 7d3e0f3
Showing 18 changed files with 625 additions and 25 deletions.
@@ -0,0 +1,25 @@
Ensembling Models
#################

The :func:`~hippynn.graphs.make_ensemble` function combines several trained models into a single graph.

By default, ensembling is based on the ``db_name`` of the nodes in each input graph.
Nodes that share a ``db_name`` are assigned an ensemble node, which collects the
members' versions of that quantity and also computes their
mean and standard deviation.
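
The grouping described above can be illustrated with a small standalone sketch. It does not use hippynn and is not how ``make_ensemble`` is implemented: the helper ``combine_by_name`` is hypothetical, predictions are plain floats rather than tensors, and the choice of a population standard deviation is an assumption made only for the illustration.

```python
from collections import defaultdict
from statistics import fmean, pstdev


def combine_by_name(member_outputs):
    """Group per-model predictions by name, then summarize each group.

    Hypothetical helper for illustration only; it is not part of hippynn.
    """
    grouped = defaultdict(list)
    for outputs in member_outputs:
        for name, value in outputs.items():
            grouped[name].append(value)

    summary = {}
    for name, values in grouped.items():
        summary[name] = {
            "all": values,           # every member's prediction
            "mean": fmean(values),   # ensemble mean
            "std": pstdev(values),   # population std (an assumption here)
        }
    return summary


# Three stand-in "models" that each predict a quantity named "T".
members = [{"T": 1.0}, {"T": 2.0}, {"T": 3.0}]
result = combine_by_name(members)
print(result["T"]["mean"], result["T"]["std"])
```

The ``all``, ``mean``, and ``std`` keys mirror the outputs that the example script reads from an ensemble node.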

An ensemble can be built from a glob string or from a list of directories where
the models are saved::

    from hippynn.graphs import make_ensemble
    model_form = '../../collected_models/quad0_b512_p5_GPU*'
    ensemble_graph, ensemble_info = make_ensemble(model_form)
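
To see what a glob string such as the one above selects before handing it to ``make_ensemble``, the pattern can be expanded with the standard library. The sketch below is an assumption about a reasonable pre-check, not a description of what ``make_ensemble`` does internally; ``resolve_model_dirs`` is a hypothetical helper, and the temporary directories only mimic the naming pattern from the example.

```python
import glob
import os
import tempfile


def resolve_model_dirs(models):
    """Return a sorted list of directories from a glob string or a list of paths.

    Hypothetical helper for illustration; not part of hippynn.
    """
    if isinstance(models, str):
        candidates = glob.glob(models)   # expand the wildcard pattern
    else:
        candidates = list(models)        # already an explicit list
    # Keep only directories, since each saved model lives in its own directory.
    return sorted(path for path in candidates if os.path.isdir(path))


with tempfile.TemporaryDirectory() as root:
    for name in ("quad0_b512_p5_GPU0", "quad0_b512_p5_GPU1"):
        os.mkdir(os.path.join(root, name))
    # A stray file that matches the pattern but is not a model directory.
    open(os.path.join(root, "quad0_b512_p5_GPU.log"), "w").close()

    found = resolve_model_dirs(os.path.join(root, "quad0_b512_p5_GPU*"))
    found_names = [os.path.basename(path) for path in found]

print(found_names)
```

Checking the expanded list first makes it easier to notice when a pattern matches fewer or more directories than intended.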

The ensemble graph takes the inputs that are required for all of the models in the ensemble.
The ``ensemble_info`` object provides the counts for the inputs and targets of the ensemble
and the counts of the corresponding quantities across the ensemble members.
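
The layout of ``ensemble_info`` is not shown in this commit view, so the following is only a stand-in for the idea of counting quantities across members. The function ``count_quantities`` and the example names ``"T"`` and ``"Grad"`` are hypothetical.

```python
from collections import Counter


def count_quantities(member_names):
    """Count how many ensemble members provide each named quantity.

    Hypothetical helper for illustration; the real ensemble_info may differ.
    """
    counts = Counter()
    for names in member_names:
        counts.update(set(names))  # each member counts at most once per name
    return counts


# Three stand-in members; the last one lacks the "Grad" target.
member_targets = [["T", "Grad"], ["T", "Grad"], ["T"]]
target_counts = count_quantities(member_targets)
print(target_counts)
```

A count lower than the number of members indicates a quantity that only some of the models provide.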

A typical use case is to then build a Predictor or ASE Calculator from the ensemble.
See :file:`~examples/ensembling_models.py` for a detailed example.
@@ -0,0 +1,56 @@
import torch
import hippynn

# Use the first GPU if one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = 'cpu'

### Building the ensemble requires just one function call.
model_form = '../../collected_models/quad0_b512_p5_GPU*'
ensemble_graph, ensemble_info = hippynn.graphs.make_ensemble(model_form)

# Retrieve the ensemble node which has just been created.
# Its name is the prefix 'ensemble_' followed by the db_name from the ensemble members.
ensemble_energy = ensemble_graph.node_from_name("ensemble_T")

### Building an ASE calculator for the ensemble

import ase.build
import ase.units

from hippynn.interfaces.ase_interface import HippynnCalculator

# The ensemble node has `mean`, `std`, and `all` outputs.
energy_node = ensemble_energy.mean
extra_properties = {"ens_predictions": ensemble_energy.all, "ens_std": ensemble_energy.std}
calc = HippynnCalculator(energy=energy_node, extra_properties=extra_properties)
calc.to(device)

# Build a molecule and attach the calculator.
molecule = ase.build.molecule("CH4")
molecule.calc = calc

energy_value = molecule.get_potential_energy()  # Triggers the calculation and fills calc.results

print("Got energy", energy_value)
print("In units of kcal/mol", energy_value / (ase.units.kcal/ase.units.mol))

# All outputs from the ensemble members. Because the model was trained in kcal/mol, these are too.
# The name in the results dictionary comes from the key in the 'extra_properties' dictionary.
print("All predictions:", calc.results["ens_predictions"])

### Building a Predictor object for the ensemble
pred = hippynn.graphs.Predictor.from_graph(ensemble_graph)

# Get batch-like inputs to the ensemble by adding a leading batch dimension.
z_vals = torch.as_tensor(molecule.get_atomic_numbers()).unsqueeze(0)
r_vals = torch.as_tensor(molecule.positions).unsqueeze(0)

# Move the predictor to the inputs' dtype and to the chosen device.
pred.to(r_vals.dtype)
pred.to(device)
# Do some computation
output = pred(Z=z_vals, R=r_vals)
# Print the output of a node using the node or the db_name.
print(output[ensemble_energy.all])
print(output["T_all"])
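
The example script converts the ASE energy to kcal/mol by dividing by ``ase.units.kcal/ase.units.mol``. ASE reports energies in eV, so that factor is the size of one kcal/mol in eV. The sketch below reproduces the arithmetic without ASE, using the exact SI values of the Avogadro constant and the elementary charge. The constant and function names are hypothetical, and ASE's own value may differ in the trailing digits depending on which CODATA set it uses.

```python
# Exact SI (2019) defining constants.
AVOGADRO = 6.02214076e23            # particles per mole
ELEMENTARY_CHARGE = 1.602176634e-19  # joules per electronvolt
JOULES_PER_KCAL = 4184.0             # thermochemical kilocalorie

# One kcal/mol expressed in eV per particle (hypothetical name).
KCAL_PER_MOL_IN_EV = JOULES_PER_KCAL / (AVOGADRO * ELEMENTARY_CHARGE)


def ev_to_kcal_per_mol(energy_ev):
    """Convert an energy in eV to kcal/mol."""
    return energy_ev / KCAL_PER_MOL_IN_EV


def kcal_per_mol_to_ev(energy_kcal_per_mol):
    """Convert an energy in kcal/mol to eV."""
    return energy_kcal_per_mol * KCAL_PER_MOL_IN_EV


one_ev = ev_to_kcal_per_mol(1.0)
print(KCAL_PER_MOL_IN_EV, one_ev)
```

Dividing by the unit factor, as the script does, moves a value out of ASE's internal units; multiplying moves it back in.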