Skip to content

Commit

Permalink
improved printing
Browse files Browse the repository at this point in the history
  • Loading branch information
abcamiletto committed Aug 31, 2023
1 parent 60f7e7b commit 9f7d82f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
37 changes: 29 additions & 8 deletions flopper/flopper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from functools import partial
from typing import Any, Dict, List

import torch.nn as nn
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import FlopCountAnalysis, flop_count_table

from .custom_ops import FLOPPER_OPS

Expand All @@ -23,8 +22,14 @@ def count_flops(model: nn.Module, *args, silent: bool = False, **kwargs) -> int:
>>> import torch
>>> from flopper import count_flops
>>> model = torch.nn.Linear(10, 10)
>>> count_flops(model, torch.randn(1, 10))
FLOPs of the model Linear: 0.00 GFLOPs
>>> flops = count_flops(model, torch.randn(1, 10))
FLOPs of the model Linear: X.XX GFLOPs
Get more specific information about the FLOPs with the following methods:
>>> flops.by_operator()
>>> flops.by_module()
>>> flops.by_module_and_operator()
>>> flops.get_table()
"""
if kwargs is not None:
model.forward = partial(model.forward, **kwargs) # To be tested
Expand All @@ -33,7 +38,23 @@ def count_flops(model: nn.Module, *args, silent: bool = False, **kwargs) -> int:
flops = FlopCountAnalysis(model, args).set_op_handle(**FLOPPER_OPS)

if not silent:
gflops = flops.total() / 1e9
print(f"FLOPs of the model {model.__class__.__name__}: {gflops:.2f} GFLOPs")

return flops.total()
flops_str = smart_format(flops)
name = model.__class__.__name__
print(f"FLOPs of the model {name}: {flops_str}FLOPs")

flops.get_table = lambda: flop_count_table(flops)

return flops


def smart_format(num, decimals=2):
if num > 1e12:
return f"{num / 1e12:.{decimals}f} T"
elif num > 1e9:
return f"{num / 1e9:.{decimals}f} G"
elif num > 1e6:
return f"{num / 1e6:.{decimals}f} M"
elif num > 1e3:
return f"{num / 1e3:.{decimals}f} K"
else:
return f"{num:.{decimals}f }"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ authors = [
]
dependencies = ["fvcore>=0.1.5", "torch>=1.8"]
requires-python = ">=3.8"
version = "0.1.1"
version = "0.1.2"

[project.urls]
Homepage = "https://github.com/abcamiletto/flopper"
Expand Down

0 comments on commit 9f7d82f

Please sign in to comment.