From f96a1790fe87d8845f87c3dc6067cc6643879e9d Mon Sep 17 00:00:00 2001 From: Miles Macklin Date: Mon, 3 Feb 2025 12:49:36 -0800 Subject: [PATCH] Fix for performance regression on wp.tile_matmul() --- warp/examples/benchmarks/benchmark_gemm.py | 43 ++- warp/examples/benchmarks/benchmark_tile.py | 179 ----------- warp/native/builtin.h | 1 - warp/native/tile.h | 148 ++++++++- warp/native/tile_gemm.h | 341 --------------------- 5 files changed, 163 insertions(+), 549 deletions(-) delete mode 100644 warp/examples/benchmarks/benchmark_tile.py delete mode 100644 warp/native/tile_gemm.h diff --git a/warp/examples/benchmarks/benchmark_gemm.py b/warp/examples/benchmarks/benchmark_gemm.py index 588483e08..2f140cf5c 100644 --- a/warp/examples/benchmarks/benchmark_gemm.py +++ b/warp/examples/benchmarks/benchmark_gemm.py @@ -13,23 +13,28 @@ wp.set_module_options({"fast_math": True, "enable_backward": False}) -def create_mlp_kernel(m, n, k): +# returns a kernel to compute a GEMM given m,n,k tile sizes +def create_gemm_kernel(m, n, k): TILE_M = m TILE_N = n TILE_K = k @wp.kernel - def mlp(x: wp.array2d(dtype=float), weights_wp: wp.array2d(dtype=float), n_k: int, output: wp.array2d(dtype=float)): - i_m, i_n = wp.tid() + def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), output: wp.array2d(dtype=float)): + i, j = wp.tid() sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32) - for count in range(n_k): - feat = wp.tile_load(x, shape=(TILE_M, TILE_K), offset=(i_m * TILE_M, count * TILE_K)) - weight = wp.tile_load(weights_wp, shape=(TILE_K, TILE_N), offset=(count * TILE_K, i_n * TILE_N)) - wp.tile_matmul(feat, weight, sum) - wp.tile_store(output, sum, offset=(i_m * TILE_M, i_n * TILE_N)) + count = A.shape[1] // TILE_K - return mlp + for k in range(count): + a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K)) + b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N)) + + wp.tile_matmul(a, b, sum) + + wp.tile_store(output, sum, offset=(i * TILE_M, j * TILE_N)) + + return gemm def benchmark_torch(A, B, warm_up, iterations): @@ -55,19 +60,25 @@ def benchmark_warp(A, B, config, warm_up, iterations): TILE_K = config[2] BLOCK_DIM = config[3] - mlp = create_mlp_kernel(TILE_M, TILE_N, TILE_K) + mlp = create_gemm_kernel(TILE_M, TILE_N, TILE_K) M = A.shape[0] N = B.shape[1] - K = A.shape[1] output = wp.zeros((M, N), dtype=float) + # create launch command + cmd = wp.launch_tiled( + kernel=mlp, + dim=[M // TILE_M, N // TILE_N], + inputs=[A, B, output], + block_dim=BLOCK_DIM, + record_cmd=True, + ) + # warm-up for _ in range(warm_up): - wp.launch_tiled( - kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM - ) + cmd.launch() # check output if warm_up > 0: @@ -77,9 +88,7 @@ def benchmark_warp(A, B, config, warm_up, iterations): timers = {} with wp.ScopedTimer("warp", print=False, dict=timers, synchronize=True): for _ in range(iterations): - wp.launch_tiled( - kernel=mlp, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, K // TILE_K, output], block_dim=BLOCK_DIM - ) + cmd.launch() return timers["warp"][0] diff --git a/warp/examples/benchmarks/benchmark_tile.py b/warp/examples/benchmarks/benchmark_tile.py deleted file mode 100644 index 7258376da..000000000 --- a/warp/examples/benchmarks/benchmark_tile.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. -# NVIDIA CORPORATION and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION is strictly prohibited. - -import numpy as np -import torch - -import warp as wp - -wp.init() -wp.set_module_options({"enable_backward": False, "fast_math": True}) -wp.set_device("cuda:0") - -wp.build.clear_kernel_cache() - - -@wp.kernel -def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)): - # output index - i, j = wp.tid() - - sum = float(0.0) - - for k in range(0, A.shape[1]): - sum += A[i, k] * B[k, j] - - C[i, j] = sum - - -TILE_M = wp.constant(64) -TILE_N = wp.constant(64) -TILE_K = wp.constant(8) - - -@wp.kernel -def gemm_tiled(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)): - # output tile index - i, j = wp.tid() - - sum = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32) - - _M = A.shape[0] - _N = B.shape[1] - K = A.shape[1] - - count = int(K / 8) # TODO: code-gen bug if you use a constant before passing it to a kwd arg (in this case TILE_K) - - for k in range(count): - a = wp.tile_load(A, i * TILE_M, k * TILE_K, m=TILE_M, n=TILE_K) - b = wp.tile_load(B, k * TILE_K, j * TILE_N, m=TILE_K, n=TILE_N) - - # sum += a*b - wp.tile_matmul(a, b, sum) - - wp.tile_store(C, i * TILE_M, j * TILE_N, sum) - - -def benchmark_numpy(A, B, C): - timers = {} - iters = 10 - - # warm up - for _i in range(10): - _C = A @ B - - with wp.ScopedTimer("NumPy", dict=timers): - for _i in range(iters): - _C = A @ B - - return min(timers["NumPy"]) - - -def benchmark_warp_simt(A, B, C): - timers = {} - iters = 10 - - A_wp = wp.array(A) - B_wp = wp.array(B) - C_wp = wp.array(C) - - # warm up - for _i in range(10): - wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp]) - - with wp.ScopedTimer("Warp (SIMT)", dict=timers, print=False, synchronize=True): - for _i in range(iters): - wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp]) - - return min(timers["Warp (SIMT)"]) - - -def benchmark_warp_tiled(A, B, C): - timers = {} - iters = 10 - - # must match with the tile_matmul() partition size - SUB_TILE_M = 4 - SUB_TILE_N = 4 - - num_threads = int(TILE_M / SUB_TILE_M) * int(TILE_N / SUB_TILE_N) - A_wp = wp.array(A) - B_wp = wp.array(B) - C_wp = wp.array(C) - - # warm up - wp.capture_begin() - - for _i in range(iters): - wp.launch(gemm_tiled, dim=(int(M / TILE_M), int(N / TILE_N)), inputs=[A_wp, B_wp, C_wp], tile_size=num_threads) - - graph = wp.capture_end() - - with wp.ScopedTimer("Warp (Tiled)", dict=timers, print=False, synchronize=True): - # for i in range(iters): - # wp.launch(gemm_tiled, dim=(int(M/TILE_M), int(N/TILE_N)), inputs=[A_wp, B_wp, C_wp], tile_size=num_threads) - wp.capture_launch(graph) - - return min(timers["Warp (Tiled)"]) - - -def benchmark_torch(A, B, C): - A_tc = torch.from_numpy(A).to("cuda:0") - B_tc = torch.from_numpy(B).to("cuda:0") - C_tc = torch.from_numpy(C).to("cuda:0") - - # warm-up - for _i in range(10): - torch.matmul(A_tc, B_tc, out=C_tc) - - timers = {} - iters = 10 - - torch.cuda.synchronize() - - with wp.ScopedTimer("Torch", dict=timers, print=False): - for _i in range(iters): - torch.matmul(A_tc, B_tc) # , out=C_tc) - - torch.cuda.synchronize() - - return min(timers["Torch"]) - - -results_torch = [] -results_warp_simt = [] -results_warp_tiled = [] - -print("{:>8s} {:>8s} {:>8s} {:>8s} {:>8s} {:>8s}".format("M", "N", "K", "Torch", "Warp (SIMT)", "Warp (Tiled)")) -print("--------------------------------------------------------") - -for i in range(2, 33): - # for i in range(8,9): - - M = i * 128 - N = M - K = N - - # M = TILE_M*21 - # K = TILE_K*7 - # N = TILE_M*12 - - rng = np.random.default_rng(42) - - A = rng.random((M, K), dtype=np.float32) - B = rng.random((K, N), dtype=np.float32) - C = np.zeros((M, N), dtype=np.float32) - - results_torch.append(benchmark_torch(A, B, C)) - results_warp_simt.append(0.0) # benchmark_warp_simt(A, B, C)) - results_warp_tiled.append(benchmark_warp_tiled(A, B, C)) - - print( - "{:>8d} {:>8d} {:>8d} {:>8f} {:>8f} {:>8f}".format( - M, N, K, results_torch[-1], results_warp_simt[-1], results_warp_tiled[-1] - ) - ) diff --git a/warp/native/builtin.h b/warp/native/builtin.h index 6c8fb6376..1e60adb61 100644 --- a/warp/native/builtin.h +++ b/warp/native/builtin.h @@ -1761,6 +1761,5 @@ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expect // only include in kernels for now #if defined(__CUDACC_RTC__) #include "tile.h" -#include "tile_gemm.h" #include "tile_reduce.h" #endif diff --git a/warp/native/tile.h b/warp/native/tile.h index 83c4f6913..cec06d424 100644 --- a/warp/native/tile.h +++ b/warp/native/tile.h @@ -35,10 +35,6 @@ #endif #define WP_USE_ASYNC_PIPELINE 0 -#if WP_USE_ASYNC_PIPELINE -#include "cuda_pipeline_primitives.h" -#endif // WP_USE_ASYNC_PIPELINE - #define WP_USE_REGISTER_GEMM 0 /* Tile Expressions @@ -308,6 +304,22 @@ struct tile_global_t tile_global_t(array_t& a, const Coord& c) : data(a), offset(c) { } + + inline CUDA_CALLABLE int index_from_coord(const Coord& coord) const + { + // element index + int index = 0; + + WP_PRAGMA_UNROLL + for (int i=0; i < Shape::N; ++i) + { + // global = offset + coord + int c = offset[i] + coord[i]; + index += data.strides[i]*c; + } + + return index/sizeof(T); + } inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const { @@ -1059,7 +1071,45 @@ struct tile_shared_t template inline CUDA_CALLABLE void copy_to_global(const Global& dest) - { + { + // vectorized loads for specific input/output shapes + if constexpr (Layout::Shape::N == 2) + { + constexpr int lastdim = Layout::Shape::N-1; + constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1; + const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T); + const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]); + const bool aligned = (elements*sizeof(T))%sizeof(float4) == 0; + + if (contiguous_dest && contiguous_src && aligned) + { + constexpr int M = Layout::Shape::dim(0); + constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4); + + // alias of shared tile with 128bit type + using SrcLayout = tile_layout_strided_t>; + tile_shared_t src128((float4*)data.ptr); + float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))]; + + assert(((uint64_t)(data.ptr))%sizeof(float4) == 0); + assert(((uint64_t)(ptr))%sizeof(float4) == 0); + + const int stride_i = dest.data.strides[0]/sizeof(float4); + const int stride_j = 1; + + WP_PRAGMA_UNROLL + for (int i=threadIdx.x; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM) + { + auto c = SrcLayout::coord_from_linear(i); + + dest128[stride_i*c[0] + stride_j*c[1]] = src128.data(i); + } + + return; + } + } + + // scalar bounds checked path WP_PRAGMA_UNROLL for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM) { @@ -1068,12 +1118,93 @@ struct tile_shared_t } } + __device__ __forceinline__ + void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src) + { + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned long long saddr = 0ULL; + unsigned long long gaddr = 0ULL; + + asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(saddr) : "l"(shared_dest)); + asm volatile("cvta.to.global.u64 %0, %1;" : "=l"(gaddr) : "l"(global_src)); + + // Use cp.async on newer architectures + asm volatile( + "cp.async.ca.shared.global [%0], [%1], 16;\n" + : + : "l"(saddr), "l"(gaddr) + ); + #else + // use regular load/store through register on older arches + *shared_dest = *global_src; + #endif + } + + __device__ __forceinline__ + void cp_async_commit_and_wait_all_128() + { + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile( + "cp.async.commit_group;\n" + "cp.async.wait_group 0;\n" ::); + #endif + } + template inline CUDA_CALLABLE void copy_from_global(const Global& src) - { + { if (initialized) WP_TILE_SYNC(); + + // vectorized loads for specific input/output shapes + if constexpr (Layout::Shape::N == 2) + { + constexpr int lastdim = Layout::Shape::N-1; + constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1; + const bool contiguous_src = src.data.strides[lastdim] == sizeof(T); + const int elements = (src.data.shape[lastdim] - src.offset[lastdim]); + const bool aligned = (elements*sizeof(T))%sizeof(float4) == 0; + + if (contiguous_dest && contiguous_src && aligned) + { + constexpr int M = Layout::Shape::dim(0); + constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4); + + // alias of shared tile with 128bit type + using DestLayout = tile_layout_strided_t>; + tile_shared_t dest128((float4*)data.ptr); + float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))]; + + assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0); + assert(((uint64_t)(src128))%sizeof(float4) == 0); + + const int stride_i = src.data.strides[0]/sizeof(float4); + const int stride_j = 1; + + WP_PRAGMA_UNROLL + for (int i=threadIdx.x; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM) + { + auto c = DestLayout::coord_from_linear(i); + +#if WP_USE_ASYNC_PIPELINE + cp_async_global_to_shared_128(&dest128.data(i), &src128[stride_i*c[0] + stride_j*c[1]]); +#else + dest128.data(i) = src128[stride_i*c[0] + stride_j*c[1]]; +#endif // WP_USE_ASYNC_PIPELINE + } + +#if WP_USE_ASYNC_PIPELINE + cp_async_commit_and_wait_all_128(); +#endif // WP_USE_ASYNC_PIPELINE + + initialized = true; + WP_TILE_SYNC(); + return; + } + } + // scalar bounds checked path WP_PRAGMA_UNROLL for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM) { @@ -1944,11 +2075,6 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti using T = typename TileA::Type; -#if WP_USE_ASYNC_PIPELINE - __pipeline_wait_prior(0); - WP_TILE_SYNC(); -#endif - #if WP_USE_REGISTER_GEMM partitioned_gemm::matmul(A, B, C); #else diff --git a/warp/native/tile_gemm.h b/warp/native/tile_gemm.h deleted file mode 100644 index 2ab0fe40b..000000000 --- a/warp/native/tile_gemm.h +++ /dev/null @@ -1,341 +0,0 @@ -/** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. - * NVIDIA CORPORATION and its licensors retain all intellectual property - * and proprietary rights in and to this software, related documentation - * and any modifications thereto. Any use, reproduction, disclosure or - * distribution of this software and related documentation without an express - * license agreement from NVIDIA CORPORATION is strictly prohibited. - */ - -#pragma once - -#include "builtin.h" - -#define USE_CUTE 0 - -#if USE_CUTE -#include "cutlass/include/cute/tensor.hpp" -#include "cutlass/include/cute/algorithm/cooperative_gemm.hpp" -#endif // USE_CUTE - -namespace wp -{ - -/* -// 2D tile zero -template -inline CUDA_CALLABLE array_t tile_zeros() -{ - const int length = M*N; - - WP_TILE_SHARED __align__(16) T data[length]; - - WP_PRAGMA_UNROLL - for (int t=threadIdx.x; t < length; t += blockDim.x) - { - data[t] = T(0.0); - } - - return array_t(data, M, N, nullptr); -} - -// 2D tile load -template -inline CUDA_CALLABLE array_t tile_load(const array_t& src, int i, int j) -{ - const int length = M*N; - - WP_TILE_SHARED __align__(16) T data[length]; - - //--------------- - // naive-synchronous load - // - // WP_PRAGMA_UNROLL - // for (int t=threadIdx.x; t < length; t += blockDim.x) - // { - // data[t] = index(src, i*M + t/N, j*N + t%N); - // } - - //--------------- - // async 128 bit loads (assumes row-major i.e.: stride 1 on y axis and 4-element alignment on dimension) - const int s = 4; - - WP_PRAGMA_UNROLL - for (int t=threadIdx.x*s; t < length; t += blockDim.x*s) - { - __pipeline_memcpy_async(&data[t], - &index(src, i*M + t/N, j*N + t%N), - sizeof(T)*s); - } - - __pipeline_commit(); - - - return array_t(data, M, N, nullptr); -} - -// 2D tile store -template -inline CUDA_CALLABLE void tile_store(array_t& dest, int i, int j, const array_t& src) -{ - const int M = src.shape[0]; - const int N = src.shape[1]; - - const int length = M*N; - - // cooperatively store the tile, using a block-stride iterator - WP_PRAGMA_UNROLL - for (int t=threadIdx.x; t < length; t += blockDim.x) - { - index(dest, i*M + t/N, j*N + t%N) = src.data[t]; - } -} -*/ - -template -inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride) -{ - return p[i*stride + j]; -} - -template -inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride) -{ - return p[i*stride + j]; -} - -template -struct partition_t -{ - inline partition_t(array_t A) - { - data = A; - - // todo: do ceil div for non-multiples of M,N - shape[0] = A.shape[0]/M; - shape[1] = A.shape[1]/N; - } - - // underlying data - array_t data; - - // partition dimensions - int shape[2]; -}; - -template -inline int partition_size(const partition_t& tile) -{ - return tile.shape[0]*tile.shape[1]; -} - -// returns the x, y coordinates of a tile given a linear index -template -inline void partition_coord(const partition_t& tile, const int t, int& i, int& j) -{ - i = t/tile.shape[1]; - j = t%tile.shape[1]; -} - -template -inline mat_t partition_load(const partition_t& tile, int i, int j) -{ - mat_t out; - - const int tile_i = i*M; - const int tile_j = j*N; - - WP_PRAGMA_UNROLL - for (int i=0; i < M; ++i) - { - WP_PRAGMA_UNROLL - for (int j=0; j < N; ++j) - { - out.data[i][j] = index(tile.data, tile_i + i, tile_j + j); - } - } - - return out; -} - -template -inline void partition_store(const partition_t& tile, int i, int j, const mat_t& value) -{ - mat_t out; - - const int tile_i = M*i; - const int tile_j = N*j; - - WP_PRAGMA_UNROLL - for (int i=0; i < M; ++i) - { - WP_PRAGMA_UNROLL - for (int j=0; j < N; ++j) - { - index(tile.data, tile_i + i, tile_j + j) = value.data[i][j]; - } - } -} - - -#if !USE_CUTE - -template -inline CUDA_CALLABLE void gemm(const array_t& A, const array_t& B, const array_t& out) -{ - const int TILE_M = 4; - const int TILE_N = 4; - const int TILE_K = 4; - - partition_t A_tile = partition_t(A); - partition_t B_tile = partition_t(B); - partition_t C_tile = partition_t(out); - - const int length = partition_size(C_tile); - - __pipeline_wait_prior(0); - - WP_TILE_SYNC(); - - for (int t=threadIdx.x; t < length; t += blockDim.x) - { - int i, j; - partition_coord(C_tile, t, i, j); - - // accumulator - mat_t sum = partition_load(C_tile, i, j); - - WP_PRAGMA_UNROLL - for (int k=0; k < A_tile.shape[1]; k++) - { - const mat_t a = partition_load(A_tile, i, k); - const mat_t b = partition_load(B_tile, k, j); - - sum += mul(a, b); - } - - partition_store(C_tile, i, j, sum); - } - - WP_TILE_SYNC(); -} - - -// 2D gemm accumulate out += A*B -template -inline CUDA_CALLABLE void tile_matmul_scalar(const TileA& A, - const TileB& B, - TileC& out) -{ - const int length = tile_size(out); - - WP_TILE_SYNC(); - - using T = typename TileA::Type; - - WP_PRAGMA_UNROLL - for (int t=threadIdx.x; t < length; t += WP_TILE_BLOCK_DIM) - { - // compute output index - const int i = t/out.N; - const int j = t%out.N; - - T sum(0.0); - - WP_PRAGMA_UNROLL - for (int k=0; k < A.N; ++k) - { - T a = A(i,k); - T b = B(k,j); - - sum += a*b; // todo: use fmaf() - } - - out(i,j) += sum; - } - - WP_TILE_SYNC(); -} - -#else - - -template -inline CUDA_CALLABLE void tile_matmul(const array_t& A, const array_t& B, const array_t& out) -{ - using namespace cute; - - __pipeline_wait_prior(0); - - // ensure smem tile is ready - WP_TILE_SYNC(); - - // Define CTA matrix size (static) - auto bM = Int<64>{}; - auto bN = Int<64>{}; - auto bK = Int<8>{}; - - // Define the smem layouts (static) - auto sA = make_layout(make_shape(bM, bK), LayoutRight{}); - auto sB = make_layout(make_shape(bN, bK)); - auto sC = make_layout(make_shape(bM, bN), LayoutRight{}); - - Tensor s_a_tensor = make_tensor(make_smem_ptr(A.data), sA); - Tensor s_b_tensor = make_tensor(make_smem_ptr(B.data), sB); - Tensor s_c_tensor = make_tensor(make_smem_ptr(out.data), sC); - - - // TiledMMA tiled_mma = make_tiled_mma(UniversalFMA{}, - // Layout>{}); // 16x8x1 UniversalFMA, assumes blockDim=128 - - - // TiledMMA tiled_mma = make_tiled_mma(UniversalFMA{}, - // Layout,Stride<_16,_1>>{}); // 8x16x1 UniversalFMA, assumes blockDim=128 - - - - TiledMMA tiled_mma = make_tiled_mma(UniversalFMA{}, - Layout,Stride<_64,_1>>{}); // 8x16x1 UniversalFMA, assumes blockDim=128 - - - cooperative_gemm< AutoVectorizingCopyWithAssumedAlignment>, - AutoVectorizingCopyWithAssumedAlignment>, - AutoVectorizingCopyWithAssumedAlignment> - >( - threadIdx.x, tiled_mma, - 1.0f, s_a_tensor, s_b_tensor, 1.0f, s_c_tensor, - cute::identity(), cute::identity(), cute::identity(), cute::identity() - ); - - WP_TILE_SYNC(); - -} - -#endif // USE_CUTE - - -#if 0 - -template -void tile_matmul(TileA& a, TileB& b, TileC& c) -{ - static_assert(wp::is_same::value, "Error, tile datatypes must match"); - static_assert(TileA::N == TileB::M, "Error, inner dimensions must match"); - static_assert(TileC::M == TileA::M, "Error, first output dimension must match"); - static_assert(TileC::N == TileB::N, "Error, second output dimension must match"); - - tile_matmul_scalar(a, b, c); -} - - -template -void adj_tile_matmul(TileA& a, TileB& b, TileC& c, - AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c) -{ - tile_matmul_scalar(adj_c, wp::tile_transpose(b), adj_a); - tile_matmul_scalar(wp::tile_transpose(a), adj_c, adj_b); -} - -#endif // 0 - -} // namespace wp