Skip to content

Commit

Permalink
build(kernel): 添加 cublasLt
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 30, 2024
1 parent f1faf3e commit 9495516
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/04kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ if(USE_CUDA)
# nvrtc for cuda kernel compile
# cublas for matmul
# cudnn for conv and others
target_link_libraries(kernel PUBLIC cuda nvrtc cublas cudnn kernel_cuda)
target_link_libraries(kernel PUBLIC cuda nvrtc cublas cublasLt cudnn kernel_cuda)
target_include_directories(kernel PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()
if(USE_KUNLUN)
Expand Down
33 changes: 33 additions & 0 deletions src/04kernel/src/utilities/cuda/cublaslt_context.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "common.h"
#include "cublaslt_context.hh"

namespace refactor::kernel::cublas {

CublasLtContext::CublasLtContext() : runtime::Resource() {
if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
RUNTIME_ERROR("Failed to create cublasLt handle");
}
}
CublasLtContext::~CublasLtContext() {
if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) {
fmt::println("Failed to destroy cublasLt handle");
abort();
}
}

auto CublasLtContext::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto CublasLtContext::build() noexcept -> runtime::ResourceBox {
return std::make_unique<CublasLtContext>();
}

auto CublasLtContext::resourceTypeId() const noexcept -> size_t {
return typeId();
}
auto CublasLtContext::description() const noexcept -> std::string_view {
return "CublasLtContext";
}

}// namespace refactor::kernel::cublas
33 changes: 33 additions & 0 deletions src/04kernel/src/utilities/cuda/cublaslt_context.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef KERNEL_CUBLASLT_CONTEXT_HH
#define KERNEL_CUBLASLT_CONTEXT_HH

#include "runtime/resource.h"
#include <cublasLt.h>

#define CUBLAS_ASSERT(STATUS) \
if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \
fmt::println("cublas failed on \"" #STATUS "\" with {}", \
(int) status); \
abort(); \
}

namespace refactor::kernel::cublas {

struct CublasLtContext final : public runtime::Resource {
cublasLtHandle_t handle;

CublasLtContext();
~CublasLtContext();
CublasLtContext(CublasLtContext const &) noexcept = delete;
CublasLtContext(CublasLtContext &&) noexcept = delete;

static size_t typeId() noexcept;
static runtime::ResourceBox build() noexcept;

size_t resourceTypeId() const noexcept final;
std::string_view description() const noexcept final;
};

}// namespace refactor::kernel::cublas

#endif// KERNEL_CUBLASLT_CONTEXT_HH

0 comments on commit 9495516

Please sign in to comment.