Skip to content

Commit

Permalink
fix(datasets): improve robustness + TypeError fix
Browse files Browse the repository at this point in the history
- Use index file to get pdb_codes instead of listing the data directory.
- Check ligand options in init to ensure they are valid.
- Fix TypeError due to the unexpected `edgew_p` arg passed to `get_target_edge_weights`.
  • Loading branch information
jyaacoub committed Oct 31, 2023
1 parent 0a90155 commit fcce036
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
37 changes: 21 additions & 16 deletions src/data_processing/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
class BaseDataset(torchg.data.InMemoryDataset, abc.ABC):
FEATURE_OPTIONS = cfg.PRO_FEAT_OPT
EDGE_OPTIONS = cfg.EDGE_OPT
LIGAND_FEATURE_OPTIONS = cfg.LIG_FEAT_OPT
LIGAND_EDGE_OPTIONS = cfg.LIG_EDGE_OPT

def __init__(self, save_root:str, data_root:str, aln_dir:str,
cmap_threshold:float, feature_opt='nomsa',
Expand Down Expand Up @@ -112,9 +114,15 @@ def __init__(self, save_root:str, data_root:str, aln_dir:str,
f"Invalid edge_opt '{edge_opt}', choose from {self.EDGE_OPTIONS}"
self.edge_opt = edge_opt

# Validating subset
# check ligand options:
assert ligand_feature in self.LIGAND_FEATURE_OPTIONS, \
f"Invalid ligand_feature '{ligand_feature}', choose from {self.LIGAND_FEATURE_OPTIONS}"
self.ligand_feature = ligand_feature
assert ligand_edge in self.LIGAND_EDGE_OPTIONS, \
f"Invalid ligand_edge '{ligand_edge}', choose from {self.LIGAND_EDGE_OPTIONS}"
self.ligand_edge = ligand_edge

# Validating subset
subset = subset or 'full'
save_root = os.path.join(save_root, f'{self.feature_opt}_{self.edge_opt}_{self.ligand_feature}_{self.ligand_edge}') # e.g.: path/to/root/nomsa_anm
print('save_root:', save_root)
Expand Down Expand Up @@ -334,8 +342,7 @@ def process(self):
edge_opt=self.edge_opt,
cmap=pro_cmap,
n_modes=5, n_cpu=4,
af_confs=af_confs,
edgew_p=self.edgew_p(code))
af_confs=af_confs)
np.save(self.edgew_p(code), pro_edge_weight)
pro_edge_weight = torch.Tensor(pro_edge_weight[edge_idx[0], edge_idx[1]])

Expand Down Expand Up @@ -473,15 +480,11 @@ def pre_process(self):
missing_pid = df_pid.prot_id == '------'
df_pid[missing_pid] = df_pid[missing_pid].assign(prot_id = df_pid[missing_pid].index)

############# get pdb codes based on data root dir #############
pdb_codes = os.listdir(self.data_root)
# filter out readme and index folders
pdb_codes = [p for p in pdb_codes if p != 'index' and p != 'readme']

############# creating MSA: #############
#NOTE: assuming MSAs are already created, since this would take a long time to do.
# create_aln_files(df_seq, self.aln_p)
if self.aln_dir is not None:
pdb_codes = df_binding.index # pdbcodes
############# validating codes #############
if self.aln_dir is not None: # create msa if 'msaF' is selected
#NOTE: assuming MSAs are already created, since this would take a long time to do.
# create_aln_files(df_seq, self.aln_p)
# WARNING: use feature_extraction.process_msa method instead
# PDBbindProcessor.fasta_to_aln_dir(self.aln_dir,
# os.path.join(os.path.dirname(self.aln_dir),
Expand All @@ -491,11 +494,13 @@ def pre_process(self):
valid_codes = [c for c in pdb_codes if os.path.isfile(self.aln_p(c))]
# filters out those that do not have aln file
print(f'Number of codes with aln files: {len(valid_codes)} out of {len(pdb_codes)}')
pdb_codes = valid_codes
else: # check if exists
valid_codes = [c for c in pdb_codes if os.path.isfile(self.pdb_p(c))]

#TODO: filter out pdbs that dont have confirmations if edge type af2


pdb_codes = valid_codes
#TODO: filter out pdbs that don't have conformations if edge type is af2
# currently we treat all edges as the same if no conformations are found...
# (see protein_edges.get_target_edge_weights():232)
assert len(pdb_codes) > 0, 'Too few PDBCodes, need at least 1...'


Expand Down
2 changes: 1 addition & 1 deletion src/utils/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def init_model(model:str, pro_feature:str, edge:str, dropout:float, ligand_featu
@validate_args({'data': data_opt, 'pro_feature': pro_feature_opt, 'edge_opt': edge_opt,
'ligand_feature':cfg.LIG_FEAT_OPT, 'ligand_edge':cfg.LIG_EDGE_OPT})
def load_dataset(data:str, pro_feature:str, edge_opt:str, subset:str=None, path:str=cfg.DATA_ROOT,
ligand_feature:str=None, ligand_edge:str=None):
ligand_feature:str='original', ligand_edge:str='binary'):
# subset is used for train/val/test split.
# can also be used to specify the cross-val fold used by train1, train2, etc.
if data == 'PDBbind':
Expand Down

0 comments on commit fcce036

Please sign in to comment.