diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index f87f1beb7e4e..c6b607aad813 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -16,6 +16,7 @@
 from deepspeed.moe.layer import MoE
 from deepspeed.utils.timer import FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, STEP_GLOBAL_TIMER
 from deepspeed.utils.torch import required_torch_version
+import einops
 
 Tensor = torch.Tensor
 
@@ -82,6 +83,7 @@ def start_profile(self, ignore_list=None):
         self.reset_profile()
         _patch_functionals()
         _patch_tensor_methods()
+        _patch_miscellaneous_operations()
 
         def register_module_hooks(module, ignore_list):
             if ignore_list and type(module) in ignore_list:
@@ -137,6 +139,7 @@ def stop_profile(self):
         if self.started and self.func_patched:
             _reload_functionals()
             _reload_tensor_methods()
+            _reload_miscellaneous_operations()
             self.func_patched = False
 
         def remove_profile_attrs(module):
@@ -787,6 +790,29 @@ def _einsum_flops_compute(equation, *operands):
     raise NotImplementedError("Unsupported einsum operation.")
 
 
+def _einops_einsum_flops_compute(*args):
+    """
+    Count flops for the einops.einsum operation.
+    """
+    *operands, equation = args
+    input_shapes = [o.shape for o in operands]
+
+    # Re-map the equation so that the same contraction written with
+    # different letters is normalized to a single canonical form.
+    letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
+    mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
+    equation = equation.translate(mapping)
+
+    np_arrs = [np.zeros(s) for s in input_shapes]
+    optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
+    for line in optim.split("\n"):
+        if "optimized flop" in line.lower():
+            flop = int(float(line.split(":")[-1]))
+            return flop, 0
+
+    raise NotImplementedError("Unsupported einops.einsum operation.")
+
+
 def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None):
     """
     Count flops for the tensor addmm operation.
@@ -937,6 +963,10 @@ def _patch_tensor_methods():
     torch.baddbmm = wrapFunc(torch.baddbmm, _tensor_addmm_flops_compute)
 
 
+def _patch_miscellaneous_operations():
+    einops.einsum = wrapFunc(einops.einsum, _einops_einsum_flops_compute)
+
+
 def _reload_functionals():
     # torch.nn.functional does not support importlib.reload()
     F.linear = old_functions[F.linear.__str__]
@@ -995,6 +1025,10 @@ def _reload_tensor_methods():
     torch.baddbmm = old_functions[torch.baddbmm.__str__]
 
 
+def _reload_miscellaneous_operations():
+    einops.einsum = old_functions[einops.einsum.__str__]
+
+
 def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
     gates_size = w_ih.shape[0]
     # matrix matrix mult ih state and internal state
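
For context on the core counting trick: `_einops_einsum_flops_compute` does not count flops itself, it delegates to NumPy. `np.einsum_path` returns a `(path, report)` pair, and the human-readable report contains an "Optimized FLOP count" line that the patch parses. A minimal standalone sketch of the same idea, assuming only NumPy (the helper name `einsum_flops` is illustrative, not part of the patch, and it returns flops only, not the `(flops, macs)` pair used by the profiler):

    from collections import OrderedDict
    import numpy as np

    def einsum_flops(equation, *shapes):
        # Normalize letters so "ij,jk->ik" and "pq,qr->pr" map to the same
        # canonical equation, mirroring the translate() step in the patch.
        letters = OrderedDict((c, 0) for c in equation if c.isalpha()).keys()
        table = {ord(c): 97 + i for i, c in enumerate(letters)}
        equation = equation.translate(table)

        # np.einsum_path returns (path, report); the report string contains
        # a line such as "  Optimized FLOP count:  6.554e+04".
        arrays = [np.zeros(s) for s in shapes]
        report = np.einsum_path(equation, *arrays, optimize="optimal")[1]
        for line in report.split("\n"):
            if "optimized flop" in line.lower():
                return int(float(line.split(":")[-1]))
        raise NotImplementedError("Unsupported einsum operation.")

    # 64x32 @ 32x16 matmul: ~2*64*32*16 = 65536 flops. The report rounds to
    # a few significant digits, so this prints 65540 rather than 65536.
    print(einsum_flops("ij,jk->ik", (64, 32), (32, 16)))

Note that this counting path, like the existing `_einsum_flops_compute`, treats each alphabetic character in the pattern as a separate axis label, and the patch only intercepts calls made through the `einops.einsum` module attribute; code that does `from einops import einsum` binds the original function before patching and is not counted.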