diff --git a/sgdml/cli.py b/sgdml/cli.py index 998bf0b..8478b96 100644 --- a/sgdml/cli.py +++ b/sgdml/cli.py @@ -223,7 +223,7 @@ def _print_dataset_properties(dataset, title_str='Dataset properties'): def _print_task_properties( - use_sym, use_cprsn, use_E, use_E_cstr, title_str='Task properties' + use_sym, use_cprsn, use_E, use_E_cstr, descriptor, title_str='Task properties' ): print(ui.white_bold_str(title_str)) @@ -252,7 +252,11 @@ def _print_task_properties( 'Compression:', 'requested' if use_cprsn else 'not requested' ) ) - + print( + ' {:<16} {}'.format( + 'Descriptor:', descriptor[0] + ) + ) def _print_model_properties(model, title_str='Model properties'): @@ -270,6 +274,10 @@ def _print_model_properties(model, title_str='Model properties'): _, cprsn_keep_idxs = np.unique( np.sort(model['perms'], axis=0), axis=1, return_index=True ) + try: + print(' {:<18} {}'.format('Descriptor:', model['use_descriptor'][0])) + except: + print(' {:<18} {}'.format('Descriptor:', 'descriptor was not specified.')) n_atoms_kept = cprsn_keep_idxs.shape[0] print( ' {:<18} {}'.format( @@ -349,6 +357,7 @@ def all( use_E, use_E_cstr, use_cprsn, + use_descriptor, overwrite, max_processes, use_torch, @@ -398,6 +407,7 @@ def all( use_E, use_E_cstr, use_cprsn, + use_descriptor, overwrite, max_processes, task_dir, @@ -424,6 +434,7 @@ def all( overwrite=False, max_processes=max_processes, use_torch=use_torch, + use_descriptor=use_descriptor, **kwargs ) @@ -443,6 +454,7 @@ def all( overwrite=False, max_processes=max_processes, use_torch=use_torch, + use_descriptor=use_descriptor, **kwargs ) @@ -466,6 +478,7 @@ def create( # noqa: C901 use_E, use_E_cstr, use_cprsn, + use_descriptor, overwrite, max_processes, task_dir=None, @@ -487,7 +500,7 @@ def create( # noqa: C901 print() _print_task_properties( - use_sym=not gdml, use_cprsn=use_cprsn, use_E=use_E, use_E_cstr=use_E_cstr + use_sym=not gdml, use_cprsn=use_cprsn, use_E=use_E, use_E_cstr=use_E_cstr, descriptor=use_descriptor ) print() @@ -531,6 +544,7 @@ def create( # noqa: C901 use_cprsn=use_cprsn, use_E=use_E, use_E_cstr=use_E_cstr, + descriptor=use_descriptor[0], model0=model0, ) @@ -583,7 +597,8 @@ def create( # noqa: C901 if not use_E: log.info( 'Energy labels will be ignored for training.\n' - + 'Note: If available in the dataset file, the energy labels will however still be used to generate stratified training, test and validation datasets. Otherwise a random sampling is used.' + + 'Note: If available in the dataset file, the energy labels will however still be used to generate ' + 'stratified training, test and validation datasets. Otherwise a random sampling is used.' ) if 'E' not in dataset: @@ -618,6 +633,7 @@ def create( # noqa: C901 n_train, valid_dataset, n_valid, + use_descriptor, sig=1, use_sym=not gdml, use_E=use_E, @@ -808,7 +824,7 @@ def _online_err(err, size, n, mae_n_sum, rmse_n_sum): def validate( - model_dir, dataset, overwrite, max_processes, use_torch, command=None, **kwargs + model_dir, dataset, overwrite, max_processes, use_torch, use_descriptor, command=None, **kwargs ): dataset_path_extracted, dataset_extracted = dataset @@ -827,6 +843,7 @@ def validate( overwrite, max_processes, use_torch, + use_descriptor, command, **kwargs ) @@ -859,6 +876,7 @@ def test( overwrite, max_processes, use_torch, + use_descriptor, command=None, **kwargs ): # noqa: C901 @@ -1176,7 +1194,8 @@ def test( else 'only {:,}'.format(len(test_idxs)) ) log.warning( - 'This model has previously been tested on {:,} points, which is why the errors for the current test run with {} points have NOT been used to update the model file.\n'.format( + 'This model has previously been tested on {:,} points, which is why the errors for the current test ' + 'run with {} points have NOT been used to update the model file.\n'.format( model['n_test'], add_info_str ) + 'Run \'{} test -o {} {} {}\' to overwrite.'.format( @@ -1258,8 +1277,10 @@ def select( if any_model_is_tested: log.error( - 'One or more models in the given directory have already been tested. This means that their recorded expected errors are test errors, not validation errors. However, one should never perform model selection based on the test error!\n' - + 'Please run the validation command (again) with the overwrite option \'-o\', then this selection command.' + 'One or more models in the given directory have already been tested. This means that their recorded' + ' expected errors are test errors, not validation errors. However, one should never perform model ' + 'selection based on the test error!\n Please run the validation command (again) with the overwrite ' + 'option \'-o\', then this selection command.' ) return @@ -1439,6 +1460,18 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False): action='store_true', help='use PyTorch for validation and test (including kernel evaluations in some numerical solvers)', ) + # TODO: Add more descriptors + parent_parser.add_argument( + '-descr', + '--descriptor', + metavar='', + dest='use_descriptor', + type=io.parse_descriptor, + help='sets the descriptor to be used and their required arguments (e.g. -descr coulomb_matrix/exp_decay_matrix).', + # choices=['coulomb_matrix', 'exp_decay_matrix'], + default=['coulomb_matrix'], + nargs='+', + ) subparsers = parser.add_subparsers(title='commands', dest='command') subparsers.required = True @@ -1472,7 +1505,8 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False): for subparser in [parser_all, parser_create]: _add_argument_dataset( subparser, - help='path to dataset file (train/validation/test subsets are sampled from here if no seperate dataset are specified)', + help='path to dataset file (train/validation/test subsets are sampled from here if no seperate dataset are ' + 'specified)', ) _add_argument_sample_size(subparser, 'train') _add_argument_sample_size(subparser, 'valid') @@ -1571,7 +1605,8 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False): '--model0', metavar='', type=lambda x: io.is_file_type(x, 'model'), - help='initial model file used as a source for training task parameters, including training and validation subsets, permutations, initial set of coefficients (for numerical solvers)', + help='initial model file used as a source for training task parameters, including training and validation ' + 'subsets, permutations, initial set of coefficients (for numerical solvers)', nargs='?', default=None, ) @@ -1601,6 +1636,16 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False): if not args.model_file.endswith('.npz'): args.model_file += '.npz' + # post-processing of descriptor flag + # Looks for dependencies in case that an external descriptor is provided + if 'use_descriptor' in args: + + if args.use_descriptor[0] == 'coulomb_matrix': + pass + + if args.use_descriptor[0] == 'exp_decay_matrix': + pass + _print_splash(args.max_processes) # Check PyTorch GPU support. @@ -1609,12 +1654,15 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False): if not torch.cuda.is_available(): print() # TODO: print only if log level includes warning log.warning( - 'Your PyTorch installation does not see any GPU(s) on your system and will thus run all calculations on the CPU! Unless this is what you want, we recommend running CPU calculations without \'--torch\' for improved performance.' + 'Your PyTorch installation does not see any GPU(s) on your system and will thus run all ' + 'calculations on the CPU! Unless this is what you want, we recommend running CPU calculations ' + 'without \'--torch\' for improved performance.' ) else: print() log.critical( - 'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or disable the PyTorch option.' + 'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or ' + 'disable the PyTorch option.' ) sys.exit() diff --git a/sgdml/predict.py b/sgdml/predict.py index 5191b9e..539c3c0 100644 --- a/sgdml/predict.py +++ b/sgdml/predict.py @@ -280,7 +280,12 @@ def __init__( self.n_atoms = model['z'].shape[0] - self.desc = Desc(self.n_atoms, max_processes=max_processes) + self.use_descriptor = model['use_descriptor'] + + # elif self.use_descriptor[0] == 'non_default_descr': + # pass + + self.desc = Desc(self.n_atoms, max_processes=max_processes, use_descriptor=self.use_descriptor[0]) glob['desc_func'] = self.desc self.lat_and_inv = ( diff --git a/sgdml/torchtools.py b/sgdml/torchtools.py index 7f41f08..079d7d3 100644 --- a/sgdml/torchtools.py +++ b/sgdml/torchtools.py @@ -85,6 +85,7 @@ def __init__(self, model, lat_and_inv=None, batch_size=None, max_memory=None): self._sig = int(model['sig']) self._c = float(model['c']) self._std = float(model.get('std', 1)) + self.use_descriptor = model['use_descriptor'] desc_siz = model['R_desc'].shape[0] n_perms, self._n_atoms = model['perms'].shape @@ -191,20 +192,29 @@ def _forward(self, Rs): dists[:, i, j] = np.inf i, j = np.tril_indices(self._n_atoms, k=-1) - xs = 1 / dists[:, i, j] # R_desc (1000, 36) + if self.use_descriptor[0] == 'coulomb_matrix': + xs = 1 / dists[:, i, j] + elif self.use_descriptor[0] == 'exp_decay_matrix': + exp_dists = torch.exp(-dists) + xs = exp_dists[:, i, j] x_diffs = (q * xs)[:, None, :] - q * self._xs_train x_dists = x_diffs.norm(dim=-1) - exp_xs = 5.0 / (3 * sig ** 2) * torch.exp(-x_dists) - dot_x_diff_Jx_alphas = (x_diffs * self._Jx_alphas).sum(dim=-1) - exp_xs_1_x_dists = exp_xs * (1 + x_dists) - F1s_x = ((exp_xs * dot_x_diff_Jx_alphas)[..., None] * x_diffs).sum(dim=1) + exp_xs = 5.0 / (3 * sig ** 2) * torch.exp(-x_dists) # ok + dot_x_diff_Jx_alphas = (x_diffs * self._Jx_alphas).sum(dim=-1) # ok + exp_xs_1_x_dists = exp_xs * (1 + x_dists) # ok + F1s_x = ((exp_xs * dot_x_diff_Jx_alphas)[..., None] * x_diffs).sum(dim=1) # ok F2s_x = exp_xs_1_x_dists.mm(self._Jx_alphas) Fs_x = (F1s_x - F2s_x) * self._std - Fs = ((expand_tril(Fs_x) / dists ** 3)[..., None] * diffs).sum( - dim=1 - ) # * R_d_desc + if self.use_descriptor[0] == 'coulomb_matrix': + Fs = ((expand_tril(Fs_x) / dists ** 3)[..., None] * diffs).sum( + dim=1 + ) # * R_d_desc + elif self.use_descriptor[0] == 'exp_decay_matrix': + Fs = ((expand_tril(Fs_x) * exp_dists / dists)[..., None] * diffs).sum( + dim=1 + ) # * R_d_desc Es = (exp_xs_1_x_dists * dot_x_diff_Jx_alphas).sum(dim=-1) / q Es = self._c + Es * self._std diff --git a/sgdml/train.py b/sgdml/train.py index efca944..7804549 100755 --- a/sgdml/train.py +++ b/sgdml/train.py @@ -261,7 +261,8 @@ def __init__(self, max_processes=None, use_torch=False): if use_torch and not _has_torch: raise ImportError( - 'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or disable the PyTorch option.' + 'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or ' + 'disable the PyTorch option.' ) def __del__(self): @@ -275,6 +276,7 @@ def create_task( n_train, valid_dataset, n_valid, + use_descriptor, sig, lam=1e-15, use_sym=True, @@ -311,6 +313,8 @@ def create_task( validation dataset. n_valid : int Number of validation points to sample. + use_descriptor : list + List containing the name of the descriptor to use. sig : int Hyper-parameter (kernel length scale). lam : float, optional @@ -453,6 +457,7 @@ def create_task( 'use_E_cstr': use_E_cstr, 'use_sym': use_sym, 'use_cprsn': use_cprsn, + 'use_descriptor': use_descriptor, 'solver_name': solver, 'solver_tol': solver_tol, } @@ -470,7 +475,8 @@ def create_task( lat_and_inv = (task['lattice'], np.linalg.inv(task['lattice'])) except np.linalg.LinAlgError: raise ValueError( # TODO: Document me - 'Provided dataset contains invalid lattice vectors (not invertible). Note: Only rank 3 lattice vector matrices are supported.' + 'Provided dataset contains invalid lattice vectors (not invertible). Note: Only rank 3 lattice ' + 'vector matrices are supported.' ) if 'r_unit' in train_dataset and 'e_unit' in train_dataset: @@ -486,7 +492,8 @@ def create_task( np.random.choice(n_train, 1000, replace=False), :, : ] self.log.info( - 'Symmetry search has been restricted to a random subset of 1000/{:d} training points for faster convergence.'.format( + 'Symmetry search has been restricted to a random subset of 1000/{:d} training points for ' + 'faster convergence.'.format( n_train ) ) @@ -701,6 +708,7 @@ def create_model( 'tril_perms_lin': tril_perms_lin, 'use_E': task['use_E'], 'use_cprsn': task['use_cprsn'], + 'use_descriptor': task['use_descriptor'], } if solver_resid is not None: @@ -802,15 +810,25 @@ def train( # noqa: C901 n_train, n_atoms = task['R_train'].shape[:2] + use_descriptor = task['use_descriptor'] + desc = Desc( - n_atoms, max_processes=self._max_processes + n_atoms, max_processes=self._max_processes, use_descriptor=use_descriptor[0] ) sig = np.squeeze(task['sig']) lam = np.squeeze(task['lam']) n_perms = task['perms'].shape[0] - tril_perms = np.array([desc.perm(p) for p in task['perms']]) + + if use_descriptor[0] == 'coulomb_matrix' or use_descriptor[0] == 'exp_decay_matrix': + tril_perms = np.array([desc.perm(p) for p in task['perms']]) + descriptor_type = { + 'use_descriptor': use_descriptor[0] + } + + # elif 'other_descriptor' == use_descriptor[0]: + # pass dim_i = 3 * n_atoms dim_d = desc.dim @@ -826,7 +844,8 @@ def train( # noqa: C901 lat_and_inv = (task['lattice'], np.linalg.inv(task['lattice'])) except np.linalg.LinAlgError: raise ValueError( # TODO: Document me - 'Provided dataset contains invalid lattice vectors (not invertible). Note: Only rank 3 lattice vector matrices are supported.' + 'Provided dataset contains invalid lattice vectors (not invertible). Note: Only rank 3 lattice ' + 'vector matrices are supported.' ) # # TODO: check if all atoms are within unit cell @@ -1085,14 +1104,17 @@ def _recov_int_const(self, model, task): # TODO: document e_err_inconsist retur self.log.warning( 'The provided dataset contains gradients instead of force labels (flipped sign). Please correct!\n' + ui.color_str('Note:', bold=True) - + 'Note: The energy prediction accuracy of the model will thus neither be validated nor tested in the following steps!' + + 'Note: The energy prediction accuracy of the model will thus neither be validated nor tested in the ' + 'following steps!' ) return None if corrcoef < 0.95: self.log.warning( 'Inconsistent energy labels detected!\n' - + 'The predicted energies for the training data are only weakly correlated with the reference labels (correlation coefficient {:.2f}) which indicates that the issue is most likely NOT just a unit conversion error.\n\n'.format( + + 'The predicted energies for the training data are only weakly correlated with the reference labels ' + '(correlation coefficient {:.2f}) which indicates that the issue is most likely NOT just a unit ' + 'conversion error.\n\n'.format( corrcoef ) + ui.color_str('Troubleshooting tips:\n', bold=True) @@ -1132,7 +1154,8 @@ def _recov_int_const(self, model, task): # TODO: document e_err_inconsist retur if np.abs(e_fact - 1) > 1e-1: self.log.warning( 'Different scales in energy vs. force labels detected!\n' - + 'The integrated forces differ from the energy labels by factor ~{:.2f}, meaning that the trained model will likely fail to predict energies accurately.\n\n'.format( + + 'The integrated forces differ from the energy labels by factor ~{:.2f}, meaning that the trained ' + 'model will likely fail to predict energies accurately.\n\n'.format( e_fact ) + ui.color_str('Troubleshooting tips:\n', bold=True) diff --git a/sgdml/utils/desc.py b/sgdml/utils/desc.py index fd9ccc9..450dd5c 100755 --- a/sgdml/utils/desc.py +++ b/sgdml/utils/desc.py @@ -46,7 +46,7 @@ def _from_r_alias(obj, r, lat_and_inv=None): class Desc(object): - def __init__(self, n_atoms, max_processes=None): + def __init__(self, n_atoms, max_processes=None, use_descriptor='coulomb_matrix'): """ Generate descriptors and their Jacobians for molecular geometries, including support for periodic boundary conditions. @@ -64,28 +64,36 @@ def __init__(self, n_atoms, max_processes=None): self.n_atoms = n_atoms self.dim_i = 3 * n_atoms - # Size of the resulting descriptor vector. - self.dim = (n_atoms * (n_atoms - 1)) // 2 - - # Precompute indices for nonzero entries in desriptor derivatives. - self.d_desc_mask = np.zeros((n_atoms, n_atoms - 1), dtype=np.int) - for a in range(n_atoms): # for each partial derivative - rows, cols = np.tril_indices(n_atoms, -1) - self.d_desc_mask[a, :] = np.concatenate( - [np.where(rows == a)[0], np.where(cols == a)[0]] - ) - - self.M = np.arange(1, n_atoms) # indexes matrix row-wise, skipping diagonal - for a in range(1, n_atoms): - self.M = np.concatenate((self.M, np.delete(np.arange(n_atoms), a))) + self.use_descriptor = use_descriptor - self.A = np.repeat( - np.arange(n_atoms), n_atoms - 1 - ) # [0, 0, ..., 1, 1, ..., 2, 2, ...] - - self.d_desc = np.zeros( - (self.dim, n_atoms, 3) - ) # template for descriptor matrix (zeros are important) + # Size of the resulting descriptor vector. + if self.use_descriptor == 'coulomb_matrix' or self.use_descriptor == 'exp_decay_matrix': + # Size of the resulting descriptor vector. + self.dim = (n_atoms * (n_atoms - 1)) // 2 + + # Precompute indices for nonzero entries in desriptor derivatives. + self.d_desc_mask = np.zeros((n_atoms, n_atoms - 1), dtype=np.int) + for a in range(n_atoms): # for each partial derivative + rows, cols = np.tril_indices(n_atoms, -1) + self.d_desc_mask[a, :] = np.concatenate( + [np.where(rows == a)[0], np.where(cols == a)[0]] + ) + + self.M = np.arange(1, n_atoms) # indexes matrix row-wise, skipping diagonal + for a in range(1, n_atoms): + self.M = np.concatenate((self.M, np.delete(np.arange(n_atoms), a))) + + self.A = np.repeat( + np.arange(n_atoms), n_atoms - 1 + ) # [0, 0, ..., 1, 1, ..., 2, 2, ...] + + self.d_desc = np.zeros( + (self.dim, n_atoms, 3) + ) # template for descriptor matrix (zeros are important) + + # --- TODO: Add precomputable variables for new descriptor here + # elif self.use_descriptor == '': + # pass self.max_processes = max_processes @@ -250,8 +258,16 @@ def _from_r(self, r, lat_and_inv=None): pd = self._pdist(r, lat_and_inv) - r_desc = self._r_to_desc(r, pd) - r_d_desc = self._r_to_d_desc(r, pd, lat_and_inv) + if self.use_descriptor == 'coulomb_matrix': + pd = self._pdist(r, lat_and_inv) + + r_desc = self._r_to_desc(r, pd) + r_d_desc = self._r_to_d_desc(r, pd, lat_and_inv) + + elif self.use_descriptor == 'exp_decay_matrix': + pd = self._pdist(r, lat_and_inv) + r_desc = self._r_to_desc_exp_decay(r, pd) + r_d_desc = self._r_to_d_desc_exp_decay(r, pd, lat_and_inv) return r_desc, r_d_desc @@ -284,6 +300,7 @@ def _pdist(self, r, lat_and_inv=None): return sp.spatial.distance.squareform(pdist, checks=False) + # --- Coulomb Matrix descriptor def _r_to_desc(self, r, pdist): """ Generate descriptor for a set of atom positions in Cartesian @@ -352,3 +369,69 @@ def _r_to_d_desc(self, r, pdist, lat_and_inv=None): ] return self.d_desc.reshape(self.dim, self.dim_i) + + # --- Exponential decay matrix + def _r_to_desc_exp_decay(self, r, pdist): + """ + Generate descriptor for a set of atom positions in Cartesian + coordinates. + Parameters + ---------- + r : :obj:`numpy.ndarray` + Array of size 3N containing the Cartesian coordinates of + each atom. + pdist : :obj:`numpy.ndarray` + Array of size N x N containing the Euclidean distance + (2-norm) for each pair of atoms. + Returns + ------- + :obj:`numpy.ndarray` + Descriptor representation as 1D array of size N(N-1)/2 + """ + + # Add singleton dimension if input is (,3N). + if r.ndim == 1: + r = r[None, :] + + return np.exp(-pdist[np.tril_indices(self.n_atoms, -1)]) + + def _r_to_d_desc_exp_decay(self, r, pdist, lat_and_inv=None): + """ + Generate descriptor Jacobian for a set of atom positions in + Cartesian coordinates. + This method can apply the minimum-image convention as periodic + boundary condition for distances between atoms, given the edge + length of the (square) unit cell. + Parameters + ---------- + r : :obj:`numpy.ndarray` + Array of size 3N containing the Cartesian coordinates of + each atom. + pdist : :obj:`numpy.ndarray` + Array of size N x N containing the Euclidean distance + (2-norm) for each pair of atoms. + lat_and_inv : tuple of :obj:`numpy.ndarray`, optional + Tuple of 3x3 matrix containing lattice vectors as columns and its inverse. + Returns + ------- + :obj:`numpy.ndarray` + Array of size N(N-1)/2 x 3N containing all partial + derivatives of the descriptor. + """ + + r = r.reshape(-1, 3) + + np.seterr(divide='ignore', invalid='ignore') + + pdiff = r[:, None] - r[None, :] # pairwise differences ri - rj + if lat_and_inv is not None: + pdiff = self.pbc_diff( + pdiff.reshape(self.n_atoms ** 2, 3), lat_and_inv + ).reshape(self.n_atoms, self.n_atoms, 3) + + d_desc_elem = np.exp(-pdist[:, :, None]) * pdiff / pdist[:, :, None] + self.d_desc[self.d_desc_mask.ravel(), self.A, :] = d_desc_elem[ + self.M, self.A, : + ] + + return self.d_desc.reshape(self.dim, self.dim_i) \ No newline at end of file diff --git a/sgdml/utils/io.py b/sgdml/utils/io.py index 1868a78..52b5176 100755 --- a/sgdml/utils/io.py +++ b/sgdml/utils/io.py @@ -160,7 +160,7 @@ def z_to_z_str(z): def train_dir_name( - dataset, n_train, use_sym, use_cprsn, use_E, use_E_cstr, model0=None + dataset, n_train, use_sym, use_cprsn, use_E, use_E_cstr, descriptor, model0=None ): theory_level_str = re.sub(r'[^\w\-_\.]', '.', str(dataset['theory'])) @@ -172,10 +172,11 @@ def train_dir_name( noE_str = '-noE' if not use_E else '' Ecstr_str = '-Ecstr' if use_E_cstr else '' - return '%ssgdml_cv_%s-%s-train%d%s%s%s%s' % ( + return '%ssgdml_cv_%s-%s-%s-train%d%s%s%s%s' % ( m0_str, dataset['name'].astype(str), theory_level_str, + descriptor, n_train, sym_str, cprsn_str, @@ -190,7 +191,7 @@ def task_file_name(task): n_perms = task['perms'].shape[0] sig = np.squeeze(task['sig']) - return 'task-train%d-sym%d-sig%04d.npz' % (n_train, n_perms, sig) + return 'task-%s-train%d-sym%d-sig%04d.npz' % (task['use_descriptor'][0], n_train, n_perms, sig) def model_file_name(task_or_model, is_extended=False): @@ -205,8 +206,8 @@ def model_file_name(task_or_model, is_extended=False): r'[^\w\-_\.]', '.', str(np.squeeze(task_or_model['dataset_theory'])) ) theory_level_str = re.sub(r'\.\.', '.', theory_level_str) - return '%s-%s-train%d-sym%d.npz' % (dataset, theory_level_str, n_train, n_perms) - return 'model-train%d-sym%d-sig%04d.npz' % (n_train, n_perms, sig) + return '%s-%s-%s-train%d-sym%d.npz' % (dataset, theory_level_str, task_or_model['use_descriptor'][0], n_train, n_perms) + return 'model-%s-train%d-sym%d-sig%04d.npz' % (task_or_model['use_descriptor'][0], n_train, n_perms, sig) def dataset_md5(dataset): @@ -705,3 +706,21 @@ def parse_list_or_range(arg): arg ) ) + +def parse_descriptor(arg): + """ + Parses a string that represents either just descriptor's name or required arguments too. + Parameters + ---------- + arg : :obj:`str` + string. + Returns + ------- + str or :obj:`list` of str + Raises + ------ + ArgumentTypeError + If input can neither be interpreted as a str or list of str. + """ + + return arg \ No newline at end of file