diff --git a/src/04kernel/src/attributes/transpose_info.cc b/src/04kernel/src/attributes/transpose_info.cc index 82eb0c9f..2b563d03 100644 --- a/src/04kernel/src/attributes/transpose_info.cc +++ b/src/04kernel/src/attributes/transpose_info.cc @@ -35,7 +35,12 @@ namespace refactor::kernel { } } } - if (rank == 0) { return; } + if (rank == 0) { + dims = {{1, 1}}; + blockSize *= blockCount; + blockCount = 1; + return; + } // 合并连续的维度 { std::vector mapDim(rank, 0); @@ -68,6 +73,14 @@ namespace refactor::kernel { } perm.resize(rank); } + // 合并末尾连续访存 + if (perm.back() == rank - 1) { + blockSize *= shape.back(); + blockCount /= shape.back(); + shape.pop_back(); + perm.pop_back(); + --rank; + } // 计算 stride struct StrideI { dim_t strideI;