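"""Evaluate a pretrained VTN model on UCF101, Something-Something-V2, or
Kinetics-400 and report validation loss, top-1, and top-5 accuracy."""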
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"  # comma-separated list, no spaces

import json
from argparse import ArgumentParser
from glob import glob

import torch
torch.backends.cudnn.benchmark = True
import torch.nn as nn
from einops import rearrange
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from model import VTN
from utils import load_yaml
from utils.data import UCF101, SMTHV2
from utils.utils import preprocess
# Command-line arguments
parser = ArgumentParser()
parser.add_argument("--annotations", type=str, default="dataset/kinetics-400/annotations.json", help="Dataset labels path")
parser.add_argument("--root-dir", type=str, default="dataset/kinetics-400/val", help="Dataset files root-dir")
parser.add_argument("--classInd", type=str, default="dataset/ucf/annotation/classInd.txt", help="ClassInd file")
parser.add_argument("--classes", type=int, default=400, help="Number of classes")
parser.add_argument("--dataset", choices=['ucf', 'smth', 'kinetics'], default='kinetics', help='Dataset type')
parser.add_argument("--per_sample", type=int, default=2, help="Clips per sample")
parser.add_argument("--weight-path", type=str, default="weights/kinetics/lin-v3/weights_20.pth", help='Path to load weights')
# Hyperparameters
parser.add_argument("--batch-size", type=int, default=8, help="Batch size")
parser.add_argument("--config", type=str, default="configs/lin-vtn.yaml", help="Config file")
# Parse arguments
args = parser.parse_args()
print(args)
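# Example invocation (defaults shown above; adjust paths to your layout):
#   python test.py --dataset kinetics --config configs/lin-vtn.yaml \
#       --weight-path weights/kinetics/lin-v3/weights_20.pth --per_sample 2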
# Load config
cfg = load_yaml(args.config)
# Load model
model = VTN(**vars(cfg))
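# NB: the evaluation code below accesses model.module..., so the script assumes
# at least one GPU and a checkpoint saved from a DataParallel-wrapped model.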
if torch.cuda.is_available():
    model = nn.DataParallel(model).cuda()
model.load_state_dict(torch.load(args.weight_path))
model.eval()
# Load dataset
if args.dataset == 'ucf':
    # Map class names to indices
    class_map = {}
    with open(args.classInd, "r") as f:
        for line in f:
            index, name = line.strip().split()
            class_map[name] = int(index)
    dataset = UCF101(args.annotations, args.root_dir, preprocess=preprocess, classes=args.classes, frames=cfg.frames, train=False, class_map=class_map)
elif args.dataset == 'smth':
    dataset = SMTHV2(args.annotations, args.root_dir, preprocess=preprocess, frames=cfg.frames)
elif args.dataset == 'kinetics':
    import av
    from functools import lru_cache

    def getK(arr, k=16):
        # Pick k elements from arr at evenly spaced indices
        out = []
        ratio = len(arr) / k
        for i in range(k):
            out.append(arr[int(i * ratio)])
        return out
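    # read_video decodes the whole file, resamples it to `new_frames` evenly
    # spaced frames (so each `frames`-frame clip spans ~`target` seconds), then
    # slides a window to produce `per_sample` overlapping clips.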
    @lru_cache  # memoize decoded videos (bare form needs Python >= 3.8)
    def read_video(root, frames, target=5.12, per_sample=10):
        try:
            # Read video
            cap = av.open(root)
            # Metadata
            fps = float(cap.streams.video[0].average_rate)
            duration = cap.streams.video[0].frames / fps
            target_fps = frames / target
            # Resample so that `frames` frames span roughly `target` seconds
            new_frames = int(target_fps * duration) if duration >= target else frames
            imgs = getK([img.to_image() for img in cap.decode(video=0)], k=new_frames)
            cap.close()
            # Stride between the start frames of consecutive clips
            diff = (new_frames - frames) / max(per_sample - 1, 1)
            # Generate clips
            out = []
            for i in range(per_sample):
                start = int(i * diff)
                out.append(imgs[start:start + frames])
            return out
        except Exception as e:
            print(f"Read error of video {root}, {e}")
            raise  # returning None here would fail later with an opaque TypeError
    # Kinetics-400 validation set
    class Kinetics400(Dataset):
        def __init__(self, labels, root_dir, mean, std, frames=16, per_sample=1):
            assert per_sample > 0
            with open(labels, "r") as f:
                labels = json.load(f)
            files = glob(f"{root_dir}/*/*")
            # (video path, class index); the class name is the parent directory name
            self.src = [(file, labels[file.split('/')[-2]]) for file in files]
            self.frames = frames
            self.per_sample = per_sample
            self.resize = transforms.Resize(256)
            # Keep three of the five crops: top-left, bottom-right, and center
            self.three_crop = transforms.Compose([
                transforms.FiveCrop(224),
                transforms.Lambda(lambda crops: [crop for i, crop in enumerate(crops) if i in [0, 3, 4]])
            ])
            self.preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

        def __len__(self):
            return len(self.src)

        def __getitem__(self, idx):
            if torch.is_tensor(idx):
                idx = idx.tolist()
            path, label = self.src[idx]
            videos = read_video(path, self.frames, per_sample=self.per_sample)
            out = []
            for imgs in videos:
                # Three spatial crops for every frame of the clip
                three_imgs = [self.three_crop(self.resize(img)) for img in imgs]
                for j in range(3):
                    crops = [self.preprocess(three_imgs[i][j]).unsqueeze(0) for i in range(len(three_imgs))]
                    out.append(torch.cat(crops).unsqueeze(0))
            # Shape: (per_sample * 3, frames, C, H, W)
            return torch.cat(out), int(label)
    dataset = Kinetics400(args.annotations, args.root_dir,
                          mean=model.module.spatial_transformer.default_cfg['mean'],
                          std=model.module.spatial_transformer.default_cfg['std'],
                          frames=cfg.frames, per_sample=args.per_sample)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=10, persistent_workers=True)
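# Each of the 10 persistent workers keeps its own read_video lru_cache, so
# decoded frames accumulate in memory over the epoch.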
# Loss
loss_func = nn.CrossEntropyLoss()
# Log-softmax (monotonic, so top-k rankings match softmax)
softmax = nn.LogSoftmax(dim=1)
# Validation
val_loss = 0
top1_acc = 0
top5_acc = 0
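# Each video contributes per_sample clips x 3 crops; predictions are averaged
# over these views before computing loss and accuracy.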
for src, target in tqdm(dataloader, desc="Validating"):
    # Fold the per-video views (clips x crops) into the batch dimension
    src = rearrange(src, 'b p f c h w -> (b p) f c h w')
    if torch.cuda.is_available():
        src = src.cuda()
        target = target.cuda()
    with torch.no_grad():
        output = model(src)
        # Average logits over the per_sample * 3 views of each video
        avg_logits = torch.mean(rearrange(output, '(b p) d -> b p d', p=args.per_sample * 3), dim=1)
        loss = loss_func(avg_logits, target)
        val_loss += loss.item()
        output = softmax(output)
        # Average log-probabilities over the same views
        output = torch.mean(rearrange(output, '(b p) d -> b p d', p=args.per_sample * 3), dim=1)
        # Top 1
        top1_acc += torch.sum(torch.argmax(output, dim=1) == target).cpu().item()
        # Top 5
        _, idx = torch.topk(output, 5, dim=1)
        for label, top5 in zip(target, idx):
            if label in top5:
                top5_acc += 1
count = len(dataset)  # not len(dataloader) * batch_size: the last batch may be smaller
val_loss = val_loss / len(dataloader)
top1_acc = top1_acc / count
top5_acc = top5_acc / count
print(f'Loss: {val_loss}, Top 1: {top1_acc}, Top 5: {top5_acc}')