Skip to content

Latest commit

 

History

History

hardshrink

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 

HardShrink

0x00 说明

包含以下内容:

  • hardshrink_f32_kernel
  • hardshrink_f32x4_kernel(float4向量化版本)
  • hardshrink_f16_kernel(fp16版本)
  • hardshrink_f16x2_kernel(fp16向量化版本，每线程处理2个half，即half2)
  • hardshrink_f16x8_kernel(fp16向量化版本，每线程处理8个half)
  • hardshrink_f16x8_pack_kernel(fp16向量化,pack版本)
  • PyTorch bindings

测试

# 仅针对Ada架构编译；若不指定TORCH_CUDA_ARCH_LIST，将默认编译所有架构(Volta, Ampere, Ada, Hopper, ...)，耗时较长
export TORCH_CUDA_ARCH_LIST=Ada 
python3 hardshrink.py

输出:

-------------------------------------------------------------------------------------
                                        S=1024, K=1024
           out_f32: ['0.0         ', '0.0         '], time:0.00428367ms
         out_f32x4: ['0.0         ', '0.0         '], time:0.00370097ms
        out_f32_th: ['0.0         ', '0.0         '], time:0.00739479ms
-------------------------------------------------------------------------------------
           out_f16: ['0.0         ', '0.0         '], time:0.00417209ms
         out_f16x2: ['0.0         ', '0.0         '], time:0.00280690ms
         out_f16x8: ['0.0         ', '0.0         '], time:0.00266337ms
     out_f16x8pack: ['0.0         ', '0.0         '], time:0.00255847ms
        out_f16_th: ['0.0         ', '0.0         '], time:0.00636697ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=1024, K=2048
           out_f32: ['-0.82629085 ', '0.0         '], time:0.00588202ms
         out_f32x4: ['-0.82629085 ', '0.0         '], time:0.00565958ms
        out_f32_th: ['-0.82629085 ', '0.0         '], time:0.01137996ms
-------------------------------------------------------------------------------------
           out_f16: ['-0.82617188 ', '0.0         '], time:0.00589466ms
         out_f16x2: ['-0.82617188 ', '0.0         '], time:0.00484967ms
         out_f16x8: ['-0.82617188 ', '0.0         '], time:0.00404549ms
     out_f16x8pack: ['-0.82617188 ', '0.0         '], time:0.00391698ms
        out_f16_th: ['-0.82617188 ', '0.0         '], time:0.00740528ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=1024, K=4096
           out_f32: ['0.0         ', '0.0         '], time:0.01001120ms
         out_f32x4: ['0.0         ', '0.0         '], time:0.00953746ms
        out_f32_th: ['0.0         ', '0.0         '], time:0.01943731ms
-------------------------------------------------------------------------------------
           out_f16: ['0.0         ', '0.0         '], time:0.01005745ms
         out_f16x2: ['0.0         ', '0.0         '], time:0.00977778ms
         out_f16x8: ['0.0         ', '0.0         '], time:0.00651360ms
     out_f16x8pack: ['0.0         ', '0.0         '], time:0.00592923ms
        out_f16_th: ['0.0         ', '0.0         '], time:0.01157641ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=2048, K=1024
           out_f32: ['0.90714514  ', '0.0         '], time:0.00693989ms
         out_f32x4: ['0.90714514  ', '0.0         '], time:0.00567746ms
        out_f32_th: ['0.90714514  ', '0.0         '], time:0.01138806ms
-------------------------------------------------------------------------------------
           out_f16: ['0.90722656  ', '0.0         '], time:0.00694394ms
         out_f16x2: ['0.90722656  ', '0.0         '], time:0.00400567ms
         out_f16x8: ['0.90722656  ', '0.0         '], time:0.00397611ms
     out_f16x8pack: ['0.90722656  ', '0.0         '], time:0.00391245ms
        out_f16_th: ['0.90722656  ', '0.0         '], time:0.00740814ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=2048, K=2048
           out_f32: ['-0.61146295 ', '-1.60048997 '], time:0.01001549ms
         out_f32x4: ['-0.61146295 ', '-1.60048997 '], time:0.00960755ms
        out_f32_th: ['-0.61146295 ', '-1.60048997 '], time:0.01942730ms
