diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6718a9a19..dee68bdb0c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -418,9 +418,8 @@ def add_ub( if name in ub_cfgs: final_cfg = get_default_config(name) final_cfg.update(ub_cfgs[name]) - final_cfg["fp8_buf"] = ( - (name in layers_all_gather_overlap) - or ub_cfgs[name].get("fp8_buf", False) + final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or ub_cfgs[name].get( + "fp8_buf", False ) add_ub(name, **final_cfg)