The option to specify the desired molecular descriptor #11

Open · wants to merge 1 commit into master
72 changes: 60 additions & 12 deletions sgdml/cli.py
@@ -223,7 +223,7 @@ def _print_dataset_properties(dataset, title_str='Dataset properties'):


def _print_task_properties(
use_sym, use_cprsn, use_E, use_E_cstr, title_str='Task properties'
use_sym, use_cprsn, use_E, use_E_cstr, descriptor, title_str='Task properties'
):

print(ui.white_bold_str(title_str))
@@ -252,7 +252,11 @@ def _print_task_properties(
'Compression:', 'requested' if use_cprsn else 'not requested'
)
)

print(
' {:<16} {}'.format(
'Descriptor:', descriptor[0]
)
)

def _print_model_properties(model, title_str='Model properties'):

@@ -270,6 +274,10 @@ def _print_model_properties(model, title_str='Model properties'):
_, cprsn_keep_idxs = np.unique(
np.sort(model['perms'], axis=0), axis=1, return_index=True
)
try:
print(' {:<18} {}'.format('Descriptor:', model['use_descriptor'][0]))
except KeyError:
print(' {:<18} {}'.format('Descriptor:', 'not specified'))
n_atoms_kept = cprsn_keep_idxs.shape[0]
print(
' {:<18} {}'.format(
@@ -349,6 +357,7 @@ def all(
use_E,
use_E_cstr,
use_cprsn,
use_descriptor,
overwrite,
max_processes,
use_torch,
@@ -398,6 +407,7 @@ def all(
use_E,
use_E_cstr,
use_cprsn,
use_descriptor,
overwrite,
max_processes,
task_dir,
@@ -424,6 +434,7 @@ def all(
overwrite=False,
max_processes=max_processes,
use_torch=use_torch,
use_descriptor=use_descriptor,
**kwargs
)

@@ -443,6 +454,7 @@ def all(
overwrite=False,
max_processes=max_processes,
use_torch=use_torch,
use_descriptor=use_descriptor,
**kwargs
)

@@ -466,6 +478,7 @@ def create( # noqa: C901
use_E,
use_E_cstr,
use_cprsn,
use_descriptor,
overwrite,
max_processes,
task_dir=None,
@@ -487,7 +500,7 @@ def create( # noqa: C901
print()

_print_task_properties(
use_sym=not gdml, use_cprsn=use_cprsn, use_E=use_E, use_E_cstr=use_E_cstr
use_sym=not gdml, use_cprsn=use_cprsn, use_E=use_E, use_E_cstr=use_E_cstr, descriptor=use_descriptor
)
print()

@@ -531,6 +544,7 @@ def create( # noqa: C901
use_cprsn=use_cprsn,
use_E=use_E,
use_E_cstr=use_E_cstr,
descriptor=use_descriptor[0],
model0=model0,
)

@@ -583,7 +597,8 @@ def create( # noqa: C901
if not use_E:
log.info(
'Energy labels will be ignored for training.\n'
+ 'Note: If available in the dataset file, the energy labels will however still be used to generate stratified training, test and validation datasets. Otherwise a random sampling is used.'
+ 'Note: If available in the dataset file, the energy labels will however still be used to generate '
'stratified training, test and validation datasets. Otherwise a random sampling is used.'
)

if 'E' not in dataset:
@@ -618,6 +633,7 @@ def create( # noqa: C901
n_train,
valid_dataset,
n_valid,
use_descriptor,
sig=1,
use_sym=not gdml,
use_E=use_E,
@@ -808,7 +824,7 @@ def _online_err(err, size, n, mae_n_sum, rmse_n_sum):


def validate(
model_dir, dataset, overwrite, max_processes, use_torch, command=None, **kwargs
model_dir, dataset, overwrite, max_processes, use_torch, use_descriptor, command=None, **kwargs
):

dataset_path_extracted, dataset_extracted = dataset
@@ -827,6 +843,7 @@ def validate(
overwrite,
max_processes,
use_torch,
use_descriptor,
command,
**kwargs
)
@@ -859,6 +876,7 @@ def test(
overwrite,
max_processes,
use_torch,
use_descriptor,
command=None,
**kwargs
): # noqa: C901
@@ -1176,7 +1194,8 @@ def test(
else 'only {:,}'.format(len(test_idxs))
)
log.warning(
'This model has previously been tested on {:,} points, which is why the errors for the current test run with {} points have NOT been used to update the model file.\n'.format(
'This model has previously been tested on {:,} points, which is why the errors for the current test '
'run with {} points have NOT been used to update the model file.\n'.format(
model['n_test'], add_info_str
)
+ 'Run \'{} test -o {} {} {}\' to overwrite.'.format(
@@ -1258,8 +1277,10 @@ def select(

if any_model_is_tested:
log.error(
'One or more models in the given directory have already been tested. This means that their recorded expected errors are test errors, not validation errors. However, one should never perform model selection based on the test error!\n'
+ 'Please run the validation command (again) with the overwrite option \'-o\', then this selection command.'
'One or more models in the given directory have already been tested. This means that their recorded'
' expected errors are test errors, not validation errors. However, one should never perform model '
'selection based on the test error!\n'
'Please run the validation command (again) with the overwrite option \'-o\', then this selection command.'
)
return

@@ -1439,6 +1460,18 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
action='store_true',
help='use PyTorch for validation and test (including kernel evaluations in some numerical solvers)',
)
# TODO: Add more descriptors
parent_parser.add_argument(
'-descr',
'--descriptor',
metavar='<descriptor [args...]>',
dest='use_descriptor',
type=io.parse_descriptor,
help='sets the descriptor to be used and its required arguments (e.g. -descr coulomb_matrix or -descr exp_decay_matrix).',
# choices=['coulomb_matrix', 'exp_decay_matrix'],
default=['coulomb_matrix'],
nargs='+',
)
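The parsing callback io.parse_descriptor is not part of this diff. A minimal sketch of what it could look like, assuming it only needs to sanity-check each token; with nargs='+' the callback sees every token individually, so the first token is the descriptor name and any further tokens are its arguments, and the default ['coulomb_matrix'] keeps use_descriptor[0] well defined everywhere else in this change:

    # Hypothetical sketch, not taken from sgdml.utils.io: a permissive `type`
    # callback for --descriptor. An invocation might look like
    #   sgdml all ethanol.npz 200 1000 -descr exp_decay_matrix
    # (exact positional arguments of the `all` command aside).
    import argparse

    def parse_descriptor(token):
        token = token.strip()
        if not token:
            raise argparse.ArgumentTypeError('empty descriptor token')
        return token

Because a type callback cannot tell a descriptor name from a descriptor argument, validating the name itself is better done after parsing, which is what the post-processing block added further down leaves room for.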

subparsers = parser.add_subparsers(title='commands', dest='command')
subparsers.required = True
@@ -1472,7 +1505,8 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
for subparser in [parser_all, parser_create]:
_add_argument_dataset(
subparser,
help='path to dataset file (train/validation/test subsets are sampled from here if no seperate dataset are specified)',
help='path to dataset file (train/validation/test subsets are sampled from here if no separate datasets are '
'specified)',
)
_add_argument_sample_size(subparser, 'train')
_add_argument_sample_size(subparser, 'valid')
@@ -1571,7 +1605,8 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
'--model0',
metavar='<initial_model_file>',
type=lambda x: io.is_file_type(x, 'model'),
help='initial model file used as a source for training task parameters, including training and validation subsets, permutations, initial set of coefficients (for numerical solvers)',
help='initial model file used as a source for training task parameters, including training and validation '
'subsets, permutations, initial set of coefficients (for numerical solvers)',
nargs='?',
default=None,
)
@@ -1601,6 +1636,16 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
if not args.model_file.endswith('.npz'):
args.model_file += '.npz'

# Post-processing of the descriptor flag:
# looks for dependencies in case an external descriptor is provided.
if 'use_descriptor' in args:

if args.use_descriptor[0] == 'coulomb_matrix':
pass

if args.use_descriptor[0] == 'exp_decay_matrix':
pass
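Both descriptor branches above are placeholders for dependency checks. A hedged sketch of what could live here in the meantime, reusing the log.critical / sys.exit() pattern already used in this module; this is a suggestion, not part of the PR:

    # Hypothetical sketch only: reject unknown descriptor names early so that
    # 'use_descriptor' never reaches task creation with an unsupported value.
    # Relies on the `args`, `log` and `sys` names already present in this scope.
    if args.use_descriptor[0] not in ('coulomb_matrix', 'exp_decay_matrix'):
        log.critical(
            'Unknown descriptor \'{}\'. Currently supported: coulomb_matrix, '
            'exp_decay_matrix.'.format(args.use_descriptor[0])
        )
        sys.exit()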

_print_splash(args.max_processes)

# Check PyTorch GPU support.
@@ -1609,12 +1654,15 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False):
if not torch.cuda.is_available():
print() # TODO: print only if log level includes warning
log.warning(
'Your PyTorch installation does not see any GPU(s) on your system and will thus run all calculations on the CPU! Unless this is what you want, we recommend running CPU calculations without \'--torch\' for improved performance.'
'Your PyTorch installation does not see any GPU(s) on your system and will thus run all '
'calculations on the CPU! Unless this is what you want, we recommend running CPU calculations '
'without \'--torch\' for improved performance.'
)
else:
print()
log.critical(
'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or disable the PyTorch option.'
'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or '
'disable the PyTorch option.'
)
sys.exit()

7 changes: 6 additions & 1 deletion sgdml/predict.py
@@ -280,7 +280,12 @@ def __init__(

self.n_atoms = model['z'].shape[0]

self.desc = Desc(self.n_atoms, max_processes=max_processes)
self.use_descriptor = model['use_descriptor']

# elif self.use_descriptor[0] == 'non_default_descr':
# pass

self.desc = Desc(self.n_atoms, max_processes=max_processes, use_descriptor=self.use_descriptor[0])
glob['desc_func'] = self.desc

self.lat_and_inv = (
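Unlike _print_model_properties in cli.py, which guards the lookup with a try/except, the predictor constructor reads model['use_descriptor'] directly, so model files created before this change would raise a KeyError here. A hedged sketch of a backward-compatible fallback; the default itself is a suggestion, not part of the PR:

    # Hypothetical guard: fall back to the Coulomb matrix for model files that
    # predate the 'use_descriptor' field. `model` is the loaded model archive
    # used throughout the constructor.
    if 'use_descriptor' in model:
        self.use_descriptor = model['use_descriptor']
    else:
        self.use_descriptor = ['coulomb_matrix']

    self.desc = Desc(
        self.n_atoms,
        max_processes=max_processes,
        use_descriptor=self.use_descriptor[0],
    )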
26 changes: 18 additions & 8 deletions sgdml/torchtools.py
@@ -85,6 +85,7 @@ def __init__(self, model, lat_and_inv=None, batch_size=None, max_memory=None):
self._sig = int(model['sig'])
self._c = float(model['c'])
self._std = float(model.get('std', 1))
self.use_descriptor = model['use_descriptor']

desc_siz = model['R_desc'].shape[0]
n_perms, self._n_atoms = model['perms'].shape
@@ -191,20 +192,29 @@ def _forward(self, Rs):
dists[:, i, j] = np.inf
i, j = np.tril_indices(self._n_atoms, k=-1)

xs = 1 / dists[:, i, j] # R_desc (1000, 36)
if self.use_descriptor[0] == 'coulomb_matrix':
xs = 1 / dists[:, i, j]
elif self.use_descriptor[0] == 'exp_decay_matrix':
exp_dists = torch.exp(-dists)
xs = exp_dists[:, i, j]

x_diffs = (q * xs)[:, None, :] - q * self._xs_train
x_dists = x_diffs.norm(dim=-1)
exp_xs = 5.0 / (3 * sig ** 2) * torch.exp(-x_dists)
dot_x_diff_Jx_alphas = (x_diffs * self._Jx_alphas).sum(dim=-1)
exp_xs_1_x_dists = exp_xs * (1 + x_dists)
F1s_x = ((exp_xs * dot_x_diff_Jx_alphas)[..., None] * x_diffs).sum(dim=1)
exp_xs = 5.0 / (3 * sig ** 2) * torch.exp(-x_dists) # ok
dot_x_diff_Jx_alphas = (x_diffs * self._Jx_alphas).sum(dim=-1) # ok
exp_xs_1_x_dists = exp_xs * (1 + x_dists) # ok
F1s_x = ((exp_xs * dot_x_diff_Jx_alphas)[..., None] * x_diffs).sum(dim=1) # ok
F2s_x = exp_xs_1_x_dists.mm(self._Jx_alphas)
Fs_x = (F1s_x - F2s_x) * self._std

Fs = ((expand_tril(Fs_x) / dists ** 3)[..., None] * diffs).sum(
dim=1
) # * R_d_desc
if self.use_descriptor[0] == 'coulomb_matrix':
Fs = ((expand_tril(Fs_x) / dists ** 3)[..., None] * diffs).sum(
dim=1
) # * R_d_desc
elif self.use_descriptor[0] == 'exp_decay_matrix':
Fs = ((expand_tril(Fs_x) * exp_dists / dists)[..., None] * diffs).sum(
dim=1
) # * R_d_desc

Es = (exp_xs_1_x_dists * dot_x_diff_Jx_alphas).sum(dim=-1) / q
Es = self._c + Es * self._std
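For reference, the two descriptors touched in _forward differ only in how a pairwise distance d is mapped to a descriptor entry: 1/d for coulomb_matrix and exp(-d) for exp_decay_matrix. That is why the force back-projection divides by dists ** 3 in one branch and multiplies by exp_dists / dists in the other. A small standalone check of those chain-rule factors, written against plain PyTorch rather than the PR's code paths:

    # Sanity check of the gradient factors assumed above. For a pair
    # difference vector diff = r_i - r_j with d = |diff|:
    #   d(1/d)/d(diff)     = -diff / d**3         -> the '/ dists ** 3' branch
    #   d(exp(-d))/d(diff) = -exp(-d) * diff / d  -> the '* exp_dists / dists' branch
    import torch

    diff = torch.tensor([1.0, 2.0, 2.0], requires_grad=True)  # d = 3

    d = diff.norm()
    (1.0 / d).backward()
    print(diff.grad)                              # equals -diff / d**3
    print(-diff.detach() / d.detach() ** 3)

    diff.grad = None
    d = diff.norm()                               # rebuild the autograd graph
    torch.exp(-d).backward()
    print(diff.grad)                              # equals -exp(-d) * diff / d
    print(-torch.exp(-d.detach()) * diff.detach() / d.detach())

One more observation: if use_descriptor[0] matches neither name, xs (and later Fs) is left undefined in _forward, so an early name check on the CLI side, like the one sketched after the descriptor flag's post-processing block, would keep that path unreachable.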