Skip to content

Commit

Permalink
fix: fix macro to suppress compilation warning (flashinfer-ai#231)
Browse files Browse the repository at this point in the history
There are some mistakes in our macro definitions which results in lots
of warnings and potential bugs.
This PR fixes the issue.
  • Loading branch information
yzh119 authored May 4, 2024
1 parent 11ca502 commit 94bcf6f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
8 changes: 4 additions & 4 deletions python/csrc/pytorch_extension_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ using namespace flashinfer;
} \
}()

#define _DISPATCH_CASE(case_expr, var, ...) \
case case_expr: { \
constexpr auto var = case_expr; \
return __VA_ARGS__(); \
#define _DISPATCH_CASE(case_expr, case_var, ...) \
case case_expr: { \
constexpr auto case_var = case_expr; \
return __VA_ARGS__(); \
}

#define DISPATCH_group_size(expr, const_expr, ...) \
Expand Down
28 changes: 13 additions & 15 deletions python/generate_dispatch_inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,79 +27,77 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
for _ in args.head_dims
]
)
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(...) \\
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(const_var, ...) \\
{dispatch_head_dims_entries}
// EOL
"""
# group sizes
dispatch_group_sizes_entries = "\n".join(
[
" _DISPATCH_CASE({}, GROUP_SIZE, __VA_ARGS__) \\".format(_)
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
for _ in args.group_sizes
]
)
dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(...) \\
dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(case_var, ...) \\
{dispatch_group_sizes_entries}
// EOL
"""
# page sizes
dispatch_page_sizes_entries = "\n".join(
[
" _DISPATCH_CASE({}, PAGE_SIZE, __VA_ARGS__) \\".format(_)
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
for _ in args.page_sizes
]
)
dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(...) \\
dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(case_var, ...) \\
{dispatch_page_sizes_entries}
// EOL
"""
# kv layouts
dispatch_kv_layouts_entries = "\n".join(
[
" _DISPATCH_CASE({}, KV_LAYOUT, __VA_ARGS__) \\".format(
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(
kv_layout_literal[_]
)
for _ in args.kv_layouts
]
)
dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(...) \\
dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(case_var, ...) \\
{dispatch_kv_layouts_entries}
// EOL
"""
# positional encoding modes
dispatch_pos_encoding_modes_entries = "\n".join(
[
" _DISPATCH_CASE({}, POS_ENCODING_MODE, __VA_ARGS__) \\".format(
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(
pos_encoding_mode_literal[_]
)
for _ in args.pos_encoding_modes
]
)
dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(...) \\
dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(case_var, ...) \\
{dispatch_pos_encoding_modes_entries}
// EOL
"""
# allow fp16 qk reductions
dispatch_allow_fp16_qk_reduction_entries = "\n".join(
[
" _DISPATCH_CASE({}, ALLOW_FP16_QK_REDUCTION, __VA_ARGS__) \\".format(
bool_literal[_]
)
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_])
for _ in args.allow_fp16_qk_reductions
]
)
dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(...) \\
dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(case_var, ...) \\
{dispatch_allow_fp16_qk_reduction_entries}
// EOL
"""
# causal
dispatch_causal_entries = "\n".join(
[
" _DISPATCH_CASE({}, CAUSAL, __VA_ARGS__) \\".format(bool_literal[_])
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_])
for _ in args.causals
]
)
dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(...) \\
dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(case_var, ...) \\
{dispatch_causal_entries}
// EOL
"""
Expand Down

0 comments on commit 94bcf6f

Please sign in to comment.