# NOTE(review): reconstructed from a whitespace-collapsed diff of
# pynta/coveragedependence.py (hunk following split_triad).  Block structure,
# in particular the placement of the error-reporting section after both
# fitting passes, is inferred from syntax -- confirm against the repository.


def is_descendent_of_or_is(node, ancestor_node):
    """Return True if ``ancestor_node`` is ``node`` or one of its ancestors.

    Follows the ``parent`` links starting at ``node`` and compares every
    visited node, including the parentless one at the top of the chain, to
    ``ancestor_node`` by identity.

    Args:
        node: object exposing a ``parent`` attribute (``None`` at the top).
        ancestor_node: candidate ancestor, compared with ``is``.

    Returns:
        bool: True when ``ancestor_node`` is found on the chain, else False.
    """
    current = node
    # Test the node itself *before* looking at its parent so that a
    # parentless ancestor (or node is ancestor_node at the top) still matches.
    while current is not None:
        if current is ancestor_node:
            return True
        current = current.parent
    return False


class CoverageDependenceRegressor(MultiEvalSubgraphIsomorphicDecisionTreeRegressor):
    """Regressor whose ``fit_rule`` fits node rules in two passes.

    The first pass fits nodes outside the ``"Root_Triad"`` subtree, the second
    pass fits nodes inside it; within each pass nodes are fitted one depth
    level at a time against the residual left by the earlier fits.
    """

    def fit_rule(self, alpha=0.1):
        """Fit a ``Rule`` for every node with successive Lasso regressions.

        Args:
            alpha: regularization strength forwarded to ``linear_model.Lasso``.

        Side effects:
            Sets ``rule`` on the fitted nodes (``"Root"`` gets a zero rule),
            resets ``self.node_uncertainties``, updates ``self.data_delta``,
            logs training/validation/test MAE and the node count, and, when a
            validation set is present, updates ``self.val_mae`` and the
            ``best_*`` snapshot attributes if the validation MAE improved.
        """
        max_depth = max(node.depth for node in self.nodes.values())
        # Force a float array: the in-place residual update below would fail
        # on an integer array because the predictions are floats.
        y = np.array([datum.value for datum in self.datums], dtype=float)
        preds = np.zeros(len(self.datums))
        self.node_uncertainties = dict()
        weights = self.weights
        triad_node = self.nodes["Root_Triad"]

        # Subtree membership does not change during fitting; compute it once
        # instead of once per node per depth per pass.
        in_triad_tree = {
            name: is_descendent_of_or_is(node, triad_node)
            for name, node in self.nodes.items()
        }

        # Pass 1 (False): nodes outside the triad subtree.
        # Pass 2 (True): nodes inside the triad subtree.
        for fit_triads in (False, True):
            for depth in range(max_depth + 1):
                if depth == 0:
                    self.nodes["Root"].rule = Rule(value=0.0, num_data=0)
                    continue  # skip Root node

                nodes = [
                    node
                    for name, node in self.nodes.items()
                    if node.depth == depth and in_triad_tree[name] == fit_triads
                ]
                if not nodes:
                    continue

                # generate matrix: A[i, j] counts how often nodes[j] occurs on
                # the parent chains of the nodes mapped to datum i
                A = sp.lil_matrix((len(self.datums), len(nodes)))
                # fit this level against what the previous fit left unexplained
                y -= preds

                for i, datum in enumerate(self.datums):
                    for node in self.mol_node_maps[datum]["nodes"]:
                        while node is not None:
                            if node in nodes:
                                j = nodes.index(node)
                                A[i, j] += 1.0
                            node = node.parent

                clf = linear_model.Lasso(
                    alpha=alpha,
                    fit_intercept=False,
                    tol=1e-4,
                    max_iter=1000000000,
                    selection="random",
                )
                if weights is not None:
                    clf.fit(A, y, sample_weight=weights)
                else:
                    clf.fit(A, y)

                preds = A * clf.coef_
                self.data_delta = preds - y

                for i, val in enumerate(clf.coef_):
                    nodes[i].rule = Rule(value=val, num_data=np.sum(A[:, i]))

        train_error = [
            self.evaluate(d.mol, estimate_uncertainty=False) - d.value
            for d in self.datums
        ]

        logging.info("training MAE: {}".format(np.mean(np.abs(np.array(train_error)))))

        if self.validation_set:
            val_error = [
                self.evaluate(d.mol, estimate_uncertainty=False) - d.value
                for d in self.validation_set
            ]
            val_mae = np.mean(np.abs(np.array(val_error)))
            if val_mae < self.min_val_error:
                # snapshot the current tree as the best one seen so far
                self.min_val_error = val_mae
                self.best_tree_nodes = list(self.nodes.keys())
                self.bestA = A
                self.best_nodes = dict(self.nodes)
                self.best_mol_node_maps = {
                    k: {"mols": v["mols"][:], "nodes": v["nodes"][:]}
                    for k, v in self.mol_node_maps.items()
                }
                self.best_rule_map = {
                    name: self.nodes[name].rule for name in self.best_tree_nodes
                }
            self.val_mae = val_mae
            logging.info("validation MAE: {}".format(self.val_mae))

        if self.test_set:
            # evaluated the same way as the training and validation sets
            test_error = [
                self.evaluate(d.mol, estimate_uncertainty=False) - d.value
                for d in self.test_set
            ]
            test_mae = np.mean(np.abs(np.array(test_error)))
            logging.info("test MAE: {}".format(test_mae))

        logging.info("# nodes: {}".format(len(self.nodes)))


# NOTE(review): the same diff also changes train_sidt_cov_dep_regressor so that
# its tree is constructed with CoverageDependenceRegressor(...) instead of
# MultiEvalSubgraphIsomorphicDecisionTreeRegressor(...), with the same
# arguments.  That function is only partially visible here and is not
# reproduced.