-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathflops.py
48 lines (43 loc) · 2.66 KB
/
flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import os
import torch
from thop import profile
from thop import clever_format
import utils
def str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string, so ``--pre_train False`` would still yield True.
    This converter accepts the usual true/false spellings instead.

    Args:
        value: The raw argument string, or an already-parsed bool (argparse
            passes non-string defaults through untouched, but this keeps
            direct calls safe too).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def build_parser():
    """Return the argument parser holding every option for this profiling script.

    The resulting namespace is handed unchanged to ``utils.create_generator``,
    so option names must stay as they are.
    """
    parser = argparse.ArgumentParser()
    # Pre-train, saving, and loading parameters
    parser.add_argument('--pre_train', type = str2bool, default = True, help = 'pre_train or not')
    parser.add_argument('--load_name', type = str, default = './track2/G_epoch10000_bs8.pth',
        help = 'load the pre-trained model with certain epoch, None for pre-training')
    parser.add_argument('--test_batch_size', type = int, default = 1, help = 'size of the testing batches for single GPU')
    parser.add_argument('--num_workers', type = int, default = 2, help = 'number of cpu threads to use during batch generation')
    parser.add_argument('--val_path', type = str, default = './validation', help = 'saving path that is a folder')
    parser.add_argument('--task_name', type = str, default = 'track2', help = 'task name for loading networks, saving, and log')
    # Network initialization parameters
    parser.add_argument('--pad', type = str, default = 'reflect', help = 'pad type of networks')
    parser.add_argument('--activ', type = str, default = 'lrelu', help = 'activation type of networks')
    parser.add_argument('--norm', type = str, default = 'none', help = 'normalization type of networks')
    parser.add_argument('--in_channels', type = int, default = 3, help = 'input channels for generator')
    parser.add_argument('--out_channels', type = int, default = 31, help = 'output channels for generator')
    parser.add_argument('--start_channels', type = int, default = 64, help = 'start channels for generator')
    parser.add_argument('--init_type', type = str, default = 'xavier', help = 'initialization type of generator')
    parser.add_argument('--init_gain', type = float, default = 0.02, help = 'initialization gain of generator')
    # Profiling parameters: spatial size (height = width) of the dummy input.
    # The default of 256 matches the previously hard-coded value.
    parser.add_argument('--input_size', type = int, default = 256, help = 'height and width of the dummy input used for profiling')
    # Dataset parameters
    parser.add_argument('--baseroot', type = str, default = './NTIRE2020_Validation_RealWorld', help = 'baseroot')
    # Known dataset folders: NTIRE2020_Validation_Clean, NTIRE2020_Validation_RealWorld
    return parser


def main():
    """Build the generator from the CLI options and print its MAC and parameter counts."""
    # ----------------------------------------
    # Initialize the parameters
    # ----------------------------------------
    opt = build_parser().parse_args()
    # ----------------------------------------
    # Test
    # ----------------------------------------
    # Fall back to CPU so the script still runs on machines without a GPU
    # instead of crashing on an unconditional .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Initialize; the exact construction (and any weight loading) is done by
    # utils.create_generator, which is not visible here.
    generator = utils.create_generator(opt).to(device)
    # Forward a single dummy image. The channel count follows --in_channels
    # rather than being hard-coded to 3; the spatial size follows --input_size.
    dummy_input = torch.randn(1, opt.in_channels, opt.input_size, opt.input_size, device = device)
    # thop.profile reports multiply-accumulate operations (MACs); FLOPs are
    # commonly quoted as roughly 2 x MACs.
    macs, params = profile(generator, inputs = (dummy_input, ))
    macs, params = clever_format([macs, params], "%.3f")
    print(macs)
    print(params)


if __name__ == "__main__":
    main()