Skip to content

Commit

Permalink
[aarch64] patch mkldnn acl inner product to accelerate torch.compile(…
Browse files Browse the repository at this point in the history
…) for bert
  • Loading branch information
snadampal committed Dec 7, 2023
1 parent b92da8c commit 77d657e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aarch64_linux/aarch64_wheel_ci_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def parse_arguments():
# work around to fix Raspberry pie crash
print("Applying mkl-dnn patch to fix readdir crash")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/aarch64-fix-readdir-crash.patch")
# patch acl inner product to accelerate torch.compile() path
print("Applying mkl-dnn patch to acl inner product")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch")
os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
pytorch_wheel_name = complete_wheel("pytorch")
print(f"Build Compelete. Created {pytorch_wheel_name}..")
1 change: 1 addition & 0 deletions aarch64_linux/build_aarch64_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def start_build(host: RemoteHost, *,
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
host.run_cmd(f"cd $HOME && git clone https://github.com/pytorch/builder.git")
host.run_cmd(f"cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch")
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}")
print('Repair the wheel')
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
cpu: aarch64: add sbgemm (fp32 input and bf16 weights) inner
product op

---
src/cpu/aarch64/acl_inner_product.hpp | 8 ++++++--
src/cpu/cpu_inner_product_list.cpp | 4 ++++
2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
index a2be164f0..eca56b289 100644
--- a/src/cpu/aarch64/acl_inner_product.hpp
+++ b/src/cpu/aarch64/acl_inner_product.hpp
@@ -99,9 +99,13 @@ struct acl_inner_product_fwd_t : public primitive_t {
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
+ const bool is_fp32_bf16_ok
+ = expect_data_types(f32, bf16, f32, f32, undef)
+ && attr()->has_default_values(
+ primitive_attr_t::skip_mask_t::post_ops, f32);
const bool ok = is_fwd() && !has_zero_dim_memory()
- && utils::one_of(true, is_fp16_ok, is_fp32_ok)
- && weights_md_.format_kind == format_kind::any
+ && utils::one_of(
+ true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
&& set_default_params() == status::success;

if (!ok) return status::unimplemented;
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index fdd7b1776..5a3dc1ea7 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -83,6 +83,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE(ref_inner_product_fwd_t)
nullptr,
}},
+ {{forward, f32, bf16, f32}, {
+ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
+ nullptr,
+ }},
{{backward_data, f32, f32, f32}, REG_BWD_PK({
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
--
2.34.1

0 comments on commit 77d657e

Please sign in to comment.