-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathupgrade_model_version.py
42 lines (33 loc) · 1.29 KB
/
upgrade_model_version.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
import re
from collections import OrderedDict
import torch
# Matches a detection-head conv parameter such as ``cls_convs.0.weight``.
# Group 1 is the ``<branch>.<index>.`` prefix, group 2 the parameter name.
# Dots are escaped and the index may have several digits.
_HEAD_CONV_PARAM = re.compile(r'((?:cls_convs|reg_convs)\.\d+\.)(weight|bias)')


def _upgrade_key(key):
    """Return ``key`` with ``conv.`` inserted before a head conv parameter.

    Only the matched ``weight``/``bias`` segment is rewritten; any other
    occurrence of those words elsewhere in the key is left alone. Keys that
    do not match are returned unchanged.
    """
    return _HEAD_CONV_PARAM.sub(r'\1conv.\2', key, count=1)


def convert(in_file, out_file):
    """Convert keys in checkpoints.

    There can be some breaking changes during the development of mmdetection,
    and this tool is used for upgrading checkpoints trained with old versions
    to the latest one.

    The only rewrite applied is for heads that moved from ``nn.Conv2d`` to
    ``ConvModule``: ``cls_convs.N.weight`` -> ``cls_convs.N.conv.weight``
    (likewise for ``reg_convs`` and ``bias``). All other entries of the
    checkpoint are written back untouched, and key order is preserved.

    Args:
        in_file: Path of the checkpoint to read.
        out_file: Path the upgraded checkpoint is written to.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    # NOTE: torch.load unpickles the file; only run this on trusted
    # checkpoints. Mapping to CPU lets GPU-saved checkpoints be converted on
    # machines without a matching device.
    checkpoint = torch.load(in_file, map_location='cpu')
    in_state_dict = checkpoint.pop('state_dict')
    out_state_dict = OrderedDict(
        (_upgrade_key(key), val) for key, val in in_state_dict.items())
    checkpoint['state_dict'] = out_state_dict
    torch.save(checkpoint, out_file)


def main():
    """Parse the command line and upgrade the given checkpoint.

    Takes two positional arguments, ``in_file`` and ``out_file``, and passes
    them to ``convert``.
    """
    cli = argparse.ArgumentParser(description='Upgrade model version')
    positionals = (
        ('in_file', 'input checkpoint file'),
        ('out_file', 'output checkpoint file'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)
    opts = cli.parse_args()
    convert(opts.in_file, opts.out_file)


if __name__ == '__main__':
    main()