diff --git a/flopper/flopper.py b/flopper/flopper.py index eb8206c..3794839 100644 --- a/flopper/flopper.py +++ b/flopper/flopper.py @@ -1,8 +1,7 @@ from functools import partial -from typing import Any, Dict, List import torch.nn as nn -from fvcore.nn import FlopCountAnalysis +from fvcore.nn import FlopCountAnalysis, flop_count_table from .custom_ops import FLOPPER_OPS @@ -23,8 +22,14 @@ def count_flops(model: nn.Module, *args, silent: bool = False, **kwargs) -> int: >>> import torch >>> from flopper import count_flops >>> model = torch.nn.Linear(10, 10) - >>> count_flops(model, torch.randn(1, 10)) - FLOPs of the model Linear: 0.00 GFLOPs + >>> flops = count_flops(model, torch.randn(1, 10)) + FLOPs of the model Linear: X.XX GFLOPs + + Get more specific information about the FLOPs with the following methods: + >>> flops.by_operator() + >>> flops.by_module() + >>> flops.by_module_and_operator() + >>> flops.get_table() """ if kwargs is not None: model.forward = partial(model.forward, **kwargs) # To be tested @@ -33,7 +38,23 @@ def count_flops(model: nn.Module, *args, silent: bool = False, **kwargs) -> int: flops = FlopCountAnalysis(model, args).set_op_handle(**FLOPPER_OPS) if not silent: - gflops = flops.total() / 1e9 - print(f"FLOPs of the model {model.__class__.__name__}: {gflops:.2f} GFLOPs") - - return flops.total() + flops_str = smart_format(flops) + name = model.__class__.__name__ + print(f"FLOPs of the model {name}: {flops_str}FLOPs") + + flops.get_table = lambda: flop_count_table(flops) + + return flops + + +def smart_format(num, decimals=2): + if num > 1e12: + return f"{num / 1e12:.{decimals}f} T" + elif num > 1e9: + return f"{num / 1e9:.{decimals}f} G" + elif num > 1e6: + return f"{num / 1e6:.{decimals}f} M" + elif num > 1e3: + return f"{num / 1e3:.{decimals}f} K" + else: + return f"{num:.{decimals}f }" diff --git a/pyproject.toml b/pyproject.toml index 60fdde0..bc255c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ ] dependencies = ["fvcore>=0.1.5", "torch>=1.8"] requires-python = ">=3.8" -version = "0.1.1" +version = "0.1.2" [project.urls] Homepage = "https://github.com/abcamiletto/flopper"