[Dlight] Enhance vectorization loading weight for gemv (#16878)
* [Dlight] Enhance vectorization loading weight for gemv


* Update gemv.py
vinx13 authored Apr 13, 2024
1 parent 0a3fe22 commit 5c80691
Showing 2 changed files with 38 additions and 37 deletions.
18 changes: 9 additions & 9 deletions python/tvm/dlight/gpu/gemv.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""A rule for GEMV and DecodeGEMV."""
-import re
from functools import reduce
from typing import List, Optional, Union

@@ -56,10 +55,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):


def get_bytes(dtype: Union[DataType, str]) -> int:
-num = re.findall(r"\d+", dtype)
-if len(num) != 1:
-raise ValueError(f"Cannot get bytes from {dtype}")
-return int(num[0]) // 8
+if isinstance(dtype, str):
+dtype = DataType(dtype)
+return dtype.bits * dtype.lanes // 8


def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
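Note for readers of the hunk above: the rewritten helper derives the byte size from the dtype's bit width and lane count instead of scraping digits out of the type name with a regex, so vector dtypes such as "float16x4" (where the regex finds two numbers and the old code raised) are now handled. A minimal standalone sketch of the same computation, assuming only that TVM is importable; tvm.DataType is the runtime DataType class the patched file imports:

from tvm import DataType

def get_bytes_sketch(dtype) -> int:
    # Accept either a dtype string or an existing DataType object.
    if isinstance(dtype, str):
        dtype = DataType(dtype)
    # Total width of the (possibly vector) type, in bytes.
    return dtype.bits * dtype.lanes // 8

print(get_bytes_sketch("uint32"))     # 4
print(get_bytes_sketch("float16"))    # 2
print(get_bytes_sketch("float16x4"))  # 8 (the old regex-based version raised ValueError here)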
@@ -297,10 +295,11 @@ def apply(
Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local")
sch.compute_at(Aq_local, r, preserve_unit_loops=True)
s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
-s_local, vec_load = sch.split(
-s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+fused_load = sch.fuse(s_local, r_local)
+aq_vec_len = max(1, VEC_LOAD // get_bytes(sch.get(Aq_local).reads[0].buffer.dtype))
+fused_load, vec_load = sch.split(
+fused_load, factors=[None, aq_vec_len], preserve_unit_iters=True
)
-sch.reorder(s_local, r_local, vec_load)  # either s_local or r_local should be 1
sch.vectorize(vec_load)

# load vector into shared memory, shape should be the whole vector
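A reading aid, not part of the patch: the new aq_vec_len appears to treat VEC_LOAD as a byte budget for each vectorized load of the cached weight, so the number of lanes shrinks as the element type gets wider. With the CUDA setting of VEC_LOAD = 4 introduced further down in this diff, that works out to the loop extents seen in the updated test expectations. A small sketch of the arithmetic, with the byte helper inlined (names here are illustrative):

from tvm import DataType

VEC_LOAD_BYTES = 4  # the CUDA value set later in this diff

def weight_vector_lanes(weight_dtype: str) -> int:
    # Mirrors aq_vec_len = max(1, VEC_LOAD // get_bytes(weight dtype)):
    # the fused (s, r) load loop is split by this factor and the inner part vectorized.
    dt = DataType(weight_dtype)
    return max(1, VEC_LOAD_BYTES // (dt.bits * dt.lanes // 8))

print(weight_vector_lanes("float16"))  # 2 -> T.vectorized(2) in the float16 test below
print(weight_vector_lanes("uint32"))   # 1 -> T.vectorized(1) in the quantized (uint32) tests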
@@ -442,10 +441,12 @@ def apply(

TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
SUPPORT_WARP_SHUFFLE = False
+VEC_LOAD = 1
if target.kind.name == "cuda":
VEC_C = 4
LOAD_V_SHARED = True
LOAD_V_VEC = 8
+VEC_LOAD = 4
UNROLL = 256
SUPPORT_WARP_SHUFFLE = True
if isinstance(len_S, int):
@@ -522,7 +523,6 @@ def apply(
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
)
VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
-VEC_LOAD = 1

return apply(
sch,
57 changes: 29 additions & 28 deletions tests/python/dlight/test_gpu_gemv.py
@@ -120,13 +120,13 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0)
for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
-for ax2_1 in T.vectorized(1):
+for ax0, ax1, ax2_ax3_fused_0 in T.grid(1, 1, 1):
+for ax2_ax3_fused_1 in T.vectorized(2):
with T.block("lv1638_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1)
-v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
-v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
+v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
+v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_ax3_fused_0 * 2 + ax2_ax3_fused_1)
T.reads(lv1638[v0, v1, v2, v3])
T.writes(lv1638_local[v0, v1, v2, v3])
lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3]
@@ -224,11 +224,11 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-for ax0_0, ax1 in T.grid(1, 1):
+for ax0_ax1_fused in T.serial(1):
for ax0_1 in T.vectorized(1):
with T.block("lv571_local"):
-v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 16 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv571[v0, v1])
T.writes(lv571_local[v0, v1])
lv571_local[v0, v1] = lv571[v0, v1]
@@ -332,11 +332,11 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-for ax0_0, ax1 in T.grid(1, 1):
-for ax0_1 in T.vectorized(1):
+for ax0_ax1_fused_0 in range(1):
+for ax0_ax1_fused_1 in T.vectorized(1):
with T.block("lv571_local"):
-v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv571[v0, v1])
T.writes(lv571_local[v0, v1])
lv571_local[v0, v1] = lv571[v0, v1]
@@ -448,11 +448,11 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-for ax0_0, ax1 in T.grid(1, 1):
-for ax0_1 in T.vectorized(1):
+for ax0_ax1_fused_0 in range(1):
+for ax0_ax1_fused_1 in T.vectorized(1):
with T.block("lv771_local"):
-v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv771[v0, v1])
T.writes(lv771_local[v0, v1])
lv771_local[v0, v1] = lv771[v0, v1]
@@ -572,11 +572,11 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
-for ax0_1 in T.vectorized(T.int64(1)):
+for ax0_ax1_fused_0 in range(T.int64(1)):
+for ax0_ax1_fused_1 in T.vectorized(T.int64(1)):
with T.block("lv575_local"):
-v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
-v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1)
+v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0)
T.reads(lv575[v0, v1])
T.writes(lv575_local[v0, v1])
lv575_local[v0, v1] = lv575[v0, v1]
@@ -942,15 +942,16 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f
T.writes(o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0])
o_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, v_expert_id_o, v0] = T.float16(0)
for ax1_fused_u_fused_0 in T.serial(32, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
-for ax0, ax1_0, ax2 in T.grid(1, 1, 8):
-for ax1_1 in T.vectorized(1):
-with T.block("w_local"):
-v0 = T.axis.spatial(1, ax0)
-v1 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax1_0 + ax1_1)
-v2 = T.axis.spatial(4096, ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax2)
-T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
-T.writes(w_local[v0, v1, v2])
-w_local[v0, v1, v2] = w[indptr[v_expert_id_o] + v0, v1, v2]
+for ax0 in range(1):
+for ax1_ax2_fused_0 in range(8):
+for ax1_ax2_fused_1 in T.vectorized(1):
+with T.block("w_local"):
+v0 = T.axis.spatial(1, ax0)
+v1 = T.axis.spatial(16384, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1)
+v2 = T.axis.spatial(4096, ax1_fused_u_fused_0 * 128 + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * 8 + ax1_ax2_fused_0 + ax1_ax2_fused_1)
+T.reads(w[indptr[v_expert_id_o] + v0, v1, v2])
+T.writes(w_local[v0, v1, v2])
+w_local[v0, v1, v2] = w[indptr[v_expert_id_o] + v0, v1, v2]
for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(1, 8):
for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(1):
with T.block("gemv_rf_update"):
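For context on how these expected() functions are consumed, here is a hedged sketch paraphrasing the usual dlight test pattern rather than code copied from test_gpu_gemv.py; the helper name and the default target string are illustrative. The GEMV rule is applied to the unscheduled "before" PrimFunc under a GPU target, so target-dependent constants such as VEC_LOAD take effect, and the result is compared structurally against expected.

import tvm
from tvm import dlight as dl
from tvm.target import Target

def check_gemv_schedule(before: tvm.tir.PrimFunc, expected: tvm.tir.PrimFunc, target: str = "cuda"):
    # Wrap the unscheduled kernel in an IRModule and apply the dlight GEMV rule
    # under the chosen target.
    mod = tvm.IRModule({"main": before})
    with Target(target):
        scheduled = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
    # The scheduled kernel must match the expected TIR exactly.
    tvm.ir.assert_structural_equal(scheduled["main"], expected)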
