Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

results: davis-shannon CV #46 #56

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
{
"train_loss": [
1.9954961412380228,
1.1123259179788616,
1.0628154035987316,
1.0195641736376337,
0.9921986040354303,
0.9692184450681758,
0.9488197967360636,
0.940734565818468,
0.9137962934833053,
0.9058670597727847,
0.8862639353074858,
0.8738508635093417,
0.8678539732256821,
0.8494533173146501,
0.8396369166078821,
0.8195936667477399,
0.8189072838399263,
0.7990547935354763,
0.7925020887030173,
0.7814916852136365,
0.7716743027445225,
0.7743883602716753,
0.7596010671087862,
0.7445095050750317,
0.7552647592120729,
0.7467957078685276,
0.7425954988209811,
0.7552545020373258,
0.7356984313346643,
0.7358773813080207,
0.7404806569746111,
0.742597665621248,
0.7288618015735286,
0.7157279677128634,
0.7089876080574715,
0.7249864521807274,
0.7202833260994701,
0.7165739839995461,
0.7198128473854303,
0.7385284647065322,
0.7309793036150853,
0.7196460504958455,
0.7323818235880638,
0.7173825457807531,
0.7192918718256781,
0.722492459147707,
0.7188923951584549,
0.7100095051563288,
0.7068662808268471,
0.7108152348680042,
0.714867272776906,
0.7058873030200469,
0.7120912533698488,
0.7118631477294638,
0.6987585334469919,
0.7243739486037366,
0.7150540050959825,
0.7067155465895755,
0.7103592198325601,
0.7352949746562976,
0.7051169296924388,
0.7122998241236252,
0.7036621218497774,
0.7038248123732948,
0.7041913827106488,
0.7179739314700122,
0.7103856882179337,
0.7186218756951589,
0.7088906502565452,
0.7301404662347869,
0.7173615305198241,
0.7239452102839683,
0.7309273347851977,
0.7180232846268774,
0.7328828894079918,
0.727330649090288,
0.723351289419276,
0.7066450789405972,
0.7188684371242349,
0.7147803407459132,
0.7227345445086207,
0.7226420441658887,
0.7060261584933748,
0.7267584679977952,
0.7306896013628065,
0.7264216862767275,
0.7150264033652117,
0.7042778410116391,
0.7083051127290964,
0.7082227646595741,
0.712323370999533,
0.6989217922833245,
0.7070858577694908,
0.6963187569620467,
0.6941914890390054,
0.6874943163526137,
0.6820477824410902,
0.6814433606423899,
0.6704526256473191,
0.668551255891149,
0.6652297119453418,
0.661144870951508,
0.6565740831624881,
0.6341712325152042,
0.6606567606280467,
0.643499315848195,
0.6269802997976673
],
"val_loss": [
1.1018557784788425,
1.2535118042961158,
1.2187856405981676,
1.2760470970204116,
1.41563035914081,
1.3298013821858254,
1.2990144251502456,
1.2875569346781321,
1.3214407840344964,
1.2794410279208683,
1.319744275546987,
1.3295082387811117,
1.2344647842681908,
1.2835332232072076,
1.0444040316799903,
1.0201239069589967,
0.8929435918643346,
0.7577486736679334,
0.7760434191739929,
0.7643616874784692,
0.7721006384292469,
0.7641122302501998,
0.7538701039368072,
0.7168129901525325,
0.7462786953706723,
0.7326808813011554,
0.7197824846747108,
0.7124754123768844,
0.6933985606487016,
0.6861601175938165,
0.6809609269401866,
0.6760130381557794,
0.6701235365950857,
0.6707020423568752,
0.6693344269013775,
0.6909623357809388,
0.682249966654432,
0.7903768657953931,
0.7202901590730896,
0.7126389422280862,
0.7491712721756998,
0.7206178106454126,
0.6651267873543619,
0.6644183549333561,
0.7078817178355118,
0.6839179164416824,
0.7086175708553498,
0.7268013580026418,
0.7167100477489783,
0.8461501004681222,
0.7416352115425207,
0.8179783209400103,
0.8254071828286963,
0.8371337609761188,
0.9208616896025605,
0.97088622076816,
0.8756023296731961,
1.069556474998857,
0.9624758456108383,
0.859606430698928,
1.0121137776074156,
1.0009158287358066,
1.0606968613440353,
1.2167002147937243,
1.5448616045502435,
1.4364397041036014,
1.4733317930053558,
1.7337123687354337,
1.7061408191262044,
1.7082935327713469,
1.3157687119529706,
1.5440313264941878,
1.3741871246334527,
1.2794151946750831,
1.316956363407178,
1.0836214681243033,
1.4886349551956606,
1.446724139446541,
1.2230758408601685,
1.2782749346832423,
0.9854454206203861,
1.434194633492013,
0.9626338091592379,
1.1336133684715735,
1.1052076671369357,
1.161011285033176,
1.1970799049420233,
1.0871652196598265,
1.3620113483547702,
1.1277231313020353,
1.308906744110637,
1.5659388677027506,
0.8891798744897103,
1.1205095070768643,
1.0892479312224972,
0.7972673116668036,
0.8973261003230031,
1.087875069614574,
1.1212975865093975,
1.0956825447326477,
0.9849263438656184,
0.8330175951020216,
0.9281886269841179,
1.0319483467708424,
0.9451814371622703,
1.0264084387830237,
1.1390326986794022
],
"best_epoch": 44
}
8 changes: 7 additions & 1 deletion results/model_media/model_stats.csv
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ DGM_davis4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8372000580825498,0.70972544
DGM_davis1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8259045903365158,0.6727098716351115,0.5969862332191551,0.4777166877928197,0.3703959785636244,0.6911705200547978
DGM_davis3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.842545467746034,0.7168413490003224,0.623085263864242,0.4159492493391408,0.3558861870119901,0.6449412758842008
DGM_davis2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8219248012487749,0.6776357016678246,0.5900163003204868,0.4347317921133404,0.368497045519577,0.6593419386883717
DGM_davis0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8118490670490435,0.681294197609124,0.5718484991508404,0.42935783222492513,0.35005875178466256,0.6552540211436517
DGM_davis0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8118490670490435,0.681294197609124,0.5718484991508404,0.4293578322249251,0.3500587517846625,0.6552540211436517
DGM_davisD_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_chemGPTLF,0.7370126629894659,0.4531619146470319,0.4447325194344097,0.6708082424596539,0.5559361077674945,0.8190288410426423
DGM_davis4D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8262970114209492,0.6902457009938884,0.6196677619083105,0.4657317363498777,0.3800800188713128,0.6824454090620566
DGM_davis3D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8387058574888743,0.7253796342930805,0.6397184730308061,0.418395340180838,0.3501044126973091,0.646834863145794
DGM_davis2D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8496612461728558,0.7437110771416391,0.6565426331150037,0.391104889452298,0.3245244547831619,0.6253837937237405
DGM_davis1D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8312754273934078,0.6999447666182411,0.6286994684318588,0.448984799382394,0.354222957680865,0.6700632801328499
DGM_davis0D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8164859423418647,0.693536266777011,0.6013830893679215,0.46093123143684783,0.36639131860035223,0.6789191641402147
6 changes: 6 additions & 0 deletions results/model_media/model_stats_val.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ DGM_davis0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8124934102025569,0.69098278
DGM_davis3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8195586154273969,0.717774769130841,0.5720718176514452,0.3657891513377326,0.3382081135692242,0.6048050523414406
DGM_davis2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.8233405467261639,0.7040322378359025,0.5687852322102779,0.3543421502907676,0.3433588875976263,0.5952664531877868
DGM_davis1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E,0.82715800662027,0.7083018179748569,0.6132060396459778,0.4895027135441312,0.3843769813986386,0.699644705221251
DGM_davisD_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_chemGPTLF,0.7380363162506938,0.4574955465284291,0.4466001977109979,0.6649520759195586,0.551816601340746,0.8154459368465568
DGM_davis4D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.823990429488312,0.7117146794067782,0.600087030044339,0.4287954343848595,0.3553209463269589,0.654824735623861
DGM_davis3D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.827734580656503,0.702739642357159,0.5768866122321109,0.3738840686242872,0.3239032915296037,0.6114606026755013
DGM_davis2D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8292712666904243,0.7228505402132716,0.5833504792222683,0.3717008538318646,0.3222885614393665,0.6096727432253016
DGM_davis1D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8248946179115073,0.7385253536342127,0.598720558202774,0.404841846456497,0.3498974830004453,0.6362718337758612
DGM_davis0D_shannonF_binaryE_64B_0.0001LR_0.4D_2000E,0.8299184455370965,0.7248253045920677,0.6116541529841135,0.4180130896587404,0.3448278840382894,0.6465393179526984
50 changes: 50 additions & 0 deletions src/models/lig_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from selfies import encoder

