-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
127 lines (103 loc) · 5.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import sys
sys.path.append("net")
sys.path.append("net/modelzoo")
sys.path.append("net/basemodel")
sys.path.append("motion")
import torch
from absl import app
from absl import flags
from torch.utils.data import DataLoader
import pandas as pd
from BodyAE import BodyAE
from BodyMotionGenerator import BodyMotionGenerator
from ConvMotionTransformVAE import ConvMotionTransformVAE
from LstmBodyAE import LstmBodyAE
from HagglingDataset import HagglingDataset
from CharControlMotionVAE import CharControlMotionVAE
from Metrics import Metrics
# Command-line flags for this test script. FLAGS is handed wholesale to the
# dataset (HagglingDataset), the model constructors (get_model) and Metrics
# (see main), so model hyper-parameters are declared here as well.
# NOTE(review): presumably these must match the values the checkpoint under
# test was trained with -- confirm against the training script.
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 256, 'Test set mini batch size')
flags.DEFINE_string('meta', 'meta/', 'Directory containing metadata files')
flags.DEFINE_string('test', 'MannData/test/', 'Directory containing test files')
flags.DEFINE_string('output_dir', 'Data/output/', 'Folder to store final videos')
flags.DEFINE_string('ckpt_dir', 'ckpt/', 'Directory containing the model checkpoints')
flags.DEFINE_float('lmd', 0.2, 'L1 Regularization factor')
flags.DEFINE_boolean('bodyae', False, 'if True checks BodyAE model')
flags.DEFINE_integer('enc_hidden_units', 256, 'Encoder LSTM hidden units')
flags.DEFINE_integer('dec_hidden_units', 256, 'Decoder LSTM hidden units')
flags.DEFINE_integer('gat_hidden_units', 256, 'Gating network hidden units')
flags.DEFINE_integer('enc_layers', 3, 'encoder layers')
flags.DEFINE_integer('dec_layers', 1, 'decoder layers')
flags.DEFINE_integer('num_experts', 8, 'number of experts in decoder')
flags.DEFINE_float('enc_dropout', 0.25, 'encoder dropout')
flags.DEFINE_float('dec_dropout', 0.25, 'decoder dropout')
flags.DEFINE_float('gat_dropout', 0.25, 'gating network dropout')
flags.DEFINE_float('dropout', 0.25, 'dense network dropout')
flags.DEFINE_float('tf_ratio', 0.3, 'teacher forcing ratio')
flags.DEFINE_integer('seq_length', 120, 'time steps in the sequence')
flags.DEFINE_integer('latent_dim', 32, 'latent dimension')
flags.DEFINE_float('start_scheduled_sampling', 0.2, 'when to start scheduled sampling')
flags.DEFINE_float('end_scheduled_sampling', 0.4, 'when to stop scheduled sampling')
flags.DEFINE_integer('c_dim', 2, 'number of conditional variables added to latent dimension')
flags.DEFINE_bool('speak', True, 'speak classification required')
flags.DEFINE_float('lmd2', 0.2, 'Regularization factor for speaking prediction')
flags.DEFINE_float('lmd3', 0.2, 'Regularization factor for velocity prediction')
flags.DEFINE_integer('frechet_pose_dim', 42, 'Number of joint directions')
flags.DEFINE_string('frechet_ckpt', 'ckpt/Frechet/', 'file containing the model weights')
flags.DEFINE_integer('input_dim', 244, 'input pose vector dimension')
flags.DEFINE_integer('output_dim', 244, 'output pose vector dimension')
# Selects the architecture in get_model(); unknown names fall back to
# BodyMotionGenerator.
flags.DEFINE_string('model', "MVAE", 'Defines the name of the model')
flags.DEFINE_bool('pretrain', False, 'Use a pretrained model')
flags.DEFINE_bool('CNN', False, 'Cnn based model')
flags.DEFINE_bool('VAE', True, 'VAE training')
flags.DEFINE_string('pretrainedModel', 'bodyAE', 'path to pretrained weights')
# Only honoured for VAE models (see main); deterministic models run once per batch.
flags.DEFINE_integer('batch_runs', 1, 'Number of times to feed the same input to the VAE')
flags.DEFINE_integer('num_saves', 5, 'number of outputs to save')
flags.DEFINE_integer('test_ckpt', 350, 'checkpoint to test')
flags.DEFINE_string('fmt', 'mann', 'data format')
flags.DEFINE_string('device', 'cuda:0', 'Device to run the model on')
flags.DEFINE_integer('pretrained_ckpt', None, 'Number of epochs to checkpoint of pretrained model')
def pss(a, b):
    """Return ``a == b``.

    NOTE(review): not referenced by any of the code visible in this script;
    kept (as a named function rather than an assigned lambda, per PEP 8 E731)
    in case another module imports it.
    """
    return a == b
