[Misc][Benchmark] optimize benchmarks (#53)
DefTruth authored Sep 28, 2024
1 parent 499c39e commit 0c9166d
Showing 20 changed files with 1,028 additions and 241 deletions.
99 changes: 49 additions & 50 deletions README.md
@@ -42,64 +42,63 @@
| ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️|
| ✔️ [warp_reduce_f16/bf16/f32/f8/i8](./reduce/block_all_reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_reduce_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f32_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f32x4_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x2_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x2_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x8_pack_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x8_pack_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x2_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x2_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x8_pack_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x8_pack_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e4m3_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e5m2_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_i8_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_i8x16_pack_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
| ✔️ [warp_reduce_[all]](./reduce/reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f32x4_f32](./reduce/reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f16_f16](./reduce/reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f16_f32](./reduce/reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x2_f16](./reduce/reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x2_f32](./reduce/reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x8_pack_f16](./reduce/reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_f16x8_pack_f32](./reduce/reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16_bf16](./reduce/reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16_f32](./reduce/reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x2_bf16](./reduce/reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x2_f32](./reduce/reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x8_pack_bf16](./reduce/reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_bf16x8_pack_f32](./reduce/reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e4m3_f16](./reduce/reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e5m2_f16](./reduce/reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e4m3x16_pack_f16](./reduce/reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_fp8_e5m2x16_pack_f16](./reduce/reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_i8_i32](./reduce/reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
| ✔️ [reduce_i8x16_pack_i32](./reduce/reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
| ✔️ [dot_product_f32](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f32x4](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16x2_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16x8_pack_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [softmax_f32(memory fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [softmax_f32x4(memory fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16x2_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16x8_pack_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [layer_norm_f32(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f32x4(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x2_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32x4(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x2_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_pack_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_pack_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [softmax_f32(fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [softmax_f32x4(fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [softmax_f32x4](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f32x4](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16x2_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [safe_softmax_f16x8_pack_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
| ✔️ [layer_norm_f32](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f32x4](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x2_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16x8_pack_f32](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [layer_norm_f16_f32](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f32x4](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x2_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_f32](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_pack_f16](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16x8_pack_f32](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f32](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
| ✔️ [sgemm_naive_f32](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️|
| ✔️ [sgemm_sliced_k_f32](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4_bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_..._bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k_..._dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k_f16x4](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
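The warp- and block-level reduce kernels listed above all share one building block: each of a warp's 32 lanes folds its partial into a full warp sum via `__shfl_xor_sync` butterfly exchanges, and a block-wide reduction then combines one partial per warp through shared memory. The sketch below shows that pattern for an f32 sum; it is a minimal illustration, and the kernel and helper names here are indicative rather than the repo's exact signatures.

```cuda
#include <cuda_runtime.h>

// Warp-level all-reduce: after five XOR-butterfly steps, every lane
// holds the sum of all 32 lanes' values.
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask);
  return val;
}

// Block-level all-reduce: each warp reduces its values, lane 0 stages the
// warp partial in shared memory, the first warp folds those partials, and
// thread 0 accumulates the block result into global memory.
__global__ void block_all_reduce_sum_f32(const float* x, float* y, int N) {
  __shared__ float smem[32];               // one slot per warp (<=1024 threads)
  int tid  = threadIdx.x;
  int idx  = blockIdx.x * blockDim.x + tid;
  int lane = tid % 32;
  int warp = tid / 32;

  float sum = (idx < N) ? x[idx] : 0.0f;   // guard the tail of the array
  sum = warp_reduce_sum_f32(sum);
  if (lane == 0) smem[warp] = sum;
  __syncthreads();

  if (warp == 0) {
    int num_warps = (blockDim.x + 31) / 32;
    sum = (lane < num_warps) ? smem[lane] : 0.0f;
    sum = warp_reduce_sum_f32(sum);
    if (tid == 0) atomicAdd(y, sum);
  }
}
```

The vectorized variants (f32x4, f16x8_pack, i8x16_pack, ...) differ only in how each thread loads and converts its elements before entering this same reduction; a launch such as `block_all_reduce_sum_f32<<<(N + 255) / 256, 256>>>(d_x, d_y, N)` with `*d_y` zeroed beforehand would produce the full sum.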
105 changes: 97 additions & 8 deletions dot-product/README.md
@@ -14,7 +14,7 @@
## Testing

```bash
# Build only for the Ada architecture; if unset, all architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer
export TORCH_CUDA_ARCH_LIST=Ada
python3 dot_product.py
```
@@ -23,13 +23,102 @@ python3 dot_product.py

```bash
--------------------------------------------------------------------------------
S=1024, K=1024
out_f32f32: -670.21264648 , time:0.08947158ms
out_f32x4f32: -670.21435547 , time:0.02821302ms
out_f32f32_th: -670.21374512 , time:0.09709382ms
--------------------------------------------------------------------------------
out_f16f32: -670.32208252 , time:0.04000235ms
out_f16x2f32: -670.15814209 , time:0.05491829ms
out_f16x8packf32: -669.90997314 , time:0.01669478ms
out_f16f16_th: -670.50000000 , time:0.02021313ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=1024, K=2048
out_f32f32: 1040.51086426 , time:0.04557490ms
out_f32x4f32: 1040.50720215 , time:0.06275582ms
out_f32f32_th: 1040.50842285 , time:0.04762864ms
--------------------------------------------------------------------------------
out_f16f32: 1041.44299316 , time:0.03214121ms
out_f16x2f32: 1041.79589844 , time:0.03448486ms
out_f16x8packf32: 1042.22717285 , time:0.02689457ms
out_f16f16_th: 1041.00000000 , time:0.02859521ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=1024, K=4096
out_f32f32: -1859.81457520 , time:0.08664179ms
out_f32x4f32: -1859.81628418 , time:0.08621526ms
out_f32f32_th: -1859.81933594 , time:0.08647323ms
--------------------------------------------------------------------------------
out_f16f32: -1860.23291016 , time:0.05826116ms
out_f16x2f32: -1860.91186523 , time:0.04677963ms
out_f16x8packf32: -1860.25988770 , time:0.04591107ms
out_f16f16_th: -1861.00000000 , time:0.04904127ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=1024
out_f32f32: 858.98229980 , time:0.04499865ms
out_f32x4f32: 858.98461914 , time:0.04623890ms
out_f32f32_th: 858.98376465 , time:0.06848693ms
--------------------------------------------------------------------------------
out_f16f32: 858.85339355 , time:0.03274632ms
out_f16x2f32: 858.94274902 , time:0.02831578ms
out_f16x8packf32: 859.46844482 , time:0.02884459ms
out_f16f16_th: 859.00000000 , time:0.03692698ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=2048
out_f32f32: -1205.77990723 , time:0.08356524ms
out_f32x4f32: -1205.77624512 , time:0.08583307ms
out_f32f32_th: -1205.77807617 , time:0.08613133ms
--------------------------------------------------------------------------------
out_f16f32: -1205.40588379 , time:0.06001544ms
out_f16x2f32: -1205.29028320 , time:0.04738235ms
out_f16x8packf32: -1205.72924805 , time:0.04624581ms
out_f16f16_th: -1205.00000000 , time:0.04907203ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=4096
out_f32f32: -893.49169922 , time:0.16136765ms
out_f32x4f32: -893.48596191 , time:0.16174912ms
out_f32f32_th: -893.48901367 , time:0.16518927ms
--------------------------------------------------------------------------------
out_f16f32: -894.42169189 , time:0.11468077ms
out_f16x2f32: -894.61779785 , time:0.08950567ms
out_f16x8packf32: -895.26538086 , time:0.08448958ms
out_f16f16_th: -894.00000000 , time:0.09156108ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=1024
out_f32f32: 141.78890991 , time:0.08385873ms
out_f32x4f32: 141.78639221 , time:0.08500123ms
out_f32f32_th: 141.78683472 , time:0.08647728ms
--------------------------------------------------------------------------------
out_f16f32: 141.80113220 , time:0.05876780ms
out_f16x2f32: 141.62113953 , time:0.04708385ms
out_f16x8packf32: 141.15240479 , time:0.04586506ms
out_f16f16_th: 141.50000000 , time:0.04933500ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=2048
out_f32f32: -1238.80456543 , time:0.16236329ms
out_f32x4f32: -1238.80737305 , time:0.16246724ms
out_f32f32_th: -1238.80859375 , time:0.16496468ms
--------------------------------------------------------------------------------
out_f16f32: -1238.78466797 , time:0.11416745ms
out_f16x2f32: -1239.28540039 , time:0.08488607ms
out_f16x8packf32: -1238.85302734 , time:0.08867455ms
out_f16f16_th: -1239.00000000 , time:0.09029007ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=4096
out_f32f32: 556.32690430 , time:0.31692672ms
out_f32x4f32: 556.33087158 , time:0.31752276ms
out_f32f32_th: 556.32879639 , time:0.32040811ms
--------------------------------------------------------------------------------
out_f16f32: 554.45031738 , time:0.23417449ms
out_f16x2f32: 553.61444092 , time:0.16469955ms
out_f16x8packf32: 554.04040527 , time:0.16465998ms
out_f16f16_th: 554.50000000 , time:0.17046404ms
--------------------------------------------------------------------------------
```
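One pattern is consistent across every shape above: the f16x8_pack kernel is the fastest f16 variant, since a single 128-bit load fetches eight halves per operand (fewer memory transactions and load instructions than scalar f16 loads), while products are accumulated in f32 to limit rounding error. The sketch below illustrates that load-and-accumulate scheme under those assumptions; it is not the repo's exact kernel.

```cuda
#include <cuda_fp16.h>

__device__ __forceinline__ float warp_reduce_sum(float v) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    v += __shfl_xor_sync(0xffffffff, v, mask);
  return v;
}

// Each thread covers 8 halves per operand with one 128-bit (float4) load,
// widens to f32 for the multiply-accumulate, then joins a warp/block
// reduction and atomically adds its block's partial to the result.
__global__ void dot_prod_f16x8_pack_f32(const half* a, const half* b,
                                        float* out, int N) {
  __shared__ float smem[32];
  int tid = threadIdx.x;
  int idx = (blockIdx.x * blockDim.x + tid) * 8;

  float sum = 0.0f;
  if (idx + 7 < N) {
    float4 pa = *reinterpret_cast<const float4*>(a + idx);  // 8 packed halves
    float4 pb = *reinterpret_cast<const float4*>(b + idx);
    const half2* ha = reinterpret_cast<const half2*>(&pa);
    const half2* hb = reinterpret_cast<const half2*>(&pb);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      float2 fa = __half22float2(ha[i]);   // widen before multiplying
      float2 fb = __half22float2(hb[i]);
      sum += fa.x * fb.x + fa.y * fb.y;
    }
  }
  int lane = tid % 32, warp = tid / 32;
  sum = warp_reduce_sum(sum);
  if (lane == 0) smem[warp] = sum;
  __syncthreads();
  if (warp == 0) {
    sum = (lane < (blockDim.x + 31) / 32) ? smem[lane] : 0.0f;
    sum = warp_reduce_sum(sum);
    if (tid == 0) atomicAdd(out, sum);
  }
}
```

The `f16f16_th` baseline (`torch.dot` on half tensors) returns an f16 result, which is why its values in the log round to halves and whole numbers: at magnitudes around 512-1024, f16's resolution is 0.5.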
37 changes: 21 additions & 16 deletions dot-product/dot_product.py
@@ -42,19 +42,24 @@ def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: st
return out, mean_time


print("-" * 80)
S, K = 4096, 4096
a = torch.randn((S*K)).cuda().float()
b = torch.randn((S*K)).cuda().float()
run_benchmark(lib.dot_prod_f32_f32, a, b, "f32f32")
run_benchmark(lib.dot_prod_f32x4_f32, a, b, "f32x4f32")
run_benchmark(torch.dot, a, b, "f32f32_th")

print("-" * 80)
a_f16 = a.half()
b_f16 = b.half()
run_benchmark(lib.dot_prod_f16_f32, a_f16, b_f16, "f16f32")
run_benchmark(lib.dot_prod_f16x2_f32, a_f16, b_f16, "f16x2f32")
run_benchmark(lib.dot_prod_f16x8_pack_f32, a_f16, b_f16, "f16x8packf32")
run_benchmark(torch.dot, a_f16, b_f16, "f16f16_th")
print("-" * 80)
# sweep all (S, K) shape combinations instead of a single fixed size
Ss = [1024, 2048, 4096]
Ks = [1024, 2048, 4096]
SKs = [(S, K) for S in Ss for K in Ks]

for (S, K) in SKs:
    print("-" * 80)
    print(" " * 25 + f"S={S}, K={K}")
    a = torch.randn((S*K)).cuda().float()
    b = torch.randn((S*K)).cuda().float()
    run_benchmark(lib.dot_prod_f32_f32, a, b, "f32f32")
    run_benchmark(lib.dot_prod_f32x4_f32, a, b, "f32x4f32")
    run_benchmark(torch.dot, a, b, "f32f32_th")

    print("-" * 80)
    a_f16 = a.half()
    b_f16 = b.half()
    run_benchmark(lib.dot_prod_f16_f32, a_f16, b_f16, "f16f32")
    run_benchmark(lib.dot_prod_f16x2_f32, a_f16, b_f16, "f16x2f32")
    run_benchmark(lib.dot_prod_f16x8_pack_f32, a_f16, b_f16, "f16x8packf32")
    run_benchmark(torch.dot, a_f16, b_f16, "f16f16_th")
    print("-" * 80)