Development #108 (Merged)

Merged 15 commits on Jun 20, 2024
Changes from 1 commit
fix(gnn_dropout): mask error due to not dropping out ew as well
jyaacoub committed Jun 20, 2024
commit d939e4612fe5139be2f2e2c1e52206042c58fd8d
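The same fix is applied in all three files below: torch_geometric's `dropout_node` returns an edge mask along with the dropped `edge_index`, and that mask also has to be applied to the edge-weight tensor, otherwise the weights no longer line up with the edges that survived the dropout. A minimal standalone sketch of the pattern (the layer and tensor names here are illustrative placeholders, not this repository's actual modules):

```python
# Minimal sketch of the dropout_node + edge-weight masking pattern.
# Assumes PyTorch Geometric; names (conv, x, edge_index, edge_weight)
# are hypothetical, not the repo's real attributes.
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dropout_node

conv = GCNConv(in_channels=16, out_channels=32)

x = torch.randn(10, 16)                     # node features
edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
edge_weight = torch.rand(40)                # one weight per edge

# dropout_node returns (edge_index, edge_mask, node_mask); edge_mask marks
# which columns of the *original* edge_index survived the node dropout.
ei_drp, e_mask, _ = dropout_node(edge_index, p=0.2,
                                 num_nodes=x.size(0), training=True)

# Pre-fix bug: passing the full edge_weight with the dropped edge_index
# leaves edge_weight out of sync with ei_drp (wrong length), so weights
# get matched to the wrong edges or the conv raises a size error.
# out = conv(x, ei_drp, edge_weight)

# Fix: drop the weights with the same mask so the two stay aligned.
out = conv(x, ei_drp, edge_weight[e_mask] if edge_weight is not None else None)
```

The `if ew is not None` guard in the diffs keeps the existing behaviour for the 'binary' edge-weight setting, where no weight tensor is passed at all.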
src/models/branches.py (16 changes: 8 additions & 8 deletions)
@@ -86,22 +86,22 @@ def forward(self, data):
         ew = data.edge_weight if (self.edge_weight is not None and
                                   self.edge_weight != 'binary') else None

-        target_x = self.relu(target_x)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
-                                    training=self.training)
+        xt = self.relu(target_x)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)

         # conv1
-        xt = self.conv1(target_x, ei_drp, ew)
+        xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                     training=self.training)
         # conv2
-        xt = self.conv2(xt, ei_drp, ew)
+        xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                     training=self.training)
         # conv3
-        xt = self.conv3(xt, ei_drp, ew)
+        xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)

         # flatten/pool
src/models/esm_models.py (34 changes: 17 additions & 17 deletions)
@@ -95,22 +95,22 @@ def forward_pro(self, data):
         ew = data.edge_weight if (self.edge_weight is not None and
                                   self.edge_weight != 'binary') else None

-        target_x = self.relu(target_x)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-                                    training=self.training)
+        xt = self.relu(target_x)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)

         # conv1
-        xt = self.pro_conv1(target_x, ei_drp, ew)
+        xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                     training=self.training)
         # conv2
-        xt = self.pro_conv2(xt, ei_drp, ew)
+        xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                     training=self.training)
         # conv3
-        xt = self.pro_conv3(xt, ei_drp, ew)
+        xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)

         # flatten/pool
@@ -257,24 +257,24 @@ def forward_pro(self, data):
         ew = data.edge_weight if (self.edge_weight is not None and
                                   self.edge_weight != 'binary') else None

-        target_x = self.relu(target_x)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-                                    training=self.training)
+        xt = self.relu(target_x)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)

         # conv1
-        xt = self.pro_conv1(target_x, ei_drp, ew)
+        xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                     training=self.training)
         # conv2
-        xt = self.pro_conv2(xt, ei_drp, ew)
+        xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                     training=self.training)
         # conv3
-        xt = self.pro_conv3(xt, ei_drp, ew)
+        xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)

         # flatten/pool
         xt = gep(xt, data.batch)  # global pooling
         xt = self.relu(xt)
src/models/ring3.py (26 changes: 14 additions & 12 deletions)
@@ -211,20 +211,22 @@ def forward_pro(self, data):
                                   self.edge_weight != 'binary') else None

         #### Graph NN ####
-        target_x = self.relu(target_x)
-        # WARNING: dropout_node doesnt work if `ew` isnt also dropped out
-        # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-        #                             training=self.training)
-        # GNN layers:
-        xt = self.pro_conv1(target_x, ei, ew)
+        xt = self.relu(target_x)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)
+
+        # conv1
+        xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-        #                             training=self.training)
-        xt = self.pro_conv2(xt, ei, ew)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)
+        # conv2
+        xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
-        # ei_drp, _, _ = dropout_node(ei, p=self.dropout_prot_p, num_nodes=target_x.shape[0],
-        #                             training=self.training)
-        xt = self.pro_conv3(xt, ei, ew)
+        ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
+                                         training=self.training)
+        # conv3
+        xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)

         # flatten/pool