Skip to content

Commit

Permalink
optimize pair groups first
Browse files Browse the repository at this point in the history
  • Loading branch information
mjohnson541 committed Oct 15, 2024
1 parent 0b9cfa5 commit efea275
Showing 1 changed file with 90 additions and 1 deletion.
91 changes: 90 additions & 1 deletion pynta/coveragedependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,95 @@ def split_triad(tagged_grp):

return out_pairs

def is_descendent_of_or_is(node,ancestor_node):
n = node
while n.parent is not None:
if n is ancestor_node:
return True
else:
n = n.parent
return False

class CoverageDependenceRegressor(MultiEvalSubgraphIsomorphicDecisionTreeRegressor):
def fit_rule(self, alpha=0.1):
max_depth = max([node.depth for node in self.nodes.values()])
y = np.array([datum.value for datum in self.datums])
preds = np.zeros(len(self.datums))
self.node_uncertainties = dict()
weights = self.weights
W = self.W
triad_node = self.nodes["Root_Triad"]
for pair in [True,False]:
for depth in range(max_depth + 1):
if depth == 0:
self.nodes["Root"].rule = Rule(value=0.0,num_data=0)
continue #skip Root node
else:
if pair:
nodes = [node for node in self.nodes.values() if node.depth == depth and not is_descendent_of_or_is(node,triad_node)]
else:
nodes = [node for node in self.nodes.values() if node.depth == depth and is_descendent_of_or_is(node,triad_node)]

if len(nodes) == 0:
continue

# generate matrix
A = sp.lil_matrix((len(self.datums), len(nodes)))
y -= preds

for i, datum in enumerate(self.datums):
for node in self.mol_node_maps[datum]["nodes"]:
while node is not None:
if node in nodes:
j = nodes.index(node)
A[i, j] += 1.0
node = node.parent

clf = linear_model.Lasso(
alpha=alpha,
fit_intercept=False,
tol=1e-4,
max_iter=1000000000,
selection="random",
)
if weights is not None:
lasso = clf.fit(A, y, sample_weight=weights)
else:
lasso = clf.fit(A, y)

preds = A * clf.coef_
self.data_delta = preds - y

for i, val in enumerate(clf.coef_):
nodes[i].rule = Rule(value=val, num_data=np.sum(A[:, i]))

train_error = [self.evaluate(d.mol, estimate_uncertainty=False) - d.value for d in self.datums]

logging.info("training MAE: {}".format(np.mean(np.abs(np.array(train_error)))))

if self.validation_set:
val_error = [self.evaluate(d.mol, estimate_uncertainty=False) - d.value for d in self.validation_set]
val_mae = np.mean(np.abs(np.array(val_error)))
if val_mae < self.min_val_error:
self.min_val_error = val_mae
self.best_tree_nodes = list(self.nodes.keys())
self.bestA = A
self.best_nodes = {k: v for k, v in self.nodes.items()}
self.best_mol_node_maps = {
k: {"mols": v["mols"][:], "nodes": v["nodes"][:]}
for k, v in self.mol_node_maps.items()
}
self.best_rule_map = {name:self.nodes[name].rule for name in self.best_tree_nodes}
self.val_mae = val_mae
logging.info("validation MAE: {}".format(self.val_mae))

if self.test_set:
test_error = [self.evaluate(d.mol) - d.value for d in self.test_set]
test_mae = np.mean(np.abs(np.array(test_error)))
logging.info("test MAE: {}".format(test_mae))

logging.info("# nodes: {}".format(len(self.nodes)))

def train_sidt_cov_dep_regressor(pairs_datums,sampling_datums,Nconfigs,Ncoads,r_site=None,
r_atoms=None,node_fract_training=0.9):

Expand Down Expand Up @@ -1401,7 +1490,7 @@ def train_sidt_cov_dep_regressor(pairs_datums,sampling_datums,Nconfigs,Ncoads,r_
for n in root_triad_node.children:
nodes[n.name] = n

tree = MultiEvalSubgraphIsomorphicDecisionTreeRegressor([adsorbate_interaction_decomposition,adsorbate_triad_interaction_decomposition],
tree = CoverageDependenceRegressor([adsorbate_interaction_decomposition,adsorbate_triad_interaction_decomposition],
nodes=nodes,
r=[ATOMTYPES[x] for x in r_atoms],
r_bonds=[1,2,3,0.05],
Expand Down

0 comments on commit efea275

Please sign in to comment.