[PyTorch] Improve logging/messaging in attention (NVIDIA#1074)
* fix logging in attention

Signed-off-by: Charlene Yang <[email protected]>

* remove logging in fwd/bwd methods due to CPU overhead

Signed-off-by: Charlene Yang <[email protected]>

* WIP: fix check_set_window_size messaging

Signed-off-by: Charlene Yang <[email protected]>

* fix typo

Signed-off-by: Charlene Yang <[email protected]>

* fix window_size messaging

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove redundant imports

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] authored Aug 6, 2024
1 parent 5bb3a41 commit 121ff62
Showing 4 changed files with 22 additions and 106 deletions.
60 changes: 22 additions & 38 deletions transformer_engine/pytorch/attention.py
@@ -98,12 +98,12 @@
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1"))
@@ -266,6 +266,9 @@ def get_attention_backend(

# Run config
logger = logging.getLogger("DotProductAttention")
logger.setLevel(_log_level)
if not logger.hasHandlers():
logger.addHandler(_stream_handler)
device_compute_capability = get_device_compute_capability()
cudnn_version = get_cudnn_version()
run_config = {
@@ -3236,31 +3239,28 @@ def check_set_window_size(
"""
orig_window_size = window_size
if "causal" in attn_mask_type:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
if orig_window_size is None:
window_size = (-1, 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] >= 0:
elif orig_window_size == (-1, -1) or (
orig_window_size[0] >= 0 and orig_window_size[1] != 0
):
window_size = (orig_window_size[0], 0)
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
else:
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
assert False, (
"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
)
elif attn_mask_type in ["no_mask", "padding", "arbitrary"]:
if orig_window_size is None or (
orig_window_size[0] == -1 and orig_window_size[1] in [-1, 0]
):
if orig_window_size is None:
window_size = (-1, -1)
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
elif orig_window_size[0] < 0 or orig_window_size[1] < 0:
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
assert False, (
"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
)
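
For readers following the window_size changes in this hunk, the sketch below re-implements the normalization that the updated check_set_window_size performs, so the intended mapping from (attn_mask_type, window_size) to the effective sliding window is easier to see. It is illustrative only and not part of the commit; the standalone helper name _normalize_window_size and the shortened warning text are assumptions.

import warnings
from typing import Optional, Tuple

def _normalize_window_size(
    attn_mask_type: str, window_size: Optional[Tuple[int, int]]
) -> Tuple[int, int]:
    """Illustrative re-implementation of the window_size normalization in this commit."""
    if "causal" in attn_mask_type:
        # Causal masks only look left, so the right half of the window must be 0.
        if window_size is None:
            return (-1, 0)
        if window_size == (-1, -1) or (window_size[0] >= 0 and window_size[1] != 0):
            warnings.warn(f"adjusting window_size to ({window_size[0]}, 0) for {attn_mask_type}")
            return (window_size[0], 0)
        assert window_size == (-1, 0) or (window_size[0] >= 0 and window_size[1] == 0), (
            "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type
        )
        return window_size
    # no_mask / padding / arbitrary: bidirectional windows are allowed.
    if window_size is None:
        return (-1, -1)
    if window_size == (-1, 0):
        warnings.warn(f"adjusting window_size to (-1, -1) for {attn_mask_type}")
        return (-1, -1)
    assert window_size == (-1, -1) or (window_size[0] >= 0 and window_size[1] >= 0), (
        "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type
    )
    return window_size

# A few expected mappings:
assert _normalize_window_size("causal", None) == (-1, 0)
assert _normalize_window_size("padding_causal", (128, -1)) == (128, 0)  # clamps right side, warns
assert _normalize_window_size("no_mask", (-1, 0)) == (-1, -1)           # warns
assert _normalize_window_size("arbitrary", (64, 64)) == (64, 64)
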
@@ -3560,9 +3560,7 @@ def forward(
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if fp8:
logger.debug("Running forward in FP8")
if fp8_meta["recipe"].fp8_mha:
assert isinstance(qkv, Float8Tensor), "qkv must be Float8Tensors for FP8 MHA."
fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv
@@ -3646,7 +3644,6 @@ def forward(
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", qkv.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
is_training,
max_seqlen,
@@ -3699,7 +3696,6 @@ def forward(

@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc_qkvpacked")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -3753,7 +3749,6 @@ def backward(ctx, d_out):
else:
with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -3819,7 +3814,6 @@ def backward(ctx, d_out):
ctx.qkv_dtype,
).view(dqkv_fp8.shape)
else:
logger.debug("Running backward in %s", qkv.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(qkv.dtype)
dqkv, *rest = fused_attn_bwd_qkvpacked(
@@ -3937,9 +3931,7 @@ def forward(
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if fp8:
logger.debug("Running forward in FP8")
if fp8_meta["recipe"].fp8_mha:
assert isinstance(q, Float8Tensor) and isinstance(
kv, Float8Tensor
@@ -4036,7 +4028,6 @@ def forward(
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", q.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked(
is_training,
max_seqlen_q,
@@ -4100,7 +4091,6 @@ def forward(

@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc_kvpacked")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -4158,7 +4148,6 @@ def backward(ctx, d_out):
else:
with torch.cuda.nvtx.range("_FusedAttn_kvpacked"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4243,7 +4232,6 @@ def backward(ctx, d_out):
ctx.qkv_dtype,
).view(dkv_fp8.shape)
else:
logger.debug("Running backward in %s", q.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dkv, *rest = fused_attn_bwd_kvpacked(
@@ -4374,9 +4362,7 @@ def forward(
fp8_meta,
deterministic,
):
logger = logging.getLogger("FusedAttnFunc")
if fp8:
logger.debug("Running forward in FP8")
fused_attention_backend = FusedAttnBackend["FP8"]
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if fp8_meta["recipe"].fp8_mha:
@@ -4544,7 +4530,6 @@ def forward(
fp8_meta["scaling_fwd"].scale_inv.clone(),
)
else:
logger.debug("Running forward in %s", q.dtype)
out_ret, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_seqlen_q,
@@ -4618,7 +4603,6 @@ def forward(

@staticmethod
def backward(ctx, d_out):
logger = logging.getLogger("FusedAttnFunc")
if ctx.fp8_meta["recipe"].fp8_mha:
assert isinstance(
d_out, Float8Tensor
@@ -4680,7 +4664,6 @@ def backward(ctx, d_out):
else:
with torch.cuda.nvtx.range("_FusedAttn"):
if ctx.fp8:
logger.debug("Running backward in FP8")
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
@@ -4818,7 +4801,6 @@ def backward(ctx, d_out):
ctx.qkv_dtype,
).view(dv_fp8.shape)
else:
logger.debug("Running backward in %s", q.dtype)
if d_out.dtype == torch.uint8:
d_out = d_out_f8tensor.from_float8(q.dtype)
dq, dk, dv, *rest = fused_attn_bwd(
@@ -4959,7 +4941,6 @@ def __init__(
) -> None:
super().__init__()

self.logger = logging.getLogger("FusedAttention")
self.softmax_scale = softmax_scale
self.attention_dropout = attention_dropout
self.attention_dropout_ctx = attention_dropout_ctx
@@ -5306,6 +5287,9 @@ def __init__(
super().__init__()

self.logger = logging.getLogger("DotProductAttention")
self.logger.setLevel(_log_level)
if not self.logger.hasHandlers():
self.logger.addHandler(_stream_handler)
self.qkv_format = qkv_format
attn_mask_type = attn_mask_type.replace(",", "_")
if attn_mask_type == "causal_padding":
@@ -5619,7 +5603,7 @@ def forward(
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
self.fp8_meta["recipe"].fp8_dpa = True
self.logger.WARNING(
self.logger.warning(
"""Forcing fp8_meta["recipe"].fp8_dpa=True due to """
"""fp8_meta["recipe"].fp8_mha=True"""
)
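Stepping back, the attention.py changes above replace the import-time logging.basicConfig call (which configures the root logger for the whole process) with a module-level level, formatter, and stream handler that are attached to individual loggers such as "DotProductAttention" only if they have no handlers yet; debug calls are also dropped from the forward/backward hot paths because of their CPU overhead. A minimal, self-contained sketch of that pattern follows; it is an illustration of the NVTE_DEBUG/NVTE_DEBUG_LEVEL mapping visible in the diff, and the _get_logger helper and example message are hypothetical.

import logging
import os

# NVTE_DEBUG=0/1 toggles debug mode; NVTE_DEBUG_LEVEL=0/1/2 selects verbosity,
# mirroring the mapping used in attention.py above.
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
_log_level = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}.get(_log_level, logging.DEBUG)

_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter)

def _get_logger(name: str) -> logging.Logger:
    # Configure only the named logger, not the root logger, so importing the
    # module does not override the host application's logging setup.
    logger = logging.getLogger(name)
    logger.setLevel(_log_level)
    if not logger.hasHandlers():
        logger.addHandler(_stream_handler)
    return logger

# Example: run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to see DEBUG output.
logger = _get_logger("DotProductAttention")
logger.debug("backend selection and run config would be logged at this level")

Keeping debug statements out of the per-call forward/backward methods, as the commit does, avoids paying the logging call overhead on every iteration even when debug output is disabled.
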
23 changes: 0 additions & 23 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -3,8 +3,6 @@
# See LICENSE for license information.

"""GroupedLinear API"""
import os
import logging
from typing import Union, Optional, Callable, Tuple, List, Dict, Any

import torch
@@ -45,17 +43,6 @@
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

__all__ = ["GroupedLinear"]

"""
@@ -97,7 +84,6 @@ def forward(
is_grad_enabled: bool,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
logger = logging.getLogger("GroupedLinear")
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -151,8 +137,6 @@ def forward(
inputmats = inputmats_no_fp8

if fp8:
logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases

@@ -184,8 +168,6 @@ def forward(
use_split_accumulator=_2X_ACC_FPROP,
)
else:
logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
biases = (
@@ -286,8 +268,6 @@

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("GroupedLinear")

with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
fwd_scale_inverses,
@@ -353,7 +333,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

if ctx.requires_dgrad:
if ctx.fp8:
logger.debug("Running backward in FP8")
dgrad = torch.empty(
(sum(ctx.m_splits), weights_fp8[i].size(1)),
dtype=ctx.activation_dtype,
@@ -376,8 +355,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)

dgrad = torch.empty(
(sum(ctx.m_splits), weights[0].size(1)),
dtype=ctx.activation_dtype,
22 changes: 0 additions & 22 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -5,7 +5,6 @@
"""LayerNormLinear API"""
import os
import warnings
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
@@ -48,17 +47,6 @@
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor

# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
logging.basicConfig(
format="[%(levelname)-8s | %(name)-19s]: %(message)s",
level=log_levels[log_level if log_level in [0, 1, 2] else 2],
)

__all__ = ["LayerNormLinear"]


@@ -104,7 +92,6 @@ def forward(
ub_name: str,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
logger = logging.getLogger("LayerNormLinear")
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -203,8 +190,6 @@ def forward(
ln_out = ln_out_total

if fp8:
logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

@@ -259,8 +244,6 @@ def forward(
dtype=activation_dtype,
)
else:
logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
@@ -379,7 +362,6 @@ def forward(
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("LayerNormLinear")
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
0
@@ -500,8 +482,6 @@ def backward(
ub_obj = None

if ctx.fp8:
logger.debug("Running backward in FP8")

fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
out_index, meta_tensor, out_te_type, out_type = (
@@ -544,8 +524,6 @@ def backward(
)
clear_tensor_data(grad_output_c)
else:
logger.debug("Running backward in %s", ctx.activation_dtype)

# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm(
weight,
