[pyTorch] Infrastructure for C++ QuantizedTensor #1251

Draft: wants to merge 6 commits into main
Changes from all commits
8 changes: 7 additions & 1 deletion transformer_engine/common/recipe/__init__.py
@@ -49,9 +49,15 @@ class _OverrideLinearPrecision(NamedTuple):
dgrad: bool = False
wgrad: bool = False

@dataclass
class Recipe:
"""
Abstract base recipe class.
"""
pass

@dataclass()
class DelayedScaling:
class DelayedScaling(Recipe):
"""
Use the delayed scaling factor strategy. Use scale factor from previous
iteration and record amax history of `amax_history_len` steps.
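The hunk above adds an abstract `Recipe` base class and re-parents `DelayedScaling` onto it, so downstream code can accept any recipe type generically. A minimal sketch of what that enables, assuming the existing `DelayedScaling` constructor arguments and the `Format` enum exported from the same module:

```python
# Minimal sketch: DelayedScaling is now an instance of the new Recipe base
# class, so recipe handling can dispatch on the abstract type.
from transformer_engine.common.recipe import DelayedScaling, Format, Recipe

recipe = DelayedScaling(
    fp8_format=Format.HYBRID,   # E4M3 in the forward pass, E5M2 for gradients
    amax_history_len=1024,      # record amax over the last 1024 steps
    amax_compute_algo="max",    # reduce the history with a max
)

assert isinstance(recipe, Recipe)  # holds once DelayedScaling subclasses Recipe


def describe(r: Recipe) -> str:
    """Dispatch on the concrete recipe type instead of duck-typing."""
    if isinstance(r, DelayedScaling):
        return f"delayed scaling, amax history of {r.amax_history_len} steps"
    return "unrecognized recipe"


print(describe(recipe))
```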
28 changes: 5 additions & 23 deletions transformer_engine/pytorch/cpp_extensions/activation.py
@@ -8,41 +8,23 @@
import torch

import transformer_engine_torch as tex

from ..quantization_params import QuantizationParams
from ._common import canonicalize_fp8_scales

__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]


def gelu(
inp: torch.Tensor,
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
qparams: QuantizationParams,
) -> torch.Tensor:
"""GeLU with FP8 output"""

# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
allow_multiple_offsets=False,
)

# Launch kernel
return torch.ops.tex_ts.gelu_ts(
return tex.gelu(
inp,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
fp8_scales_offsets["scale_offset"],
otype,
)
qparams)


def relu(
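The activation wrappers above now take a single `QuantizationParams` object in place of the separate FP8 meta tensor, tensor index, output dtype and scale/amax/scale_inv arguments, and forward straight to `tex.gelu`. A rough sketch of the new call shape; the `QuantizationParams` construction is hypothetical, since the class is defined elsewhere in this PR:

```python
# Hedged sketch of the slimmed-down activation API. Only the gelu(inp, qparams)
# call shape is taken from the hunk above; the QuantizationParams construction
# is hypothetical.
import torch
from transformer_engine.pytorch.cpp_extensions import gelu
from transformer_engine.pytorch.quantization_params import QuantizationParams

inp = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Old: gelu(inp, fp8_meta_tensor, fp8_tensor, otype, scale=..., amax=..., scale_inv=...)
# New: all quantization state travels inside one object.
qparams = QuantizationParams()  # hypothetical construction
out = gelu(inp, qparams)
```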
129 changes: 129 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -6,12 +6,14 @@
import functools
from typing import Optional, Tuple, Union, List
import torch
from ..quantization_params import QuantizationParams
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec


__all__ = [
"general_gemm",
"gemm",
"fp8_gemm",
"grouped_gemm",
@@ -25,6 +27,133 @@ def _empty_tensor() -> torch.Tensor:
return torch.Tensor()


def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
workspace: torch.Tensor,
out_dtype: torch.dtype,
quantization_params: Optional[QuantizationParams] = None,
gelu: bool = False,
accumulate: bool = False,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_split_accumulator: bool = False,
ub_algo: tex.UbufOverlapAlgo = None,
ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
ub_buffer: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""GEMM supporting fp8 inputs."""

empty_tensor = _empty_tensor()
assert quantization_params is None, "FP8 output not supported yet"
if out is not None:
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")

# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias.dtype
bias_dtype = TE_DType[bias_dtype]

args = (
A,
True, # transa
B,
False, # transb
out,
quantization_params,
out_dtype,
bias,
bias_dtype,
gelu,
False, # grad
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)
fn = tex.te_gemm2
if ub_algo is not None:
raise ValueError("Not implemented yet!")
assert ub is not None, "ub object is None!"
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(
args
+ (
1,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(
args
+ (
0,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
fn = ub.split_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
fn = ub.atomic_gemm_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), "SPLIT_PIPELINED_RS requires extra output tensor"
args = tuple(
args
+ (
True,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
fn = ub.split_overlap_rs_p2p
assert (
extra_output_tensor is not None
), "SPLIT_PIPELINED_RS_P2P requires extra output tensor"
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor"
args = tuple(
args
+ (
True,
extra_output_tensor,
)
)
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
fn = ub.atomic_gemm_overlap_rs_p2p
assert (
extra_output_tensor is not None
), "ATOMIC_GEMM_RS_P2P requires extra output tensor"
args = tuple(args + (extra_output_tensor,))
if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
out = fn(*args)
gelu_input = empty_tensor
else:
out, gelu_input = fn(*args)

return out, gelu_input


def fp8_gemm(
A: torch.Tensor,
A_scale_inv: torch.Tensor,
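`general_gemm` is a new, higher-level entry point that wraps `tex.te_gemm2` and returns both the GEMM output and the pre-GELU buffer. A usage sketch under the signature above; shapes and the workspace size are illustrative, and `quantization_params` stays `None` because the code currently asserts that FP8 output is unsupported:

```python
# Hedged usage sketch for general_gemm, following the signature above.
# A is consumed transposed (transa=True) and B is not (transb=False),
# matching the flags hard-coded in the wrapper.
import torch
from transformer_engine.pytorch.cpp_extensions import general_gemm

A = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)  # e.g. weight [out, in]
B = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)   # e.g. input  [tokens, in]
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")

out, gelu_input = general_gemm(
    A,
    B,
    workspace,
    out_dtype=torch.bfloat16,   # annotated as torch.dtype in the signature
    quantization_params=None,   # FP8 output not supported yet
    use_split_accumulator=True,
)
# out should come back as [64, 256]; gelu_input holds the pre-GELU buffer when gelu=True.
```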
41 changes: 41 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -45,6 +45,7 @@
#include <stdexcept>
#include <vector>

#include "c10/util/ArrayRef.h"
#include "common/util/logging.h"

namespace transformer_engine {
@@ -83,6 +84,14 @@ enum FP8BwdTensors {
GRAD_INPUT3 = 5
};

class Float8Tensor {
public:
at::Tensor data;
std::optional<at::Tensor> transpose = std::nullopt;
at::Tensor scale_inv;
DType dtype;
};

} // namespace transformer_engine

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
@@ -126,6 +135,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
case torch::kInt64:
return transformer_engine::DType::kInt64;
default:
std::cout << "Type: " << static_cast<int>(t) << std::endl;
NVTE_ERROR("Invalid type");
}
}
@@ -168,4 +178,35 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);

void* getDataPtr(at::Tensor tensor, int offset = 0);

namespace std {
template <typename T>
string to_string(const vector<T>& vec) {
string ret = "[";
for (const auto& val : vec) {
ret += to_string(val) + ",";
}
if (ret.size() > 1) {
ret[ret.size() - 1] = ']';
} else {
ret += "]";
}
return ret;
}

// Torch shape -> string
template <typename T>
string to_string(const c10::ArrayRef<T>& vec) {
string ret = "[";
for (const auto& val : vec) {
ret += to_string(val) + ",";
}
if (ret.size() > 1) {
ret[ret.size() - 1] = ']';
} else {
ret += "]";
}
return ret;
}
} // namespace std

#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
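The new `Float8Tensor` helper in `common.h` bundles everything the C++ GEMM path needs to interpret quantized data: the raw bytes, an optional cached transpose, the `scale_inv` tensor and the FP8 dtype. A Python-side mirror of that grouping, purely illustrative (it is not an API added by this PR), to show what state travels together:

```python
# Illustrative mirror of the C++ Float8Tensor fields above; not part of the PR's API.
from dataclasses import dataclass
from typing import Optional

import torch
import transformer_engine_torch as tex


@dataclass
class Float8TensorMirror:
    data: torch.Tensor                        # uint8 buffer holding the FP8 bytes
    scale_inv: torch.Tensor                   # dequantization scale (1 / scale)
    dtype: tex.DType                          # e.g. tex.DType.kFloat8E4M3
    transpose: Optional[torch.Tensor] = None  # optional cached transpose of `data`
```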
31 changes: 27 additions & 4 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -7,8 +7,8 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

#include <optional>
#include "common.h"
#include "common/common.h"

/***************************************************************************************************
* Permutation
@@ -138,6 +138,21 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
* GEMM
**************************************************************************************************/

using MaybeTensor = std::optional<at::Tensor>;

std::vector<at::Tensor> te_gemm2(transformer_engine::Float8Tensor A, bool transa,
transformer_engine::Float8Tensor B, bool transb, MaybeTensor D,
MaybeTensor D_scale, transformer_engine::DType D_type,
MaybeTensor D_amax, MaybeTensor bias,
transformer_engine::DType bias_type, bool gelu, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator);
std::vector<at::Tensor> te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb,
MaybeTensor D, MaybeTensor D_scale,
transformer_engine::DType D_type, MaybeTensor D_amax,
MaybeTensor bias, transformer_engine::DType bias_type, bool gelu,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator);
void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
bool transa, at::Tensor B, at::Tensor B_scale_inverse,
transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale,
@@ -221,9 +236,7 @@ std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_tr
std::vector<int> scale_indices, std::vector<int> amax_indices,
std::vector<int> scale_inv_indices, transformer_engine::DType otype);

at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype);

void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype);
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional<at::Tensor> output = std::nullopt);

void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop,
transformer_engine::DType otype);
@@ -350,6 +363,16 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl
* Cast
**************************************************************************************************/

namespace transformer_engine::pytorch {

py::handle cast(const at::Tensor& tensor,
py::handle quantization_params,
bool rowwise_usage,
bool columnwise_usage,
py::handle proxy);

} // namespace transformer_engine::pytorch

at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax,
at::Tensor scale_inv, transformer_engine::DType otype,
const int scale_offset = 0, const int amax_offset = 0,
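Besides the two `te_gemm2` overloads (one taking `Float8Tensor` operands, one taking plain `at::Tensor`s), the header merges `fp8_transpose_noalloc` into `fp8_transpose` via a defaulted optional output argument. A hedged sketch of how that single entry point might be driven from Python; the keyword name `output` is an assumption about the pybind11 binding rather than something shown in this diff:

```python
# Hedged sketch of the merged fp8_transpose entry point. The `output` keyword
# is an assumption about how the defaulted std::optional argument is exposed.
import torch
import transformer_engine_torch as tex

# FP8 data is carried around as raw uint8 bytes.
inp = torch.randint(0, 256, (128, 64), dtype=torch.uint8, device="cuda")

# Allocating form: returns a freshly allocated transposed buffer.
transposed = tex.fp8_transpose(inp, tex.DType.kFloat8E4M3)

# Pre-allocated form: writes into a caller-provided buffer, covering what
# fp8_transpose_noalloc used to do.
out = torch.empty(64, 128, dtype=torch.uint8, device="cuda")
tex.fp8_transpose(inp, tex.DType.kFloat8E4M3, output=out)
```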
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/activation.cu
@@ -6,7 +6,8 @@

#include "extensions.h"

at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
at::Tensor gelu(at::Tensor input,
at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype) {
using namespace transformer_engine;

1 change: 1 addition & 0 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/

#include "common/common.h"
#include "extensions.h"

constexpr int block_size = 512;