Skip to content

Commit

Permalink
fix(kernel): 稍微调整 MatMulInteger 逻辑
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 26, 2024
1 parent c462d7d commit 81c2b71
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
13 changes: 7 additions & 6 deletions src/04kernel/src/attributes/mat_mul_integer_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ namespace refactor::kernel {
scalar(true) {
if (inputs.size() > i + 2) {
auto const &t = inputs[i + 2].get();
signed_ = t.dataType == DataType::I8;

auto size = t.elementsSize();
scalar = size == 1;

if (t.data) {
auto data = slice(t.data->get<uint8_t>(), size);
if (std::all_of(data.begin(), data.end(), [](auto x) { return x == 0; })) {
return;
}
withZeroPoint = std::any_of(data.begin(), data.end(), [](auto x) { return x != 0; });
} else {
withZeroPoint = true;
}
withZeroPoint = true;
signed_ = t.dataType == DataType::I8;
scalar = size == 1;
}
}

Expand Down
28 changes: 26 additions & 2 deletions src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,28 @@ namespace refactor::kernel {
});
}

/// Device-side functor that narrows an unsigned 8-bit value to a signed one.
///
/// Inputs up to 127 pass through unchanged; anything larger saturates to 127,
/// so the result never wraps around to a negative number.
/// NOTE(review): saturation discards information for inputs in 128..255 —
/// confirm this is acceptable for every caller.
struct MatMulIntegerCastFunctor {
    __device__ int8_t operator()(uint8_t x) const noexcept {
        // Clamp in `int` first, then narrow: same result as min(127, x).
        int const clamped = x < 127 ? x : 127;
        return static_cast<int8_t>(clamped);
    }
};

/// Converts `size` unsigned 8-bit elements read from `src_` into signed 8-bit
/// elements written to `dst`, applying MatMulIntegerCastFunctor to each one.
///
/// @param size  Number of elements to convert.
/// @param dst   Destination buffer receiving `size` int8_t values.
/// @param src_  Source buffer, reinterpreted as `size` uint8_t values.
///
/// The transform is dispatched with the `thrust::device` execution policy.
/// NOTE(review): both pointers are presumably device addresses — confirm with callers.
static void applyCast(
    size_t size, int8_t *dst, void const *src_) {

    // The source arrives type-erased; view it as raw unsigned bytes.
    auto const *first = static_cast<uint8_t const *>(src_);
    auto const *last = first + size;
    thrust::transform(thrust::device, first, last, dst, MatMulIntegerCastFunctor{});
}

auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {

size_t workspace = 0;
if (info.a.withZeroPoint) {
if (info.a.withZeroPoint || !info.a.signed_) {
workspace += info.batch() * info.m * info.k;
}
if (info.b.withZeroPoint) {
if (info.b.withZeroPoint || !info.b.signed_) {
workspace += info.batch() * info.k * info.n;
}

Expand Down Expand Up @@ -118,6 +133,11 @@ namespace refactor::kernel {
}
a = workspacePtr;
workspacePtr += size;
} else if (!meta.signed_) {
auto size = info.batch() * info.m * info.k;
applyCast(size, workspacePtr, a);
a = workspacePtr;
workspacePtr += size;
}
if (auto meta = info.b; meta.withZeroPoint) {
auto size = info.batch() * info.k * info.n;
Expand All @@ -136,6 +156,10 @@ namespace refactor::kernel {
}
}
b = workspacePtr;
} else if (!meta.signed_) {
auto size = info.batch() * info.k * info.n;
applyCast(size, workspacePtr, b);
b = workspacePtr;
}

auto handle = res.fetchOrStore<CublasContext>()->handle;
Expand Down

0 comments on commit 81c2b71

Please sign in to comment.