Skip to content

Commit

Permalink
[Experimental][Kleidi] Add GEMM operator tests (#1638)
Browse files Browse the repository at this point in the history
  • Loading branch information
digantdesai authored Jan 30, 2025
1 parent c1f5872 commit b559c6d
Show file tree
Hide file tree
Showing 9 changed files with 1,623 additions and 47 deletions.
4 changes: 2 additions & 2 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA
include(FetchContent)
# KleidiAI is an open-source library that provides optimized
# performance-critical routines, also known as micro-kernels, for artificial
# intelligence (AI) workloads tailored for Arm® CPUs.
# intelligence (AI) workloads tailored for Arm® CPUs.
FetchContent_Declare(kleidiai
GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git
GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this
GIT_TAG v1.2.0)
FetchContent_MakeAvailable(kleidiai)

# Temporarily exposing this to the parent scope until we wire
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_row=*/output_m_stride * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/ n * sizeof(float),
/*dst_stride_row=*/ output_m_stride * sizeof(float),
/*dst_stride_col=*/ sizeof(float),
clamp_min,
clamp_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_row=*/output_m_stride * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void kernel(
activation_data,
weight_data,
output,
/*dst_stride_row=*/n * sizeof(float),
/*dst_stride_row=*/output_m_stride * sizeof(float),
/*dst_stride_col=*/sizeof(float),
clamp_min,
clamp_max);
Expand Down
22 changes: 22 additions & 0 deletions torchao/experimental/ops/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,34 @@ if(TORCHAO_BUILD_KLEIDIAI)
add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
endif()

if(TORCHAO_BUILD_ARM_I8MM)
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
endif()

if (ANDROID_ABI)
# We are cross compiling, delay test discovery till runtime
set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
endif()

include_directories(${TORCHAO_INCLUDE_DIRS})

set(TORCHAO_PARALLEL_BACKEND "test_dummy")
add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

include(${TORCHAO_ROOT}/Utils.cmake)

if (ANDROID_ABI)
# Given where we are today this is sufficent. But needs to be revisited.
# This is also needed for native builds, but keeping it only for cross builds
# for now given the hacky nature.
file(GLOB DOTPROD_SRC_FILES test*.cpp)
message(SRC_FILES: ${DOTPROD_SRC_FILES})
set_property(SOURCE
${DOTPROD_SRC_FILES}
APPEND_STRING PROPERTY
COMPILE_FLAGS " -march=armv8.2-a+dotprod ")
endif()

add_executable(
test_linear_8bit_act_xbit_weight
test_linear_8bit_act_xbit_weight.cpp
Expand Down
41 changes: 39 additions & 2 deletions torchao/experimental/ops/tests/build_and_run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,57 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

target=${1:-"native"}
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests

IS_ARM64=0
BUILD_ARM_I8MM=0
EXTRA_ARGS=""
if [[ "${target}" == "android" ]]; then
if [[ -z ${ANDROID_NDK} ]]; then
echo "Need to set ANDROID_NDK env variable to build for Android";
exit 1;
fi
android_abi=arm64-v8a
android_platform=28 # must be >=28 for aligned_alloc
IS_ARM64=1
BUILD_ARM_I8MM=1 # Hardcoded for now
CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android}
toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake"
if [[ -z ${toolchain_file} ]]; then
echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}"
exit 1;
fi
EXTRA_ARGS="\
-DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \
-DANDROID_ABI=${android_abi} \
-DANDROID_PLATFORM=${android_platform}
"
echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}"
fi

hash arch; retval=$?
if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then
IS_ARM64=1
fi

export CMAKE_OUT=/tmp/cmake-out/torchao/tests
cmake \
-DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
${EXTRA_ARGS} \
-DCMAKE_BUILD_TYPE=Debug \
-DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \
-DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \
-S . \
-B ${CMAKE_OUT}

cmake --build ${CMAKE_OUT}

echo "Successfully built tests."

if [[ "${target}" != "native" ]]; then
echo "Skip running tests when cross compiling.";
exit 0;
fi

# Run
${CMAKE_OUT}/test_linear_8bit_act_xbit_weight
128 changes: 128 additions & 0 deletions torchao/experimental/ops/tests/generate_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Simple script to generate test cases for the torchao ops
from string import Template


def add_test_string(kernel, m, n, k, g, has_bias, has_clamp):
name = f"m{m}xn{n}xk{k}xg{g}{'_bias' if has_bias else ''}{'_clamp' if has_clamp else ''}"
d = {
"name": name,
"kernel": kernel,
"m": m,
"n": n,
"k": k,
"g": g,
"has_bias": "true" if has_bias else "false",
"has_clamp": "true" if has_clamp else "false",
}