-------------------------------------------------------------------------------------
           out_f16: ['-0.61132812 ', '-1.60058594 '], time:0.01005864ms
         out_f16x2: ['-0.61132812 ', '-1.60058594 '], time:0.00810075ms
         out_f16x8: ['-0.61132812 ', '-1.60058594 '], time:0.00622988ms
     out_f16x8pack: ['-0.61132812 ', '-1.60058594 '], time:0.00591063ms
        out_f16_th: ['-0.61132812 ', '-1.60058594 '], time:0.01156282ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=2048, K=4096
           out_f32: ['0.0         ', '0.55151719  '], time:0.01828051ms
         out_f32x4: ['0.0         ', '0.55151719  '], time:0.01861715ms
        out_f32_th: ['0.0         ', '0.55151719  '], time:0.09329295ms
-------------------------------------------------------------------------------------
           out_f16: ['0.0         ', '0.55175781  '], time:0.01676941ms
         out_f16x2: ['0.0         ', '0.55175781  '], time:0.01659012ms
         out_f16x8: ['0.0         ', '0.55175781  '], time:0.00997281ms
     out_f16x8pack: ['0.0         ', '0.55175781  '], time:0.00901723ms
        out_f16_th: ['0.0         ', '0.55175781  '], time:0.01817441ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=4096, K=1024
           out_f32: ['1.19508219  ', '0.75319326  '], time:0.01123738ms
         out_f32x4: ['1.19508219  ', '0.75319326  '], time:0.00885844ms
        out_f32_th: ['1.19508219  ', '0.75319326  '], time:0.01797009ms
-------------------------------------------------------------------------------------
           out_f16: ['1.1953125   ', '0.75341797  '], time:0.01144171ms
         out_f16x2: ['1.1953125   ', '0.75341797  '], time:0.00580645ms
         out_f16x8: ['1.1953125   ', '0.75341797  '], time:0.00570226ms
     out_f16x8pack: ['1.1953125   ', '0.75341797  '], time:0.00586438ms
        out_f16_th: ['1.1953125   ', '0.75341797  '], time:0.01156831ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=4096, K=2048
           out_f32: ['0.58269709  ', '-0.82688355 '], time:0.01827741ms
         out_f32x4: ['0.58269709  ', '-0.82688355 '], time:0.01818061ms
        out_f32_th: ['0.58269709  ', '-0.82688355 '], time:0.09307647ms
-------------------------------------------------------------------------------------
           out_f16: ['0.58251953  ', '-0.82666016 '], time:0.01676655ms
         out_f16x2: ['0.58251953  ', '-0.82666016 '], time:0.01336956ms
         out_f16x8: ['0.58251953  ', '-0.82666016 '], time:0.00968409ms
     out_f16x8pack: ['0.58251953  ', '-0.82666016 '], time:0.00905466ms
        out_f16_th: ['0.58251953  ', '-0.82666016 '], time:0.01817107ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        S=4096, K=4096
           out_f32: ['0.0         ', '0.0         '], time:0.14530301ms
         out_f32x4: ['0.0         ', '0.0         '], time:0.14572287ms
        out_f32_th: ['0.0         ', '0.0         '], time:0.29304075ms
-------------------------------------------------------------------------------------
           out_f16: ['0.0         ', '0.0         '], time:0.03192568ms
         out_f16x2: ['0.0         ', '0.0         '], time:0.03173161ms
         out_f16x8: ['0.0         ', '0.0         '], time:0.01897788ms
     out_f16x8pack: ['0.0         ', '0.0         '], time:0.01702213ms
        out_f16_th: ['0.0         ', '0.0         '], time:0.09349132ms
-------------------------------------------------------------------------------------