from src.models.prior_work import DGraphDTA
from src.models.pro_mod import EsmDTA


class ChemDTA(DGraphDTA):
def __init__(self, mol_output_dim=128, dropout=0.2, *args, **kwargs):
Expand Down Expand Up @@ -33,6 +35,54 @@ def forward_mol(self, data_mol):
mol_x = data_mol.lig_seq


# get SELFIES from SMILES
selfies = [encoder(s) for s in mol_x]

# get tokens
res = self.tokenizer(selfies, return_tensors="pt", padding=True)

res['input_ids'] = res['input_ids'].to(data_mol.x.device)
res['attention_mask'] = res['attention_mask'].to(data_mol.x.device)
res['token_type_ids'] = res['token_type_ids'].to(data_mol.x.device)

# model
model_output = self.model(**res).last_hidden_state

# flatten to [L, 128]
x = torch.mean(model_output, dim=1)

x = self.relu(self.mol_fc_g1(x))
x = self.dropout(x)
x = self.mol_fc_g2(x)
x = self.dropout(x)
return x

class ChemEsmDTA(EsmDTA):
def __init__(self, mol_output_dim=128, dropout=0.2, *args, **kwargs):
    """Set up an EsmDTA model whose ligand branch uses a frozen ChemGPT encoder.

    Args:
        mol_output_dim: Output width of the last ligand linear layer.
        dropout: Dropout value forwarded to ``EsmDTA.__init__``.
        *args: Extra positional arguments forwarded to ``EsmDTA.__init__``.
        **kwargs: Extra keyword arguments forwarded to ``EsmDTA.__init__``.
            ``edge_weight_opt`` falls back to ``'binary'`` when not given.
    """
    # Hard-coding edge_weight_opt next to **kwargs raised
    # "TypeError: got multiple values for keyword argument" as soon as a
    # caller supplied its own value; treat 'binary' as a default instead.
    kwargs.setdefault('edge_weight_opt', 'binary')
    super(ChemEsmDTA, self).__init__(*args, dropout=dropout, **kwargs)

    print('ChemEsmDTA Loaded')
    # Input width of the first ligand linear layer.
    # NOTE(review): presumably the ChemGPT-4.7M hidden size — confirm.
    num_features_mol = 128

    #### ChemGPT ####

    # Tokenizer and model are read from the same local snapshot directory;
    # local_files_only=True prevents any download attempt.
    chemgpt_dir = "../hf_models/models--ncfrey--ChemGPT-4.7M/snapshots/7438a282460b3038e17a27e25b85b1376e9a23e2/"
    self.tokenizer = AutoTokenizer.from_pretrained(chemgpt_dir, local_files_only=True)
    self.model = AutoModel.from_pretrained(chemgpt_dir, local_files_only=True)

    self.model.requires_grad_(False)  # freeze weights

    # Register '[PAD]' as the padding token so batched tokenization with
    # padding=True works.
    self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Two-layer head applied to the ligand embedding.
    self.mol_fc_g1 = nn.Linear(num_features_mol, 1024)
    self.mol_fc_g2 = nn.Linear(1024, mol_output_dim)

