Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make simd a default behavior #34

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ include(cmake/default.cmake)
#gcc 10 g++10
# First compile faiss before anything else
#set(CMAKE_CXX_FLAGS "-fno-openmp")
#test avx2

add_subdirectory(thirdparty/faiss)

# Set specific options for Faiss compilation
Expand All @@ -29,6 +31,33 @@ set(CMAKE_CXX_FLAGS "-std=c++20 -Wall -Werror=return-type")
set(CMAKE_CXX_FLAGS_DEBUG "-g -O0 -DNO_RACE_CHECK -DCANDY_DEBUG_MODE=1")
set(CMAKE_CXX_FLAGS_RELEASE "-Wno-ignored-qualifiers -Wno-sign-compare -O3")
set(PROJECT_BINARY_DIR_RAW ${PROJECT_BINARY_DIR})

# Valid values are "generic", "avx2", "avx512".

detect_avx512_support(AVX512_AVAILABLE)
# Use AVX-512 based on the result
if(AVX512_AVAILABLE)
message(STATUS "AVX-512 support detected.")
set(CANDY_AVX512 1)
set(CANDY_AVX2 1)
else()
message(STATUS "AVX-512 support NOT detected.")
detect_avx2_support(AVX2_AVAILABLE)
if(AVX2_AVAILABLE)
message(STATUS "AVX-2 support detected.")
set(CANDY_AVX2 1)
else ()
message(STATUS "AVX-2 support not detected.")
set(CANDY_AVX2 0)
set(CANDY_AVX512 0)
endif ()
endif()
configure_file(
"${PROJECT_SOURCE_DIR}/include/simd_config.h.in"
"${PROJECT_BINARY_DIR}/include/simd_config.h"
)


#set(CMAKE_CUDA_STANDARD 11)
#set(CMAKE_CUDA_FLAGS "-std=c++11")
option(ENABLE_OPENCL
Expand Down
46 changes: 45 additions & 1 deletion cmake/macros.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,48 @@ endmacro()

macro(get_headers HEADER_FILES)
file(GLOB_RECURSE ${HEADER_FILES} "include/*.h" "include/*.hpp")
endmacro()
endmacro()

# Define the function to detect AVX-512 support
function(detect_avx512_support result_var)
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS "-mavx512f")
check_cxx_source_compiles("
#include <immintrin.h>
int main() {
__m512i vec = _mm512_set1_epi32(1); // AVX-512 intrinsic
return 0;
}
" HAVE_AVX512)

if(HAVE_AVX512)
#message(STATUS "AVX-512 support detected.")
set(${result_var} 1 PARENT_SCOPE)
else()
# message(STATUS "AVX-512 support NOT detected.")
set(${result_var} 0 PARENT_SCOPE)
endif()
endfunction()

function(detect_avx2_support result_var)
include(CheckCXXSourceCompiles)
# Save the current compiler flags to restore them later
set(saved_flags "${CMAKE_CXX_FLAGS}")
# Test AVX2 intrinsic support by compiling a minimal test program
check_cxx_source_compiles("
#include <immintrin.h>
int main() {
__m256i vec = _mm256_set1_epi32(1); // AVX2 intrinsic
return 0;
}
" HAVE_AVX2)

# Restore the original compiler flags
set(CMAKE_CXX_FLAGS "${saved_flags}" PARENT_SCOPE)
# Return TRUE or FALSE based on the test result
if(HAVE_AVX2)
set(${result_var} 1 PARENT_SCOPE)
else()
set(${result_var} 0 PARENT_SCOPE)
endif()
endfunction()
33 changes: 5 additions & 28 deletions include/CANDY/LSHAPGIndex/basis.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,9 @@ struct Res//the result of knns

inline float cal_inner_product(float* v1, float* v2, int dim)
{
#if (defined __AVX2__ && defined __USE__AVX2__ZX__)
return faiss::fvec_inner_product_avx512(v1, v2, dim);
#else
return calIp_fast(v1, v2, dim);
#endif

return calIp_fast(v1, v2, dim);

}

inline float cal_lengthSquare(float* v1, int dim)
Expand All @@ -187,34 +185,13 @@ inline float cal_lengthSquare(float* v1, int dim)
extern int _g_dist_mes;
inline float cal_dist(float* v1, float* v2, int dim)
{
if(_g_dist_mes==1) {
return 1.0-cal_inner_product(v1,v2,dim);
}
#ifdef USE_SQRDIST
#if (defined __AVX2__ && defined __USE__AVX2__ZX__)
return faiss::fvec_L2sqr_avx512(v1, v2, dim);
#else
return calL2Sqr_fast(v1, v2, dim);
#endif
#else
#if (defined __AVX2__ && defined __USE__AVX2__ZX__)
return sqrt(faiss::fvec_L2sqr_avx512(v1, v2, dim));
#else
return sqrt(calL2Sqr_fast(v1, v2, dim));
#endif
#endif
return calL2Sqr_fast(v1, v2, dim);

}

inline float cal_distSqrt(float* v1, float* v2, int dim)
{
#if (defined __AVX2__ && defined __USE__AVX2__ZX__)
return sqrt(faiss::fvec_L2sqr_avx512(v1, v2, dim));
#else
return sqrt(calL2Sqr_fast(v1, v2, dim));
#endif
//return sqrt(calL2Sqr_fast(v1, v2, dim));

return calL2Sqr_fast(v1, v2, dim);
}

template <class T>
Expand Down
Loading
Loading