From 0ab364eabfe63ba330189fdcf701f4e85f202d1f Mon Sep 17 00:00:00 2001
From: jyaacoub
Date: Tue, 31 Oct 2023 12:04:34 -0400
Subject: [PATCH] fix(datasets): seq len cutoff #50

Was running into memory issues with PDBbind since 5 proteins are above
2000 amino acids in length. #50
---
 .gitignore                      |  2 +-
 playground.py                   | 32 +++++++++++++++++++++++++++-----
 src/data_processing/datasets.py |  8 ++++----
 src/utils/loader.py             |  6 ++++--
 4 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/.gitignore b/.gitignore
index 012e5965..2b2946a7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -207,6 +207,6 @@ lib/mgltools_x86_64Linux2_1.5.7/MGLToolsPckgs/AutoDockTools/Utilities24/*
 lib/mgltools_x86_64Linux2_1.5.7p1.tar.gz
 
 log_test/
-slurm_tests/
+slurm_out_DDP/
 /*.sh
 results/model_checkpoints/ours/*.model*
diff --git a/playground.py b/playground.py
index bb1395b4..d82e4a9c 100644
--- a/playground.py
+++ b/playground.py
@@ -1,9 +1,31 @@
-# %%
-from src.data_analysis.figures import prepare_df, fig3_edge_feat
-df = prepare_df('results/model_media/model_stats.csv')
+#%%
+from src.data_processing.datasets import PDBbindDataset
+from src.utils import config as cfg
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# d0 = pd.read_csv(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_anm/full/XY.csv', index_col=0)
+d0 = pd.read_csv(f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa_anm/full/XY.csv', index_col=0)
+
+d0['len'] = d0.prot_seq.str.len()
 
 # %%
-fig3_edge_feat(df, show=True, exclude=[])
+n, bins, patches = plt.hist(d0['len'], bins=20)
+# Set labels and title
+plt.xlabel('Protein Sequence length')
+plt.ylabel('Frequency')
+plt.title('Histogram of Protein Sequence length (PDBbind)')
+
+# Add counts to each bin
+for count, x, patch in zip(n, bins, patches):
+    plt.text(x + 0.5, count, str(int(count)), ha='center', va='bottom')
+
+cutoff = 1500
+print(f"Eliminating codes above {cutoff} length would reduce the dataset by: {len(d0[d0['len'] > cutoff])}")
+print(f"\t - Eliminates {len(d0[d0['len'] > cutoff].index.unique())} unique proteins")
+
+# %% -d PDBbind -f nomsa -e anm
+from src.utils.loader import Loader
+d1 = Loader.load_dataset('PDBbind', 'nomsa', 'anm')
 
 # %%
-print('test')
\ No newline at end of file
diff --git a/src/data_processing/datasets.py b/src/data_processing/datasets.py
index 03239d5d..fd807cfb 100644
--- a/src/data_processing/datasets.py
+++ b/src/data_processing/datasets.py
@@ -27,10 +27,10 @@
 # See: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html
 # for details on how to create a dataset
 class BaseDataset(torchg.data.InMemoryDataset, abc.ABC):
-    FEATURE_OPTIONS = cfg.PRO_FEAT_OPT
     EDGE_OPTIONS = cfg.EDGE_OPT
-    LIGAND_FEATURE_OPTIONS = cfg.LIG_FEAT_OPT
+    FEATURE_OPTIONS = cfg.PRO_FEAT_OPT
     LIGAND_EDGE_OPTIONS = cfg.LIG_EDGE_OPT
+    LIGAND_FEATURE_OPTIONS = cfg.LIG_FEAT_OPT
 
     def __init__(self, save_root:str, data_root:str, aln_dir:str,
                  cmap_threshold:float, feature_opt='nomsa',
@@ -92,7 +92,7 @@ def __init__(self, save_root:str, data_root:str, aln_dir:str,
         self.data_root = data_root
         self.cmap_threshold = cmap_threshold
         self.overwrite = overwrite
-        max_seq_len = 100000 or max_seq_len
+        max_seq_len = max_seq_len or 100000
         assert max_seq_len >= 100, 'max_seq_len cant be smaller than 100.'
         self.max_seq_len = max_seq_len
 
@@ -383,7 +383,7 @@ def process(self):
 
 
 class PDBbindDataset(BaseDataset): # InMemoryDataset is used if the dataset is small and can fit in CPU memory
-    def __init__(self, save_root=f'{cfg.DATA_ROOT}/PDBbindDataset/nomsa',
+    def __init__(self, save_root=f'{cfg.DATA_ROOT}/PDBbindDataset',
                  data_root=f'{cfg.DATA_ROOT}/v2020-other-PL',
                  aln_dir=None,
                  cmap_threshold=8.0, feature_opt='nomsa', *args, **kwargs):
diff --git a/src/utils/loader.py b/src/utils/loader.py
index e0f832e7..cb99fcb6 100644
--- a/src/utils/loader.py
+++ b/src/utils/loader.py
@@ -111,7 +111,8 @@ def load_dataset(data:str, pro_feature:str, edge_opt:str, subset:str=None, path:
             subset=subset,
             af_conf_dir='../colabfold/pdbbind_af2_out/out0',
             ligand_feature=ligand_feature,
-            ligand_edge=ligand_edge
+            ligand_edge=ligand_edge,
+            max_seq_len=1500
             )
    elif data in ['davis', 'kiba']:
        dataset = DavisKibaDataset(
@@ -123,7 +124,8 @@ def load_dataset(data:str, pro_feature:str, edge_opt:str, subset:str=None, path:
             edge_opt=edge_opt,
             subset=subset,
             ligand_feature=ligand_feature,
-            ligand_edge=ligand_edge
+            ligand_edge=ligand_edge,
+            max_seq_len=1500
             )
    else:
        raise Exception(f'Invalid data option, pick from {Loader.data_opt}')
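
Note: the hunks above only fix the `max_seq_len = max_seq_len or 100000` default and thread `max_seq_len=1500` into the dataset constructors; the filtering step itself lives elsewhere in BaseDataset and is not shown in this diff. Below is a minimal sketch of how such a cutoff could be applied to an XY.csv table, assuming the `prot_seq` column used in the playground snippet; `drop_long_sequences` is a hypothetical helper, not code from this patch.

    import pandas as pd

    def drop_long_sequences(csv_path: str, max_seq_len: int = 1500) -> pd.DataFrame:
        """Drop entries whose protein sequence exceeds max_seq_len residues."""
        df = pd.read_csv(csv_path, index_col=0)
        too_long = df.prot_seq.str.len() > max_seq_len
        # Report how many rows / unique proteins the cutoff removes,
        # mirroring the printout in the playground cell above.
        print(f"Dropping {too_long.sum()} rows "
              f"({df[too_long].index.nunique()} unique proteins) over {max_seq_len} residues")
        return df[~too_long]

Called on the same XY.csv path as the playground cell, this would reproduce the counts printed there for cutoff = 1500.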