Skip to content

Commit

Permalink
Add sweep_utils.py script to tune heuristics (pytorch#3656)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#732

As title. Heuristics tuning script for `bf16_fast_gemv`.

Currently the script needs a manual hack that updates the kernel to accept block dims before it will work; see the comments in the code.

Reviewed By: ipiszy

Differential Revision: D68786295
  • Loading branch information
YUNQIUGUO authored and facebook-github-bot committed Feb 5, 2025
1 parent 8d6c0bf commit 12f4ac4
Showing 1 changed file with 83 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch._inductor.utils import do_bench_using_profiling


class SweepHeuristics:
    """Sweep candidate CUDA launch configurations for `bf16_fast_gemv`.

    For each (n, k) weight shape, benchmark the custom op under every valid
    (block_dim_x, block_dim_y) pair and print the fastest configuration
    together with its achieved memory bandwidth.
    """

    def __init__(self, m=1, block_dims=None, nks=None):
        """Initialize the sweep space.

        Args:
            m: GEMV batch (M) dimension. Defaults to 1, matching the
                original script's decode-style GEMV workload.
            block_dims: Optional override list of (block_dim_x, block_dim_y)
                candidates. Defaults to the built-in sweep list below.
            nks: Optional override list of (n, k) weight shapes to sweep.
        """
        self.m = m
        # Candidate (block_dim_x, block_dim_y) pairs. Every product is
        # <= 1024, staying within CUDA's threads-per-block limit.
        self.block_dims = (
            block_dims
            if block_dims is not None
            else [
                (32, 1),
                (32, 4),
                (32, 8),
                (32, 16),
                (32, 32),
                (64, 1),
                (64, 2),
                (64, 4),
                (64, 8),
                (64, 16),
                (128, 1),
                (128, 2),
                (128, 4),
                (128, 8),
                (256, 1),
                (256, 2),
                (256, 4),
                (512, 1),
                (512, 2),
                (1024, 1),
            ]
        )
        # (n, k) weight shapes to sweep.
        self.nks = (
            nks
            if nks is not None
            else [(1280, 8192), (8192, 1024), (7168, 8192), (8192, 3584)]
        )

    def _is_valid_config(self, n, k, block_dim_x):
        """Return True when `block_dim_x` evenly tiles the (n, k) problem.

        Both k and n must be divisible by block_dim_x, and each thread's
        share of k must be a multiple of 8 (the original code's constraint,
        written here with integer division — exact because the divisibility
        check runs first).
        """
        return (
            k % block_dim_x == 0
            and n % block_dim_x == 0
            and (k // block_dim_x) % 8 == 0
        )

    def _bandwidth_gbps(self, n, k, elapsed_ms):
        """Effective bandwidth in GiB/s for one GEMV.

        Bytes moved = 2-byte activations (m*k) + 2-byte weights (n*k)
        + 2-byte outputs (m*n), divided by the elapsed time in seconds.
        """
        return (
            (self.m * k * 2 + n * k * 2 + self.m * n * 2)
            / (elapsed_ms / 1000)
            / (1024**3)
        )

    def sweep_heuristics(self) -> None:
        """Benchmark every valid config per (n, k) and print the best one."""
        for n, k in self.nks:
            # NOTE(review): inputs are created as torch.half (fp16) even
            # though the op is named bf16_fast_gemv — confirm intended dtype.
            x = torch.randn(size=(self.m, k), dtype=torch.half, device="cuda")
            w = torch.randn(size=(n, k), dtype=torch.half, device="cuda")
            best_elapsed_time, best_block_dim_x, best_block_dim_y = None, None, None

            for block_dim_x, block_dim_y in self.block_dims:
                if not self._is_valid_config(n, k, block_dim_x):
                    continue
                # This requires
                # 1. update for testing purpose for `bf16_fast_gemv` pytorch custom op to pass in block_dim_x and block_dim_y
                # 2. modify the fp16_fast_gemv.cu kernel signature to reflect the block_dim heuristics
                # https://www.internalfb.com/code/fbsource/[bafd6390bc8c842b46d81be1a27dafd384503a53]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py?lines=365
                res = do_bench_using_profiling(
                    lambda: torch.ops.fbgemm.bf16_fast_gemv(
                        x.T, w, block_dim_x=block_dim_x, block_dim_y=block_dim_y
                    )
                )

                if best_elapsed_time is None or res < best_elapsed_time:
                    best_elapsed_time, best_block_dim_x, best_block_dim_y = (
                        res,
                        block_dim_x,
                        block_dim_y,
                    )
            if best_elapsed_time is None:
                print("Error: No valid elapsed time found. Exiting the function.")
                return
            bw = self._bandwidth_gbps(n, k, best_elapsed_time)

            print(f"best elapsed time: {best_elapsed_time} ms")
            print(f"best block_dim_x: {best_block_dim_x}")
            print(f"best block_dim_y: {best_block_dim_y}")
            print(f"best bw: {bw} GB/s")


if __name__ == "__main__":
    # Guard the sweep behind a main check: importing this module (e.g. from
    # a test or another tool) must not launch CUDA benchmarking as a side
    # effect.
    sweep_instance = SweepHeuristics()
    sweep_instance.sweep_heuristics()

0 comments on commit 12f4ac4

Please sign in to comment.