Commit

Fix for performance regression on wp.tile_matmul()

mmacklin authored and shi-eric committed Feb 3, 2025
1 parent 013af01 commit f96a179
Showing 5 changed files with 163 additions and 549 deletions.
43 changes: 26 additions & 17 deletions warp/examples/benchmarks/benchmark_gemm.py
@@ -13,23 +13,28 @@
 wp.set_module_options({"fast_math": True, "enable_backward": False})


-def create_mlp_kernel(m, n, k):
+# returns a kernel to compute a GEMM given m,n,k tile sizes
+def create_gemm_kernel(m, n, k):
     TILE_M = m
     TILE_N = n
     TILE_K = k

     @wp.kernel
-    def mlp(x: wp.array2d(dtype=float), weights_wp: wp.array2d(dtype=float), n_k: int, output: wp.array2d(dtype=float)):
-        i_m, i_n = wp.tid()
+    def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
+        i, j = wp.tid()
+
         sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
-        for count in range(n_k):
-            feat = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i_m * TILE_M, count * TILE_K))
-            weight = wp.tile_load(weights_wp, shape=(TILE_K, TILE_N), offset=(count * TILE_K, i_n * TILE_N))
-            wp.tile_matmul(feat, weight, sum)
-
-        wp.tile_store(output, sum, offset=(i_m * TILE_M, i_n * TILE_N))
+        count = A.shape[1] // TILE_K

-    return mlp
+        for k in range(count):
+            a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
+            b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
+
+            wp.tile_matmul(a, b, sum)
+
+        wp.tile_store(output, sum, offset=(i * TILE_M, j * TILE_N))
+
+    return gemm


 def benchmark_torch(A, B, warm_up, iterations):
@@ -55,19 +60,25 @@ def benchmark_warp(A, B, config, warm_up, iterations):
     TILE_K = config[2]
     BLOCK_DIM = config[3]

-    mlp = create_mlp_kernel(TILE_M, TILE_N, TILE_K)
+    mlp = create_gemm_kernel(TILE_M, TILE_N, TILE_K)

     M = A.shape[0]
     N = B.shape[1]
     K = A.shape[1]

     output = wp.zeros((M, N), dtype=float)

+    # create launch command
+    cmd = wp.launch_tiled(
+        kernel=mlp,
+        dim=[M // TILE_M, N // TILE_N],
+        inputs=[A, B, output],
+        block_dim=BLOCK_DIM,
+        record_cmd=True,
+    )
+
     # warm-up
     for _ in range(warm_up):
-        wp.launch_tiled(
-            kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM
-        )
+        cmd.launch()

     # check output
     if warm_up > 0:
@@ -77,9 +88,7 @@ def benchmark_warp(A, B, config, warm_up, iterations):
     timers = {}
     with wp.ScopedTimer("warp", print=False, dict=timers, synchronize=True):
         for _ in range(iterations):
-            wp.launch_tiled(
-                kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM
-            )
+            cmd.launch()

     return timers["warp"][0]
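For context, the pattern the updated benchmark relies on is Warp's recorded launch: calling wp.launch_tiled(..., record_cmd=True) builds the launch object once, and cmd.launch() replays it without re-packing arguments or recomputing launch dimensions on every iteration. A minimal sketch of that pattern follows; the copy kernel, tile sizes, and array shapes are illustrative assumptions and are not part of this commit.

# Sketch only: recorded-launch pattern used by the updated benchmark_gemm.py.
# The kernel, tile sizes, and shapes below are illustrative, not repository code.
import warp as wp

wp.init()

TILE_M = 32
TILE_N = 32
BLOCK_DIM = 64


@wp.kernel
def tile_copy(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # each block copies one (TILE_M, TILE_N) tile from src to dst
    i, j = wp.tid()
    t = wp.tile_load(src, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
    wp.tile_store(dst, t, offset=(i * TILE_M, j * TILE_N))


M, N = 1024, 1024
src = wp.ones((M, N), dtype=float)
dst = wp.zeros((M, N), dtype=float)

# record the launch once: arguments and launch dimensions are captured here
cmd = wp.launch_tiled(
    kernel=tile_copy,
    dim=[M // TILE_M, N // TILE_N],
    inputs=[src, dst],
    block_dim=BLOCK_DIM,
    record_cmd=True,
)

# replay the recorded launch inside the hot loop with minimal per-iteration overhead
for _ in range(100):
    cmd.launch()

wp.synchronize()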
179 changes: 0 additions & 179 deletions warp/examples/benchmarks/benchmark_tile.py

This file was deleted.

1 change: 0 additions & 1 deletion warp/native/builtin.h
@@ -1761,6 +1761,5 @@ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expect
 // only include in kernels for now
 #if defined(__CUDACC_RTC__)
 #include "tile.h"
-#include "tile_gemm.h"
 #include "tile_reduce.h"
 #endif