diff --git a/src/simd/distances_avx.cc b/src/simd/distances_avx.cc
index fe7ee0cdc..f919d7b45 100644
--- a/src/simd/distances_avx.cc
+++ b/src/simd/distances_avx.cc
@@ -876,6 +876,95 @@ bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, co
     dis3 = _mm256_reduce_add_ps(msum_3);
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// int8: plain scalar loops; elements are widened to int32, summed in int32, and the total is returned as float
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+float
+int8_vec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d) {  // sum of x[i] * y[i] for i in [0, d)
+    int32_t res = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; i++) {
+        res += (int32_t)x[i] * (int32_t)y[i];  // widen both operands before multiplying
+    }
+    return (float)res;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+float
+int8_vec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d) {  // sum of (x[i] - y[i])^2 for i in [0, d)
+    int32_t res = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; i++) {
+        const int32_t tmp = (int32_t)x[i] - (int32_t)y[i];  // difference taken in int32
+        res += tmp * tmp;
+    }
+    return (float)res;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+float
+int8_vec_norm_L2sqr_avx(const int8_t* x, size_t d) {  // sum of x[i]^2 for i in [0, d)
+    int32_t res = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; i++) {
+        res += (int32_t)x[i] * (int32_t)x[i];
+    }
+    return (float)res;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+void
+int8_vec_inner_product_batch_4_avx(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2,
+                                   const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2,
+                                   float& dis3) {  // disK receives the inner product of x with yK
+    int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0;  // one int32 accumulator per y vector
+
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        auto x_i = (int32_t)x[i];  // x[i] is widened once and reused for all four products
+        d0 += x_i * (int32_t)y0[i];
+        d1 += x_i * (int32_t)y1[i];
+        d2 += x_i * (int32_t)y2[i];
+        d3 += x_i * (int32_t)y3[i];
+    }
+
+    dis0 = (float)d0;
+    dis1 = (float)d1;
+    dis2 = (float)d2;
+    dis3 = (float)d3;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+void
+int8_vec_L2sqr_batch_4_avx(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3,
+                           const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) {
+    int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0;  // one int32 accumulator per y vector
+
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        auto x_i = (int32_t)x[i];  // x[i] is widened once and reused for all four differences
+        const int32_t q0 = x_i - (int32_t)y0[i];
+        const int32_t q1 = x_i - (int32_t)y1[i];
+        const int32_t q2 = x_i - (int32_t)y2[i];
+        const int32_t q3 = x_i - (int32_t)y3[i];
+        d0 += q0 * q0;
+        d1 += q1 * q1;
+        d2 += q2 * q2;
+        d3 += q3 * q3;
+    }
+
+    dis0 = (float)d0;  // disK receives the squared L2 distance between x and yK
+    dis1 = (float)d1;
+    dis2 = (float)d2;
+    dis3 = (float)d3;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
 ///////////////////////////////////////////////////////////////////////////////
 // for cardinal
 
diff --git a/src/simd/distances_avx.h b/src/simd/distances_avx.h
index fb86aa1bb..2c4cbc4bf 100644
--- a/src/simd/distances_avx.h
+++ b/src/simd/distances_avx.h
@@ -106,6 +106,27 @@ bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, co
                            const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
                            float& dis2, float& dis3);
 
