Skip to content

Commit

Permalink
[Iluvatar] Add vdot_heur_block_size
Browse files Browse the repository at this point in the history
  • Loading branch information
junjian.zhan committed Feb 11, 2025
1 parent 5aa3ab7 commit c988738
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/flag_gems/runtime/backend/_iluvatar/heuristics_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ def batch_norm_heur_block_n(args):
return min(BLOCK_N, max(1, 2**14 // BLOCK_M))


def vdot_heur_block_size(args):
n = args["n_elements"]
if n < 1024:
return 32
elif n < 8192:
return 256
else:
return 1024


HEURISTICS_CONFIGS = {
"argmax": {
"BLOCK_M": argmax_heur_block_m,
Expand Down Expand Up @@ -289,4 +299,7 @@ def batch_norm_heur_block_n(args):
"BLOCK_M": batch_norm_heur_block_m,
"BLOCK_N": batch_norm_heur_block_n,
},
"vdot": {
"BLOCK_SIZE": vdot_heur_block_size,
},
}

0 comments on commit c988738

Please sign in to comment.