From b3e89a68ba80a041c8f6f616eebbf7ceed041e82 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 31 Jan 2024 17:53:54 +0800 Subject: [PATCH] =?UTF-8?q?fix(kernel):=20=E4=B8=BA=20TransposeInfo=20?= =?UTF-8?q?=E5=BA=94=E5=AF=B9=E6=9B=B4=E5=A4=9A=20coner=20case=20=E5=B9=B6?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=80=E4=B8=AA=E5=8D=95=E5=85=83=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/04kernel/src/attributes/transpose_info.cc | 7 +++- .../test/attributes/test_transpose_info.cpp | 39 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 src/04kernel/test/attributes/test_transpose_info.cpp diff --git a/src/04kernel/src/attributes/transpose_info.cc b/src/04kernel/src/attributes/transpose_info.cc index d3b57315..9ae385a9 100644 --- a/src/04kernel/src/attributes/transpose_info.cc +++ b/src/04kernel/src/attributes/transpose_info.cc @@ -73,6 +73,12 @@ namespace refactor::kernel { } perm.resize(rank); } + if (rank <= 1) { + dims = {{1, 1}}; + blockSize *= blockCount; + blockCount = 1; + return; + } // 合并末尾连续访存 if (perm.back() == rank - 1) { blockSize *= shape.back(); @@ -81,7 +87,6 @@ namespace refactor::kernel { perm.pop_back(); --rank; } - // 计算 stride struct StrideI { dim_t strideI; diff --git a/src/04kernel/test/attributes/test_transpose_info.cpp b/src/04kernel/test/attributes/test_transpose_info.cpp new file mode 100644 index 00000000..fd735801 --- /dev/null +++ b/src/04kernel/test/attributes/test_transpose_info.cpp @@ -0,0 +1,39 @@ +#include "kernel/attributes/transpose_info.h" +#include + +using namespace refactor; +using namespace kernel; + +TEST(kernel, TransposeInfo) { + { + TransposeInfo info( + DataType::F32, + {1, 2, 3, 2, 1}, + {1, 2, 3, 0, 4}); + EXPECT_EQ(info.blockSize, 48); + EXPECT_EQ(info.blockCount, 1); + EXPECT_EQ(info.dims.size(), 1); + } + { + TransposeInfo info( + DataType::F32, + {1, 1, 2, 1, 1}, + {1, 2, 3, 0, 4}); + EXPECT_EQ(info.blockSize, 8); + EXPECT_EQ(info.blockCount, 1); + EXPECT_EQ(info.dims.size(), 1); + } + { + TransposeInfo info( + DataType::F32, + {1, 2, 3, 4, 5}, + {2, 3, 1, 0, 4}); + EXPECT_EQ(info.blockSize, 20); + EXPECT_EQ(info.blockCount, 24); + EXPECT_EQ(info.dims.size(), 2); + EXPECT_EQ(info.dims[1].strideI, 12); + EXPECT_EQ(info.dims[1].strideO, 1); + EXPECT_EQ(info.dims[0].strideI, 1); + EXPECT_EQ(info.dims[0].strideO, 2); + } +}