+///////////////////////////////////////////////////////////////////////////////
+// int8 distance kernels; the definitions are in distances_avx.cc
+
+float
+int8_vec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);
+
+float
+int8_vec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);
+
+float
+int8_vec_norm_L2sqr_avx(const int8_t* x, size_t d);
+
+void
+int8_vec_inner_product_batch_4_avx(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2,
+                                   const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2,
+                                   float& dis3);
+
+void
+int8_vec_L2sqr_batch_4_avx(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3,
+                           const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
+
/////////////////////////////////////////////////////////////////////////////// // for cardinal diff --git a/src/simd/distances_avx512.cc b/src/simd/distances_avx512.cc index 848e66d1d..e5661e4a3 100644 --- a/src/simd/distances_avx512.cc +++ b/src/simd/distances_avx512.cc @@ -682,6 +682,95 @@ bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, dis3 = _mm512_reduce_add_ps(m512_res_3); } +/////////////////////////////////////////////////////////////////////////////// +// int8 + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +int8_vec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d) { + int32_t res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)y[i]; + } + return (float)res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +int8_vec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d) { + int32_t res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return (float)res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +int8_vec_norm_L2sqr_avx512(const int8_t* x, size_t d) { + int32_t res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)x[i]; + } + return (float)res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void +int8_vec_inner_product_batch_4_avx512(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, + const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3) { + int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0; + + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + auto x_i = (int32_t)x[i]; + d0 += x_i * (int32_t)y0[i]; + d1 += x_i * (int32_t)y1[i]; + d2 += x_i * (int32_t)y2[i]; + d3 += x_i * (int32_t)y3[i]; + } + + dis0 = (float)d0; + dis1 = 
(float)d1; + dis2 = (float)d2; + dis3 = (float)d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +void +int8_vec_L2sqr_batch_4_avx512(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0; + + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + auto x_i = (int32_t)x[i]; + const int32_t q0 = x_i - (int32_t)y0[i]; + const int32_t q1 = x_i - (int32_t)y1[i]; + const int32_t q2 = x_i - (int32_t)y2[i]; + const int32_t q3 = x_i - (int32_t)y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = (float)d0; + dis1 = (float)d1; + dis2 = (float)d2; + dis3 = (float)d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + /////////////////////////////////////////////////////////////////////////////// // for cardinal diff --git a/src/simd/distances_avx512.h b/src/simd/distances_avx512.h index 371d8cb4e..f1cda20ba 100644 --- a/src/simd/distances_avx512.h +++ b/src/simd/distances_avx512.h @@ -100,6 +100,27 @@ bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +/////////////////////////////////////////////////////////////////////////////// +// int8 + +float +int8_vec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_L2sqr_avx512(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_norm_L2sqr_avx512(const int8_t* x, size_t d); + +void +int8_vec_inner_product_batch_4_avx512(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, + const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3); + +void +int8_vec_L2sqr_batch_4_avx512(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3, + const size_t d, 
float& dis0, float& dis1, float& dis2, float& dis3); + /////////////////////////////////////////////////////////////////////////////// // for cardinal diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc index 2db261c03..2e7c96754 100644 --- a/src/simd/distances_neon.cc +++ b/src/simd/distances_neon.cc @@ -2113,6 +2113,85 @@ bf16_vec_L2sqr_batch_4_neon(const knowhere::bf16* x, const knowhere::bf16* y0, c dis3 = vaddvq_f32(res.val[3]); } +/////////////////////////////////////////////////////////////////////////////// +// int8 + +float +int8_vec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d) { + // TODO caiyd: use ref implementation temporarily + int32_t res = 0; + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)y[i]; + } + return (float)res; +} + +float +int8_vec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d) { + // TODO caiyd: use ref implementation temporarily + int32_t res = 0; + for (size_t i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return (float)res; +} + +float +int8_vec_norm_L2sqr_neon(const int8_t* x, size_t d) { + // TODO caiyd: use ref implementation temporarily + int32_t res = 0; + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)x[i]; + } + return (float)res; +} + +void +int8_vec_inner_product_batch_4_neon(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, + const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3) { + // TODO caiyd: use ref implementation temporarily + int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0; + + for (size_t i = 0; i < d; ++i) { + auto x_i = (int32_t)x[i]; + d0 += x_i * (int32_t)y0[i]; + d1 += x_i * (int32_t)y1[i]; + d2 += x_i * (int32_t)y2[i]; + d3 += x_i * (int32_t)y3[i]; + } + + dis0 = (float)d0; + dis1 = (float)d1; + dis2 = (float)d2; + dis3 = (float)d3; +} + +void +int8_vec_L2sqr_batch_4_neon(const int8_t* x, const int8_t* y0, const int8_t* y1, const 
int8_t* y2, const int8_t* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + // TODO caiyd: use ref implementation temporarily + int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0; + + for (size_t i = 0; i < d; ++i) { + auto x_i = (int32_t)x[i]; + const int32_t q0 = x_i - (int32_t)y0[i]; + const int32_t q1 = x_i - (int32_t)y1[i]; + const int32_t q2 = x_i - (int32_t)y2[i]; + const int32_t q3 = x_i - (int32_t)y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = (float)d0; + dis1 = (float)d1; + dis2 = (float)d2; + dis3 = (float)d3; +} + /////////////////////////////////////////////////////////////////////////////// // for cardinal diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h index 871263ca0..30e7559e1 100644 --- a/src/simd/distances_neon.h +++ b/src/simd/distances_neon.h @@ -117,6 +117,27 @@ bf16_vec_L2sqr_batch_4_neon(const knowhere::bf16* x, const knowhere::bf16* y0, c const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +/////////////////////////////////////////////////////////////////////////////// +// int8 + +float +int8_vec_inner_product_neon(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_L2sqr_neon(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_norm_L2sqr_neon(const int8_t* x, size_t d); + +void +int8_vec_inner_product_batch_4_neon(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, + const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3); + +void +int8_vec_L2sqr_batch_4_neon(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + /////////////////////////////////////////////////////////////////////////////// // for cardinal diff --git a/src/simd/distances_ref.cc b/src/simd/distances_ref.cc index 62711d76a..fac0e4ca3 100644 --- 
a/src/simd/distances_ref.cc +++ b/src/simd/distances_ref.cc @@ -379,6 +379,80 @@ bf16_vec_L2sqr_batch_4_ref(const knowhere::bf16* x, const knowhere::bf16* y0, co dis3 = d3; } +/////////////////////////////////////////////////////////////////////////////// +// int8 + +float +int8_vec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d) { + int32_t res = 0; + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)y[i]; + } + return (float)res; +} + +float +int8_vec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d) { + int32_t res = 0; + for (size_t i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return (float)res; +} + +float +int8_vec_norm_L2sqr_ref(const int8_t* x, size_t d) { + int32_t res = 0; + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)x[i]; + } + return (float)res; +} + +void +int8_vec_inner_product_batch_4_ref(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, + const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3) { + int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0; + + for (size_t i = 0; i < d; ++i) { + auto x_i = (int32_t)x[i]; + d0 += x_i * (int32_t)y0[i]; + d1 += x_i * (int32_t)y1[i]; + d2 += x_i * (int32_t)y2[i]; + d3 += x_i * (int32_t)y3[i]; + } + + dis0 = (float)d0; + dis1 = (float)d1; + dis2 = (float)d2; + dis3 = (float)d3; +} + +void +int8_vec_L2sqr_batch_4_ref(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) { + int32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0; + + for (size_t i = 0; i < d; ++i) { + auto x_i = (int32_t)x[i]; + const int32_t q0 = x_i - (int32_t)y0[i]; + const int32_t q1 = x_i - (int32_t)y1[i]; + const int32_t q2 = x_i - (int32_t)y2[i]; + const int32_t q3 = x_i - (int32_t)y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = (float)d0; + dis1 = (float)d1; + 
dis2 = (float)d2; + dis3 = (float)d3; +} + /////////////////////////////////////////////////////////////////////////////// // for cardinal diff --git a/src/simd/distances_ref.h b/src/simd/distances_ref.h index b34f940f1..07d6dc156 100644 --- a/src/simd/distances_ref.h +++ b/src/simd/distances_ref.h @@ -134,6 +134,26 @@ bf16_vec_L2sqr_batch_4_ref(const knowhere::bf16* x, const knowhere::bf16* y0, co const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); +/////////////////////////////////////////////////////////////////////////////// +// int8 +float +int8_vec_inner_product_ref(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_L2sqr_ref(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_norm_L2sqr_ref(const int8_t* x, size_t d); + +void +int8_vec_inner_product_batch_4_ref(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, + const int8_t* y3, const size_t d, float& dis0, float& dis1, float& dis2, + float& dis3); + +void +int8_vec_L2sqr_batch_4_ref(const int8_t* x, const int8_t* y0, const int8_t* y1, const int8_t* y2, const int8_t* y3, + const size_t d, float& dis0, float& dis1, float& dis2, float& dis3); + /////////////////////////////////////////////////////////////////////////////// // for cardinal float diff --git a/src/simd/distances_sse.cc b/src/simd/distances_sse.cc index 83f6bb620..8648e4990 100644 --- a/src/simd/distances_sse.cc +++ b/src/simd/distances_sse.cc @@ -19,6 +19,7 @@ #include #include "distances_ref.h" +#include "faiss/impl/platform_macros.h" namespace faiss { @@ -492,5 +493,45 @@ bf16_vec_norm_L2sqr_sse(const knowhere::bf16* x, size_t d) { return _mm_cvtss_f32(m_res); } +/////////////////////////////////////////////////////////////////////////////// +// int8 + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +int8_vec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d) { + int32_t res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for 
(size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)y[i]; + } + return (float)res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +int8_vec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d) { + int32_t res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + const int32_t tmp = (int32_t)x[i] - (int32_t)y[i]; + res += tmp * tmp; + } + return (float)res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +float +int8_vec_norm_L2sqr_sse(const int8_t* x, size_t d) { + int32_t res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + res += (int32_t)x[i] * (int32_t)x[i]; + } + return (float)res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + } // namespace faiss #endif diff --git a/src/simd/distances_sse.h b/src/simd/distances_sse.h index 51574de2c..ff58c411b 100644 --- a/src/simd/distances_sse.h +++ b/src/simd/distances_sse.h @@ -70,4 +70,16 @@ bf16_vec_L2sqr_sse(const knowhere::bf16* x, const knowhere::bf16* y, size_t d); float bf16_vec_norm_L2sqr_sse(const knowhere::bf16* x, size_t d); +/////////////////////////////////////////////////////////////////////////////// +// int8 + +float +int8_vec_inner_product_sse(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_L2sqr_sse(const int8_t* x, const int8_t* y, size_t d); + +float +int8_vec_norm_L2sqr_sse(const int8_t* x, size_t d); + } // namespace faiss diff --git a/src/simd/hook.cc b/src/simd/hook.cc index 697478280..9f978866e 100644 --- a/src/simd/hook.cc +++ b/src/simd/hook.cc @@ -82,6 +82,14 @@ decltype(bf16_vec_norm_L2sqr) bf16_vec_norm_L2sqr = bf16_vec_norm_L2sqr_ref; decltype(bf16_vec_inner_product_batch_4) bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_ref; decltype(bf16_vec_L2sqr_batch_4) bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_ref; +// int8 +decltype(int8_vec_L2sqr) int8_vec_L2sqr = int8_vec_L2sqr_ref; +decltype(int8_vec_inner_product) int8_vec_inner_product = 
int8_vec_inner_product_ref; +decltype(int8_vec_norm_L2sqr) int8_vec_norm_L2sqr = int8_vec_norm_L2sqr_ref; + +decltype(int8_vec_inner_product_batch_4) int8_vec_inner_product_batch_4 = int8_vec_inner_product_batch_4_ref; +decltype(int8_vec_L2sqr_batch_4) int8_vec_L2sqr_batch_4 = int8_vec_L2sqr_batch_4_ref; + /////////////////////////////////////////////////////////////////////////////// #if defined(__x86_64__) bool @@ -217,6 +225,14 @@ fvec_hook(std::string& simd_type) { bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_avx512; bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_avx512; + // int8 + int8_vec_inner_product = int8_vec_inner_product_avx512; + int8_vec_L2sqr = int8_vec_L2sqr_avx512; + int8_vec_norm_L2sqr = int8_vec_norm_L2sqr_avx512; + + int8_vec_inner_product_batch_4 = int8_vec_inner_product_batch_4_avx512; + int8_vec_L2sqr_batch_4 = int8_vec_L2sqr_batch_4_avx512; + // simd_type = "AVX512"; support_pq_fast_scan = true; @@ -256,6 +272,14 @@ fvec_hook(std::string& simd_type) { bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_avx; bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_avx; + // int8 + int8_vec_inner_product = int8_vec_inner_product_avx; + int8_vec_L2sqr = int8_vec_L2sqr_avx; + int8_vec_norm_L2sqr = int8_vec_norm_L2sqr_avx; + + int8_vec_inner_product_batch_4 = int8_vec_inner_product_batch_4_avx; + int8_vec_L2sqr_batch_4 = int8_vec_L2sqr_batch_4_avx; + // simd_type = "AVX2"; support_pq_fast_scan = true; @@ -294,6 +318,14 @@ fvec_hook(std::string& simd_type) { bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_ref; bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_ref; + // int8 + int8_vec_inner_product = int8_vec_inner_product_sse; + int8_vec_L2sqr = int8_vec_L2sqr_sse; + int8_vec_norm_L2sqr = int8_vec_norm_L2sqr_sse; + + int8_vec_inner_product_batch_4 = int8_vec_inner_product_batch_4_ref; + int8_vec_L2sqr_batch_4 = int8_vec_L2sqr_batch_4_ref; + // simd_type = "SSE4_2"; support_pq_fast_scan = false; @@ -332,6 
+364,14 @@ fvec_hook(std::string& simd_type) { bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_ref; bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_ref; + // int8 + int8_vec_inner_product = int8_vec_inner_product_ref; + int8_vec_L2sqr = int8_vec_L2sqr_ref; + int8_vec_norm_L2sqr = int8_vec_norm_L2sqr_ref; + + int8_vec_inner_product_batch_4 = int8_vec_inner_product_batch_4_ref; + int8_vec_L2sqr_batch_4 = int8_vec_L2sqr_batch_4_ref; + // simd_type = "GENERIC"; support_pq_fast_scan = false; diff --git a/src/simd/hook.h b/src/simd/hook.h index e0f1396d4..ed6f22546 100644 --- a/src/simd/hook.h +++ b/src/simd/hook.h @@ -108,6 +108,15 @@ extern void (*bf16_vec_inner_product_batch_4)(const knowhere::bf16*, const knowh extern void (*bf16_vec_L2sqr_batch_4)(const knowhere::bf16*, const knowhere::bf16*, const knowhere::bf16*, const knowhere::bf16*, const knowhere::bf16*, const size_t, float&, float&, float&, float&); +// int8 +extern float (*int8_vec_inner_product)(const int8_t*, const int8_t*, size_t); +extern float (*int8_vec_L2sqr)(const int8_t*, const int8_t*, size_t); +extern float (*int8_vec_norm_L2sqr)(const int8_t*, size_t); + +extern void (*int8_vec_inner_product_batch_4)(const int8_t*, const int8_t*, const int8_t*, const int8_t*, const int8_t*, + const size_t, float&, float&, float&, float&); +extern void (*int8_vec_L2sqr_batch_4)(const int8_t*, const int8_t*, const int8_t*, const int8_t*, const int8_t*, + const size_t, float&, float&, float&, float&); /////////////////////////////////////////////////////////////////////////////// #if defined(__x86_64__) diff --git a/tests/ut/test_simd.cc b/tests/ut/test_simd.cc index 0f44cbacc..484055e4d 100644 --- a/tests/ut/test_simd.cc +++ b/tests/ut/test_simd.cc @@ -71,7 +71,12 @@ TEST_CASE("Test distance") { const auto x_bf16 = ConvertVector(x.get(), nx, dim); const auto y_bf16 = ConvertVector(y.get(), ny, dim); - // int8 + // int8 should have no precision loss + const float int8_tolerance = 0.000001f; + const 
auto x_int8 = ConvertVector(x.get(), nx, dim); + const auto y_int8 = ConvertVector(y.get(), ny, dim); + + // int8, for hnsw sq, obsolete const auto xi = ConvertVector(x.get(), nx, dim); const auto yi = ConvertVector(y.get(), ny, dim); @@ -125,6 +130,30 @@ TEST_CASE("Test distance") { } } + SECTION("test single distance calculation for int8") { + // calculate the float result ref + std::vector ref_ip, ref_L2sqr, ref_norm_L2sqr; + for (size_t i = 0; i < ny; i++) { + const knowhere::int8* x_data = x_int8.get(); + const knowhere::int8* y_data = y_int8.get() + dim; + ref_ip.push_back(faiss::int8_vec_inner_product_ref(x_data, y_data, dim)); + ref_L2sqr.push_back(faiss::int8_vec_L2sqr_ref(x_data, y_data, dim)); + ref_norm_L2sqr.push_back(faiss::int8_vec_norm_L2sqr_ref(y_data, dim)); + } + + // int8 + for (size_t i = 0; i < ny; i++) { + const knowhere::int8* x_data = x_int8.get(); + const knowhere::int8* y_data = y_int8.get() + dim; + REQUIRE_THAT(faiss::int8_vec_inner_product(x_data, y_data, dim), + Catch::Matchers::WithinRel(ref_ip[i], int8_tolerance)); + REQUIRE_THAT(faiss::int8_vec_L2sqr(x_data, y_data, dim), + Catch::Matchers::WithinRel(ref_L2sqr[i], int8_tolerance)); + REQUIRE_THAT(faiss::int8_vec_norm_L2sqr(y_data, dim), + Catch::Matchers::WithinRel(ref_norm_L2sqr[i], int8_tolerance)); + } + } + // obsolete SECTION("test single distance calculation for hnsw sq") { // calculate the int32 result ref @@ -240,6 +269,39 @@ TEST_CASE("Test distance") { REQUIRE_THAT(l2_batch_4[2], Catch::Matchers::WithinRel(ref_l2_batch_4[2], tolerance)); REQUIRE_THAT(l2_batch_4[3], Catch::Matchers::WithinRel(ref_l2_batch_4[3], tolerance)); } + + // int8 + { + const knowhere::int8* x_data = x_int8.get(); + std::vector y_data{y_int8.get(), y_int8.get() + dim, y_int8.get() + 2 * dim, + y_int8.get() + 3 * dim}; + + // calculate the int8 result ref + std::vector ref_l2_batch_4(4), ref_ip_batch_4(4); + faiss::int8_vec_inner_product_batch_4_ref(x_data, y_data[0], y_data[1], y_data[2], y_data[3], 
dim, + ref_ip_batch_4[0], ref_ip_batch_4[1], ref_ip_batch_4[2], + ref_ip_batch_4[3]); + faiss::int8_vec_L2sqr_batch_4_ref(x_data, y_data[0], y_data[1], y_data[2], y_data[3], dim, + ref_l2_batch_4[0], ref_l2_batch_4[1], ref_l2_batch_4[2], + ref_l2_batch_4[3]); + + // int8 + std::vector l2_batch_4(4), ip_batch_4(4); + faiss::int8_vec_inner_product_batch_4(x_data, y_data[0], y_data[1], y_data[2], y_data[3], dim, + ip_batch_4[0], ip_batch_4[1], ip_batch_4[2], ip_batch_4[3]); + faiss::int8_vec_L2sqr_batch_4(x_data, y_data[0], y_data[1], y_data[2], y_data[3], dim, l2_batch_4[0], + l2_batch_4[1], l2_batch_4[2], l2_batch_4[3]); + + REQUIRE_THAT(ip_batch_4[0], Catch::Matchers::WithinRel(ref_ip_batch_4[0], int8_tolerance)); + REQUIRE_THAT(ip_batch_4[1], Catch::Matchers::WithinRel(ref_ip_batch_4[1], int8_tolerance)); + REQUIRE_THAT(ip_batch_4[2], Catch::Matchers::WithinRel(ref_ip_batch_4[2], int8_tolerance)); + REQUIRE_THAT(ip_batch_4[3], Catch::Matchers::WithinRel(ref_ip_batch_4[3], int8_tolerance)); + + REQUIRE_THAT(l2_batch_4[0], Catch::Matchers::WithinRel(ref_l2_batch_4[0], int8_tolerance)); + REQUIRE_THAT(l2_batch_4[1], Catch::Matchers::WithinRel(ref_l2_batch_4[1], int8_tolerance)); + REQUIRE_THAT(l2_batch_4[2], Catch::Matchers::WithinRel(ref_l2_batch_4[2], int8_tolerance)); + REQUIRE_THAT(l2_batch_4[3], Catch::Matchers::WithinRel(ref_l2_batch_4[3], int8_tolerance)); + } } SECTION("test ny distance calculation") {