diff --git a/src/02hardware/include/hardware/devices/nvidia.h b/src/02hardware/include/hardware/devices/nvidia.h
index d19dd315..18a4269d 100644
--- a/src/02hardware/include/hardware/devices/nvidia.h
+++ b/src/02hardware/include/hardware/devices/nvidia.h
@@ -3,6 +3,13 @@
 
 #include "../device.h"
 
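+// asserts that a CUDA runtime call succeeded; the includer is expected to provide <cuda_runtime.h>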
+#define CUDA_ASSERT(STATUS)                                                          \
+    if (auto status = (STATUS); status != cudaSuccess) {                             \
+        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cudaGetErrorString(status), (int) status));        \
+    }
+
 namespace refactor::hardware {
 
     class Nvidia final : public Device {
diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc
index fd10cb70..20f63c0f 100644
--- a/src/02hardware/src/devices/nvidia/device.cc
+++ b/src/02hardware/src/devices/nvidia/device.cc
@@ -4,12 +4,6 @@
 #ifdef USE_CUDA
 #include "memory.hh"
 #include <cuda_runtime.h>
-
-#define CUDA_ASSERT(STATUS)                                                          \
-    if (auto status = (STATUS); status != cudaSuccess) {                             \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status));        \
-    }
 #endif
 
 namespace refactor::hardware {
diff --git a/src/02hardware/src/devices/nvidia/memory.cc b/src/02hardware/src/devices/nvidia/memory.cc
index 42310196..1c3be21e 100644
--- a/src/02hardware/src/devices/nvidia/memory.cc
+++ b/src/02hardware/src/devices/nvidia/memory.cc
@@ -1,15 +1,9 @@
 #ifdef USE_CUDA
 
 #include "memory.hh"
-#include "common.h"
+#include "hardware/devices/nvidia.h"
 #include <cuda_runtime.h>
 
-#define CUDA_ASSERT(STATUS)                                                          \
-    if (auto status = (STATUS); status != cudaSuccess) {                             \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status));        \
-    }
-
 namespace refactor::hardware {
     using M = NvidiaMemory;
 
diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h
new file mode 100644
index 00000000..16d5fb0e
--- /dev/null
+++ b/src/04kernel/include/kernel/attributes/attention_info.h
@@ -0,0 +1,19 @@
+#ifndef KERNEL_ATTENTION_INFO_H
+#define KERNEL_ATTENTION_INFO_H
+
+#include "../tensor.h"
+
+namespace refactor::kernel {
+
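+    // shape and kv-cache configuration shared by the attention collector and kernels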
+    struct AttentionInfo {
+        DataType dataType;
+        dim_t batch, nHead, nKVHead, seqLen, headDim,
+            cacheLen;     // kv cache length (the cache output's shape[2]); 0 when there is no cache
+        bool concatCache, // set for the 4- and 6-input forms of the op
+            resetCache;   // set only for the 6-input form
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_ATTENTION_INFO_H
diff --git a/src/04kernel/include/kernel/collectors/attention.h b/src/04kernel/include/kernel/collectors/attention.h
index 527bc63f..abf33957 100644
--- a/src/04kernel/include/kernel/collectors/attention.h
+++ b/src/04kernel/include/kernel/collectors/attention.h
@@ -6,9 +6,8 @@
 namespace refactor::kernel {
 
     struct AttentionCollector final : public InfoCollector {
-        dim_t maxSeqLen;
 
-        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
+        AttentionCollector(decltype(_target)) noexcept;
 
         std::vector<KernelBox>
         filter(TensorRefs inputs, TensorRefs outputs) const final;
diff --git a/src/04kernel/src/collectors/attention.cc b/src/04kernel/src/collectors/attention.cc
index 3933097f..a778c128 100644
--- a/src/04kernel/src/collectors/attention.cc
+++ b/src/04kernel/src/collectors/attention.cc
@@ -1,38 +1,64 @@
 #include "kernel/collectors/attention.h"
+#include "kernel/attributes/attention_info.h"
 // #include "../kernels/attention/cpu_kernel.hh"
 #include "../kernels/attention/cuda_kernel.hh"
 
 namespace refactor::kernel {
 
     AttentionCollector::AttentionCollector(
-        decltype(_target) target,
-        decltype(maxSeqLen) maxSeqLen_) noexcept
-        : InfoCollector(target),
-          maxSeqLen(maxSeqLen_) {}
+        decltype(_target) target) noexcept
+        : InfoCollector(target) {}
 
     std::vector<KernelBox>
     AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         auto const &query = inputs[0].get();
         auto const &key = inputs[1].get();
-        auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
-        auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];
 
-        std::vector<KernelBox> ans;
+        AttentionInfo info{
+            .dataType = query.dataType,
+            .batch = query.shape[0],
+            .nHead = query.shape[1],
+            .nKVHead = key.shape[1],
+            .seqLen = query.shape[2],
+            .headDim = query.shape[3],
+            .cacheLen = 0,
+            .concatCache = false,
+            .resetCache = false,
+        };
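+        // the op's arity encodes its kv-cache mode (mirrored by the switch below):
+        //   1 output             -> no kv cache
+        //   3 outputs, 3 inputs  -> kv cache emitted, neither concatenated nor reset
+        //   3 outputs, 4 inputs  -> concatenate onto the cache
+        //   3 outputs, 6 inputs  -> reset the cache, then concatenate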
+        switch (outputs.size()) {
+            case 1:
+                // no kv cache
+                ASSERT(inputs.size() == 3, "the cache-free form takes exactly 3 inputs");
+                break;
+            case 3:
+                switch (inputs.size()) {
+                    case 6:
+                        info.resetCache = true;
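+                        // fall through: the 6-input form also concatenates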
+                    case 4:
+                        info.concatCache = true;
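+                        // fall through: every cached form records the cache length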
+                    case 3:
+                        info.cacheLen = outputs[1].get().shape[2];
+                        break;
+                    default:
+                        UNREACHABLE();
+                }
+                break;
+            default:
+                UNREACHABLE();
+        }
+
+        std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
                 break;
             case decltype(_target)::Nvidia: {
-                decltype(AttentionCuda::info) info{
-                    .dataType = query.dataType,
-                    .batch = query.shape[0],
-                    .nHead = query.shape[1],
-                    .nKVHead = key.shape[1],
-                    .pastSeqLen = static_cast<dim_t>(pastSeqLen),
-                    .seqLen = query.shape[2],
-                    .cacheLen = cacheLen,
-                    .headDim = query.shape[3],
-                    .resetCache = false,
-                };
                 if (auto ptr = AttentionCuda::build(info); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
new file mode 100644
index 00000000..a0f3f56a
--- /dev/null
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -0,0 +1,135 @@
+#include "../../utilities/cuda/cublaslt_utils.cuh"
+#include "cuda_kernel.hh"
+#include "hardware/functions.h"
+
+namespace refactor::kernel {
+    using K = AttentionCuda;
+    using namespace cublas;
+
+    RoutineWorkspace K::lower(Resources &res) const {
+        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+
+        constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW;
+        constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL;
+
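+        // only the cache-free, ungrouped (nHead == nKVHead) case is implemented so far;
+        // every other configuration falls through to the TODO at the bottom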
+        if (!info.cacheLen) {
+            if (info.nHead == info.nKVHead) {
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+                    MatrixDescriptor q, k, v, att;
+                    cublasLtMatmulAlgo_t algoQK, algoAV;
+                    size_t attSize, workspaceSizeQK, workspaceSizeAV;
+
+                    Descriptors(CublasLtContext const &context,
+                                cublasComputeType_t compute,
+                                AttentionInfo info)
+                        : mul(compute, CUDA_R_32F),
+                          q(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.headDim),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
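+                          // k shares q/v's row-major storage but is described column-major,
+                          // so the first matmul reads it as k^T without an explicit transpose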
+                          k(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.headDim),
+                              .cols = static_cast<uint64_t>(info.seqLen),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = COL_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          v(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.headDim),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          att(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.seqLen),
+                              .majorStride = static_cast<int64_t>(info.seqLen),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
+                          }),
+                          attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
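+                        // tune both matmuls once at lower time; the routine only replays the chosen algos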
+                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
+                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+                        algoQK = algoQK_;
+                        algoAV = algoAV_;
+                        workspaceSizeQK = workspaceSizeQK_;
+                        workspaceSizeAV = workspaceSizeAV_;
+                    }
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
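+                // workspace: [attention scores | QK workspace | AV workspace], each 256-byte aligned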
+                auto workspaceSize = d->attSize;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+                workspaceSize += d->workspaceSizeQK;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+                workspaceSize += d->workspaceSizeAV;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto o = outputs[0];
+                        auto att = workspace;
+                        auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
+                        auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
+
+                        float alpha = 1, beta = 0;
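+                        // att = q * k^T, batched over batch * nHead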
+                        cublasLtMatmul(
+                            handle, d->mul.get(),
+                            &alpha,
+                            q, d->q.get(),
+                            k, d->k.get(),
+                            &beta,
+                            att, d->att.get(),
+                            att, d->att.get(),
+                            &d->algoQK,
+                            workspaceQK, d->workspaceSizeQK,
+                            cudaStreamLegacy);
+
+                        // TODO inline mask && softmax
+
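+                        // o = att * v, reusing q's descriptor for the output layout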
+                        cublasLtMatmul(
+                            handle, d->mul.get(),
+                            &alpha,
+                            att, d->att.get(),
+                            v, d->v.get(),
+                            &beta,
+                            o, d->q.get(),
+                            o, d->q.get(),
+                            &d->algoAV,
+                            workspaceAV, d->workspaceSizeAV,
+                            cudaStreamLegacy);
+                    };
+                return {std::move(routine), workspaceSize};
+            }
+        }
+        TODO("");
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.hh b/src/04kernel/src/kernels/attention/cuda_kernel.hh
index 5ea19ae8..20cf9712 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.hh
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.hh
@@ -1,17 +1,13 @@
 #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
 #define KERNEL_ATTENTION_CUDA_KERNEL_HH
 
+#include "kernel/attributes/attention_info.h"
 #include "kernel/kernel.h"
-#include "kernel/tensor.h"
 
 namespace refactor::kernel {
 
     struct AttentionCuda final : public Kernel {
-        struct {
-            DataType dataType;
-            dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
-            bool resetCache;
-        } info;
+        AttentionInfo info;
 
         AttentionCuda(decltype(info)) noexcept;
 
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.cu b/src/04kernel/src/utilities/cuda/cublaslt_context.cu
deleted file mode 100644
index 2fc8fb18..00000000
--- a/src/04kernel/src/utilities/cuda/cublaslt_context.cu
+++ /dev/null
@@ -1,33 +0,0 @@
-#include "common.h"
-#include "cublaslt_context.hh"
-
-namespace refactor::kernel::cublas {
-
-    CublasLtContext::CublasLtContext() : runtime::Resource() {
-        if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
-            RUNTIME_ERROR("Failed to create cublasLt handle");
-        }
-    }
-    CublasLtContext::~CublasLtContext() {
-        if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) {
-            fmt::println("Failed to destroy cublasLt handle");
-            abort();
-        }
-    }
-
-    auto CublasLtContext::typeId() noexcept -> size_t {
-        static uint8_t ID = 1;
-        return reinterpret_cast<size_t>(&ID);
-    }
-    auto CublasLtContext::build() noexcept -> runtime::ResourceBox {
-        return std::make_unique<CublasLtContext>();
-    }
-
-    auto CublasLtContext::resourceTypeId() const noexcept -> size_t {
-        return typeId();
-    }
-    auto CublasLtContext::description() const noexcept -> std::string_view {
-        return "CublasLtContext";
-    }
-
-}// namespace refactor::kernel::cublas
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.hh b/src/04kernel/src/utilities/cuda/cublaslt_context.hh
deleted file mode 100644
index 84e1d2d9..00000000
--- a/src/04kernel/src/utilities/cuda/cublaslt_context.hh
+++ /dev/null
@@ -1,33 +0,0 @@
-#ifndef KERNEL_CUBLASLT_CONTEXT_HH
-#define KERNEL_CUBLASLT_CONTEXT_HH
-
-#include "runtime/resource.h"
-#include <cublasLt.h>
-
-#define CUBLAS_ASSERT(STATUS)                                      \
-    if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \
-        fmt::println("cublas failed on \"" #STATUS "\" with {}",   \
-                     (int) status);                                \
-        abort();                                                   \
-    }
-
-namespace refactor::kernel::cublas {
-
-    struct CublasLtContext final : public runtime::Resource {
-        cublasLtHandle_t handle;
-
-        CublasLtContext();
-        ~CublasLtContext();
-        CublasLtContext(CublasLtContext const &) noexcept = delete;
-        CublasLtContext(CublasLtContext &&) noexcept = delete;
-
-        static size_t typeId() noexcept;
-        static runtime::ResourceBox build() noexcept;
-
-        size_t resourceTypeId() const noexcept final;
-        std::string_view description() const noexcept final;
-    };
-
-}// namespace refactor::kernel::cublas
-
-#endif// KERNEL_CUBLASLT_CONTEXT_HH
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
new file mode 100644
index 00000000..d07af6ab
--- /dev/null
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -0,0 +1,148 @@
+#include "cublaslt_utils.cuh"
+#include "hardware/devices/nvidia.h"
+
+namespace refactor::kernel::cublas {
+
+    CublasLtContext::CublasLtContext() : runtime::Resource() {
+        if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
+            RUNTIME_ERROR("Failed to create cublasLt handle");
+        }
+    }
+    CublasLtContext::~CublasLtContext() {
+        if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) {
+            fmt::println("Failed to destroy cublasLt handle");
+            abort();
+        }
+    }
+
+    auto CublasLtContext::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto CublasLtContext::build() noexcept -> runtime::ResourceBox {
+        return std::make_unique<CublasLtContext>();
+    }
+
+    auto CublasLtContext::resourceTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto CublasLtContext::description() const noexcept -> std::string_view {
+        return "CublasLtContext";
+    }
+
+    cudaDataType dataTypeConvert(DataType dt) {
+        switch (dt) {
+            case DataType::F32:
+                return CUDA_R_32F;
+            default:
+                TODO("");
+        }
+    }
+
+    MatMulDescriptor::MatMulDescriptor(cublasComputeType_t compute, cudaDataType data)
+        : _internal(nullptr) {
+        CUBLASLT_ASSERT(cublasLtMatmulDescCreate(&_internal, compute, data));
+    }
+    MatMulDescriptor::~MatMulDescriptor() {
+        CUBLASLT_ASSERT(cublasLtMatmulDescDestroy(_internal));
+    }
+    cublasLtMatmulDesc_t MatMulDescriptor::get() const noexcept {
+        return _internal;
+    }
+
+    MatrixDescriptor::MatrixDescriptor(MatrixLayout layout)
+        : _internal(nullptr) {
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutCreate(
+            &_internal,
+            layout.dataType,
+            layout.rows,
+            layout.cols,
+            layout.majorStride));
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
+            _internal,
+            CUBLASLT_MATRIX_LAYOUT_ORDER,
+            &layout.order,
+            sizeof(layout.order)));
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
+            _internal,
+            CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+            &layout.batchCount,
+            sizeof(layout.batchCount)));
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
+            _internal,
+            CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+            &layout.batchStride,
+            sizeof(layout.batchStride)));
+    }
+    MatrixDescriptor::~MatrixDescriptor() {
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutDestroy(_internal));
+    }
+    cublasLtMatrixLayout_t MatrixDescriptor::get() const noexcept {
+        return _internal;
+    }
+
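+    // queries cublasLt's heuristic for the single best algorithm for c = a * b, plus its workspace size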
+    std::pair<cublasLtMatmulAlgo_t, size_t>
+    tune(cublasLtHandle_t handle,
+         MatMulDescriptor const &matmul,
+         MatrixDescriptor const &a,
+         MatrixDescriptor const &b,
+         MatrixDescriptor const &c) {
+
+        int device;
+        CUDA_ASSERT(cudaGetDevice(&device));
+        cudaDeviceProp prop;
+        CUDA_ASSERT(cudaGetDeviceProperties(&prop, device));
+
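+        // no workspace ceiling; require the device's texture alignment on every operand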
+        auto workspace = std::numeric_limits<uint64_t>::max();
+        auto alignment = prop.textureAlignment;
+
+        cublasLtMatmulPreference_t preference;
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+            &workspace,
+            sizeof(workspace)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
+            &alignment,
+            sizeof(alignment)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
+            &alignment,
+            sizeof(alignment)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
+            &alignment,
+            sizeof(alignment)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
+            &alignment,
+            sizeof(alignment)));
+
+        cublasLtMatmulHeuristicResult_t result;
+        int ansN;
+        CUBLASLT_ASSERT(cublasLtMatmulAlgoGetHeuristic(
+            handle,
+            matmul.get(),
+            a.get(),
+            b.get(),
+            c.get(),
+            c.get(),
+            preference,
+            1,
+            &result,
+            &ansN));
+        ASSERT(ansN == 1, "cublasLt heuristic returned no algorithm");
+
+        return {result.algo, result.workspaceSize};
+    }
+
+}// namespace refactor::kernel::cublas
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
new file mode 100644
index 00000000..5dd23607
--- /dev/null
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -0,0 +1,75 @@
+#ifndef KERNEL_CUBLASLT_UTILS_CUH
+#define KERNEL_CUBLASLT_UTILS_CUH
+
+#include "common.h"
+#include "runtime/resource.h"
+#include <cublasLt.h>
+
+#define CUBLASLT_ASSERT(STATUS)                                    \
+    if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \
+        fmt::println("cublasLt failed on \"" #STATUS "\" with {}", \
+                     (int) status);                                \
+        abort();                                                   \
+    }
+
+namespace refactor::kernel::cublas {
+
+    struct CublasLtContext final : public runtime::Resource {
+        cublasLtHandle_t handle;
+
+        CublasLtContext();
+        ~CublasLtContext();
+        CublasLtContext(CublasLtContext const &) noexcept = delete;
+        CublasLtContext(CublasLtContext &&) noexcept = delete;
+
+        static size_t typeId() noexcept;
+        static runtime::ResourceBox build() noexcept;
+
+        size_t resourceTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+    };
+
+    cudaDataType dataTypeConvert(DataType);
+
+    class MatMulDescriptor {
+        cublasLtMatmulDesc_t _internal;
+
+    public:
+        MatMulDescriptor(cublasComputeType_t, cudaDataType);
+        ~MatMulDescriptor();
+        MatMulDescriptor(MatMulDescriptor const &) noexcept = delete;
+        MatMulDescriptor(MatMulDescriptor &&) noexcept = delete;
+        cublasLtMatmulDesc_t get() const noexcept;
+    };
+
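+    // describes one (possibly batched) matrix operand; majorStride is the leading dimension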
+    struct MatrixLayout {
+        cudaDataType dataType;
+        uint64_t rows, cols;
+        int64_t majorStride;
+        cublasLtOrder_t order;
+        int32_t batchCount;
+        int64_t batchStride;
+    };
+
+    class MatrixDescriptor {
+        cublasLtMatrixLayout_t _internal;
+
+    public:
+        MatrixDescriptor(MatrixLayout layout);
+        ~MatrixDescriptor();
+        MatrixDescriptor(MatrixDescriptor const &) noexcept = delete;
+        MatrixDescriptor(MatrixDescriptor &&) noexcept = delete;
+        cublasLtMatrixLayout_t get() const noexcept;
+    };
+
+    std::pair<cublasLtMatmulAlgo_t, size_t>
+    tune(cublasLtHandle_t,
+         MatMulDescriptor const &,
+         MatrixDescriptor const &,
+         MatrixDescriptor const &,
+         MatrixDescriptor const &);
+
+}// namespace refactor::kernel::cublas
+
+#endif// KERNEL_CUBLASLT_UTILS_CUH