fix(EsmDTA): added missing args and refactored attributes #113
For retraining on new splits #113
jyaacoub committed Jul 10, 2024
1 parent a0e4405 commit 9d561da
Showing 4 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/__init__.py
@@ -92,7 +92,7 @@
'dropout_prot': 0.0,
'output_dim': 128,
'pro_extra_fc_lyr': False,
- # 'pro_emb_dim': 512 # just for reference since this is the default for EDI
+ 'pro_emb_dim': 512 # just for reference since this is the default for EDI
}
},
'davis_gvpl_esm_aflow': {
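With 'pro_emb_dim' uncommented, the value is an active hyperparameter rather than a reference note, so it is actually forwarded when this config block is unpacked. A minimal sketch of that flow, assuming the surrounding dict is passed to init_model as keyword arguments (the dict name below is illustrative, not from the repo):

# Illustrative only: the key now reaches init_model as a kwarg instead of
# silently falling back to the EDI default of 512.
tuned_kwargs = {
    'dropout_prot': 0.0,
    'output_dim': 128,
    'pro_extra_fc_lyr': False,
    'pro_emb_dim': 512,
}
# model = init_model('EDI', pro_feature, pro_edge, dropout=0.2, **tuned_kwargs)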
26 changes: 13 additions & 13 deletions src/models/esm_models.py
@@ -18,7 +18,7 @@ class EsmDTA(BaseModel):
def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
num_features_pro=320, pro_emb_dim=54, num_features_mol=78,
output_dim=128, dropout=0.2, pro_feat='esm_only', edge_weight_opt='binary',
- dropout_prot=0.0, extra_profc_layer=False):
+ dropout_prot=0.0, pro_extra_fc_lyr=False):

super(EsmDTA, self).__init__(pro_feat, edge_weight_opt)

@@ -33,13 +33,13 @@ def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
self.pro_conv2 = GCNConv(pro_emb_dim, pro_emb_dim * 2)
self.pro_conv3 = GCNConv(pro_emb_dim * 2, pro_emb_dim * 4)

- if not extra_profc_layer:
+ self.pro_extra_fc_lyr = pro_extra_fc_lyr
+ if not pro_extra_fc_lyr:
self.pro_fc_g1 = nn.Linear(pro_emb_dim * 4, 1024)
else:
self.pro_fc_g1 = nn.Linear(pro_emb_dim * 4, pro_emb_dim * 2)
self.pro_fc_g1b = nn.Linear(pro_emb_dim * 2, 1024)

- self.extra_profc_layer = extra_profc_layer
self.pro_fc_g2 = nn.Linear(1024, output_dim)


@@ -52,7 +52,7 @@ def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
# note that dropout for edge and nodes is handled by torch_geometric in forward pass
- self.dropout_prot_p = dropout_prot
+ self.dropout_gnn = dropout_prot

# combined layers
self.fc1 = nn.Linear(2 * output_dim, 1024)
@@ -100,17 +100,17 @@ def forward_pro(self, data):
training=self.training)

# conv1
- xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+ xt = self.pro_conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
xt = self.relu(xt)
ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
training=self.training)
# conv2
- xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+ xt = self.pro_conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
xt = self.relu(xt)
ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
training=self.training)
# conv3
- xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+ xt = self.pro_conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
xt = self.relu(xt)

# flatten/pool
@@ -123,7 +123,7 @@ def forward_pro(self, data):
xt = self.relu(xt)
xt = self.dropout(xt)

- if self.extra_profc_layer:
+ if self.pro_extra_fc_lyr:
xt = self.pro_fc_g1b(xt)
xt = self.relu(xt)
xt = self.dropout(xt)
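For context on the unchanged lines around these fixes: dropout_node returns a filtered edge_index together with a boolean mask over the original edges, which is why each convolution receives ei_drp alongside ew[e_mask]. A small self-contained sketch of that pattern (the sizes and dropout rate are made up; this is not the repo's forward_pro):

import torch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dropout_node

x = torch.randn(10, 320)            # 10 protein nodes with 320 ESM features
ei = torch.randint(0, 10, (2, 40))  # toy edge_index
ew = torch.rand(ei.shape[1])        # one weight per edge

conv = GCNConv(320, 54)
ei_drp, e_mask, _ = dropout_node(ei, p=0.2, num_nodes=x.shape[0], training=True)
# edges touching dropped nodes are removed, so the edge weights are masked the same way
out = conv(x, ei_drp, ew[e_mask] if ew is not None else ew)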
@@ -213,7 +213,7 @@ def __init__(self, esm_head: str = 'westlake-repl/SaProt_35M_AF2',
edge_weight_opt='binary', **kwargs):
super().__init__(esm_head, num_features_pro, pro_emb_dim, num_features_mol,
output_dim, dropout, pro_feat, edge_weight_opt,
- extra_profc_layer=True,**kwargs)
+ pro_extra_fc_lyr=True,**kwargs)

# overwrite the forward_pro pass to account for new saprot model
def forward_pro(self, data):
@@ -262,17 +262,17 @@ def forward_pro(self, data):
training=self.training)

# conv1
- xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+ xt = self.pro_conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
xt = self.relu(xt)
ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
training=self.training)
# conv2
- xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+ xt = self.pro_conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
xt = self.relu(xt)
ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
training=self.training)
# conv3
- xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+ xt = self.pro_conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
xt = self.relu(xt)

# flatten/pool
@@ -285,7 +285,7 @@ def forward_pro(self, data):
xt = self.relu(xt)
xt = self.dropout(xt)

- if self.extra_profc_layer:
+ if self.pro_extra_fc_lyr:
xt = self.pro_fc_g1b(xt)
xt = self.relu(xt)
xt = self.dropout(xt)
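Taken together, the changes in this file make every attribute referenced in forward_pro match the name assigned in __init__, with pro_extra_fc_lyr gating an optional second protein FC layer. A stripped-down, illustrative module showing just that pattern (not the full EsmDTA; the ESM and GNN parts are omitted):

import torch.nn as nn

class ProFCHead(nn.Module):
    """Illustrative only: the renamed flag and optional extra FC layer from EsmDTA."""
    def __init__(self, pro_emb_dim=512, output_dim=128, pro_extra_fc_lyr=False, dropout=0.2):
        super().__init__()
        self.pro_extra_fc_lyr = pro_extra_fc_lyr  # stored under the name the forward pass reads
        if not pro_extra_fc_lyr:
            self.pro_fc_g1 = nn.Linear(pro_emb_dim * 4, 1024)
        else:
            self.pro_fc_g1 = nn.Linear(pro_emb_dim * 4, pro_emb_dim * 2)
            self.pro_fc_g1b = nn.Linear(pro_emb_dim * 2, 1024)
        self.pro_fc_g2 = nn.Linear(1024, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, xt):  # xt: (batch, pro_emb_dim * 4) pooled protein embedding
        xt = self.dropout(self.relu(self.pro_fc_g1(xt)))
        if self.pro_extra_fc_lyr:
            xt = self.dropout(self.relu(self.pro_fc_g1b(xt)))
        return self.pro_fc_g2(xt)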
2 changes: 1 addition & 1 deletion src/train_test/distributed.py
@@ -152,7 +152,7 @@ def dtrain(args, unknown_args):
map_location=torch.device(f'cuda:{args.gpu}')))

model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # use if model contains batchnorm.
- model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+ model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)

torch.distributed.barrier() # Sync params across GPUs before training

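Because the attribute fixes above mean every registered parameter now participates in the forward pass, DDP can run with find_unused_parameters=False and skip the per-iteration search for unused parameters (PyTorch raises an error suggesting the flag if some parameters really do go unused). A hedged sketch of the wrapping step, assuming the process group is already initialized and that model and args.gpu exist as in dtrain:

import torch.nn as nn

model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # only needed if the model contains batchnorm
model = nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.gpu],
    find_unused_parameters=False,  # all parameters receive gradients, so skip the search
)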
8 changes: 7 additions & 1 deletion src/utils/loader.py
@@ -115,9 +115,15 @@ def init_model(model:str, pro_feature:str, pro_edge:str, dropout:float=0.2, **kw
pro_feat='all', # to include all feats (esm + 52 from DGraphDTA)
edge_weight_opt=pro_edge)
elif model == 'EDI':
+ pro_emb_dim = 512 # increase embedding size
+ if "pro_emb_dim" in kwargs:
+     pro_emb_dim = kwargs['pro_emb_dim']
+     logging.warning(f'pro_emb_dim changed from default of {512} to {pro_emb_dim} for model EDI')
+     del kwargs['pro_emb_dim']
+
model = EsmDTA(esm_head='facebook/esm2_t6_8M_UR50D',
num_features_pro=320,
- pro_emb_dim=512, # increase embedding size
+ pro_emb_dim=pro_emb_dim,
dropout=dropout,
pro_feat='esm_only',
edge_weight_opt=pro_edge,
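The EDI branch now honours a pro_emb_dim passed through kwargs: it logs a warning and deletes the key so the remaining kwargs are forwarded to EsmDTA unchanged. A hedged usage sketch (the import path and the pro_feature/pro_edge values are assumptions, not taken from the repo):

from src.utils.loader import init_model  # assumed import path

# Default behaviour: EDI is built with pro_emb_dim=512.
model = init_model('EDI', pro_feature='...', pro_edge='binary', dropout=0.2)

# Override via kwargs: a warning is logged, then the key is removed so the rest of
# the kwargs pass through to EsmDTA untouched.
model = init_model('EDI', pro_feature='...', pro_edge='binary', dropout=0.2, pro_emb_dim=320)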
