Merge pull request #100 from jyaacoub/development
Development
jyaacoub authored May 23, 2024
2 parents e54a179 + c124f0c commit 23c8ae1
Showing 13 changed files with 360 additions and 132 deletions.
159 changes: 139 additions & 20 deletions playground.py
@@ -1,29 +1,148 @@
#%%
from src.data_prep.downloaders import Downloader
#%% 1. Gather data for davis, kiba and pdbbind datasets
import os
import pandas as pd
import matplotlib.pyplot as plt
from src.analysis.utils import combine_dataset_pids
from src import config as cfg
df_prots = combine_dataset_pids(dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.PDBbind], # just davis and pdbbind for now
                                subset='test')

Downloader.download_SDFs(['CHEMBL245769', 'CHEMBL55788'], save_dir='./')

#%%
import pandas as pd
#%% 2. Load TCGA data
df_tcga = pd.read_csv('../downloads/TCGA_ALL.maf', sep='\t')

df = pd.read_csv("../data/TCGA_BRCA_Mutations.csv")
df = df.loc[df['SWISSPROT'].dropna().index]
df[['Gene', 'SWISSPROT']].head()
#%% 3. Pre-filtering
df_tcga = df_tcga[df_tcga['Variant_Classification'] == 'Missense_Mutation']
df_tcga['seq_len'] = pd.to_numeric(df_tcga['Protein_position'].str.split('/').str[1])
df_tcga = df_tcga[df_tcga['seq_len'] < 5000]
df_tcga['seq_len'].plot.hist(bins=100, title="sequence length histogram capped at 5K")
plt.show()
df_tcga = df_tcga[df_tcga['seq_len'] < 1200]
df_tcga['seq_len'].plot.hist(bins=100, title="sequence length after capped at 1.2K")

# %%
from src.utils.pdb import pdb2uniprot
df2 = pd.read_csv("../data/PlatinumDataset/nomsa_binary_original_binary/full/cleaned_XY.csv", index_col=0)
df2['pdb_id'] = df2.prot_id.str.split("_").str[0]
#%% 4. Merging df_prots with TCGA
df_tcga['uniprot'] = df_tcga['SWISSPROT'].str.split('.').str[0]
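# e.g. 'P04637.2' -> 'P04637' (drop the trailing version suffix so it can match prot_id below)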

dfm = df_tcga.merge(df_prots[df_prots.db != 'davis'],
                    left_on='uniprot', right_on='prot_id', how='inner')

# for davis we have to merge on Hugo_Symbol
dfm_davis = df_tcga.merge(df_prots[df_prots.db == 'davis'],
                          left_on='Hugo_Symbol', right_on='prot_id', how='inner')

dfm = pd.concat([dfm,dfm_davis], axis=0)

del dfm_davis # to save mem

# %% 5. Post-filtering step
# 5.1. Filter for only those sequences with matching sequence length (to get rid of non-matched isoforms)
# seq_len_x is from tcga, seq_len_y is from our dataset
tmp = len(dfm)
# allow for some error due to missing amino acids from pdb file in PDBbind dataset
# - assumption here is that isoforms will differ by more than 50 amino acids
dfm = dfm[(dfm.seq_len_y <= dfm.seq_len_x) & (dfm.seq_len_x <= dfm.seq_len_y+50)]
print(f"Filter #1 (seq_len) : {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}")

# 5.2. Filter out those that don't have the same reference seq according to the "Protein_position" and "Amino_acids" columns

# Extract mutation location and reference amino acid from 'Protein_position' and 'Amino_acids' columns
dfm['mt_loc'] = pd.to_numeric(dfm['Protein_position'].str.split('/').str[0])
dfm = dfm[dfm['mt_loc'] < dfm['seq_len_y']]
dfm[['ref_AA', 'mt_AA']] = dfm['Amino_acids'].str.split('/', expand=True)

dfm['db_AA'] = dfm.apply(lambda row: row['prot_seq'][row['mt_loc']-1], axis=1)

# Filter #2: Match proteins with the same reference amino acid at the mutation location
tmp = len(dfm)
dfm = dfm[dfm['db_AA'] == dfm['ref_AA']]
print(f"Filter #2 (ref_AA match): {tmp:5d} - {tmp-len(dfm):5d} = {len(dfm):5d}")
print('\n',dfm.db.value_counts())
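# toy sanity check for Filter #2 (hypothetical values):
#   Protein_position='4/8', Amino_acids='A/V', prot_seq='MKTAYIAK'
#   -> mt_loc=4, ref_AA='A', db_AA=prot_seq[3]='A'  => row kept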


# %% final seq len distribution
n_bins = 25
lengths = dfm.seq_len_x
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# Plot histogram
n, bins, patches = ax.hist(lengths, bins=n_bins, color='blue', alpha=0.7)
ax.set_title('TCGA final filtering for db matches')

# Add counts to each bin
for count, x, patch in zip(n, bins, patches):
    ax.text(x + 0.5, count, str(int(count)), ha='center', va='bottom')

uniprots = pdb2uniprot(df2.pdb_id.unique())
df_pid = pd.DataFrame(list(uniprots.items()), columns=['pdbID', 'uniprot'])
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Frequency')

plt.tight_layout()
plt.show()

# %% Getting updated sequences
def apply_mut(row):
    ref_seq = list(row['prot_seq'])
    ref_seq[row['mt_loc']-1] = row['mt_AA']
    return ''.join(ref_seq)

dfm['mt_seq'] = dfm.apply(apply_mut, axis=1)
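# sanity sketch (hypothetical row): apply_mut({'prot_seq': 'MKTA', 'mt_loc': 2, 'mt_AA': 'R'}) == 'MRTA'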


# %%
dfm.to_csv("/cluster/home/t122995uhn/projects/data/tcga/tcga_maf_davis_pdbbind.csv")
# %%
# find overlap between uniprots and swissprot ids
uniprots = set(uniprots) # convert to set for faster lookup
df['uniprot'] = df['SWISSPROT'].str.split('.').str[0]
# Merge the original df with uni_to_pdb_df
df = df.merge(df_pid, on='uniprot', how='left')
from src.utils.seq_alignment import MSARunner
from tqdm import tqdm
import pandas as pd
import os

DATA_DIR = '/cluster/home/t122995uhn/projects/data/tcga'
CSV = f'{DATA_DIR}/tcga_maf_davis_pdbbind.csv'
N_CPUS = 6
NUM_ARRAYS = 10
array_idx = 0  # ${SLURM_ARRAY_TASK_ID}
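# when submitted as a SLURM job array, array_idx would come from the environment instead, e.g.:
#   array_idx = int(os.environ.get('SLURM_ARRAY_TASK_ID', 0))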

df = pd.read_csv(CSV, index_col=0)
df.sort_values(by='seq_len_y', inplace=True)


# %%
df[~df['pdbID'].isna()]
for DB in df.db.unique():
    print('DB', DB)
    RAW_DIR = f'{DATA_DIR}/{DB}'
    # should already be unique if these are proteins mapped from tcga!
    unique_df = df[df['db'] == DB]
    ########################## Get job partition
    partition_size = len(unique_df) / NUM_ARRAYS
    start, end = int(array_idx*partition_size), int((array_idx+1)*partition_size)
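    # e.g. (illustrative) 100 proteins with NUM_ARRAYS=10 and array_idx=3 -> rows [30, 40)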

    unique_df = unique_df[start:end]

    #################################### create fastas
    fa_dir = os.path.join(RAW_DIR, f'{DB}_fa')
    fasta_fp = lambda idx,pid: os.path.join(fa_dir, f"{idx}-{pid}.fasta")
    os.makedirs(fa_dir, exist_ok=True)
    for idx, (prot_id, pro_seq) in tqdm(
            unique_df[['prot_id', 'prot_seq']].iterrows(),
            desc='Creating fastas',
            total=len(unique_df)):
        with open(fasta_fp(idx,prot_id), "w") as f:
            f.write(f">{prot_id},{idx},{DB}\n{pro_seq}")

    ##################################### Run hhblits
    aln_dir = os.path.join(RAW_DIR, f'{DB}_aln')
    aln_fp = lambda idx,pid: os.path.join(aln_dir, f"{idx}-{pid}.a3m")
    os.makedirs(aln_dir, exist_ok=True)

    # finally running
    for idx, (prot_id, pro_seq) in tqdm(
            unique_df[['prot_id', 'mt_seq']].iterrows(),
            desc='Running hhblits',
            total=len(unique_df)):
        in_fp = fasta_fp(idx,prot_id)
        out_fp = aln_fp(idx,prot_id)

        if not os.path.isfile(out_fp):
            print(MSARunner.hhblits(in_fp, out_fp, n_cpus=N_CPUS, return_cmd=True))
            break
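    # note: return_cmd=True appears to make MSARunner.hhblits return the shell command
    # (printed above for inspection) rather than executing it here -- an assumption
    # from the print + break pattern, not confirmed by the source.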

# %%
52 changes: 26 additions & 26 deletions rayTrain_Tune.py
@@ -79,38 +79,14 @@ def train_func(config):
      torch.cuda.device_count(), "devices")
print("CUDA VERSION:", torch.__version__)

search_space = {
    ## constants:
    "epochs": 20,
    "model": cfg.MODEL_OPT.DG,
    "dataset": cfg.DATA_OPT.PDBbind,
    "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
    "edge_opt": cfg.PRO_EDGE_OPT.aflow,
    "lig_feat_opt": cfg.LIG_FEAT_OPT.original,
    "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

    "fold_selection": 0,
    "save_checkpoint": False,

    ## hyperparameters to tune:
    "lr": ray.tune.loguniform(1e-5, 1e-3),
    "batch_size": ray.tune.choice([32, 64, 128]), # local batch size

    # model architecture hyperparams
    "architecture_kwargs":{
        "dropout": ray.tune.uniform(0.0, 0.5),
        "output_dim": ray.tune.choice([128, 256, 512]),
    }
}
# 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'):
# search_space = {
# ## constants:
# "epochs": 20,
# "model": cfg.MODEL_OPT.GVPL,
# "model": cfg.MODEL_OPT.DG,
# "dataset": cfg.DATA_OPT.PDBbind,
# "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
# "edge_opt": cfg.PRO_EDGE_OPT.aflow,
# "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
# "lig_feat_opt": cfg.LIG_FEAT_OPT.original,
# "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

# "fold_selection": 0,
@@ -126,6 +102,30 @@ def train_func(config):
# "output_dim": ray.tune.choice([128, 256, 512]),
# }
# }
# 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'):
search_space = {
    ## constants:
    "epochs": 20,
    "model": cfg.MODEL_OPT.GVPL,
    "dataset": cfg.DATA_OPT.davis,
    "feature_opt": cfg.PRO_FEAT_OPT.nomsa,
    "edge_opt": cfg.PRO_EDGE_OPT.aflow,
    "lig_feat_opt": cfg.LIG_FEAT_OPT.gvp,
    "lig_edge_opt": cfg.LIG_EDGE_OPT.binary,

    "fold_selection": 0,
    "save_checkpoint": False,

    ## hyperparameters to tune:
    "lr": ray.tune.loguniform(1e-5, 1e-3),
    "batch_size": ray.tune.choice([32, 64, 128]), # local batch size

    # model architecture hyperparams
    "architecture_kwargs":{
        "dropout": ray.tune.uniform(0.0, 0.5),
        "output_dim": ray.tune.choice([128, 256, 512]),
    }
}
# search space for GVPL_RNG MODEL:
# search_space = {
# ## constants:
7 changes: 6 additions & 1 deletion results/model_media/model_stats.csv
@@ -176,4 +176,9 @@ DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_20
DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7015552107083158,0.6057001165731564,0.569729312002546,2.358233326230585,1.2048145031929016,1.5356540385876585
DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7085153085744522,0.6208329976682688,0.5859156875580817,2.2240677925526744,1.1722256461779277,1.4913308796349234
DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6947751328102371,0.5816010010184015,0.5528265670239861,2.3897321791041333,1.2105133893376303,1.545875861479224
DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6855330290119371,0.5726521315119578,0.5325884299986027,2.4758763933999823,1.2316968268061441,1.57349178370908
DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6855330290119371,0.5726521315119578,0.5325884299986027,2.4758763933999823,1.231696826806144,1.57349178370908
GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7581108896345126,0.48942831390709474,0.45529634961852566,0.5072034546494222,0.4161494350523831,0.7121821779919953
GVPLM_davis1D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.768273891734927,0.550820575655664,0.4743629409308578,0.4386155850598172,0.3857342382535988,0.6622805939024767
GVPLM_davis3D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7603887296892994,0.5280132204660057,0.4597706307310624,0.4791434307834897,0.3716412002266697,0.6922018714099881
GVPLM_davis2D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7694890904708525,0.5217685104759502,0.4753118765332927,0.4673188139402668,0.4174866665020387,0.6836072073495618
GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.760520722940427,0.5078128106282308,0.4629361403886756,0.5235965111752061,0.3783202273117297,0.723599689866715
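Run names encode the full training configuration in one string. A hedged decoding sketch (field meanings inferred from the search_space keys in rayTrain_Tune.py above, not confirmed by the source):

    # hypothetical name decoder (illustrative only)
    name = 'GVPLM_davis0D_nomsaF_aflowE_128B_0.000136LR_0.0272D_2000E_gvpLF_binaryLE'
    model, ds_fold, pro_feat, pro_edge, batch, lr, dropout, epochs, lig_feat, lig_edge = name.split('_')
    # -> GVPL model, davis fold 0, nomsa protein features, aflow edges, batch 128, gvp ligand features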
5 changes: 5 additions & 0 deletions results/model_media/model_stats_val.csv
@@ -156,3 +156,8 @@ DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_20
DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6980325998261337,0.5867948698567493,0.5613217730433377,2.504824019579536,1.2573465726293105,1.5826635838293417
DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6760856277267194,0.5188577783623333,0.5051689763782492,2.7333317052016373,1.2894853849669785,1.653279076623677
DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.678631965986838,0.5169422908392893,0.5136610285075189,2.659271637069346,1.2811931355709805,1.6307273337591868
GVPLM_davis0D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7727048764494124,0.5829543940410362,0.46680064762373447,0.4654286995546293,0.3997238369103796,0.6822233501974476
GVPLM_davis1D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7767585519522199,0.5805240837724788,0.4946235296015073,0.4138118485637668,0.3817763907346182,0.6432820909708017
GVPLM_davis3D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.8202768328439823,0.6402255623870499,0.5435232833589796,0.3676188640274792,0.3097654533567193,0.6063158121206136
GVPLM_davis2D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7854232789039847,0.5996614387831698,0.527865354318397,0.4509265081187592,0.4061779738808213,0.67151061653466
GVPLM_davis4D_nomsaF_aflowE_128B_0.0001360163557088453LR_0.027175922988649594D_2000E_gvpLF_binaryLE,0.7849321063417397,0.6186046172952341,0.5194671423553721,0.4498645421549498,0.3525368623308257,0.670719421334249
4 changes: 2 additions & 2 deletions src/analysis/figures.py
@@ -344,7 +344,7 @@ def fig6_protein_appearance(datasets=['kiba', 'PDBbind'], show=False):

def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'mse'],
                 fig_callable=fig4_pro_feat_violin, fig_scale=(5,4),
                 show=False, **kwargs):
                 show=False, title_postfix='', **kwargs):
    # Create subplots with datasets as columns and metrics as rows
    fig, axes = plt.subplots(len(metrics), len(datasets),
                             figsize=(fig_scale[0]*len(datasets),
@@ -362,7 +362,7 @@ def fig_combined(df, datasets=['PDBbind','davis', 'kiba'], metrics=['cindex', 'm

    # Add titles only to the top row and left column
    if j == 0:
        ax.set_title(f'{dataset}')
        ax.set_title(f'{dataset}{title_postfix}')
        ax.set_xlabel('')
        ax.set_xticklabels([])
    elif j < len(metrics)-1: # middle row
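The new title_postfix argument simply appends a tag to each dataset's column title. A minimal hedged usage sketch (assuming df is the collected model-stats dataframe):

    fig_combined(df, datasets=['davis'], metrics=['cindex', 'mse'],
                 fig_callable=fig4_pro_feat_violin, title_postfix=' (val)')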
27 changes: 27 additions & 0 deletions src/analysis/utils.py
@@ -3,6 +3,7 @@
from scipy.stats import ttest_ind
import pandas as pd
import numpy as np
from src import config as cfg


def count_missing_res(pdb_file: str) -> Tuple[int,int]:
@@ -117,6 +118,32 @@ def generate_markdown(results, names=None, verbose=False, thresh_sig=False, cind
    if verbose: print(md_output)
    return md_table

def combine_dataset_pids(data_dir=cfg.DATA_ROOT,
                         dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba, cfg.DATA_OPT.PDBbind],
                         target='nomsa_aflow_gvp_binary', subset='full',
                         xy='cleaned_XY.csv'):
    df_all = None
    dir_p = {'davis':'DavisKibaDataset/davis',
             'kiba': 'DavisKibaDataset/kiba',
             'PDBbind': 'PDBbindDataset'}
    dbs = {d.value: f'{data_dir}/{dir_p[d]}/{target}/{subset}/{xy}' for d in dbs}

    for DB, fp in dbs.items():
        print(DB, fp)
        df = pd.read_csv(fp, index_col=0)
        df['pdb_id'] = df.prot_id.str.split("_").str[0]
        df = df[['prot_id', 'prot_seq']].drop_duplicates(subset='prot_id')

        df['seq_len'] = df['prot_seq'].str.len()
        df['db'] = DB
        df.reset_index(inplace=True)

        df = df[['db','code', 'prot_id', 'seq_len', 'prot_seq']] # reorder them.
        df.index.name = 'db_idx'
        df_all = df if df_all is None else pd.concat([df_all, df], axis=0)

    return df_all
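# example usage (mirrors playground.py above):
#   df_prots = combine_dataset_pids(dbs=[cfg.DATA_OPT.davis, cfg.DATA_OPT.PDBbind],
#                                   subset='test')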

if __name__ == '__main__':
    #NOTE: the following is code for stratifying AutoDock Vina results by
    # missing residues to identify if there is a correlation between missing