Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolves #113 #117

Merged
merged 32 commits into from
Jul 9, 2024
Merged
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
dfc29e0
feat(downloader): KLIFS pocket sequences #109
jyaacoub Jun 21, 2024
cf077ae
results(davis): gvpl_aflow with corrected hparams #90
jyaacoub Jun 21, 2024
1c9c441
feat(tcga): analysis script #95 #111
jyaacoub Jun 21, 2024
7027018
results(tuned): tuned davis_gvpl_esm_aflow and retuned kiba_GVPL_aflo…
jyaacoub Jun 25, 2024
76b8acf
fix(ESM_GVPL): safe_load_state_dict via inheritance of BaseModel #90
jyaacoub Jun 25, 2024
4152347
fix(ESM_GVPL): syntax, missing comma
jyaacoub Jun 26, 2024
891a39c
fix(tuned): davis_DG params + loader #90
jyaacoub Jun 26, 2024
cf4f2e9
refactor(playground): prediction with tuned models moved to its own s…
jyaacoub Jun 27, 2024
42542a0
results(DG,GVPL): updated training with unified test set #90
jyaacoub Jun 28, 2024
d9f8091
chore: update gitignore for new models folder
jyaacoub Jun 28, 2024
bd796de
results(GVPL_ESM): davis GVPL+esm performance #90
jyaacoub Jul 2, 2024
3b8b0a8
results(kiba): updated gvpL_aflow results #90
jyaacoub Jul 2, 2024
f442919
fix(prepare_df): parse for GVPL_ESM model results #90
jyaacoub Jul 3, 2024
9ac093f
feat(resplit): resplit stub for #113
jyaacoub Jul 3, 2024
2e61fd2
refactor(datasets): logging for #114
jyaacoub Jul 3, 2024
cdf930f
fix(split): explicit val_size for balanced_kfold_split
jyaacoub Jul 3, 2024
765caf2
fix(loader): init_dataset_object and splitting overlap issue #112 #113
jyaacoub Jul 3, 2024
0c43a7c
fix(loader): max_seq_len kwarg #112
jyaacoub Jul 4, 2024
c47be94
feat(resplit): for resplitting existing datasets into proper folds #1…
jyaacoub Jul 4, 2024
ef0106c
feat(resplit): extract csvs from "like_dataset" #112 #113
jyaacoub Jul 4, 2024
c4c7741
feat: davis splits #112 #113
jyaacoub Jul 4, 2024
099c3a3
fix(splitting): created davis splits #113
jyaacoub Jul 4, 2024
b032b5b
fix(config): new results dir for issue #113
jyaacoub Jul 4, 2024
e80e225
chore(gitignore): ignoring checkpoints for #113
jyaacoub Jul 5, 2024
c7fdc86
refactor(playground): #113
jyaacoub Jul 5, 2024
256563c
chore(pdbbind): created pdbbind test set #113
jyaacoub Jul 6, 2024
b15e83d
fix(init_dataset): adding `resplit` to create_datasets #113
jyaacoub Jul 8, 2024
20c9343
fix(datasets): paths for aflow files #116
jyaacoub Jul 8, 2024
3fd9367
results(davis_DGM): retrained davis_DGM on new splits #113 #112
jyaacoub Jul 8, 2024
69add71
feat(splits): created kiba and pdbind splits #113
jyaacoub Jul 8, 2024
1361c7e
results(davis_gvpl): retrained davis_gvpl on new splits #113 #112
jyaacoub Jul 8, 2024
a0e4405
results(davis): retrained aflow models #113 due to issue #116
jyaacoub Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix(loader): init_dataset_object and splitting overlap issue #112 #113
Example usage:
```python
#%% now based on this test set we can create the splits that will be used for all models
# 5-fold cross validation + test set
import pandas as pd
from src import cfg
from src.train_test.splitting import balanced_kfold_split
from src.utils.loader import Loader

test_df = pd.read_csv('/home/jean/projects/data/splits/davis_test_genes_oncoG.csv')
test_prots = set(test_df.prot_id)

db = Loader.load_dataset(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_binary_original_binary/full/')

#%%
train, val, test = balanced_kfold_split(db,
                k_folds=5, test_split=0.1, val_split=0.1,
                test_prots=test_prots, random_seed=0, verbose=True
                )

#%%
db.save_subset_folds(train, 'train')
db.save_subset_folds(val, 'val')
db.save_subset(test, 'test')
```
  • Loading branch information
jyaacoub committed Jul 3, 2024
commit 765caf29813abeec73279d90fe88f03cf26673ac
25 changes: 25 additions & 0 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,28 @@
#%% now based on this test set we can create the splits that will be used for all models
# 5-fold cross validation + test set
#
# Builds reproducible train/val/test splits for the Davis dataset, constrained so
# that a fixed, predefined set of proteins (oncogenes of interest) always lands in
# the test split. Run as Jupyter-style cells (the `#%%` markers delimit cells).
import pandas as pd
from src import cfg
from src.train_test.splitting import balanced_kfold_split
from src.utils.loader import Loader

