Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

analytical second order derivatives (Hessians) #15

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
17 changes: 15 additions & 2 deletions scripts/sgdml_dataset_from_extxyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import argparse
import os
import sys
import re

try:
from ase.io import read
Expand All @@ -45,7 +46,8 @@
# Assumes that the atoms in each molecule are in the same order.
def read_nonstd_ext_xyz(f):
n_atoms = None

pattern = re.compile('[eE]nergy=([\+\-0-9\.]+) ')

R, z, E, F = [], [], [], []
for i, line in enumerate(f):
line = line.strip()
Expand All @@ -59,7 +61,12 @@ def read_nonstd_ext_xyz(f):
try:
e = float(line)
except ValueError:
pass
# Try to read energy from comment line as
# Energy=(.*) ...
match = pattern.findall(line)
if len(match) > 0:
e = float(match[0])
E.append(e)
else:
E.append(e)

Expand Down Expand Up @@ -189,6 +196,12 @@ def read_nonstd_ext_xyz(f):
base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())

print('Please provide a descriptor for the level of theory used to create this dataset.')
theory = raw_input('> ').strip()
if theory == '':
theory = 'unknown'
base_vars['theory'] = theory

print('Please provide a description of the length unit used in your input file, e.g. \'Ang\' or \'au\': ')
print('Note: This string will be stored in the dataset file and passed on to models files for later reference.')
r_unit = raw_input('> ').strip()
Expand Down
Binary file added sgdml/_bmark_cache.npz
Binary file not shown.
146 changes: 146 additions & 0 deletions sgdml/test_torchtools_hessian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python
# coding: utf-8
"""
test torch implementation of machine-learned sGDML model (energies, gradients and Hessians)
"""
# Path of the sGDML model file (.npz) under test; a relative path, so it is
# resolved against the current working directory.
model_file = 'models/ethanol.npz'
# Path of the XYZ geometry file used as the starting structure; also relative
# to the current working directory.
geometry_file = 'geometries/ethanol.xyz'

import numpy as np
import numpy.linalg as la
import torch
import logging
import time

from sgdml.utils import io

from sgdml.torchtools_hessian import GDMLPredict
from sgdml.torchtools import GDMLTorchPredict

# Logging
# `getLogger` alone leaves the effective level at WARNING, which silently drops
# every `logger.info` report of this script (timings and errors). Install a
# default handler (a no-op if the root logger already has handlers) and enable
# INFO for this module's logger only, so other libraries stay quiet.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# GPU or CPU?
# Floating point tensors created from here on default to double precision.
torch.set_default_dtype(torch.float64)
if torch.cuda.is_available():
    logger.info("CUDA available")
    device = torch.device("cuda")
else:
    device = torch.device('cpu')

# load model fitted to ground state forces
# NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code -- only load model files from trusted sources.
model = np.load(model_file, allow_pickle=True)
# reference implementation; it receives inputs that live on `device`, so the
# module itself has to be moved there too (otherwise CUDA runs fail with a
# device mismatch)
gdml_ref = GDMLTorchPredict(model).to(device)
# new implementation with analytical Hessians
gdml = GDMLPredict(model).to(device)
# load geometry (the second return value of read_xyz is not needed here)
r, _ = io.read_xyz(geometry_file)
# presumably a single flattened geometry of shape (1, 3*N) -- confirm against
# io.read_xyz
coords = torch.from_numpy(r).to(device)

##########################################################################
# #
# check that energies and gradients agree with reference implementation #
# #
##########################################################################

# make random numbers reproducible
torch.manual_seed(0)


def _timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)``.

    CUDA kernels are launched asynchronously, so the device is synchronised
    before starting and before stopping the clock; otherwise the clock could
    stop while queued kernels are still running. ``time.perf_counter`` is used
    because it is monotonic and has a higher resolution than ``time.time``.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
    t_start = time.perf_counter()
    result = fn(*args)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return result, time.perf_counter() - t_start


