Skip to content

Commit

Permalink
fix(kernel): 为 TransposeInfo 应对更多 coner case 并增加一个单元测试
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 31, 2024
1 parent 7db54c1 commit b3e89a6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/04kernel/src/attributes/transpose_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ namespace refactor::kernel {
}
perm.resize(rank);
}
if (rank <= 1) {
dims = {{1, 1}};
blockSize *= blockCount;
blockCount = 1;
return;
}
// 合并末尾连续访存
if (perm.back() == rank - 1) {
blockSize *= shape.back();
Expand All @@ -81,7 +87,6 @@ namespace refactor::kernel {
perm.pop_back();
--rank;
}

// 计算 stride
struct StrideI {
dim_t strideI;
Expand Down
39 changes: 39 additions & 0 deletions src/04kernel/test/attributes/test_transpose_info.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "kernel/attributes/transpose_info.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace kernel;

TEST(kernel, TransposeInfo) {
{
TransposeInfo info(
DataType::F32,
{1, 2, 3, 2, 1},
{1, 2, 3, 0, 4});
EXPECT_EQ(info.blockSize, 48);
EXPECT_EQ(info.blockCount, 1);
EXPECT_EQ(info.dims.size(), 1);
}
{
TransposeInfo info(
DataType::F32,
{1, 1, 2, 1, 1},
{1, 2, 3, 0, 4});
EXPECT_EQ(info.blockSize, 8);
EXPECT_EQ(info.blockCount, 1);
EXPECT_EQ(info.dims.size(), 1);
}
{
TransposeInfo info(
DataType::F32,
{1, 2, 3, 4, 5},
{2, 3, 1, 0, 4});
EXPECT_EQ(info.blockSize, 20);
EXPECT_EQ(info.blockCount, 24);
EXPECT_EQ(info.dims.size(), 2);
EXPECT_EQ(info.dims[1].strideI, 12);
EXPECT_EQ(info.dims[1].strideO, 1);
EXPECT_EQ(info.dims[0].strideI, 1);
EXPECT_EQ(info.dims[0].strideO, 2);
}
}

0 comments on commit b3e89a6

Please sign in to comment.