diff --git a/playground.py b/playground.py index a49b1fa6..1c2ce530 100644 --- a/playground.py +++ b/playground.py @@ -1,29 +1,148 @@ -#%% -from src.data_prep.downloaders import Downloader +#%% 1.Gather data for davis,kiba and pdbbind datasets +import os +import pandas as pd +import matplotlib.pyplot as plt +from src.analysis.utils import combine_dataset_pids +from src import config as cfg +df_prots = combine_dataset_pids(dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.PDBbind], # just davis and pdbbind for now + subset='test') -Downloader.download_SDFs(['CHEMBL245769', 'CHEMBL55788'], save_dir='./') -#%% -import pandas as pd +#%% 2. Load TCGA data +df_tcga = pd.read_csv('../downloads/TCGA_ALL.maf', sep='\t') -df = pd.read_csv("../data/TCGA_BRCA_Mutations.csv") -df = df.loc[df['SWISSPROT'].dropna().index] -df[['Gene', 'SWISSPROT']].head() +#%% 3. Pre filtering +df_tcga = df_tcga[df_tcga['Variant_Classification'] == 'Missense_Mutation'] +df_tcga['seq_len'] = pd.to_numeric(df_tcga['Protein_position'].str.split('/').str[1]) +df_tcga = df_tcga[df_tcga['seq_len'] < 5000] +df_tcga['seq_len'].plot.hist(bins=100, title="sequence length histogram capped at 5K") +plt.show() +df_tcga = df_tcga[df_tcga['seq_len'] < 1200] +df_tcga['seq_len'].plot.hist(bins=100, title="sequence length after capped at 1.2K") -# %% -from src.utils.pdb import pdb2uniprot -df2 = pd.read_csv("../data/PlatinumDataset/nomsa_binary_original_binary/full/cleaned_XY.csv", index_col=0) -df2['pdb_id'] = df2.prot_id.str.split("_").str[0] +#%% 4. Merging df_prots with TCGA +df_tcga['uniprot'] = df_tcga['SWISSPROT'].str.split('.').str[0] + +dfm = df_tcga.merge(df_prots[df_prots.db != 'davis'], + left_on='uniprot', right_on='prot_id', how='inner') + +# for davis we have to merge on HUGO_SYMBOLS +dfm_davis = df_tcga.merge(df_prots[df_prots.db == 'davis'], + left_on='Hugo_Symbol', right_on='prot_id', how='inner') + +dfm = pd.concat([dfm,dfm_davis], axis=0) + +del dfm_davis # to save mem + +# %% 5. Post filtering step +# 5.1. Filter for only those sequences with matching sequence length (to get rid of nonmatched isoforms) +# seq_len_x is from tcga, seq_len_y is from our dataset +tmp = len(dfm) +# allow for some error due to missing amino acids from pdb file in PDBbind dataset +# - assumption here is that isoforms will differ by more than 50 amino acids +dfm = dfm[(dfm.seq_len_y <= dfm.seq_len_x) & (dfm.seq_len_x<= dfm.seq_len_y+50)] +print(f"Filter #1 (seq_len) : {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}") + +# 5.2. 
Filter out those that dont have the same reference seq according to the "Protein_position" and "Amino_acids" col + +# Extract mutation location and reference amino acid from 'Protein_position' and 'Amino_acids' columns +dfm['mt_loc'] = pd.to_numeric(dfm['Protein_position'].str.split('/').str[0]) +dfm = dfm[dfm['mt_loc'] < dfm['seq_len_y']] +dfm[['ref_AA', 'mt_AA']] = dfm['Amino_acids'].str.split('/', expand=True) + +dfm['db_AA'] = dfm.apply(lambda row: row['prot_seq'][row['mt_loc']-1], axis=1) + +# Filter #2: Match proteins with the same reference amino acid at the mutation location +tmp = len(dfm) +dfm = dfm[dfm['db_AA'] == dfm['ref_AA']] +print(f"Filter #2 (ref_AA match): {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}") +print('\n',dfm.db.value_counts()) + + +# %% final seq len distribution +n_bins = 25 +lengths = dfm.seq_len_x +fig, ax = plt.subplots(1, 1, figsize=(10, 5)) + +# Plot histogram +n, bins, patches = ax.hist(lengths, bins=n_bins, color='blue', alpha=0.7) +ax.set_title('TCGA final filtering for db matches') + +# Add counts to each bin +for count, x, patch in zip(n, bins, patches): + ax.text(x + 0.5, count, str(int(count)), ha='center', va='bottom') -uniprots = pdb2uniprot(df2.pdb_id.unique()) -df_pid = pd.DataFrame(list(uniprots.items()), columns=['pdbID', 'uniprot']) +ax.set_xlabel('Sequence Length') +ax.set_ylabel('Frequency') + +plt.tight_layout() +plt.show() + +# %% Getting updated sequences +def apply_mut(row): + ref_seq = list(row['prot_seq']) + ref_seq[row['mt_loc']-1] = row['mt_AA'] + return ''.join(ref_seq) + +dfm['mt_seq'] = dfm.apply(apply_mut, axis=1) + + +# %% +dfm.to_csv("/cluster/home/t122995uhn/projects/data/tcga/tcga_maf_davis_pdbbind.csv") # %% -# find overlap between uniprots and swissprot ids -uniprots = set(uniprots) # convert to set for faster lookup -df['uniprot'] = df['SWISSPROT'].str.split('.').str[0] -# Merge the original df with uni_to_pdb_df -df = df.merge(df_pid, on='uniprot', how='left') +from src.utils.seq_alignment import MSARunner +from tqdm import tqdm +import pandas as pd +import os + +DATA_DIR = '/cluster/home/t122995uhn/projects/data/tcga' +CSV = f'{DATA_DIR}/tcga_maf_davis_pdbbind.csv' +N_CPUS= 6 +NUM_ARRAYS = 10 +array_idx = 0#${SLURM_ARRAY_TASK_ID} + +df = pd.read_csv(CSV, index_col=0) +df.sort_values(by='seq_len_y', inplace=True) + # %% -df[~df['pdbID'].isna()] +for DB in df.db.unique(): + print('DB', DB) + RAW_DIR = f'{DATA_DIR}/{DB}' + # should already be unique if these are proteins mapped form tcga! 
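+    # restrict to this database's proteins before carving out this job's array partition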
+ unique_df = df[df['db'] == DB] + ########################## Get job partition + partition_size = len(unique_df) / NUM_ARRAYS + start, end = int(array_idx*partition_size), int((array_idx+1)*partition_size) + + unique_df = unique_df[start:end] + + #################################### create fastas + fa_dir = os.path.join(RAW_DIR, f'{DB}_fa') + fasta_fp = lambda idx,pid: os.path.join(fa_dir, f"{idx}-{pid}.fasta") + os.makedirs(fa_dir, exist_ok=True) + for idx, (prot_id, pro_seq) in tqdm( + unique_df[['prot_id', 'prot_seq']].iterrows(), + desc='Creating fastas', + total=len(unique_df)): + with open(fasta_fp(idx,prot_id), "w") as f: + f.write(f">{prot_id},{idx},{DB}\n{pro_seq}") + + ##################################### Run hhblits + aln_dir = os.path.join(RAW_DIR, f'{DB}_aln') + aln_fp = lambda idx,pid: os.path.join(aln_dir, f"{idx}-{pid}.a3m") + os.makedirs(aln_dir, exist_ok=True) + + # finally running + for idx, (prot_id, pro_seq) in tqdm( + unique_df[['prot_id', 'mt_seq']].iterrows(), + desc='Running hhblits', + total=len(unique_df)): + in_fp = fasta_fp(idx,prot_id) + out_fp = aln_fp(idx,prot_id) + + if not os.path.isfile(out_fp): + print(MSARunner.hhblits(in_fp, out_fp, n_cpus=N_CPUS, return_cmd=True)) + break + # %% diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index 103dcdee..89f8e815 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -79,38 +79,14 @@ def train_func(config): torch.cuda.device_count(), "devices") print("CUDA VERSION:", torch.__version__) - search_space = { - ## constants: - "epochs": 20, - "model": cfg.MODEL_OPT.DG, - "dataset": cfg.DATA_OPT.PDBbind, - "feature_opt": cfg.PRO_FEAT_OPT.nomsa, - "edge_opt": cfg.PRO_EDGE_OPT.aflow, - "lig_feat_opt": cfg.LIG_FEAT_OPT.original, - "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, - - "fold_selection": 0, - "save_checkpoint": False, - - ## hyperparameters to tune: - "lr": ray.tune.loguniform(1e-5, 1e-3), - "batch_size": ray.tune.choice([32, 64, 128]), # local batch size - - # model architecture hyperparams - "architecture_kwargs":{ - "dropout": ray.tune.uniform(0.0, 0.5), - "output_dim": ray.tune.choice([128, 256, 512]), - } - } - # 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'): # search_space = { # ## constants: # "epochs": 20, - # "model": cfg.MODEL_OPT.GVPL, + # "model": cfg.MODEL_OPT.DG, # "dataset": cfg.DATA_OPT.PDBbind, # "feature_opt": cfg.PRO_FEAT_OPT.nomsa, # "edge_opt": cfg.PRO_EDGE_OPT.aflow, - # "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + # "lig_feat_opt": cfg.LIG_FEAT_OPT.original, # "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, # "fold_selection": 0, @@ -126,6 +102,30 @@ def train_func(config): # "output_dim": ray.tune.choice([128, 256, 512]), # } # } + # 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'): + search_space = { + ## constants: + "epochs": 20, + "model": cfg.MODEL_OPT.GVPL, + "dataset": cfg.DATA_OPT.davis, + "feature_opt": cfg.PRO_FEAT_OPT.nomsa, + "edge_opt": cfg.PRO_EDGE_OPT.aflow, + "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp, + "lig_edge_opt": cfg.LIG_EDGE_OPT.binary, + + "fold_selection": 0, + "save_checkpoint": False, + + ## hyperparameters to tune: + "lr": ray.tune.loguniform(1e-5, 1e-3), + "batch_size": ray.tune.choice([32, 64, 128]), # local batch size + + # model architecture hyperparams + "architecture_kwargs":{ + "dropout": ray.tune.uniform(0.0, 0.5), + "output_dim": ray.tune.choice([128, 256, 512]), + } + } # search space for GVPL_RNG MODEL: # search_space = { # ## constants: diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index 0bddc7fe..4e3821ae 100644 --- 
a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -176,4 +176,9 @@ DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_20 DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7015552107083158,0.6057001165731564,0.569729312002546,2.358233326230585,1.2048145031929016,1.5356540385876585 DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7085153085744522,0.6208329976682688,0.5859156875580817,2.2240677925526744,1.1722256461779277,1.4913308796349234 DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6947751328102371,0.5816010010184015,0.5528265670239861,2.3897321791041333,1.2105133893376303,1.545875861479224 -DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6855330290119371,0.5726521315119578,0.5325884299986027,2.4758763933999823,1.2316968268061441,1.57349178370908 +DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6855330290119371,0.5726521315119578,0.5325884299986027,2.4758763933999823,1.231696826806144,1.57349178370908 +GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7581108896345126,0.48942831390709474,0.45529634961852566,0.5072034546494222,0.4161494350523831,0.7121821779919953 +GVPLM_davis1D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.768273891734927,0.550820575655664,0.4743629409308578,0.4386155850598172,0.3857342382535988,0.6622805939024767 +GVPLM_davis3D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7603887296892994,0.5280132204660057,0.4597706307310624,0.4791434307834897,0.3716412002266697,0.6922018714099881 +GVPLM_davis2D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7694890904708525,0.5217685104759502,0.4753118765332927,0.4673188139402668,0.4174866665020387,0.6836072073495618 +GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.760520722940427,0.5078128106282308,0.4629361403886756,0.5235965111752061,0.3783202273117297,0.723599689866715 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index 602be48f..d018c6b1 100644 --- a/results/model_media/model_stats_val.csv +++ b/results/model_media/model_stats_val.csv @@ -156,3 +156,8 @@ DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_20 DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6980325998261337,0.5867948698567493,0.5613217730433377,2.504824019579536,1.2573465726293105,1.5826635838293417 DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6760856277267194,0.5188577783623333,0.5051689763782492,2.7333317052016373,1.2894853849669785,1.653279076623677 DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.678631965986838,0.5169422908392893,0.5136610285075189,2.659271637069346,1.2811931355709805,1.6307273337591868 +GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7727048764494124,0.5829543940410362,0.46680064762373447,0.4654286995546293,0.3997238369103796,0.6822233501974476 
+GVPLM_davis1D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7767585519522199,0.5805240837724788,0.4946235296015073,0.4138118485637668,0.3817763907346182,0.6432820909708017 +GVPLM_davis3D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.8202768328439823,0.6402255623870499,0.5435232833589796,0.3676188640274792,0.3097654533567193,0.6063158121206136 +GVPLM_davis2D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7854232789039847,0.5996614387831698,0.527865354318397,0.4509265081187592,0.4061779738808213,0.67151061653466 +GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7849321063417397,0.6186046172952341,0.5194671423553721,0.4498645421549498,0.3525368623308257,0.670719421334249 diff --git a/src/analysis/figures.py b/src/analysis/figures.py index adb39e8f..a0380f80 100644 --- a/src/analysis/figures.py +++ b/src/analysis/figures.py @@ -344,7 +344,7 @@ def fig6_protein_appearance(datasets=['kiba', 'PDBbind'], show=False): def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'mse'], fig_callable=fig4_pro_feat_violin, fig_scale=(5,4), - show=False, **kwargs): + show=False, title_postfix='', **kwargs): # Create subplots with datasets as columns and metrics as rows fig, axes = plt.subplots(len(metrics), len(datasets), figsize=(fig_scale[0]*len(datasets), @@ -362,7 +362,7 @@ def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'm # Add titles only to the top row and left column if j == 0: - ax.set_title(f'{dataset}') + ax.set_title(f'{dataset}{title_postfix}') ax.set_xlabel('') ax.set_xticklabels([]) elif j < len(metrics)-1: # middle row diff --git a/src/analysis/utils.py b/src/analysis/utils.py index 6ad8a4dd..34feb2ef 100644 --- a/src/analysis/utils.py +++ b/src/analysis/utils.py @@ -3,6 +3,7 @@ from scipy.stats import ttest_ind import pandas as pd import numpy as np +from src import config as cfg def count_missing_res(pdb_file: str) -> Tuple[int,int]: @@ -117,6 +118,32 @@ def generate_markdown(results, names=None, verbose=False, thresh_sig=False, cind if verbose: print(md_output) return md_table +def combine_dataset_pids(data_dir=cfg.DATA_ROOT, + dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba, cfg.DATA_OPT.PDBbind], + target='nomsa_aflow_gvp_binary', subset='full', + xy='cleaned_XY.csv'): + df_all = None + dir_p = {'davis':'DavisKibaDataset/davis', + 'kiba': 'DavisKibaDataset/kiba', + 'PDBbind': 'PDBbindDataset'} + dbs = {d.value: f'{data_dir}/{dir_p[d]}/{target}/{subset}/{xy}' for d in dbs} + + for DB, fp in dbs.items(): + print(DB, fp) + df = pd.read_csv(fp, index_col=0) + df['pdb_id'] = df.prot_id.str.split("_").str[0] + df = df[['prot_id', 'prot_seq']].drop_duplicates(subset='prot_id') + + df['seq_len'] = df['prot_seq'].str.len() + df['db'] = DB + df.reset_index(inplace=True) + + df = df[['db','code', 'prot_id', 'seq_len', 'prot_seq']] # reorder them. 
+ df.index.name = 'db_idx' + df_all = df if df_all is None else pd.concat([df_all, df], axis=0) + + return df_all + if __name__ == '__main__': #NOTE: the following is code for stratifying AutoDock Vina results by # missing residues to identify if there is a correlation between missing diff --git a/src/data_prep/datasets.py b/src/data_prep/datasets.py index b873c98a..da504a8d 100644 --- a/src/data_prep/datasets.py +++ b/src/data_prep/datasets.py @@ -447,9 +447,10 @@ def _create_ligand_graphs(self, df:pd.DataFrame, node_feat, edge): processed_ligs = {} errors = [] if node_feat == cfg.LIG_FEAT_OPT.gvp: - for code, lig_seq in tqdm(df['SMILE'].items(), desc='Creating ligand graphs', + for code, (lig_seq, lig_id) in tqdm(df[['SMILE', 'lig_id']].iterrows(), desc='Creating ligand graphs', total=len(df)): - processed_ligs[lig_seq] = GVPFeaturesLigand().featurize_as_graph(self.sdf_p(code)) + processed_ligs[lig_seq] = GVPFeaturesLigand().featurize_as_graph(self.sdf_p(code, + lig_id=lig_id)) return processed_ligs for lig_seq in tqdm(df['SMILE'].unique(), desc='Creating ligand graphs'): @@ -506,12 +507,12 @@ def file_real(fp): logging.info('Created cleaned_XY.csv file') - ###### Get Protein Graphs ###### - processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt) - ###### Get Ligand Graphs ###### processed_ligs = self._create_ligand_graphs(self.df, self.ligand_feature, self.ligand_edge) + ###### Get Protein Graphs ###### + processed_prots = self._create_protein_graphs(self.df, self.pro_feat_opt, self.pro_edge_opt) + ###### Save ###### logging.info('Saving...') torch.save(processed_prots, self.processed_paths[1]) @@ -562,7 +563,7 @@ def af_conf_files(self, pid) -> list[str]|str: def pdb_p(self, code): return os.path.join(self.data_root, code, f'{code}_protein.pdb') - def sdf_p(self, code): + def sdf_p(self, code, **kwargs): return os.path.join(self.data_root, code, f'{code}_ligand.sdf') def cmap_p(self, pid): @@ -632,7 +633,8 @@ def pre_process(self): # Get binding data: df_binding = PDBbindProcessor.get_binding_data(self.raw_paths[0]) # _data.2020 - df_binding.drop(columns=['resolution', 'release_year', 'lig_name'], inplace=True) + df_binding.drop(columns=['resolution', 'release_year'], inplace=True) + df_binding.rename({'lig_name':'lig_id'}, inplace=True) pdb_codes = df_binding.index # pdbcodes ############## validating codes ############# @@ -728,27 +730,32 @@ def __init__(self, save_root=f'{cfg.DATA_ROOT}/DavisKibaDataset/', feature_opt=feature_opt, *args, **kwargs) def af_conf_files(self, code) -> list[str]: - """Davis has issues since prot_ids are not really that unique""" - # removing () from string since file names cannot include them and localcolabfold replaces them with _ + """Davis has issues since prot_ids are not really unique""" + if self.alphaflow: + fp = f'{self.af_conf_dir}/{code}.pdb' + fp = fp if os.path.exists(fp) else None + return fp + + # removing () from string since localcolabfold replaces them with _ code = re.sub(r'[()]', '_', code) + # localcolabfold has 'unrelaxed' as the first part after the code/ID. # output must be in out directory - - # TODO: fix this to work with alphaflow outputs - # if self.alphaflow: - # fp = f'{self.af_conf_dir}/{pid}.pdb' - # fp = fp if os.path.exists(fp) else None - # return fp - - return glob(f'{self.af_conf_dir}/out?/{code}_unrelaxed*_alphafold2_ptm_model_*.pdb') + def sdf_p(self, code, lig_id): + # code is just a placeholder since other datasets (pdbbind) need it. 
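+        # davis/kiba ligand structures are stored by their ligand ID under <data_root>/lig_sdf/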
+ return os.path.join(self.data_root, 'lig_sdf', f'{lig_id}.sdf') + def pdb_p(self, code, safe=True): code = re.sub(r'[()]', '_', code) # davis and kiba dont have their own structures so this must be made using # af or some other method beforehand. if (self.pro_edge_opt not in cfg.OPT_REQUIRES_PDB) and \ (self.pro_feat_opt not in cfg.OPT_REQUIRES_PDB): return None + + if self.alphaflow: + return self.af_conf_files(code) file = glob(os.path.join(self.af_conf_dir, f'highQ/{code}_unrelaxed_rank_001*.pdb')) # should only be one file @@ -862,11 +869,13 @@ def pre_process(self): prot_seq = list(prot_seq.values()) # get ligand sequences (order is important since they are indexed by row in affinity matrix): - ligand_seq = json.load(open(self.raw_paths[1], 'r'), + lig_dict = json.load(open(self.raw_paths[1], 'r'), object_hook=OrderedDict) - ligand_seq = list(ligand_seq.values()) + lig_id = list(lig_dict.keys()) + ligand_seq = list(lig_dict.values()) # Get binding data: + # for davis this matrix should contain no nan values affinity_mat = pickle.load(open(self.raw_paths[2], 'rb'), encoding='latin1') lig_r, prot_c = np.where(~np.isnan(affinity_mat)) # index values corresponding to non-nan values @@ -895,7 +904,7 @@ def pre_process(self): # af_conf_files will be different for alphaflow (single file) no_confs = [c for c in codes if ( (self.pdb_p(c, safe=False) is None) or # no highQ structure - (Chain.get_model_count(self.af_conf_files(c)) < 5))] # only if not for foldseek + (Chain.get_model_count(self.af_conf_files(c)) < 5))] # single file needs Chain.get_model_count else: no_confs = [c for c in codes if ( (self.pdb_p(c, safe=False) is None) or # no highQ structure @@ -907,7 +916,7 @@ def pre_process(self): print(f'Number of codes missing af2 configurations: {len(no_confs)} / {len(codes)}') invalid_codes = set(no_aln + no_cmap + no_confs) - # filtering out invalid codes: + # filtering out invalid codes and storing their index vals. lig_r = [r for i,r in enumerate(lig_r) if codes[prot_c[i]] not in invalid_codes] prot_c = [c for c in prot_c if codes[c] not in invalid_codes] @@ -916,7 +925,8 @@ def pre_process(self): # creating binding dataframe: # code,SMILE,pkd,prot_seq df = pd.DataFrame({ - 'code': [codes[c] for c in prot_c], + 'code': [codes[c] for c in prot_c], + 'lig_id': [lig_id[r] for r in lig_r], 'SMILE': [ligand_seq[r] for r in lig_r], 'prot_seq': [prot_seq[c] for c in prot_c] }) @@ -993,9 +1003,8 @@ def af_conf_files(self, pid, map=True) -> list[str]: return glob(f'{self.af_conf_dir}/{pid}_model_*.pdb') - def sdf_p(self, code) -> str: + def sdf_p(self, code, lig_id) -> str: """Needed for gvp ligand branch (uses coordinate info)""" - lig_id = self.df.loc[code].lig_id return os.path.join(self.raw_paths[2], f'{lig_id}.sdf') def pdb_p(self, code): diff --git a/src/data_prep/downloaders.py b/src/data_prep/downloaders.py index 876a915d..60307818 100644 --- a/src/data_prep/downloaders.py +++ b/src/data_prep/downloaders.py @@ -1,10 +1,11 @@ -from typing import Iterable, List -import os +from typing import Iterable, List, Callable +import os, time import requests as r from io import StringIO from urllib.parse import quote from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed class Downloader: @staticmethod @@ -77,55 +78,78 @@ def get_file_obj(ID: str, url=lambda x: f'https://files.rcsb.org/download/{x}.pd The file object. 
""" return StringIO(r.get(url(ID)).text) + + @staticmethod + def download_single_file(id: str, save_path: Callable[[str], str], url: Callable[[str], str], + url_backup: Callable[[str], str], max_retries=4) -> tuple: + """ + Helper function to download a single file. + """ + fp = save_path(id) + # Check if the file already exists + if os.path.isfile(fp): + return id, 'already downloaded' + + os.makedirs(os.path.dirname(fp), exist_ok=True) + def fetch_url(url): + retries = 0 + while retries <= max_retries: + resp = r.get(url) + if resp.status_code == 503: + wait_time = 2 ** retries # Exponential backoff + time.sleep(wait_time) + retries += 1 + else: + return resp + return resp # Return the last response after exhausting retries + + resp = fetch_url(url(id)) + if resp.status_code >= 400 and url_backup: + resp = fetch_url(url_backup(id)) + + if resp.status_code >= 400: + return id, resp.status_code + else: + with open(fp, 'w') as f: + f.write(resp.text) + return id, 'downloaded' + @staticmethod def download(IDs: Iterable[str], - save_path=lambda x:'./data/structures/ligands/{x}.sdf', - url=lambda x: f'https://files.rcsb.org/ligands/download/{x}_ideal.sdf', - tqdm_desc='Downloading files', - tqdm_disable=False) -> dict: + save_path=lambda x: f'./data/structures/ligands/{x}.sdf', + url=lambda x: f'https://files.rcsb.org/ligands/download/{x}_ideal.sdf', + tqdm_desc='Downloading files', + url_backup=None, # for if the first url fails + tqdm_disable=False, + max_workers=None) -> dict: """ - Generalized download function for downloading any file type from any site. + Generalized multithreaded download function for downloading any file type from any site. - URL and save_path are passed in as callable functions which accept a string (the ID) - and return a url or save path for that file. - Parameters ---------- - `IDs` : Iterable[str] + IDs : Iterable[str] List of IDs to download - `save_path` : Callable[[str], str], optional + save_path : Callable[[str], str], optional Callable fn that returns the save path for file, by default lambda x :'./data/structures/ligands/{x}.sdf' - `url` : Callable[[str], str], optional + url : Callable[[str], str], optional Callable fn that returns the url to download file, by default lambda x :f'https://files.rcsb.org/ligands/download/{x}_ideal.sdf' - + max_workers : int, optional + Number of threads to use for downloading files. 
+ Returns ------- dict - status of each ID (whether it was downloaded or not) + status of each ID (whether it was downloaded or not) """ - ID_status = {} - for id in tqdm(IDs, tqdm_desc, disable=tqdm_disable): - if id in ID_status: continue - fp = save_path(id) - # checking to make sure that we didnt already download file - if os.path.isfile(fp): - ID_status[id] = 'already downloaded' - continue - - os.makedirs(os.path.dirname(fp), - exist_ok=True) - - resp = r.get(url(id)) - if resp.status_code >= 400: - ID_status[id] = resp.status_code - else: - with open(fp, 'w') as f: - f.write(resp.text) - ID_status[id] = 'downloaded' + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(Downloader.download_single_file, id, save_path, url, url_backup): id for id in IDs} + for future in tqdm(as_completed(futures), desc=tqdm_desc, total=len(IDs), disable=tqdm_disable): + id, status = future.result() + ID_status[id] = status return ID_status @staticmethod @@ -173,6 +197,7 @@ def download_SDFs(ligand_ids: List[str], 'CHEMBL': lambda x: f'https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/xref/registryID/{x}/record/sdf?record_type=3d', 'name': lambda x: f'https://files.rcsb.org/ligands/download/{x}_ideal.sdf'} + lid = ligand_ids[0] if lid.isdigit(): url = urls['CID'] @@ -182,7 +207,8 @@ def download_SDFs(ligand_ids: List[str], url = urls['name'] save_path = lambda x: os.path.join(save_dir, f'{x}.sdf') - return Downloader.download(ligand_ids, save_path=save_path, url=url, + url_backup=lambda x: url(x).split('?')[0] # fallback to 2d conformer structure + return Downloader.download(ligand_ids, save_path=save_path, url=url, url_backup=url_backup, tqdm_desc='Downloading ligand sdfs', **kwargs) if __name__ == '__main__': diff --git a/src/data_prep/init_dataset.py b/src/data_prep/init_dataset.py index f9529476..d9613040 100644 --- a/src/data_prep/init_dataset.py +++ b/src/data_prep/init_dataset.py @@ -12,9 +12,9 @@ from src.train_test.splitting import train_val_test_split, balanced_kfold_split def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:list[str]|str, - pro_overlap:bool=False, data_root:str=cfg.DATA_ROOT, ligand_features:list[str]=['original'], ligand_edges:list[str]=['binary'], + pro_overlap:bool=False, data_root:str=cfg.DATA_ROOT, k_folds:int=None, random_seed:int=0, train_split:float=0.8, @@ -62,7 +62,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis create_pfm_np_files(f'{data_root}/{data}/aln', processes=4) if 'af_conf_dir' not in kwargs: if EDGE in cfg.OPT_REQUIRES_AFLOW_CONF: - kwargs['af_conf_dir'] = f'/{data}/alphaflow_io/out_pdb_MD-distilled/' + kwargs['af_conf_dir'] = f'{data_root}/{data}/alphaflow_io/out_pdb_MD-distilled/' else: kwargs['af_conf_dir'] = f'../colabfold/{data}_af2_out/' @@ -144,12 +144,3 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis dataset.save_subset(test_loader, subset_names[2]) del dataset # free up memory - -if __name__ == "__main__": - create_datasets(data_opt=['davis'], # 'PDBbind' 'kiba' davis - feat_opt=['nomsa'], # nomsa 'msa' 'shannon'] - edge_opt=['af2-anm'], # for anm and af2 we need structures! 
(see colabfold-highQ) - pro_overlap=False, - #/home/jyaacoub/projects/data/ - #'/cluster/home/t122995uhn/projects/data/' - data_root='/cluster/home/t122995uhn/projects/data/') \ No newline at end of file diff --git a/src/data_prep/processors.py b/src/data_prep/processors.py index 157f853e..84d21a58 100644 --- a/src/data_prep/processors.py +++ b/src/data_prep/processors.py @@ -79,6 +79,24 @@ def csv_to_fasta_dir(csv_or_df:str or pd.DataFrame, out_dir:str): with open(fasta_fp, "w") as f: f.write(f">{prot_id}\n{pro_seq}") + @staticmethod + def fasta_to_df(fp) -> pd.DataFrame: + d = {} + with open(fp, 'r') as f: + line = f.readline() + while line: + if line.startswith('>'): + desc = line[1:].strip() + seq = '' + line = f.readline() + while line and not line.startswith('>'): + seq += line.strip() + line = f.readline() + d[desc] = seq + else: + line = f.readline() + return pd.DataFrame.from_records(list(d.items()), columns=['names', 'prot_seq']) + @staticmethod def fasta_to_aln_file(in_fp, out_fp): """ diff --git a/src/utils/config.py b/src/utils/config.py index aacc236b..9e914fa8 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -123,6 +123,13 @@ class LIG_FEAT_OPT(StringEnum): MMSEQ2_BIN = f'{Path.home()}/lib/mmseqs/bin/mmseqs' RING3_BIN = f'{Path.home()}/lib/ring-3.0.0/ring/bin/ring' +if 'uhnh4h' in DOMAIN_NAME: + UniRef_dir = '/cluster/projects/kumargroup/sequence_databases/UniRef30_2020_06/UniRef30_2020_06' + hhsuite_bin_dir = '/cluster/tools/software/centos7/hhsuite/3.3.0/bin' +else: + UniRef_dir = '' + hhsuite_bin_dir = '' + ########################### # LOGGING STUFF: diff --git a/src/utils/residue.py b/src/utils/residue.py index 6f493095..c4c96811 100644 --- a/src/utils/residue.py +++ b/src/utils/residue.py @@ -194,6 +194,15 @@ def __len__(self): def __repr__(self): return f'' + def __iter__(self): + """ + Generator function to iterate through all chains and return their sequences. 
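+
+        Yields (chain_ID, sequence) tuples; note that iteration sets `t_chain`
+        and resets the cached sequence for each chain it visits.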
+ """ + for cid in self._chains.keys(): + self.t_chain = cid + self._seq = None # Reset sequence cache + yield cid, self.getSequence() + def reset_attributes(self): self._seq = None self._coords = None @@ -228,8 +237,8 @@ def t_chain(self, chain_ID:str): assert len(chain_ID) == 1, f"Invalid chain ID {chain_ID}" # reset so that they are updated on next getter calls self.reset_attributes() - self._t_chain = chain_ID - + self._t_chain = chain_ID + @property def sequence(self) -> str: return self.getSequence() diff --git a/src/utils/seq_alignment.py b/src/utils/seq_alignment.py index f2677a31..937085d7 100644 --- a/src/utils/seq_alignment.py +++ b/src/utils/seq_alignment.py @@ -9,10 +9,9 @@ class MSARunner(Processor): - hhsuite_bin_dir = '/cluster/tools/software/centos7/hhsuite/3.3.0/bin' - bin_hhblits = f'{hhsuite_bin_dir}/hhblits' - bin_hhfilter = f'{hhsuite_bin_dir}/hhfilter' - UniRef_dir = '/cluster/projects/kumargroup/mslobody/Protein_Communities/01_MSA/databases/UniRef30_2020_06' + UniRef_dir = cfg.UniRef_dir + bin_hhblits = f'{cfg.hhsuite_bin_dir}/hhblits' + bin_hhfilter = f'{cfg.hhsuite_bin_dir}/hhfilter' @staticmethod def hhblits(f_in:str, f_out:str, n_cpus=6, n_iter:int=2, @@ -216,20 +215,33 @@ def read_clusttsv_output(tsv_path:str): if __name__ == '__main__': - from src.data_prep.datasets import BaseDataset + from src.utils.seq_alignment import MSARunner + from tqdm import tqdm + import pandas as pd + import os csv = '/cluster/home/t122995uhn/projects/data/PlatinumDataset/nomsa_binary/full/XY.csv' df = pd.read_csv(csv, index_col=0) #################### Get unique proteins: - unique_df = BaseDataset.get_unique_prots(df) - + # sorting by sequence length before dropping so that we keep the longest protein sequence instead of just the first. + df['seq_len'] = df['prot_seq'].str.len() + df = df.sort_values(by='seq_len', ascending=False) + + # create new numerated index col for ensuring the first unique uniprotID is fetched properly + df.reset_index(drop=False, inplace=True) + unique_pro = df[['prot_id']].drop_duplicates(keep='first') + + # reverting index to code-based index + df.set_index('code', inplace=True) + unique_df = df.iloc[unique_pro.index] + ########################## Get job partition - num_arrays = 100 + NUM_ARRAYS = 100 array_idx = 0#${SLURM_ARRAY_TASK_ID} - partition_size = len(unique_df) / num_arrays + partition_size = len(unique_df) / NUM_ARRAYS start, end = int(array_idx*partition_size), int((array_idx+1)*partition_size) - + unique_df = unique_df[start:end] - + raw_dir = '/cluster/home/t122995uhn/projects/data/PlatinumDataset/raw' #################################### create fastas