# coords holds 3 cartesian components per atom along its second dimension
natom = coords.size()[1] // 3
# timings for different batch sizes
for batch_size in [1, 10, 100, 1000]:
    logger.info(f"batch size {batch_size}")
    # batch (B,3*N): copies of the input geometry with a random displacement
    # in [0, 0.1) added to every coordinate
    rs = coords.repeat(batch_size, 1) + 0.1 * torch.rand(batch_size, 3 * natom).to(device)
    # (B, N, 3), the layout passed to the reference implementation
    rs_3N = rs.reshape(batch_size, -1, 3)

    # compute energy and force with reference implementation
    (en_ref, force_ref), t_ref = _timed(gdml_ref.forward, rs_3N)
    # gradient = -force, flattened to the (B,3*N) layout of the new implementation
    grad_ref = -force_ref.reshape(rs.size())
    logger.info(f"timing reference implementation, energy+gradient : {t_ref} seconds")

    # and compare with new implementation
    (en, grad, hessian), t_new = _timed(gdml.forward, rs)
    logger.info(f"timing new implementation, energy+gradient+hessian : {t_new} seconds")

    # error per sample
    err_en = torch.norm(en_ref - en) / batch_size
    err_grad = torch.norm(grad_ref - grad) / batch_size

    logger.info(f"   error of energy   :  {err_en}")
    logger.info(f"   error of gradient :  {err_grad}")

    assert err_en < 1.0e-4
    assert err_grad < 1.0e-4

###############################################################
# #
# compare numerical and analytic Hessians of sGDML potential #
# #
###############################################################
from sgdml.intf.ase_calc import SGDMLCalculator
from ase.io.xyz import read_xyz
from ase.optimize import BFGS
from ase.vibrations import Vibrations
from ase.units import kcal, mol

import os
import tempfile

# compute Hessian numerically using ASE
with open(geometry_file) as f:
    # read_xyz yields structures; only the first one is used
    molecule = next(read_xyz(f))
sgdml_calc = SGDMLCalculator(model_file)
molecule.calc = sgdml_calc

# optimization
opt = BFGS(molecule)
opt.run(fmax=0.001)
# optimized geometry, flattened to a single row (1, 3*N)
coords_opt = torch.from_numpy(molecule.get_positions()).reshape(1, -1).to(device)

# frequencies
# The displacement files are written below a fresh private directory instead of
# the fixed path "/tmp/vib_sgdml": a fixed path is not portable, collides
# between concurrent runs, and leftovers from an aborted run could be picked up
# by a later one.
with tempfile.TemporaryDirectory() as vib_dir:
    vib = Vibrations(molecule, name=os.path.join(vib_dir, "vib_sgdml"))
    vib.run()
    # return value intentionally unused; presumably this call populates `vib.H`
    # as a side effect -- confirm against the installed ASE version
    vib.get_energies()
    vib.clean()

# convert numerical Hessian from eV Ang^{-2} to kcal/mol Ang^{-2}
hessian_numerical = vib.H / (kcal / mol)


# Analytical Hessian taken directly from the sGDML model: the third output of
# the forward pass, of which the first (and only) batch entry is converted to
# a numpy array on the host.
outputs = gdml.forward(coords_opt)
hessian_batch = outputs[2]
hessian_analytical = hessian_batch[0, :, :].cpu().numpy()

# A Hessian has to equal its own transpose; measure how far it deviates.
asymmetry = hessian_analytical - hessian_analytical.T
err_sym = la.norm(asymmetry)
logger.info(f"|Hessian-Hessian^T|= {err_sym}")
assert err_sym < 1.0e-8

"""
# compare Hessians visually
import matplotlib.pyplot as plt
ax1 = plt.subplot(1,3,1)
ax1.set_title("numerical Hessian")
ax1.imshow(hessian_numerical)

ax2 = plt.subplot(1,3,2)
ax2.set_title("analytical Hessian")
ax2.imshow(hessian_analytical)

ax3 = plt.subplot(1,3,3)
ax3.set_title("difference")
ax3.imshow(hessian_numerical - hessian_analytical)

plt.show()
"""

# Relative deviation between the two Hessians, normalised by the norm of the
# numerical one; they have to agree within numerical errors.
hessian_diff = hessian_numerical - hessian_analytical
norm_numerical = la.norm(hessian_numerical)
err = la.norm(hessian_diff) / norm_numerical
logger.info(f"|Hessian(num)-Hessian(ana)|/|Hessian(num)|= {err}")
assert err < 1.0e-3
Loading