Skip to content

Commit

Permalink
modify the model_split.py file
Browse files Browse the repository at this point in the history
``` shell
python demo/top_down_img_demo.py configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/ViTPose_base_wholebody_256x192.py target/wholebody.pth  --img-root tests/data/coco/ --json-file tests/data/coco/test_coco.json --out-img-root vis_results
```
When I use the previous model_split.py, above command occurs error, and the modified version can work properly.
  • Loading branch information
seaman1900 authored Feb 25, 2023
1 parent 4fd8507 commit e871213
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions tools/model_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def main():
value = torch.cat([value, experts[key.replace('fc2.', f'experts.{target_expert}.')]], dim=0)
new_ckpt['state_dict'][key] = value

torch.save(new_ckpt, os.path.join(args.targetPath, 'coco.pth'))
torch.save(new_ckpt, os.path.join(args.target, 'coco.pth'))

names = ['aic', 'mpii', 'ap10k', 'apt36k','wholebody']
num_keypoints = [14, 16, 17, 17, 133]
Expand Down Expand Up @@ -86,8 +86,19 @@ def main():

for tensor_name in ['keypoint_head.final_layer.weight', 'keypoint_head.final_layer.bias']:
new_ckpt['state_dict'][tensor_name] = new_ckpt['state_dict'][tensor_name][:num_keypoints[i]]


# remove unnecessary part in the state dict
for j in range(5):
# remove associate part
for tensor_name in weight_names:
new_ckpt['state_dict'].pop(tensor_name.replace('keypoint_head', f'associate_keypoint_heads.{j}'))
# remove expert part
keys = new_ckpt['state_dict'].keys()
for key in list(keys):
if 'expert' in keys:
new_ckpt['state_dict'].pop(key)

torch.save(new_ckpt, os.path.join(args.target, f'{names[i]}.pth'))

if __name__ == '__main__':
main()
main()

0 comments on commit e871213

Please sign in to comment.