Skip to content

Commit

Permalink
fix flops counter
Browse files Browse the repository at this point in the history
  • Loading branch information
ycszen committed Apr 20, 2021
1 parent b4c0ae4 commit 0ff7560
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 14 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,12 @@ Examples:
### Get the compulationaly complexity
You can use the following commands to compute the complexity of one model.
```shell
python tools/summary_network.py ${CONFIG_FILE} --shape ${SHAPE} [--with-head]
python tools/summary_network.py ${CONFIG_FILE} --shape ${SHAPE}
```

Arguments:

- `SHAPE`: Input size.
- `--with-head`: If specified, the computed complexity contains the complexity of the pose head.

Examples:

Expand All @@ -247,7 +246,6 @@ Examples:
```shell
python tools/summary_network.py configs/top_down/lite_hrnet/coco/litehrnet_18_coco_256x192.py \
--shape 256 256 \
--with-head
```

## Acknowledgement
Expand Down
16 changes: 6 additions & 10 deletions tools/summary_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

from mmcv import Config
# from mmcv.cnn import get_model_complexity_info
from tools.torchstat_utils import model_stats
from torchstat_utils import model_stats

from mmpose.models import build_posenet
import sys
sys.path.append('.')
from models import build_posenet


def parse_args():
Expand All @@ -17,10 +19,6 @@ def parse_args():
nargs='+',
default=[2048, 1024],
help='input image size')
parser.add_argument(
'--with-head',
action='store_true',
help='whether to compute the complexity of the deconv head.')
parser.add_argument('--out-file', type=str,
help='Output file name')
args = parser.parse_args()
Expand All @@ -42,10 +40,8 @@ def main():
model = build_posenet(cfg.model)
model.eval()

if args.with_head and hasattr(model, 'forward_with_head'):
model.forward = model.forward_with_head
elif not args.with_head and hasattr(model, 'forward_without_head'):
model.forward = model.forward_without_head
if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
Expand Down
2 changes: 1 addition & 1 deletion tools/torchstat_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .torchstat import analyze
from torchstat import analyze
import pandas as pd
import copy

Expand Down

0 comments on commit 0ff7560

Please sign in to comment.