def get_model():
    """
    Build the network named by ``FLAGS.model`` and move it onto ``FLAGS.device``.

    Recognised names are 'bodyAE', 'lstmAE', 'MTVAE' and 'MVAE'; any other
    value falls back to BodyMotionGenerator.

    :return: PyTorch model that extends nn.Module, placed on FLAGS.device
    """
    registry = {
        'bodyAE': BodyAE,
        'lstmAE': LstmBodyAE,
        'MTVAE': ConvMotionTransformVAE,
        'MVAE': CharControlMotionVAE,
    }
    model_cls = registry.get(FLAGS.model, BodyMotionGenerator)
    network = model_cls(FLAGS)
    return network.to(torch.device(FLAGS.device))
def _results_to_frame(results):
    """Stack the per-run metric records into one DataFrame (one row per record).

    :param results: list of objects returned by ``Metrics.compute_and_save``.
        NOTE(review): each is assumed to be a dict/Series (one row) or a
        DataFrame, the types ``DataFrame.append`` accepted -- confirm in Metrics.
    :return: the stacked DataFrame, or an empty DataFrame when there are no records.
    """
    if not results:
        return pd.DataFrame()
    frames = [r if isinstance(r, pd.DataFrame) else pd.DataFrame([r]) for r in results]
    # A single concat replaces the per-row DataFrame.append (removed in pandas 2.0),
    # which also avoided re-copying the whole frame on every iteration.
    return pd.concat(frames, ignore_index=True)


def _save_summary(df_mean, df_std, out_dir):
    """Write the mean/std summaries to ``out_dir``/mean.csv and std.csv.

    The directory is created first so a fresh checkout does not fail with an
    OSError on the first run.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    df_mean.to_csv(os.path.join(out_dir, 'mean.csv'))
    df_std.to_csv(os.path.join(out_dir, 'std.csv'))


def main(arg):
    """Evaluate the checkpoint selected by the flags on the test set.

    Loads model ``FLAGS.model`` at checkpoint ``FLAGS.test_ckpt`` from
    ``FLAGS.ckpt_dir``, runs it over every batch of the HagglingDataset at
    ``FLAGS.test`` without gradients, and collects what
    ``Metrics.compute_and_save`` returns for each run. The column-wise mean and
    standard deviation are printed and written to
    ``testResults/<model>/mean.csv`` and ``std.csv``.

    :param arg: positional command-line arguments forwarded by absl.app.run (unused).
    """
    import os

    test_dataset = HagglingDataset(FLAGS.test, FLAGS)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=10)

    ckpt = os.path.join(FLAGS.ckpt_dir, FLAGS.model)
    model = get_model()
    model.load_model(ckpt, FLAGS.test_ckpt)
    model.eval()
    metrics = Metrics(FLAGS)

    # The original code reassigned the same value inside `if FLAGS.VAE`, so that
    # check was dead. Per the flag's help text, repeated passes over the same
    # input are meant for the stochastic VAE only; other models run once.
    batch_runs = FLAGS.batch_runs if FLAGS.VAE else 1

    results = []
    with torch.no_grad():
        for i_batch, batch in enumerate(test_dataloader):
            for test_num in range(batch_runs):
                predictions, targets = model(batch)
                out = metrics.compute_and_save(predictions, targets, batch, i_batch, test_num)
                print(out)
                results.append(out)

    df = _results_to_frame(results)
    df_mean = df.mean(axis=0)
    df_std = df.std(axis=0)
    print(df_mean)
    print(df_std)
    _save_summary(df_mean, df_std, os.path.join('testResults', FLAGS.model))
# Script entry point: absl parses the command-line flags defined above and then
# calls main() with the remaining positional arguments.
if __name__ == "__main__":
    app.run(main)