inference.py
import os, random, dgl, torch
import pandas as pd
from tqdm import tqdm
from dgl.dataloading import GraphDataLoader
from data.data import BAPredDataset
from model.model import PredictionPKD


def inference(protein_pdb, ligand_file, output, batch_size, model_path, device='cpu'):
    # Build the protein/ligand graph dataset and batch it without shuffling,
    # so output rows stay in input order.
    dataset = BAPredDataset(protein_pdb=protein_pdb, ligand_file=ligand_file)
    loader = GraphDataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    model = PredictionPKD(57, 256, 13, 25, 20, 6, 0.2).to(device)
    weight_path = f'{model_path}/BAPred.pth'
    # map_location keeps the checkpoint loadable on CPU-only machines.
    checkpoint = torch.load(weight_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    results = {
        "Name": [],
        "pKd": [],
        "Kcal/mol": [],
    }
    with torch.no_grad():
        progress_bar = tqdm(total=len(loader.dataset), unit='ligand')
        for data in loader:
            bgp, bgl, bgc, error, idx, name = data
            bgp, bgl, bgc = bgp.to(device), bgl.to(device), bgc.to(device)
            pkd = model(bgp, bgl, bgc)
            pkd = pkd.view(-1)
            # Mask out ligands that failed featurization so they show up as NaN.
            pkd[error == 1] = float('nan')
            results["Name"].extend([str(i) for i in name])
            results['pKd'].extend(pkd.tolist())
            # Convert pKd to a binding free energy in kcal/mol.
            results['Kcal/mol'].extend((pkd / -0.73349).tolist())
            progress_bar.update(len(idx))
        progress_bar.close()

    df = pd.DataFrame(results).round(4)
    df.to_csv(output, sep='\t', na_rep='NaN', index=False)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--protein_pdb', default='./example/1KLT_rec.pdb', help='receptor .pdb')
    parser.add_argument('-l', '--ligand_file', default='./example/chk.sdf', help='ligand .sdf/.mol2/.txt')
    parser.add_argument('-o', '--output', default='./example/result.csv', help='result output file')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    parser.add_argument('--ncpu', default=4, type=int, help="cpu worker number")
    parser.add_argument('--device', type=str, default='cuda', help='choose device: cpu or cuda')
    parser.add_argument('--model_path', type=str, default='./weight', help='model weight path')
    args = parser.parse_args()

    # Fall back to CPU when CUDA is requested but not available.
    if args.device == 'cpu':
        device = torch.device("cpu")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("gpu is not available, run on cpu")
        device = torch.device("cpu")

    inference(
        protein_pdb=args.protein_pdb,
        ligand_file=args.ligand_file,
        output=args.output,
        batch_size=args.batch_size,
        model_path=args.model_path,
        device=device,  # pass the resolved torch.device so the CPU fallback actually applies
    )
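
# A minimal usage sketch, assuming the example inputs under ./example and the
# weights under ./weight (the argparse defaults above):
#
#   python inference.py -r ./example/1KLT_rec.pdb -l ./example/chk.sdf \
#       -o ./example/result.csv --batch_size 128 --device cuda
#
# The output is a tab-separated file with one Name / pKd / Kcal-per-mol row per ligand;
# ligands that failed featurization are reported as NaN.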