From 121ff62af8c6938c9cdab15e04e62c97ac524264 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Tue, 6 Aug 2024 10:00:20 -0700
Subject: [PATCH] [PyTorch] Improve logging/messaging in attention (#1074)

* fix logging in attention

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove logging in fwd/bwd methods due to CPU overhead

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix check_set_window_size messaging

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix typo

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix window_size messaging

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant imports

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py     | 60 +++++++------------
 .../pytorch/module/grouped_linear.py        | 23 -------
 .../pytorch/module/layernorm_linear.py      | 22 -------
 transformer_engine/pytorch/module/linear.py | 23 -------
 4 files changed, 22 insertions(+), 106 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 0790315400..7586cc1bcb 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -98,12 +98,12 @@
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
+_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
+_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
+_stream_handler = logging.StreamHandler()
+_stream_handler.setFormatter(_formatter)
 
 _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
 _NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -266,6 +266,9 @@ def get_attention_backend(
 
     # Run config
     logger = logging.getLogger("DotProductAttention")
+    logger.setLevel(_log_level)
+    if not logger.hasHandlers():
+        logger.addHandler(_stream_handler)
     device_compute_capability = get_device_compute_capability()
     cudnn_version = get_cudnn_version()
     run_config = {
@@ -3236,31 +3239,28 @@ def check_set_window_size(
     """
     orig_window_size = window_size
     if "causal" in attn_mask_type:
-        if orig_window_size is None or (
-            orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
-        ):
+        if orig_window_size is None:
             window_size = (-1, 0)
-            warnings.warn(
-                "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
-            )
-        elif orig_window_size[0] >= 0:
+        elif orig_window_size == (-1, -1) or (
+            orig_window_size[0] >= 0 and orig_window_size[1] != 0
+        ):
             window_size = (orig_window_size[0], 0)
             warnings.warn(
                 "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
             )
-        else:
+        elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
             assert False, (
                 "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
             )
     elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
-        if orig_window_size is None or (
-            orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
-        ):
+        if orig_window_size is None:
+            window_size = (-1, -1)
+        elif orig_window_size == (-1, 0):
             window_size = (-1, -1)
             warnings.warn(
                 "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
             )
-        elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
+        elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
             assert False, (
                 "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
             )
@@ -3560,9 +3560,7 @@ def forward(
         fp8_meta,
         deterministic,
     ):
-        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
         if fp8:
-            logger.debug("Running forward in FP8")
             if fp8_meta["recipe"].fp8_mha:
                 assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
                 fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
@@ -3646,7 +3644,6 @@ def forward(
                 fp8_meta["scaling_fwd"].scale_inv.clone(),
             )
         else:
-            logger.debug("Running forward in %s", qkv.dtype)
             out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
                 is_training,
                 max_seqlen,
@@ -3699,7 +3696,6 @@ def forward(
 
     @staticmethod
     def backward(ctx, d_out):
-        logger = logging.getLogger("FusedAttnFunc_qkvpacked")
         if ctx.fp8_meta["recipe"].fp8_mha:
             assert isinstance(
                 d_out, Float8Tensor
@@ -3753,7 +3749,6 @@ def backward(ctx, d_out):
         else:
             with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                     fp8_dtype_backward = get_fp8_te_dtype(
                         ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -3819,7 +3814,6 @@ def backward(ctx, d_out):
                         ctx.qkv_dtype,
                     ).view(dqkv_fp8.shape)
                 else:
-                    logger.debug("Running backward in %s", qkv.dtype)
                     if d_out.dtype == torch.uint8:
                         d_out = d_out_f8tensor.from_float8(qkv.dtype)
                     dqkv, *rest = fused_attn_bwd_qkvpacked(
@@ -3937,9 +3931,7 @@ def forward(
         fp8_meta,
         deterministic,
     ):
-        logger = logging.getLogger("FusedAttnFunc_kvpacked")
         if fp8:
-            logger.debug("Running forward in FP8")
             if fp8_meta["recipe"].fp8_mha:
                 assert isinstance(q, Float8Tensor) and isinstance(
                     kv, Float8Tensor
@@ -4036,7 +4028,6 @@ def forward(
                 fp8_meta["scaling_fwd"].scale_inv.clone(),
             )
         else:
-            logger.debug("Running forward in %s", q.dtype)
             out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
                 is_training,
                 max_seqlen_q,
@@ -4100,7 +4091,6 @@ def forward(
 
     @staticmethod
     def backward(ctx, d_out):
-        logger = logging.getLogger("FusedAttnFunc_kvpacked")
         if ctx.fp8_meta["recipe"].fp8_mha:
             assert isinstance(
                 d_out, Float8Tensor
@@ -4158,7 +4148,6 @@ def backward(ctx, d_out):
         else:
             with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                     fp8_dtype_backward = get_fp8_te_dtype(
                         ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4243,7 +4232,6 @@ def backward(ctx, d_out):
                         ctx.qkv_dtype,
                     ).view(dkv_fp8.shape)
                 else:
-                    logger.debug("Running backward in %s", q.dtype)
                     if d_out.dtype == torch.uint8:
                         d_out = d_out_f8tensor.from_float8(q.dtype)
                     dq, dkv, *rest = fused_attn_bwd_kvpacked(
@@ -4374,9 +4362,7 @@ def forward(
         fp8_meta,
         deterministic,
     ):
-        logger = logging.getLogger("FusedAttnFunc")
         if fp8:
-            logger.debug("Running forward in FP8")
             fused_attention_backend = FusedAttnBackend["FP8"]
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             if fp8_meta["recipe"].fp8_mha:
@@ -4544,7 +4530,6 @@ def forward(
                 fp8_meta["scaling_fwd"].scale_inv.clone(),
             )
         else:
-            logger.debug("Running forward in %s", q.dtype)
             out_ret, aux_ctx_tensors = fused_attn_fwd(
                 is_training,
                 max_seqlen_q,
@@ -4618,7 +4603,6 @@ def forward(
 
     @staticmethod
     def backward(ctx, d_out):
-        logger = logging.getLogger("FusedAttnFunc")
         if ctx.fp8_meta["recipe"].fp8_mha:
             assert isinstance(
                 d_out, Float8Tensor
@@ -4680,7 +4664,6 @@ def backward(ctx, d_out):
         else:
             with torch.cuda.nvtx.range("_FusedAttn"):
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                     fp8_dtype_backward = get_fp8_te_dtype(
                         ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4818,7 +4801,6 @@ def backward(ctx, d_out):
                         ctx.qkv_dtype,
                     ).view(dv_fp8.shape)
                 else:
-                    logger.debug("Running backward in %s", q.dtype)
                     if d_out.dtype == torch.uint8:
                         d_out = d_out_f8tensor.from_float8(q.dtype)
                     dq, dk, dv, *rest = fused_attn_bwd(
@@ -4959,7 +4941,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.logger = logging.getLogger("FusedAttention")
         self.softmax_scale = softmax_scale
         self.attention_dropout = attention_dropout
         self.attention_dropout_ctx = attention_dropout_ctx
@@ -5306,6 +5287,9 @@ def __init__(
         super().__init__()
 
         self.logger = logging.getLogger("DotProductAttention")
+        self.logger.setLevel(_log_level)
+        if not self.logger.hasHandlers():
+            self.logger.addHandler(_stream_handler)
         self.qkv_format = qkv_format
         attn_mask_type = attn_mask_type.replace(",", "_")
         if attn_mask_type == "causal_padding":
@@ -5619,7 +5603,7 @@ def forward(
             if self.fp8_meta["recipe"].fp8_mha:
                 if not self.fp8_meta["recipe"].fp8_dpa:
                     self.fp8_meta["recipe"].fp8_dpa = True
-                    self.logger.WARNING(
+                    self.logger.warning(
                         """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                         """fp8_meta["recipe"].fp8_mha=True"""
                     )
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 8aeb068412..c55225eed9 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -3,8 +3,6 @@
 # See LICENSE for license information.
 
 """GroupedLinear API"""
-import os
-import logging
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
@@ -45,17 +43,6 @@
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
 
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
-
 __all__ = ["GroupedLinear"]
 
 """
@@ -97,7 +84,6 @@ def forward(
         is_grad_enabled: bool,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
-        logger = logging.getLogger("GroupedLinear")
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -151,8 +137,6 @@ def forward(
             inputmats = inputmats_no_fp8
 
         if fp8:
-            logger.debug("Running forward in FP8")
-
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
 
@@ -184,8 +168,6 @@ def forward(
                 use_split_accumulator=_2X_ACC_FPROP,
             )
         else:
-            logger.debug("Running forward in %s", activation_dtype)
-
             # Cast for native AMP
             weights = [cast_if_needed(w, activation_dtype) for w in weights]
             biases = (
@@ -286,8 +268,6 @@ def forward(
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        logger = logging.getLogger("GroupedLinear")
-
         with torch.cuda.nvtx.range("_GroupedLinear_backward"):
             (
                 fwd_scale_inverses,
@@ -353,7 +333,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
 
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
                     dgrad = torch.empty(
                         (sum(ctx.m_splits), weights_fp8[i].size(1)),
                         dtype=ctx.activation_dtype,
@@ -376,8 +355,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         use_split_accumulator=_2X_ACC_DGRAD,
                     )
                 else:
-                    logger.debug("Running backward in %s", ctx.activation_dtype)
-
                     dgrad = torch.empty(
                         (sum(ctx.m_splits), weights[0].size(1)),
                         dtype=ctx.activation_dtype,
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 262c6f8d16..10560cdad6 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -5,7 +5,6 @@
 """LayerNormLinear API"""
 import os
 import warnings
-import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
@@ -48,17 +47,6 @@
 from ._common import _apply_normalization, _noop_cat
 from ..float8_tensor import Float8Tensor
 
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
-
 __all__ = ["LayerNormLinear"]
 
 
@@ -104,7 +92,6 @@ def forward(
         ub_name: str,
         fsdp_group: Union[dist_group_type, None],
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
-        logger = logging.getLogger("LayerNormLinear")
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -203,8 +190,6 @@ def forward(
             ln_out = ln_out_total
 
         if fp8:
-            logger.debug("Running forward in FP8")
-
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
 
@@ -259,8 +244,6 @@ def forward(
                 dtype=activation_dtype,
             )
         else:
-            logger.debug("Running forward in %s", activation_dtype)
-
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -379,7 +362,6 @@ def forward(
     def backward(
         ctx, *grad_outputs: Tuple[torch.Tensor, ...]
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        logger = logging.getLogger("LayerNormLinear")
         if isinstance(grad_outputs[0], Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
                 0
@@ -500,8 +482,6 @@ def backward(
                 ub_obj = None
 
             if ctx.fp8:
-                logger.debug("Running backward in FP8")
-
                 fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                 fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
                 out_index, meta_tensor, out_te_type, out_type = (
@@ -544,8 +524,6 @@ def backward(
                 )
                 clear_tensor_data(grad_output_c)
             else:
-                logger.debug("Running backward in %s", ctx.activation_dtype)
-
                 # DGRAD: Evaluated unconditionally to feed into Linear backward
                 _, _, _ = tex.gemm(
                     weight,
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 7510254a9d..68d333262d 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -3,8 +3,6 @@
 # See LICENSE for license information.
 
 """Linear API"""
-import os
-import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
@@ -51,17 +49,6 @@
 from ..graph import is_graph_capturing
 from ..float8_tensor import Float8Tensor
 
-# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
-_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
-# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
-_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
-log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
-log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
-logging.basicConfig(
-    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
-    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
-)
-
 __all__ = ["Linear"]
 
 
@@ -97,7 +84,6 @@ def forward(
         is_first_module_in_mha: bool,
         fsdp_group: Union[dist_group_type, None],
     ) -> torch.Tensor:
-        logger = logging.getLogger("Linear")
         is_input_fp8 = isinstance(inp, Float8Tensor)
         if is_input_fp8:
             fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
@@ -158,8 +144,6 @@ def forward(
         else:
             inputmat_total = inputmat
         if fp8:
-            logger.debug("Running forward in FP8")
-
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
             bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
 
@@ -248,8 +232,6 @@ def forward(
                 dtype=activation_dtype,
             )
         else:
-            logger.debug("Running forward in %s", activation_dtype)
-
             # Cast for native AMP
             weight = cast_if_needed(weight, activation_dtype)
             bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -373,7 +355,6 @@ def forward(
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        logger = logging.getLogger("Linear")
         if isinstance(grad_output, Float8Tensor):
             ctx.fp8_meta["scaling_bwd"].scale_inv[
                 tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -450,8 +431,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
 
             if ctx.requires_dgrad:
                 if ctx.fp8:
-                    logger.debug("Running backward in FP8")
-
                     if ctx.is_input_fp8:
                         out_index, meta_tensor, output_te_dtype, output_dtype = (
                             tex.FP8BwdTensors.GRAD_INPUT1,
@@ -494,8 +473,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         dtype=ctx.activation_dtype,
                     )
                 else:
-                    logger.debug("Running backward in %s", ctx.activation_dtype)
-
                     dgrad, _, _ = gemm(
                         weight,
                         grad_output,