import torch
import numpy as np
import random as rand
import pandas as pd
import os
import time
import esm
import copy
import torch.nn as nn
from train_test import *
from utils import plot_train_dev_metric
from data_handler import prepare_data
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score
from transformers import T5Tokenizer
from tape import TAPETokenizer
from config import *
from models import MASSAPLI
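
# Entry point: prepare the dataset, build MASSAPLI on top of the pretrained ESM
# MSA Transformer encoder, then train / test / export protein embeddings
# depending on the config flags (do_train, do_test, do_save_emb, ...).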
if __name__ == "__main__":
    rand.seed(SEED)
    torch.manual_seed(SEED)

    dataset_pack, num_atom_features, num_bond_features = prepare_data(args, task, sampled_frac)
    if task == 'PDBBind':
        train_set, valid_set, test_set, test_casf2013_set, test_astex_set = dataset_pack
    elif task in ['Kinase', 'DUDE']:
        # Kinase and DUDE provide only train and test splits, so the test set doubles as the dev set.
        train_set, test_set = dataset_pack
        valid_set = test_set
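
    # Device selection; device_ids lists the GPU ids handed to nn.DataParallel during training.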
    if torch.cuda.is_available():
        device_ids = []
        device = torch.device("cuda:{}".format(gpu_start))
        for i in range(n_gpu):
            device_ids.append(gpu_start + i)
        print('The code uses GPU...')
    else:
        device = torch.device('cpu')
        print('The code uses CPU!!!')
    assert batch_size >= n_gpu, "Batch size must be at least the number of GPUs used!!!"
""" create model, trainer and tester """
encoder, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
tokenizer = alphabet.get_batch_converter()
num_layers, pro_hid_dim = 12, 768
model = MASSAPLI(encoder, num_layers, pro_hid_dim, num_atom_features, num_bond_features,
args.radius, args.T, args.fingerprint_dim, args.p_dropout, task, return_emb, random)
best_model = None

    # Guidance embeddings are needed by the test and embedding-saving stages as
    # well, so load them once here rather than only under do_train.
    pt_embeddings = None
    if args.guide:
        pt_embeddings = pd.read_pickle("/data1/brian/Guided-PDI/data/ESM1b-Pretrain/Uniref50_embeddings.pkl")
        pt_embeddings = np.stack(pt_embeddings['pro_emb'].tolist(), axis=0)[:len(train_set)]
        pt_embeddings = torch.tensor(pt_embeddings, dtype=torch.float)

    if do_train:
        file_results = os.path.join(path_model, 'results.txt')
        file_loss = os.path.join(path_model, 'loss-metric.csv')
        f_results = open(file_results, 'a')
        start_time = time.time()
        # device_ids is only defined when CUDA is available, so wrap in DataParallel only on GPU.
        if torch.cuda.is_available():
            model = nn.DataParallel(model, device_ids=device_ids)
        model = model.to(device)
        if args.guide:
            trainer = Trainer(model, tokenizer, lr, weight_decay, batch_size, gradient_accumulation,
                              return_emb, freeze_seq_encoder, guide=args.guide, pt_embeddings=pt_embeddings)
            tester = Tester(model, tokenizer, batch_size, return_emb, training=True,
                            guide=args.guide, pt_embeddings=pt_embeddings)
        else:
            trainer = Trainer(model, tokenizer, lr, weight_decay, batch_size, gradient_accumulation,
                              return_emb, freeze_seq_encoder)
            tester = Tester(model, tokenizer, batch_size, return_emb, training=True)
        results = f'Epoch\tTime\tLoss_train\tLoss_dev\t{task_metric}_train\t{task_metric}_dev'
        with open(file_results, 'w') as f:
            f.write(results + '\n')

        """Start training."""
        print('Training...')
        print(results)
        min_loss_dev = float('inf')
        max_metric_dev = -float('inf')
        best_epoch = 0
        loss_train_epochs, loss_dev_epochs = [], []
        metric_train_epochs, metric_dev_epochs = [], []
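
        # Train for `epochs` epochs; decay the LR every `decay_interval` epochs,
        # track dev-set performance, and keep a deep copy of the best model.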
        for epoch in range(1, epochs + 1):
            start_time_epoch = time.time()
            if epoch % decay_interval == 0:
                print('LR decay from {:.6f} to {:.6f}'.format(trainer.optimizer.param_groups[0]['lr'],
                                                              trainer.optimizer.param_groups[0]['lr'] * lr_decay))
                trainer.optimizer.param_groups[0]['lr'] *= lr_decay
            loss_train, all_predict_labels_train, all_real_labels_train = trainer.train(train_set, device, task)
            loss_dev, all_predict_labels_dev, all_real_labels_dev = tester.test(valid_set, device, task)
            if task == 'PDBBind':
                # Regression task: Pearson correlation between predicted and measured affinities.
                metric_train, _ = pearsonr(all_real_labels_train, all_predict_labels_train)
                metric_dev, _ = pearsonr(all_real_labels_dev, all_predict_labels_dev)
            elif task in ['Kinase', 'DUDE']:
                # Classification task: ROC-AUC.
                metric_train = roc_auc_score(all_real_labels_train, all_predict_labels_train)
                metric_dev = roc_auc_score(all_real_labels_dev, all_predict_labels_dev)
            loss_train_epochs.append(float("%.3f" % loss_train))
            loss_dev_epochs.append(float("%.3f" % loss_dev))
            metric_train_epochs.append(float("%.3f" % metric_train))
            metric_dev_epochs.append(float("%.3f" % metric_dev))
            end_time_epoch = time.time()
            seconds = end_time_epoch - start_time_epoch
            m, s = divmod(seconds, 60)
            h, m = divmod(m, 60)
            spend_time_epoch = "%02d:%02d:%02d" % (h, m, s)
            loss_train_epoch = "%.3f" % loss_train
            loss_dev_epoch = "%.3f" % loss_dev
            metric_train_epoch = "%.3f" % metric_train
            metric_dev_epoch = "%.3f" % metric_dev
            results = [epoch, spend_time_epoch, loss_train_epoch, loss_dev_epoch, metric_train_epoch, metric_dev_epoch]
            with open(file_results, 'a') as f:
                f.write('\t'.join(map(str, results)) + '\n')
            if metric_dev > max_metric_dev:
                min_loss_dev = loss_dev
                max_metric_dev = metric_dev
                best_epoch = epoch
                best_model = copy.deepcopy(model)
            print('\t'.join(map(str, results)))
        # Save the best model (kept as a deep copy during the loop) once training finishes.
        tester.save_model(best_model, os.path.join(path_model, 'model-epoch_{}-metric_{:.3f}.pth'.format(best_epoch, max_metric_dev)))
        end_time = time.time()
        seconds = end_time - start_time
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)
        spend_time = "%02d:%02d:%02d" % (h, m, s)
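
        # Persist the per-epoch loss/metric history and plot train-vs-dev curves.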
        dict_loss = {
            'epochs': list(range(1, epochs + 1)),
            'loss_train_all': loss_train_epochs,
            'loss_dev_all': loss_dev_epochs,
            f'{task_metric}_train_all': metric_train_epochs,
            f'{task_metric}_dev_all': metric_dev_epochs,
        }
        df_loss = pd.DataFrame(dict_loss)
        df_loss.to_csv(file_loss, index=False)
        plot_train_dev_metric(list(range(1, epochs + 1)), loss_train_epochs, loss_dev_epochs, path_model, 'MSE Loss', task)
        plot_train_dev_metric(list(range(1, epochs + 1)), metric_train_epochs, metric_dev_epochs, path_model, task_metric, task)
        final_print = "All epochs took %s; the best model is from epoch %d" % (spend_time, best_epoch)
        print(final_print)
        f_results.write(final_print + '\n')
        if not do_test:
            f_results.close()
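
    # Evaluate the best checkpoint on the held-out test set. This stage reuses
    # f_results opened under do_train, so it assumes training ran in the same invocation.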
    if do_test:
        if args.guide:
            tester = Tester(best_model, tokenizer, batch_size, return_emb, training=True,
                            guide=args.guide, pt_embeddings=pt_embeddings)
        else:
            tester = Tester(best_model, tokenizer, batch_size, return_emb, training=True)
        loss_test, all_predict_labels_test, all_real_labels_test = tester.test(test_set, device, task)
        if task == 'PDBBind':
            metric_test, _ = pearsonr(all_real_labels_test, all_predict_labels_test)
        elif task in ['Kinase', 'DUDE']:
            metric_test = roc_auc_score(all_real_labels_test, all_predict_labels_test)
        test_print = "Test results: Loss: %.3f, %s: %.3f" % (loss_test, task_metric, metric_test)
        print(test_print)
        f_results.write(test_print + '\n')
        f_results.close()
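
    # Export per-protein embeddings from the fine-tuned model over the training set.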
    if do_save_emb:
        if best_model is None:
            # No model was trained in this run: load the best fine-tuned checkpoint from disk.
            suffix = 'guided'
            best_model_state_dict = torch.load("weights/init_binding_pred/DUDE-guided.pth", map_location=torch.device('cpu'))
            model.load_state_dict(best_model_state_dict)
            best_model = model.to(device)
        else:
            suffix = current_time
        if args.guide:
            tester = Tester(best_model, tokenizer, batch_size, return_emb, training=False,
                            guide=args.guide, pt_embeddings=pt_embeddings)
        else:
            tester = Tester(best_model, tokenizer, batch_size, return_emb, training=False)
        print("Embedding dataset num: %d" % len(train_set.dataset))
        loss, all_predict_labels, all_real_labels, all_pro_ids, all_pro_seqs, all_pro_embs = tester.test(train_set, device, task)
        df_emb = pd.DataFrame()
        df_emb['pro_id'] = all_pro_ids
        df_emb['pro_seq'] = all_pro_seqs
        df_emb['pro_emb'] = all_pro_embs
        os.makedirs(f'embeddings/{model_name}', exist_ok=True)
        df_emb.to_pickle(f'embeddings/{model_name}/{task}_train_{model_name}_{suffix}.pkl')
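
    # Same export, but without fine-tuning: intended for runs where do_train is
    # disabled, so `model` still holds the freshly initialized pretrained encoder.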
    if do_save_pretrained_emb:
        best_model = model.to(device)
        if args.guide:
            tester = Tester(best_model, tokenizer, batch_size, return_emb, training=False,
                            guide=args.guide, pt_embeddings=pt_embeddings)
        else:
            tester = Tester(best_model, tokenizer, batch_size, return_emb, training=False)
        print("Embedding dataset num: %d" % len(train_set.dataset))
        loss, all_predict_labels, all_real_labels, all_pro_ids, all_pro_seqs, all_pro_embs = tester.test(train_set, device, task)
        df_emb = pd.DataFrame()
        df_emb['pro_id'] = all_pro_ids
        df_emb['pro_seq'] = all_pro_seqs
        df_emb['pro_emb'] = all_pro_embs
        os.makedirs(f'embeddings/{model_name}', exist_ok=True)
        df_emb.to_pickle(f'embeddings/{model_name}/{task}_train_{model_name}_without_finetune_{current_time}.pkl')