Commit
backup before doing monkey dispatch fp8 tp
xrsrke committed Nov 18, 2024
1 parent e93cf55 commit 9510f57
Showing 22 changed files with 1,280 additions and 178 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -163,3 +163,8 @@ cython_debug/

checkpoints/
wandb/

*.csv
*.html
src/nanotron/.test_cache/
log/
221 changes: 88 additions & 133 deletions benchmark/fp8_tp_speed.py
@@ -1,5 +1,4 @@
import argparse
import itertools

import pandas as pd
import torch
@@ -11,16 +10,13 @@

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
# NOTE: without sparsity
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12


# def color_scale(val, min_val, max_val):
# """Generate a color scale from white to dark blue based on value."""
# if pd.isna(val):
# return 'background-color: white'
# normalized = (val - min_val) / (max_val - min_val) if max_val != min_val else 0
# return f'background-color: rgba(0, 0, 139, {normalized:.2f}); color: {"white" if normalized > 0.5 else "black"}'
# # NOTE: with sparsity
# h100_peak_flops_fp16_tc = 1979e12
# h100_peak_tops_float8_tc = 3958e12
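
# These peak specs are the denominators for the MFU columns computed further
# down. A quick back-of-the-envelope check of what they imply (illustrative
# only, not part of the source):
#
#   M = N = K = 16384
#   flops = 2 * M * N * K                    # ~8.8e12 FLOPs for one forward matmul
#   print(flops / h100_peak_flops_fp16_tc)   # ~0.0089 s (~8.9 ms) at the BF16 tensor-core peak
#   print(flops / h100_peak_tops_float8_tc)  # ~0.0044 s (~4.4 ms) at the FP8 tensor-core peak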


def color_scale(val, min_val, max_val, metric_type="default"):
@@ -43,55 +39,6 @@ def color_scale(val, min_val, max_val, metric_type="default"):
return f"{color}; color: {text_color}"


# [Previous benchmark_fn_in_sec and run functions remain the same...]

# def create_html_table(df, exp_number, tp_size):
# # Style the dataframe
# styled_df = df.style.format({
# 'FP8_time_ms': '{:.2f}',
# 'BF16_time_ms': '{:.2f}',
# 'FP8_TFLOPS': '{:.2f}',
# 'BF16_TFLOPS': '{:.2f}',
# 'FP8_efficiency_%': '{:.2f}',
# 'BF16_efficiency_%': '{:.2f}',
# 'Speedup': '{:.2f}'
# })

# # Apply color scaling to specific columns
# styled_df = styled_df.apply(lambda x: pd.Series([
# color_scale(v, x.min(), x.max(), 'time') if col.endswith('time_ms')
# else color_scale(v, x.min(), x.max(), 'performance') if col.endswith('TFLOPS')
# else color_scale(v, x.min(), x.max(), 'efficiency') if col.endswith('efficiency_%')
# else color_scale(v, x.min(), x.max()) if col == 'Speedup'
# else '' for v in x
# ]), axis=0)

# # Generate HTML
# html = f'''
# <html>
# <head>
# <style>
# table {{ border-collapse: collapse; width: 100%; }}
# th, td {{ padding: 12px; text-align: left; border: 1px solid #ddd; }}
# th {{ background-color: #4CAF50; color: white; }}
# tr:nth-child(even) {{ background-color: #f9f9f9; }}
# .header {{ text-align: center; padding: 20px; }}
# </style>
# </head>
# <body>
# <div class="header">
# <h2>Benchmark Results (TP_SIZE={tp_size})</h2>
# <p>Experiment: {exp_number}</p>
# </div>
# {styled_df.to_html()}
# </body>
# </html>
# '''

# with open(f'{exp_number}_benchmark_results_tp{tp_size}.html', 'w') as f:
# f.write(html)


def create_html_table(df, exp_number, tp_size):
def style_df(df):
# Create an empty DataFrame with the same shape as the input
@@ -199,6 +146,23 @@ def run_linear(input, M, N, K, parallel_context, include_backward=False):
sharded_output.sum().backward()


# def parse_args():
# parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions")
# parser.add_argument("--exp_number", type=str, help="Experiment number")
# parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size")
# # parser.add_argument(
# # "--dimensions",
# # type=str,
# # default="4096,16384,32768,28672,49152",
# # help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided",
# # )
# # Add the missing argument definitions
# parser.add_argument("--m_size", type=int, help="Explicitly set M dimension")
# parser.add_argument("--k_size", type=int, help="Explicitly set K dimension")
# parser.add_argument("--n_size", type=int, help="Explicitly set N dimension")
# return parser.parse_args()


def parse_args():
parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions")
parser.add_argument("--exp_number", type=str, help="Experiment number")
@@ -207,14 +171,17 @@ def parse_args():
"--dimensions",
type=str,
default="4096,16384,32768,28672,49152",
help="Comma-separated list of dimensions to test",
help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided",
)
parser.add_argument("--m_size", type=int, help="Explicitly set M dimension")
parser.add_argument("--k_size", type=int, help="Explicitly set K dimension")
parser.add_argument("--n_size", type=int, help="Explicitly set N dimension")
return parser.parse_args()


def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
input = torch.randn(M, K, device="cuda", requires_grad=False)
bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=False, dtype=torch.bfloat16)
def benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad):
input = torch.randn(M, K, device="cuda", requires_grad=requires_grad)
bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=requires_grad, dtype=torch.bfloat16)

# Benchmark FP8
fp8_time = benchmark_fn_in_sec(run_fp8_linear, input, M, N, K, parallel_context, include_backward)
@@ -223,13 +190,18 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
bfloat16_time = benchmark_fn_in_sec(run_linear, bfloat16_input, M, N, K, parallel_context, include_backward)

# Calculate FLOPS
# Each linear operation performs 2*M*N*K FLOPs (multiply-add)
total_flops = 2 * M * N * K // parallel_context.tensor_parallel_size
if include_backward:
# gradient w.r.t. parameters: dW = input^T @ grad_output
total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size
if requires_grad:
# gradient w.r.t. inputs: dX = grad_output @ W^T
total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size

fp8_tflops = (total_flops / fp8_time) / 1e12
bfloat16_tflops = (total_flops / bfloat16_time) / 1e12

# Calculate efficiency compared to peak performance
# Calculate efficiency
fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100
bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100
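
# The accounting above amounts to one 2*M*N*K term per matmul: the forward
# pass is one matmul, the backward pass adds one for the weight gradient, and
# one more for the input gradient only when the input itself requires grad.
# As a standalone sketch of that rule (hypothetical helper, not defined in
# this file):
#
#   def matmul_flops(M, N, K, tp_size=1, include_backward=False, requires_grad=False):
#       flops = 2 * M * N * K // tp_size           # forward: out = input @ W
#       if include_backward:
#           flops += 2 * M * N * K // tp_size      # dW = input^T @ grad_output
#           if requires_grad:
#               flops += 2 * M * N * K // tp_size  # dX = grad_output @ W^T
#       return flops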

@@ -238,12 +210,13 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
"N": N,
"K": K,
"Include_Backward": include_backward,
"Input_Requires_Grad": requires_grad,
"FP8_time_ms": fp8_time * 1000,
"BF16_time_ms": bfloat16_time * 1000,
"FP8_TFLOPS": fp8_tflops,
"BF16_TFLOPS": bfloat16_tflops,
"FP8_efficiency_%": fp8_efficiency,
"BF16_efficiency_%": bfloat16_efficiency,
"FP8_MFU%": fp8_efficiency,
"BF16_MFU%": bfloat16_efficiency,
"Speedup": bfloat16_time / fp8_time,
}
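
# benchmark_fn_in_sec is defined or imported earlier in the file and is
# unchanged by this commit. Its implementation isn't shown in this diff, but a
# CUDA-event timer along these lines would produce numbers of this shape (a
# sketch under that assumption, not the actual helper):
#
#   import torch
#
#   def benchmark_fn_in_sec(fn, *args, warmup=10, iters=50):
#       for _ in range(warmup):                        # warm up kernels and caches first
#           fn(*args)
#       start = torch.cuda.Event(enable_timing=True)
#       end = torch.cuda.Event(enable_timing=True)
#       start.record()
#       for _ in range(iters):
#           fn(*args)
#       end.record()
#       torch.cuda.synchronize()                       # wait for all timed work to finish
#       return start.elapsed_time(end) / 1000 / iters  # elapsed_time() returns milliseconds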

@@ -252,79 +225,61 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
torch.backends.cudnn.benchmark = True

args = parse_args()

dimensions = [int(d.strip()) for d in args.dimensions.split(",")]
TP_SIZE = args.tp_size
EXP_NUMBER = args.exp_number

results = []
total = len(list(itertools.product(dimensions, dimensions, dimensions)))
parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=TP_SIZE)
parallel_context = ParallelContext(
data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=args.tp_size
)

# Run benchmarks and collect results
results = []
i = 0
for M, N, K in itertools.product(dimensions, dimensions, dimensions):
i += 1
# Run forward-only case
result = benchmark_linear_operations(M, N, K, parallel_context, include_backward=False)
# combinations = list(itertools.product(
# dimensions, # M
# dimensions, # N
# dimensions, # K
# [False, True], # include_backward
# [False, True] # requires_grad
# ))
# Pair dimensions index-wise
# combinations = []
# for i in range(len(dimensions)):
# M = dimensions[i]
# N = dimensions[i]
# K = dimensions[i]
# # For each dimension pair, test all combinations of include_backward and requires_grad
# for include_backward in [False, True]:
# for requires_grad in [False, True]:
# combinations.append((M, N, K, include_backward, requires_grad))

combinations = []
# Check if explicit M, K, N dimensions are provided
if all(dim is not None for dim in [args.m_size, args.k_size, args.n_size]):
# Use explicitly provided dimensions
for include_backward in [False, True]:
# for requires_grad in [False, True]:
for requires_grad in [False]:
combinations.append((args.m_size, args.n_size, args.k_size, include_backward, requires_grad))
else:
# Use dimensions from the --dimensions argument
dimensions = [int(d.strip()) for d in args.dimensions.split(",")]
for i in range(len(dimensions)):
M = dimensions[i]
N = dimensions[i]
K = dimensions[i]
for include_backward in [False, True]:
# for requires_grad in [False, True]:
for requires_grad in [False]:
combinations.append((M, N, K, include_backward, requires_grad))

total = len(combinations)

for i, (M, N, K, include_backward, requires_grad) in enumerate(combinations, 1):
result = benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad)
results.append(result)
print(f"Experiment {i}/{total} complete (Forward-only)")

# Run forward+backward case
result = benchmark_linear_operations(M, N, K, parallel_context, include_backward=True)
results.append(result)
print(f"Experiment {i}/{total} complete (Forward+Backward)")
print(f"Experiment {i}/{total} complete")

df = pd.DataFrame(results)
df = df.round(2) # Round to 2 decimal places
df = df.sort_values(by=["M", "N", "K", "Include_Backward"])
df = df.round(2)
df = df.sort_values(by=["M", "N", "K", "Include_Backward", "Input_Requires_Grad"])

print(df)

# # Define columns to color and their respective color scales
# color_columns = {
# 'FP8_time_ms': 'Reds',
# 'BF16_time_ms': 'Blues',
# 'FP8_TFLOPS': 'Greens',
# 'BF16_TFLOPS': 'Purples',
# 'FP8_efficiency_%': 'Oranges',
# 'BF16_efficiency_%': 'Viridis',
# 'Speedup': 'RdYlBu'
# }

# # Create the table
# fig = go.Figure(data=[go.Table(
# header=dict(
# values=list(df.columns),
# fill_color='lightgrey',
# align='left',
# font=dict(size=12, color='black')
# ),
# cells=dict(
# values=[df[col] for col in df.columns],
# align='left',
# font=dict(size=11),
# # Format cells with colors for numeric columns
# fill_color=[
# 'white' if col not in color_columns else
# [f'rgba({int(255*i)}, {int(255*i)}, {int(255*i)}, 0.5)'
# for i in (df[col]-df[col].min())/(df[col].max()-df[col].min())]
# for col in df.columns
# ]
# )
# )])

# # Update layout
# fig.update_layout(
# title=f'Benchmark Results (TP_SIZE={TP_SIZE})',
# width=1200,
# height=800,
# margin=dict(l=20, r=20, t=40, b=20)
# )

# # Save the interactive HTML file
# fig.write_html(f'{EXP_NUMBER}_benchmark_results_tp{TP_SIZE}.html')
create_html_table(df, EXP_NUMBER, TP_SIZE)

print(f"\nResults have been saved to {EXP_NUMBER}_benchmark_results_tp{TP_SIZE}.html")
create_html_table(df, args.exp_number, args.tp_size)
print(f"\nResults have been saved to {args.exp_number}_benchmark_results_tp{args.tp_size}.html")
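
Usage note: assuming the usual distributed launcher for nanotron-style scripts (an assumption; the repo may use a different entry point), an invocation with explicit dimensions would look something like:

torchrun --nproc_per_node=2 benchmark/fp8_tp_speed.py --exp_number exp01 --tp_size 2 --m_size 16384 --k_size 16384 --n_size 16384

Omitting --m_size/--k_size/--n_size instead sweeps the square shapes listed in --dimensions.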