-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_mutagenesis.py
189 lines (152 loc) · 8.37 KB
/
run_mutagenesis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import os, logging, argparse
# Command-line interface for the mutagenesis script.
# Arguments are grouped as: ligand source (exactly one of SMILES / SDF), required
# I/O paths, full-saturation range options, partial-mutation options, and model
# selection options. The parsed result is exposed as `args`.
parser = argparse.ArgumentParser(description='Runs Mutagenesis on an input PDB file and a given ligand SMILES.')

# Exactly one ligand source must be given (mutually exclusive, required).
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-ls', '--ligand_smiles', type=str, help='Ligand SMILES string.')
group.add_argument('-sdf', '--ligand_sdf', type=str, help='File path to SDF file (needed for GVPL features).')

parser.add_argument('--ligand_id', type=str, required=True,
                    help='Ligand SMILES identifier, required for creating unique output path.')
parser.add_argument('--pdb_file', type=str, required=True, help='Path to the PDB file.')
parser.add_argument('--out_path', type=str, default='./',
                    help='Output directory path to save resulting mutagenesis numpy matrix with predicted pkd values')

full_mut = parser.add_argument_group('FULL SATURATION MUTAGENESIS ARGS',
                                     description="This is the default unless mutations are specified")
full_mut.add_argument('--res_start', type=int, default=0, help='Start index for mutagenesis (zero-indexed).')
# NOTE: argparse applies `type` only to string defaults, so this default stays
# float('inf') rather than an int; it acts as an "until the end" sentinel.
full_mut.add_argument('--res_end', type=int, default=float('inf'), help='End index for mutagenesis.')

partial_mut = parser.add_argument_group('PARTIAL MUTAGENESIS ARGS',
                                        description="Less intensive for when a full saturation is not needed. "
                                                    "Runs inference twice - once on native and once on mutated structure.")
partial_mut.add_argument('-mut', '--mutations', type=str, nargs='+', required=False,
                         help="The mutations to apply to the native structure in the format <native AA><index><mut AA> "
                              "(e.g.: M230A). Note that index starts at 1 as per PDB documentation.")

model_args = parser.add_argument_group('MODEL ARGS')
model_args.add_argument('--model_opt', type=str, default='davis_DG',
                        choices=['davis_DG', 'davis_gvpl', 'davis_esm',
                                 'kiba_DG', 'kiba_esm', 'kiba_gvpl',
                                 'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl',
                                 'PDBbind_gvpl_aflow'],
                        help='Model option. See MutDTA/src/__init__.py for details.')
model_args.add_argument('--fold', type=int, default=1,
                        help='Which model fold to use (there are 5 models for each option due to 5-fold CV).')
model_args.add_argument("-D", "--only_download", help="for downloading esm models if they are missing",
                        default=False, action="store_true")

args = parser.parse_args()
# Assign variables
# Unpack the parsed CLI arguments into module-level constants used by the rest
# of the script.
LIGAND_SMILES = args.ligand_smiles
LIGAND_SDF = args.ligand_sdf
if LIGAND_SDF:
    # rdkit is only needed when the ligand is supplied as an SDF file, so it is
    # imported lazily here.
    from rdkit import Chem

    _ligand_mol = Chem.MolFromMolFile(LIGAND_SDF)
    if _ligand_mol is None:
        # MolFromMolFile returns None when the file cannot be parsed; fail with
        # a clear message instead of an opaque error inside MolToSmiles.
        raise ValueError(f"Could not parse ligand SDF file: {LIGAND_SDF}")
    # The SMILES string is derived from the SDF so both representations exist.
    LIGAND_SMILES = Chem.MolToSmiles(_ligand_mol)

LIGAND_ID = args.ligand_id
PDB_FILE = args.pdb_file
OUT_PATH = args.out_path
MODEL_OPT = args.model_opt
FOLD = args.fold
RES_START = args.res_start
RES_END = args.res_end
MUTATIONS = args.mutations
# Results are grouped per ligand and per model option.
OUT_DIR = f'{OUT_PATH}/{LIGAND_ID}/{MODEL_OPT}'
ONLY_DOWNLOAD = args.only_download
# Let the root logger emit everything from DEBUG upwards.
logging.getLogger().setLevel(logging.DEBUG)

# Echo the resolved run configuration between two banner lines so it can be
# checked at a glance in the console output.
_banner = "#"*50
_config_report = [
    _banner,
    f"LIGAND_SMILES: {LIGAND_SMILES}",
    f" LIGAND_SDF: {LIGAND_SDF}",
    f" LIGAND_ID: {LIGAND_ID}",
    f" PDB_FILE: {PDB_FILE}",
    f" OUT_PATH: {OUT_PATH}",
    f" OUT_DIR: {OUT_DIR}",
    f"\n RES_START: {RES_START}",
    f" RES_END: {RES_END}",
    f"\n MUTATIONS: {MUTATIONS}",
    f"\n MODEL_OPT: {MODEL_OPT}",
    f" FOLD: {FOLD}",
    f"ONLY_DOWNLOAD: {ONLY_DOWNLOAD}",
    _banner,
]
# One entry per line, followed by a blank line after the closing banner.
print("\n".join(_config_report), end="\n\n")