def forward_mol(self, data_mol):
# get smiles list input
mol_x = data_mol.lig_seq


# get SELFIES from SMILES
selfies = [encoder(s) for s in mol_x]

Expand Down
2 changes: 1 addition & 1 deletion src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
confProDy(verbosity='none') # stop printouts from prody


MODEL_OPT = ['DG', 'DGI', 'ED', 'EDA', 'EDI', 'EDAI', 'EAT', 'CD']
MODEL_OPT = ['DG', 'DGI', 'ED', 'EDA', 'EDI', 'EDAI', 'EAT', 'CD', 'CED']

STRUCT_EDGE_OPT = ['anm', 'af2', 'af2-anm'] # edge options that require structural info (pdbs)
EDGE_OPT = ['simple', 'binary'] + STRUCT_EDGE_OPT
Expand Down
10 changes: 8 additions & 2 deletions src/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader

from src.models.lig_mod import ChemDTA
from src.models.lig_mod import ChemDTA, ChemEsmDTA
from src.models.pro_mod import EsmDTA, EsmAttentionDTA
from src.models.prior_work import DGraphDTA, DGraphDTAImproved
from src.data_processing.datasets import PDBbindDataset, DavisKibaDataset
Expand Down Expand Up @@ -99,7 +99,13 @@ def init_model(model:str, pro_feature:str, pro_edge:str, dropout:float,
elif model == 'CD':
# this model only needs sequence, no additional features.
model = ChemDTA(dropout=dropout)

elif model == 'CED':
model = ChemEsmDTA(esm_head='facebook/esm2_t6_8M_UR50D',
num_features_pro=320,
pro_emb_dim=512, # increase embedding size
dropout=dropout,
pro_feat='esm_only',
edge_weight_opt=pro_edge)
return model

@staticmethod
Expand Down