-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
29 lines (27 loc) · 953 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from n20_net import N20Net
from torch.utils.data import DataLoader
import torch
from saha_dataset import SahaDataset
def test(model_path: str = "models/saha.h5", batch_size: int = 2000):
    """Evaluate the trained N20Net on the Saha test split and print results.

    Loads the state dict stored at *model_path*, runs the model over the
    ``SahaDataset(is_train=False)`` split in batches of *batch_size*, prints
    every ground-truth / prediction pair, and finally prints the mean squared
    error computed over the whole split (not just the last batch).

    Args:
        model_path: Path of the saved ``state_dict`` to load into ``N20Net``.
        batch_size: Number of samples per evaluation batch.

    Returns:
        The split-wide MSE as a Python float, or ``None`` when the test split
        yielded no samples.
    """
    dataset = SahaDataset(is_train=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model = N20Net(dataset.x_dim)
    # The model stays on the CPU (it is never moved to another device), so map
    # the stored tensors to the CPU; this also lets a checkpoint saved from a
    # GPU run load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    # Accumulate the summed squared error and the sample count across batches
    # so the final MSE covers the entire split rather than the last batch only.
    criterion = torch.nn.MSELoss(reduction="sum")
    squared_error_sum = 0.0
    sample_count = 0

    print("Test started ...")
    print("Ground Truth\t\tPredicted")
    with torch.no_grad():
        for data, y_true in dataloader:
            # Flatten both sides so the loss compares element-wise; a (N, 1)
            # target against a (N,) prediction would otherwise broadcast.
            y_pred = model(data).reshape(-1)
            y_true = y_true.reshape(-1)
            squared_error_sum += criterion(y_pred, y_true).item()
            sample_count += y_true.numel()
            for gt_val, predicted in zip(y_true.tolist(), y_pred.tolist()):
                print(f"{gt_val:.4f}\t\t\t\t{predicted:.4f}")

    # Guard the empty-split case instead of formatting None (TypeError).
    if sample_count == 0:
        print("MSE undefined: the test set is empty")
        return None

    mse = squared_error_sum / sample_count
    print(f"MSE {mse:.4f}")
    return mse
if __name__ == "__main__":
    # Script entry point: run the evaluation only when this file is executed
    # directly, never as a side effect of importing it.
    test()