Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format simd APIs for better readability #1088

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,071 changes: 540 additions & 531 deletions src/simd/distances_avx.cc

Large diffs are not rendered by default.

105 changes: 58 additions & 47 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,10 @@ namespace faiss {
float
fvec_L2sqr_avx(const float* x, const float* y, size_t d);

float
fvec_L2sqr_avx_bf16_patch(const float* x, const float* y, size_t d);

float
fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// inner product
float
fvec_inner_product_avx(const float* x, const float* y, size_t d);

float
fvec_inner_product_avx_bf16_patch(const float* x, const float* y, size_t d);

float
fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

/// L1 distance
float
fvec_L1_avx(const float* x, const float* y, size_t d);
Expand All @@ -60,59 +42,88 @@ void
fvec_inner_product_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
float
fvec_norm_L2sqr_avx(const float* x, size_t d);

void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);

void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);
size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);

///////////////////////////////////////////////////////////////////////////////
// for hnsw sq, obsolete

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

int32_t
ivec_L2sqr_avx(const int8_t* x, const int8_t* y, size_t d);

///////////////////////////////////////////////////////////////////////////////
// fp16

float
fvec_norm_L2sqr_avx(const float* x, size_t d);
fp16_vec_inner_product_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
fp16_vec_L2sqr_avx(const knowhere::fp16* x, const knowhere::fp16* y, size_t d);

float
fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);
void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

///////////////////////////////////////////////////////////////////////////////
// bf16

float
bf16_vec_inner_product_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

float
bf16_vec_L2sqr_avx(const knowhere::bf16* x, const knowhere::bf16* y, size_t d);

float
bf16_vec_norm_L2sqr_avx(const knowhere::bf16* x, size_t d);

void
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);
void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

///////////////////////////////////////////////////////////////////////////////
// for cardinal

float
fvec_inner_product_bf16_patch_avx(const float* x, const float* y, size_t d);

float
fvec_L2sqr_bf16_patch_avx(const float* x, const float* y, size_t d);

void
fvec_inner_product_batch_4_bf16_patch_avx(const float* x, const float* y0, const float* y1, const float* y2,
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fvec_L2sqr_batch_4_bf16_patch_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

} // namespace faiss

Expand Down
Loading
Loading