Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain committed Feb 19, 2025
1 parent 0b18639 commit 9260210
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 109 deletions.
121 changes: 59 additions & 62 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
#if defined(__x86_64__)

#include "distances_avx.h"
#include "faiss/impl/platform_macros.h"

#include <immintrin.h>

#include <cassert>

#include "faiss/impl/platform_macros.h"

namespace faiss {

#define ALIGNED(x) __attribute__((aligned(x)))
Expand Down Expand Up @@ -230,32 +228,6 @@ fvec_inner_product_batch_4_avx(const float* __restrict x, const float* __restric
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_inner_product_batch_4_avx_bf16_patch(const float* __restrict x, const float* __restrict y0,
const float* __restrict y1, const float* __restrict y2,
const float* __restrict y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3) {
float d0 = 0;
float d1 = 0;
float d2 = 0;
float d3 = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (size_t i = 0; i < d; ++i) {
d0 += x[i] * bf16_float(y0[i]);
d1 += x[i] * bf16_float(y1[i]);
d2 += x[i] * bf16_float(y2[i]);
d3 += x[i] * bf16_float(y3[i]);
}

dis0 = d0;
dis1 = d1;
dis2 = d2;
dis3 = d3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
Expand Down Expand Up @@ -284,34 +256,6 @@ fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const f
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) {
float d0 = 0;
float d1 = 0;
float d2 = 0;
float d3 = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (size_t i = 0; i < d; ++i) {
const float q0 = x[i] - bf16_float(y0[i]);
const float q1 = x[i] - bf16_float(y1[i]);
const float q2 = x[i] - bf16_float(y2[i]);
const float q3 = x[i] - bf16_float(y3[i]);
d0 += q0 * q0;
d1 += q1 * q1;
d2 += q2 * q2;
d3 += q3 * q3;
}

dis0 = d0;
dis1 = d1;
dis2 = d2;
dis3 = d3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

float
fvec_norm_L2sqr_avx(const float* x, size_t d) {
__m256 msum_0 = _mm256_setzero_ps();
Expand Down Expand Up @@ -945,32 +889,85 @@ bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, co
return;
}

//
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d) {
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d) {
size_t i;
float res = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (i = 0; i < d; i++) {
const float tmp = x[i] - bf16_float(y[i]);
res += tmp * tmp;
res += x[i] * bf16_float(y[i]);
}
return res;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_inner_product_batch_4_avx_bf16_patch(const float* __restrict x, const float* __restrict y0,
const float* __restrict y1, const float* __restrict y2,
const float* __restrict y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3) {
float d0 = 0;
float d1 = 0;
float d2 = 0;
float d3 = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (size_t i = 0; i < d; ++i) {
d0 += x[i] * bf16_float(y0[i]);
d1 += x[i] * bf16_float(y1[i]);
d2 += x[i] * bf16_float(y2[i]);
d3 += x[i] * bf16_float(y3[i]);
}

dis0 = d0;
dis1 = d1;
dis2 = d2;
dis3 = d3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d) {
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d) {
size_t i;
float res = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (i = 0; i < d; i++) {
res += x[i] * bf16_float(y[i]);
const float tmp = x[i] - bf16_float(y[i]);
res += tmp * tmp;
}
return res;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3) {
float d0 = 0;
float d1 = 0;
float d2 = 0;
float d3 = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (size_t i = 0; i < d; ++i) {
const float q0 = x[i] - bf16_float(y0[i]);
const float q1 = x[i] - bf16_float(y1[i]);
const float q2 = x[i] - bf16_float(y2[i]);
const float q3 = x[i] - bf16_float(y3[i]);
d0 += q0 * q0;
d1 += q1 * q1;
d2 += q2 * q2;
d3 += q3 * q3;
}

dis0 = d0;
dis1 = d1;
dis2 = d2;
dis3 = d3;
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

} // namespace faiss
#endif
97 changes: 50 additions & 47 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,10 @@ namespace faiss {
float
fvec_L2sqr_avx(const float* x, const float* y, size_t d);

float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d);

float
fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// inner product
float
fvec_inner_product_avx(const float* x, const float* y, size_t d);

float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d);

float
fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// L1 distance
float
fvec_L1_avx(const float* x, const float* y, size_t d);
Expand All @@ -60,59 +42,80 @@ void
fvec_inner_product_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
float
fvec_norm_L2sqr_avx(const float* x, size_t d);

void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);

void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);
size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);

// for hnsw sq, obsolete
int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

// fp16
float
fvec_norm_L2sqr_avx(const float* x, size_t d);
fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);
void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

// bf16
float
bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

float
bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

float
bf16_vec_norm_L2sqr_avx(const knowhere::bf16* x, size_t d);

void
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);
void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

//
float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d);

void
fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d);

void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

} // namespace faiss

Expand Down

0 comments on commit 9260210

Please sign in to comment.