fp8 #266

Open: wants to merge 33 commits into base: main from xrsrke/fp8_for_nanotron

Changes from all commits (33 commits)
44e6574
move the basics of fp8 to this branch
xrsrke Oct 9, 2024
6c9a4d0
add fp8 tensor
xrsrke Oct 28, 2024
0f8f672
add fp8 linear
xrsrke Oct 28, 2024
7dfe3ac
add fp8 tensor parallel
xrsrke Oct 29, 2024
dda00a4
add fp8 tp profiler
xrsrke Oct 30, 2024
b4156dc
update profiling script
xrsrke Nov 1, 2024
39a4960
remove unnecessary .contiguous() in fp8 backward
xrsrke Nov 1, 2024
1ddc44c
remove unnecessary .transpose in fp8 linear backward
xrsrke Nov 1, 2024
c827594
remove unnecessary transpose input in the fwd pass, and contiguous weig…
xrsrke Nov 1, 2024
c937375
add benchmark speed with 5% speed up
xrsrke Nov 1, 2024
f3e3495
add speed benchmark
xrsrke Nov 1, 2024
4b26cf1
65% speed up in fwd+bwd pass with m=n=k=32768
xrsrke Nov 3, 2024
edb1e87
add dumb transpose in fp8_matmul_kernel
xrsrke Nov 3, 2024
e93cf55
remove transpose in kernel
xrsrke Nov 4, 2024
9510f57
backup before doing monkey dispatch fp8 tp
xrsrke Nov 18, 2024
478984a
remove fp8 tp from llama's modeling code, fix no grad in param, remov…
xrsrke Nov 20, 2024
c7d9e8a
refactor NanotronParameter to support fp8
xrsrke Nov 20, 2024
23d66cb
keep FP8 NanotronParameter's dtype in 8 bit, move converting model to…
xrsrke Nov 21, 2024
2864391
add tests for create_param_that_share_metadata, and generating hash
xrsrke Nov 21, 2024
1800efe
add fp8 optim init
xrsrke Nov 22, 2024
c5bcbe7
refactor fp8 linear, tp, parameter tests
xrsrke Nov 22, 2024
a4d6f15
move master weights to gradient accumulator
xrsrke Nov 25, 2024
fbbbf4d
refactor + add test for fp8 initialization
xrsrke Nov 28, 2024
b764b97
new changes
xrsrke Nov 28, 2024
afdfbf1
fix hanging due to NanotronParameter.__repr__ (param.data == Nanotron…
xrsrke Nov 29, 2024
79341ea
by default, do not quantize the first and last layer
xrsrke Nov 30, 2024
b440408
fix nan in fwd pass
xrsrke Dec 18, 2024
4723335
fix grad_clipping for fp8
xrsrke Dec 19, 2024
dd3259b
fix didn't update fp8 parameters in optim.step() due to grad_accum
xrsrke Jan 9, 2025
a3a13ce
remove ablated fp8 config, and unnecessary files/code
xrsrke Jan 9, 2025
e8b114b
clean up
xrsrke Jan 10, 2025
ebea115
add
xrsrke Jan 11, 2025
9a99ab6
Merge branch 'main' into xrsrke/fp8_for_nanotron
xrsrke Jan 11, 2025
5 changes: 5 additions & 0 deletions .gitignore
@@ -163,3 +163,8 @@ cython_debug/

checkpoints/
wandb/

*.csv
*.html
src/nanotron/.test_cache/
log/
256 changes: 256 additions & 0 deletions benchmark/fp8_gemm.py
@@ -0,0 +1,256 @@
import torch
import transformer_engine.pytorch.cpp_extensions as texcpp

# from transformer_engine.pytorch.module import get_workspace
# import transformer_engine_extensions as tex
import transformer_engine_torch as tex

scale = 1.0

meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
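# NOTE: FP8TensorMeta holds the per-tensor FP8 scaling state (scale, scale_inv, amax_history)
# consumed by the cast helpers below; everything here uses a fixed scale of 1.0.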


def cast_to_fp8(x, qtype):
    ret = texcpp.cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, qtype)
    ret._fp8_qtype = qtype
    return ret


def cast_from_fp8(x, qtype):
    ret = texcpp.cast_from_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, x._fp8_qtype, qtype)
    ret._fp8_qtype = qtype
    return ret


one_scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")
empty_tensor = torch.Tensor()
# workspace = get_workspace()
workspace = torch.empty(33_554_432, dtype=torch.int8, device="cuda")
assert workspace.is_cuda


# PT_DType = dict([(v, k) for k, v in texcpp.TE_DType.items()])
# PT_DType[tex.DType.kFloat8E4M3] = torch.uint8
# PT_DType[tex.DType.kFloat8E5M2] = torch.uint8


def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType:
    # NOTE: transformer engine maintains its own dtype mapping,
    # so we need to manually map torch dtypes to TE dtypes
    TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = {
        torch.int32: "kInt32",
        torch.float32: "kFloat32",
        torch.float16: "kFloat16",
        torch.bfloat16: "kBFloat16",
        # DTypes.FP8E4M3: "kFloat8E4M3",
        # DTypes.FP8E5M2: "kFloat8E5M2",
        # DTypes.KFLOAT16: "kFloat16",
    }
    return getattr(tex.DType, TORCH_DTYPE_TE_DTYPE_NAME_MAPPING[dtype])
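# For example, convert_torch_dtype_to_te_dtype(torch.bfloat16) resolves to tex.DType.kBFloat16
# via the mapping above.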


def fp8_gemm(fa, fb, trans_a, trans_b, bias=None, qtype=tex.DType.kFloat32):
    """
    # te_gemm

    input_A: (A_row, A_col)
    input_B: (B_row, B_col)

    when transa, transb = True, False
        m, k, n = A_row, A_col, B_row
        lda, ldb, ldd = A_col, A_col, A_row
        output_D: (B_row, A_row)

    when transa, transb = False, False
        m, k, n = A_col, A_row, B_row
        lda, ldb, ldd = A_col, A_row, A_col
        output_D: (B_row, A_col)

    when transa, transb = False, True
        m, k, n = A_col, A_row, B_col
        lda, ldb, ldd = A_col, B_col, A_col
        output_D: (B_col, A_col)
    """
    assert fa.is_cuda and fb.is_cuda
    assert fa.is_contiguous()
    assert fb.is_contiguous()
    device = fa.device
    fa_qtype, fb_qtype = fa._fp8_qtype, fb._fp8_qtype
    A_row, A_col = fa.shape
    B_row, B_col = fb.shape
    if trans_a and not trans_b:
        assert A_col == B_col
        C_row, C_col = B_row, A_row
    elif not trans_a and not trans_b:
        assert A_row == B_col
        C_row, C_col = B_row, A_col
    elif not trans_a and trans_b:
        assert A_row == B_row
        C_row, C_col = B_col, A_col
    out_shape = (C_row, C_col)

    # dtype = PT_DType[qtype]
    if qtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif qtype == tex.DType.kFloat16:
        dtype = torch.float16

    out = torch.empty(out_shape, dtype=dtype, device=device)
    # te_gemm is column-major.

    # tex.te_gemm(
    #     fa, one_scale_inv, fa_qtype, trans_a,
    #     fb, one_scale_inv, fb_qtype, trans_b,
    #     out, qtype,
    #     bias or empty_tensor, empty_tensor, False,
    #     workspace, workspace.shape[0],
    #     False, True,
    # )

    _empty_tensor = torch.Tensor()
    SCALE = AMAX = _empty_tensor
    TE_CONFIG_TRANSPOSE_BIAS = False

    tex.te_gemm(
        fa,
        one_scale_inv,
        fa_qtype,
        trans_a,
        fb,
        one_scale_inv,
        fb_qtype,
        trans_b,
        # out, SCALE, qtype, AMAX,
        # bias or empty_tensor, qtype, False,
        # workspace, workspace.shape[0],
        # False, True,
        out,
        SCALE,
        qtype,
        AMAX,
        torch.tensor([], dtype=dtype),
        qtype,
        _empty_tensor,
        TE_CONFIG_TRANSPOSE_BIAS,
        workspace,
        workspace.shape[0],
        False,
        True,
        0,
    )

    out._fp8_qtype = qtype
    return out


def fp8_matmul(fa, fb, bias=None, qtype=tex.DType.kFloat32):
    # trans_a = False and trans_b = False is not implemented.
    fb_qtype = fb._fp8_qtype
    fb = fb.T.contiguous()
    fb._fp8_qtype = fb_qtype
    return fp8_gemm(fb, fa, trans_a=True, trans_b=False, bias=bias, qtype=qtype)
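# Illustrative usage (shapes chosen for this sketch, not taken from the benchmark below):
#   fa = cast_to_fp8(torch.randn(256, 512, device="cuda"), tex.DType.kFloat8E4M3)
#   fb = cast_to_fp8(torch.randn(512, 128, device="cuda"), tex.DType.kFloat8E4M3)
#   out = fp8_matmul(fa, fb, qtype=tex.DType.kFloat16)  # (256, 512) @ (512, 128) -> (256, 128)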


h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}
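# NOTE: dtype_to_peak_tops is not referenced further down in this script;
# benchmark_linear_operations reads the h100_* peak constants directly.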

from torch.utils import benchmark


def benchmark_fn_in_sec(f, *args, **kwargs):
    # Manual warmup
    for _ in range(4):
        f(*args, **kwargs)

    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    measurement = t0.blocked_autorange()
    return measurement.mean
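# Example (illustrative): benchmark_fn_in_sec(torch.matmul, a, b) returns the mean
# wall-clock runtime in seconds, measured by torch.utils.benchmark after the warmup above.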


def run_fp8(a, b):
    fa = cast_to_fp8(a, tex.DType.kFloat8E4M3)
    fb = cast_to_fp8(b, tex.DType.kFloat8E4M3)
    fp8_matmul(fa, fb, qtype=tex.DType.kFloat16)


def run_bfloat16(a, b):
    a = a.to(torch.bfloat16)
    b = b.to(torch.bfloat16)
    torch.matmul(a, b)


def benchmark_linear_operations(a, b):
    M, K = a.shape
    N, _ = b.shape

    # Benchmark FP8
    fp8_time = benchmark_fn_in_sec(run_fp8, a, b)

    # Benchmark BFloat16
    bfloat16_time = benchmark_fn_in_sec(run_bfloat16, a, b)

    # Calculate FLOPS
    # Each linear operation performs 2*M*N*K FLOPs (multiply-add)
    total_flops = 2 * M * N * K

    fp8_tflops = (total_flops / fp8_time) / 1e12
    bfloat16_tflops = (total_flops / bfloat16_time) / 1e12

    # Calculate efficiency compared to peak performance
    fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100
    bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100

    return {
        "M": M,
        "N": N,
        "K": K,
        "FP8_time_ms": fp8_time * 1000,
        "BF16_time_ms": bfloat16_time * 1000,
        "FP8_TFLOPS": fp8_tflops,
        "BF16_TFLOPS": bfloat16_tflops,
        "FP8_eff%": fp8_efficiency,
        "BF16_eff%": bfloat16_efficiency,
        "Speedup": bfloat16_time / fp8_time,
    }


if __name__ == "__main__":
    # a = torch.randn(128, 128).cuda()
    # b = torch.randn(128, 128).cuda()
    # qa = cast_from_fp8(fa, tex.DType.kFloat32)
    # qb = cast_from_fp8(fb, tex.DType.kFloat32)
    # qc = torch.matmul(qa, qb)

    # E4M3/E5M2 @ E4M3/E5M2 = FP16/FP32
    # print(qc, qc2)

    import pandas as pd

    def create_benchmark_table(sizes):
        results = []
        for size in sizes:
            a = torch.randn(size, size).cuda()
            b = torch.randn(size, size).cuda()
            result = benchmark_linear_operations(a, b)
            results.append(result)

        df = pd.DataFrame(results)
        df = df.round(2)  # Round to 2 decimal places
        return df

    # Example usage:
    sizes = [4096, 16384, 32768, 28672, 49152]
    benchmark_table = create_benchmark_table(sizes)
    print(benchmark_table)