From 9d561da29ac88b03ad8f91162441006f7abedf5a Mon Sep 17 00:00:00 2001
From: jyaacoub
Date: Wed, 10 Jul 2024 11:39:40 -0400
Subject: [PATCH] fix(EsmDTA): added missing args and refactored attributes
 #113

For retraining on new splits #113
---
 src/__init__.py               |  2 +-
 src/models/esm_models.py      | 26 +++++++++++++-------------
 src/train_test/distributed.py |  2 +-
 src/utils/loader.py           |  8 +++++++-
 4 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/src/__init__.py b/src/__init__.py
index 90d806c0..66ae0c5d 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -92,7 +92,7 @@
         'dropout_prot': 0.0,
         'output_dim': 128,
         'pro_extra_fc_lyr': False,
-        # 'pro_emb_dim': 512 # just for reference since this is the default for EDI
+        'pro_emb_dim': 512 # just for reference since this is the default for EDI
         }
     },
     'davis_gvpl_esm_aflow': {
diff --git a/src/models/esm_models.py b/src/models/esm_models.py
index 372b12c4..3c9504a9 100644
--- a/src/models/esm_models.py
+++ b/src/models/esm_models.py
@@ -18,7 +18,7 @@ class EsmDTA(BaseModel):
     def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
                  num_features_pro=320, pro_emb_dim=54, num_features_mol=78,
                  output_dim=128, dropout=0.2, pro_feat='esm_only', edge_weight_opt='binary',
-                 dropout_prot=0.0, extra_profc_layer=False):
+                 dropout_prot=0.0, pro_extra_fc_lyr=False):
 
         super(EsmDTA, self).__init__(pro_feat, edge_weight_opt)
 
@@ -33,13 +33,13 @@ def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
         self.pro_conv2 = GCNConv(pro_emb_dim, pro_emb_dim * 2)
         self.pro_conv3 = GCNConv(pro_emb_dim * 2, pro_emb_dim * 4)
 
-        if not extra_profc_layer:
+        self.pro_extra_fc_lyr = pro_extra_fc_lyr
+        if not pro_extra_fc_lyr:
             self.pro_fc_g1 = nn.Linear(pro_emb_dim * 4, 1024)
         else:
             self.pro_fc_g1 = nn.Linear(pro_emb_dim * 4, pro_emb_dim * 2)
             self.pro_fc_g1b = nn.Linear(pro_emb_dim * 2, 1024)
 
-        self.extra_profc_layer = extra_profc_layer
 
 
         self.pro_fc_g2 = nn.Linear(1024, output_dim)
@@ -52,7 +52,7 @@ def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(dropout)
         # note that dropout for edge and nodes is handled by torch_geometric in forward pass
-        self.dropout_prot_p = dropout_prot
+        self.dropout_gnn = dropout_prot
 
         # combined layers
         self.fc1 = nn.Linear(2 * output_dim, 1024)
@@ -100,17 +100,17 @@ def forward_pro(self, data):
                                          training=self.training)
 
         # conv1
-        xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+        xt = self.pro_conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
         ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                          training=self.training)
         # conv2
-        xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+        xt = self.pro_conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
         ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                          training=self.training)
         # conv3
-        xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+        xt = self.pro_conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
 
         # flatten/pool
@@ -123,7 +123,7 @@ def forward_pro(self, data):
         xt = self.relu(xt)
         xt = self.dropout(xt)
 
-        if self.extra_profc_layer:
+        if self.pro_extra_fc_lyr:
             xt = self.pro_fc_g1b(xt)
             xt = self.relu(xt)
             xt = self.dropout(xt)
@@ -213,7 +213,7 @@ def __init__(self, esm_head: str = 'westlake-repl/SaProt_35M_AF2',
                  edge_weight_opt='binary', **kwargs):
         super().__init__(esm_head, num_features_pro, pro_emb_dim, num_features_mol, output_dim, dropout, pro_feat,
                          edge_weight_opt,
-                         extra_profc_layer=True,**kwargs)
+                         pro_extra_fc_lyr=True,**kwargs)
 
     # overwrite the forward_pro pass to account for new saprot model
     def forward_pro(self, data):
@@ -262,17 +262,17 @@ def forward_pro(self, data):
                                          training=self.training)
 
         # conv1
-        xt = self.conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+        xt = self.pro_conv1(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
         ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                          training=self.training)
         # conv2
-        xt = self.conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+        xt = self.pro_conv2(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
         ei_drp, e_mask, _ = dropout_node(ei, p=self.dropout_gnn, num_nodes=target_x.shape[0],
                                          training=self.training)
         # conv3
-        xt = self.conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
+        xt = self.pro_conv3(xt, ei_drp, ew[e_mask] if ew is not None else ew)
         xt = self.relu(xt)
 
         # flatten/pool
@@ -285,7 +285,7 @@ def forward_pro(self, data):
         xt = self.relu(xt)
         xt = self.dropout(xt)
 
-        if self.extra_profc_layer:
+        if self.pro_extra_fc_lyr:
             xt = self.pro_fc_g1b(xt)
             xt = self.relu(xt)
             xt = self.dropout(xt)
diff --git a/src/train_test/distributed.py b/src/train_test/distributed.py
index 97984db2..6466f531 100644
--- a/src/train_test/distributed.py
+++ b/src/train_test/distributed.py
@@ -152,7 +152,7 @@ def dtrain(args, unknown_args):
                                        map_location=torch.device(f'cuda:{args.gpu}')))
 
     model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # use if model contains batchnorm.
-    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
 
     torch.distributed.barrier() # Sync params across GPUs before training
 
diff --git a/src/utils/loader.py b/src/utils/loader.py
index bb1c2215..cbc646ae 100644
--- a/src/utils/loader.py
+++ b/src/utils/loader.py
@@ -115,9 +115,15 @@ def init_model(model:str, pro_feature:str, pro_edge:str, dropout:float=0.2, **kw
                        pro_feat='all', # to include all feats (esm + 52 from DGraphDTA)
                        edge_weight_opt=pro_edge)
     elif model == 'EDI':
+        pro_emb_dim = 512 # increase embedding size
+        if "pro_emb_dim" in kwargs:
+            pro_emb_dim = kwargs['pro_emb_dim']
+            logging.warning(f'pro_emb_dim changed from default of {512} to {pro_emb_dim} for model EDI')
+            del kwargs['pro_emb_dim']
+
         model = EsmDTA(esm_head='facebook/esm2_t6_8M_UR50D',
                        num_features_pro=320,
-                       pro_emb_dim=512, # increase embedding size
+                       pro_emb_dim=pro_emb_dim,
                        dropout=dropout,
                        pro_feat='esm_only',
                        edge_weight_opt=pro_edge,