# Make sure the output directory exists (no error if it is already there).
os.makedirs(OUT_DIR, exist_ok=True)
import numpy as np
import torch
from tqdm import tqdm
from src import TUNED_MODEL_CONFIGS
from src.utils.loader import Loader
from src.utils.residue import ResInfo
from src.data_prep.quick_prep import get_ligand_features, get_protein_features
from src.utils.mutate_model import run_modeller_multiple
# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Feature/edge options associated with the selected model option.
MODEL_PARAMS = TUNED_MODEL_CONFIGS[MODEL_OPT]
# File name of the input PDB with everything from '.pdb' onwards removed.
PDB_FILE_NAME = os.path.basename(PDB_FILE).split('.pdb')[0]

##################################################
### Loading the model and get original pkd value #
##################################################
MODEL, _ = Loader.load_tuned_model(MODEL_OPT, fold=FOLD, device=DEVICE)
MODEL.eval()
print(f"MODEL LOADED - {MODEL.__class__}")

if ONLY_DOWNLOAD:
    # The -D flag stops the script here, right after the model has been loaded
    # (per its help text it is meant for fetching missing model files).
    logging.critical("ONLY DOWNLOAD OPTION SET, EXITING")
    # `exit()` is injected by the `site` module and is intended for interactive
    # use; raising SystemExit directly works in every runtime and keeps the
    # same exit status (0).
    raise SystemExit(0)
# Main inference flow:
#   1. build ligand and protein inputs and predict on the native structure,
#   2. if explicit mutations were given, predict once on the mutated structure;
#      otherwise run a full saturation mutagenesis over the requested residue
#      range and save the resulting matrix as a .npy file.

# build ligand graph
lig = get_ligand_features(LIGAND_SMILES, MODEL_PARAMS['lig_feat_opt'], MODEL_PARAMS['lig_edge_opt'], LIGAND_SDF)

# build protein graph
pro, pdb_original = get_protein_features(PDB_FILE, MODEL_PARAMS['feature_opt'], MODEL_PARAMS['edge_opt'])
original_seq = pdb_original.sequence

# Inference only: no autograd graph is needed for any forward pass in this script.
with torch.no_grad():
    original_pkd = MODEL(pro.to(DEVICE), lig.to(DEVICE))
print("Original pkd:", original_pkd, end="\n\n")

if MUTATIONS:
    # Partial mutagenesis: apply all requested mutations at once and predict on
    # the resulting structure.
    mut_pdb_file = run_modeller_multiple(PDB_FILE, MUTATIONS)
    print(mut_pdb_file)

    pro, _ = get_protein_features(mut_pdb_file, MODEL_PARAMS['feature_opt'], MODEL_PARAMS['edge_opt'])
    with torch.no_grad():
        mut_pkd = MODEL(pro.to(DEVICE), lig.to(DEVICE))
    print("\nMutated pkd:", mut_pkd)
else:
    logging.warning("No mutations were passed in - running full saturation mutagenesis")

    # zero indexed res range to mutate:
    res_range = (max(RES_START, 0), min(RES_END, len(original_seq)))

    from src.utils.mutate_model import run_modeller

    amino_acids = ResInfo.amino_acids[:-1]  # not including "X" - unknown
    # Rows index the substituted amino acid, columns index the residue position.
    # Columns outside res_range are left at zero.
    muta = np.zeros(shape=(len(amino_acids), len(original_seq)))

    # Plain Python float of the native prediction, reused wherever the
    # "mutation" equals the native residue.
    # NOTE(review): assumes the model output holds exactly one element - confirm.
    native_pkd = float(original_pkd)

    with torch.no_grad(), tqdm(range(*res_range), ncols=100, total=(res_range[1]-res_range[0]),
                               desc='Saturation mutagenesis') as t:
        for j in t:
            for i, AA in enumerate(amino_acids):
                if i % 2 == 0:
                    # refresh the progress-bar postfix on every other amino acid
                    t.set_postfix(res=j, AA=i+1)

                if original_seq[j] == AA:  # skip same AA modifications
                    muta[i, j] = native_pkd
                    continue

                # j+1: presumably the modeller helper expects 1-based residue
                # numbers; "A" is presumably the chain identifier - confirm
                # against run_modeller.
                out_pdb_fp = run_modeller(PDB_FILE, j+1, ResInfo.code_to_pep[AA], "A")
                try:
                    pro, _ = get_protein_features(out_pdb_fp, MODEL_PARAMS['feature_opt'],
                                                  MODEL_PARAMS['edge_opt'])

                    # Sanity check on the modeller output. Raised explicitly (not
                    # via `assert`) so the check is not stripped under `python -O`.
                    if pro.pro_seq == original_seq or pro.pro_seq[j] != AA:
                        raise RuntimeError(
                            f"ERROR in modeller, {pro.pro_seq} == {original_seq} \nor {pro.pro_seq[j]} != {AA}")

                    muta[i, j] = float(MODEL(pro.to(DEVICE), lig.to(DEVICE)))
                finally:
                    # delete after use (also when feature building or inference fails)
                    os.remove(out_pdb_fp)

    # Save mutagenesis matrix
    # Only done on the saturation path: `muta` and `res_range` do not exist when
    # explicit mutations were supplied.
    OUT_FP = f"{OUT_DIR}/{res_range[0]}_{res_range[1]}-{os.path.basename(PDB_FILE).split('.pdb')[0]}.npy"
    print("Saving mutagenesis numpy matrix to", OUT_FP)
    np.save(OUT_FP, muta)