# Predefined test proteins: one `prot_id` per row in this CSV.
# NOTE(review): hard-coded absolute path — only runs on the author's machine.
test_df = pd.read_csv('/home/jean/projects/data/splits/davis_test_genes_oncoG.csv')
test_prots = set(test_df.prot_id)

# Load the full (unsplit) Davis dataset by its directory path; the path encodes
# the dataset parameters (nomsa / binary / original / binary) and the 'full' subset.
db = Loader.load_dataset(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_binary_original_binary/full/')

#%%
# 5-fold CV with 10% test / 10% val; `test_prots` pins those proteins into the
# test split, and `random_seed=0` makes the fold assignment reproducible.
train, val, test = balanced_kfold_split(db,
k_folds=5, test_split=0.1, val_split=0.1,
test_prots=test_prots, random_seed=0, verbose=True
)


#%%
# Persist the splits back onto the dataset object — presumably one CSV per fold
# for train/val and a single test subset; verify against BaseDataset.save_subset*.
db.save_subset_folds(train, 'train')
db.save_subset_folds(val, 'val')
db.save_subset(test, 'test')


# %%
########################################################################
########################## VIOLIN PLOTTING #############################
6 changes: 4 additions & 2 deletions src/train_test/splitting.py
Original file line number Diff line number Diff line change
@@ -133,7 +133,8 @@ def train_val_test_split(dataset: BaseDataset,

return train_loader, val_loader, test_loader

def balanced_kfold_split(dataset: BaseDataset,
@init_dataset_object(strict=True)
def balanced_kfold_split(dataset: BaseDataset |str,
k_folds:int=5, test_split=.1, val_split=.1,
shuffle_dataset=True, random_seed=None,
batch_train=128,
@@ -210,6 +211,7 @@ def balanced_kfold_split(dataset: BaseDataset,

# removing selected proteins from prots
prots = [p for p in prots if p not in test_prots]
prot_counts = {p: c for p, c in prot_counts.items() if p not in test_prots} # remove test_prots
print(f'Number of unique proteins in test set: {len(test_prots)} == {count} samples')


@@ -221,7 +223,7 @@ def balanced_kfold_split(dataset: BaseDataset,
prot_folds = [[[], 0, -1] for i in range(k_folds)]
# score = fold.weight - abs(fold.weight/len(fold) - item.weight)
prot_counts = sorted(list(prot_counts.items()), key=lambda x: x[1], reverse=True)
for p, c in prot_counts:
for p, c in prot_counts:
# Update scores for each fold
for fold in prot_folds:
f_len = len(fold[0])
122 changes: 64 additions & 58 deletions src/utils/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import logging
from functools import wraps
from typing import Iterable
from torch.utils.data.distributed import DistributedSampler
@@ -24,6 +25,9 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator

##########################################################
################## Class Method ##########################
##########################################################
class Loader():
model_opt = cfg.MODEL_OPT
edge_opt = cfg.PRO_EDGE_OPT
@@ -162,11 +166,12 @@ def init_model(model:str, pro_feature:str, pro_edge:str, dropout:float=0.2, **kw
edge_weight_opt=pro_edge,
**kwargs)
return model

@staticmethod
@validate_args({'data': data_opt, 'pro_feature': pro_feature_opt, 'edge_opt': edge_opt,
'ligand_feature':cfg.LIG_FEAT_OPT, 'ligand_edge':cfg.LIG_EDGE_OPT})
def load_dataset(data:str, pro_feature:str, edge_opt:str, subset:str=None, path:str=cfg.DATA_ROOT,
def load_dataset(data:str, pro_feature:str=None, edge_opt:str=None, subset:str=None,
path:str=cfg.DATA_ROOT,
ligand_feature:str='original', ligand_edge:str='binary'):
# subset is used for train/val/test split.
# can also be used to specify the cross-val fold used by train1, train2, etc.
@@ -212,6 +217,10 @@ def load_dataset(data:str, pro_feature:str, edge_opt:str, subset:str=None, path:
subset=subset,
)
else:
# Check if dataset is a string (file path) and it exists
if isinstance(data, str) and os.path.exists(data):
kwargs = Loader.parse_db_kwargs(data)
return Loader.load_dataset(**kwargs)
raise Exception(f'Invalid data option, pick from {Loader.data_opt}')

return dataset
@@ -316,61 +325,58 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str,
loaders[d] = loader

return loaders

@staticmethod
def parse_db_kwargs(db_path):
"""
Parses parameters given a path string to a db you want to load up.
If subset folder is not included then we default to 'full' for the subset
"""
kwargs = {
'data': None,
'subset': 'full',
}
# get db class/type
db_path_s = [x for x in db_path.split('/') if x]
if 'PDBbindDataset' in db_path_s:
idx_cls = db_path_s.index('PDBbindDataset')
kwargs['data'] = cfg.DATA_OPT.PDBbind
if len(db_path_s) > idx_cls+2: # +2 to skip over db_params
kwargs['subset'] = db_path_s[idx_cls+2]
# remove from string
db_path = '/'.join(db_path_s[:idx_cls+2])
elif 'DavisKibaDataset' in db_path_s:
idx_cls = db_path_s.index('DavisKibaDataset')
kwargs['data'] = cfg.DATA_OPT.davis if db_path_s[idx_cls+1] == 'davis' else cfg.DATA_OPT.kiba
if len(db_path_s) > idx_cls+3:
kwargs['subset'] = db_path_s[idx_cls+3]
db_path = '/'.join(db_path_s[:idx_cls+3])
else:
raise ValueError(f"Invalid path string, couldn't find db class info - {db_path_s}")




##################################################
########## Extra related helpful methods #########
def parse_db_kwargs(db_path):
"""
Parses parameters given a path string to a db you want to load up.
If subset folder is not included then we default to 'full' for the subset
"""
kwargs = {
'data': None,
'subset': 'full',
# get db parameters:
kwargs_p = {
'pro_feature': cfg.PRO_FEAT_OPT,
'edge_opt': cfg.PRO_EDGE_OPT,
'ligand_feature': cfg.LIG_FEAT_OPT,
'ligand_edge': cfg.LIG_EDGE_OPT,
}
# get db class/type
db_path_s = [x for x in db_path.split('/') if x]
if 'PDBbindDataset' in db_path_s:
idx_cls = db_path_s.index('PDBbindDataset')
kwargs['data'] = cfg.DATA_OPT.PDBbind
if len(db_path_s) > idx_cls+2: # +2 to skip over db_params
kwargs['subset'] = db_path_s[idx_cls+2]
# remove from string
db_path = '/'.join(db_path_s[:idx_cls+2])
elif 'DavisKibaDataset' in db_path_s:
idx_cls = db_path_s.index('DavisKibaDataset')
kwargs['data'] = cfg.DATA_OPT.davis if db_path_s[idx_cls+1] == 'davis' else cfg.DATA_OPT.kiba
if len(db_path_s) > idx_cls+3:
kwargs['subset'] = db_path_s[idx_cls+3]
db_path = '/'.join(db_path_s[:idx_cls+3])
else:
raise ValueError(f"Invalid path string, couldn't find db class info - {db_path_s}")

# get db parameters:
kwargs_p = {
'pro_feature': cfg.PRO_FEAT_OPT,
'edge_opt': cfg.PRO_EDGE_OPT,
'ligand_feature': cfg.LIG_FEAT_OPT,
'ligand_edge': cfg.LIG_EDGE_OPT,
}
db_params = os.path.basename(db_path.strip('/')).split('_')
for k, params in kwargs_p.items():
double = "_".join(db_params[:2])
db_params = os.path.basename(db_path.strip('/')).split('_')
for k, params in kwargs_p.items():
double = "_".join(db_params[:2])

if double in params:
kwargs_p[k] = double
db_params = db_params[2:]
elif db_params[0] in params:
kwargs_p[k] = db_params[0]
db_params = db_params[1:]
else:
raise ValueError(f'Invalid option, did not find {double} or {db_params[0]} in {params}')
assert len(db_params) == 0, f"still some unparsed params - {db_params}"

if double in params:
kwargs_p[k] = double
db_params = db_params[2:]
elif db_params[0] in params:
kwargs_p[k] = db_params[0]
db_params = db_params[1:]
else:
raise ValueError(f'Invalid option, did not find {double} or {db_params[0]} in {params}')
assert len(db_params) == 0, f"still some unparsed params - {db_params}"

return {**kwargs, **kwargs_p}
return {**kwargs, **kwargs_p}


# decorator to allow for input to simply be the path to the dataset directory.
def init_dataset_object(strict=True):
@@ -385,9 +391,9 @@ def wrapper(*args, **kwargs):
raise FileNotFoundError(f'Dataset does not exist - {dataset}')

# Parse and build dataset
kwargs = parse_db_kwargs(dataset)
print('Loading dataset with', kwargs)
built = Loader.load_dataset(**kwargs)
db_kwargs = Loader.parse_db_kwargs(dataset)
logging.info(f'Loading dataset with {db_kwargs}')
built = Loader.load_dataset(**db_kwargs)
elif isinstance(dataset, BaseDataset):
built = dataset
elif dataset is None:
@@ -404,4 +410,4 @@ def wrapper(*args, **kwargs):
# Return the function call output
return func(*args, **kwargs)
return wrapper
return decorator
return decorator