test_template = Template(
"""
TEST(test_linear_8bit_act_xbit_weight, Kleidi_${kernel}_${name}) {
UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>();
test_linear_8bit_act_xbit_weight<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
$has_bias /*has_bias*/,
$has_clamp /*has_clamp*/,
true /*has_kleidi*/>(
/*m=*/$m, /*n=*/$n, /*k=*/$k, /*group_size=*/$g, &ukernel_config);
}
"""
)

return [test_template.safe_substitute(d)]


def get_test_block(kernel):
# Assuming given kleidi kernel can run with all these test cases
tests = []
# GEMV, m == 1
## subtile
tests += add_test_string(kernel, 1, 2 * 1, 32, 32, False, False)
tests += add_test_string(kernel, 1, 2 * 2, 32, 32, False, False)
tests += add_test_string(kernel, 1, 2 * 3, 32, 32, True, False)
tests += add_test_string(kernel, 1, 2 * 2, 32, 32, True, True)
tests += add_test_string(kernel, 1, 2 * 3, 32, 32, False, True)
## larger: n - must be multiple of 2
tests += add_test_string(kernel, 1, 2 * 11, 32, 32, False, False)
tests += add_test_string(kernel, 1, 2 * 13, 32, 32, True, False)
tests += add_test_string(kernel, 1, 2 * 51, 32, 32, False, True)
tests += add_test_string(kernel, 1, 2 * 111, 32, 32, False, False)
## larger: k, g - must be multiple of 32
tests += add_test_string(kernel, 1, 2 * 7, 64, 32, False, False)
tests += add_test_string(kernel, 1, 2 * 11, 128, 32, True, False)
tests += add_test_string(kernel, 1, 2 * 13, 64, 64, False, True)
tests += add_test_string(kernel, 1, 2 * 17, 128, 64, False, False)

# GEMM, m > 1
## subtile
tests += add_test_string(kernel, 2, 2 * 1, 32, 32, False, False)
tests += add_test_string(kernel, 2, 2 * 2, 32, 32, False, False)
tests += add_test_string(kernel, 3, 2 * 3, 32, 32, True, False)
tests += add_test_string(kernel, 4, 2 * 4, 32, 32, True, True)
tests += add_test_string(kernel, 3, 2 * 3, 32, 32, False, True)
## larger: m
tests += add_test_string(kernel, 31, 2 * 1, 32, 32, False, False)
tests += add_test_string(kernel, 32, 2 * 2, 32, 32, False, False)
tests += add_test_string(kernel, 33, 2 * 3, 32, 32, True, False)
tests += add_test_string(kernel, 34, 2 * 4, 32, 32, True, True)
tests += add_test_string(kernel, 35, 2 * 3, 32, 32, False, True)
## larger: n - must be multiple of 2
tests += add_test_string(kernel, 7, 2 * 11, 32, 32, False, False)
tests += add_test_string(kernel, 17, 2 * 13, 32, 32, True, False)
tests += add_test_string(kernel, 23, 2 * 51, 32, 32, False, True)
tests += add_test_string(kernel, 41, 2 * 111, 32, 32, False, False)
## larger: k, g - must be multiple of 32
tests += add_test_string(kernel, 19, 2 * 7, 64, 32, False, False)
tests += add_test_string(kernel, 23, 2 * 11, 128, 32, True, False)
tests += add_test_string(kernel, 29, 2 * 13, 64, 64, False, True)
tests += add_test_string(kernel, 101, 2 * 17, 128, 64, False, False)

return "".join(tests)


def main():
kleidi_template = Template(
"""
/*****************/
// ${kernel} tests
/*****************/
${prologue}
${tests}
${epilogue}
"""
)

kleidi_kernels = [
"dotprod_1x4x32",
"dotprod_1x8x32",
"i8mm_4x8x32",
"i8mm_8x4x32",
]

print("/* Generated by generate_tests.py */")
print("/* Do not modify */")
print()
print("#if defined(TORCHAO_ENABLE_KLEIDI)")
for kernel in kleidi_kernels:
prologue, epilogue = "", ""
if "i8mm" in kernel:
prologue = "#if defined(TORCHAO_ENABLE_ARM_I8MM)"
epilogue = "#endif // TORCHAO_ENABLE_ARM_I8MM"
tests = get_test_block(kernel)
d = {
"prologue": prologue,
"kernel": kernel,
"tests": tests,
"epilogue": epilogue,
}

print(kleidi_template.safe_substitute(d))
print("#endif // TORCHAO_ENABLE_KLEIDI")


if __name__ == "__main__":
main()
Loading

0 comments on commit b559c6d

Please sign in to comment.