-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
executable file
·62 lines (47 loc) · 1.89 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import h5py
import numpy as np
import os
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from dataset import ModelNetCExtraTest
from model import PCC
def test(args):
device = torch.device('cuda')
test_loader = DataLoader(
ModelNetCExtraTest(h5_path=args.h5_path),
batch_size=args.test_batch_size,
shuffle=False,
drop_last=False
)
model = PCC(voxel_size=args.voxel_size).to(device)
model.load_state_dict(torch.load(args.ckpt, map_location=device))
model.eval()
test_pred = []
with torch.no_grad():
for pcd in tqdm(test_loader):
pcd = pcd.to(device)
pcd = pcd.permute(0, 2, 1)
logits = model(pcd)
preds = logits.argmax(dim=1)
test_pred.append(preds.detach().cpu().numpy())
test_pred = np.concatenate(test_pred)
os.makedirs(args.saved_path, exist_ok=True)
f = h5py.File(os.path.join(args.saved_path, 'results.h5'), 'w')
f.create_dataset('label', data=test_pred)
f.close()
def build_parser():
    """Build the command-line parser for extra ModelNet-C prediction.

    Returns:
        ``argparse.ArgumentParser`` whose parsed namespace is what ``test``
        expects: ``h5_path``, ``test_batch_size``, ``voxel_size``, ``ckpt``
        and ``saved_path``.
    """
    parser = argparse.ArgumentParser(description='Extra ModelNet-C Prediction')
    # Dataset settings
    parser.add_argument('--h5_path', type=str,
                        default='/mnt/ssd1/lifa_rdata/PointCloud-C/cls_extra_test_data.h5',
                        metavar='PATH', help='path to the extra test-set HDF5 file')
    parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size',
                        help='Size of batch')
    # Model settings
    parser.add_argument('--voxel_size', type=float, default=0.05, help='down sample voxel size')
    # --ckpt and --saved_path have no sensible default, and test() passes them
    # straight to torch.load / os.makedirs (which fail on None). Requiring them
    # turns a late crash into an immediate, clear usage error.
    parser.add_argument('--ckpt', type=str, required=True, metavar='PATH',
                        help='the trained checkpoint path')
    # Saved settings
    parser.add_argument('--saved_path', type=str, required=True, metavar='DIR',
                        help='directory where results.h5 is written')
    return parser


if __name__ == '__main__':
    test(build_parser().parse_args())