diff --git a/accera/acc-opt/test/thrifty_caching.mlir b/accera/acc-opt/test/thrifty_caching.mlir index 5e4e2bee..b0d4eb5c 100644 --- a/accera/acc-opt/test/thrifty_caching.mlir +++ b/accera/acc-opt/test/thrifty_caching.mlir @@ -69,8 +69,8 @@ module @test_thrifty_caching_simple_input_cache attributes {llvm.data_layout = " // CHECK: affine.for %arg6 = 0 to 16 { // CHECK: %1 = affine.load %arg1[%arg5, %arg4 + %arg6] : memref<32x32xf32, #map0> // CHECK: affine.store %1, %0[%arg5, %arg6] : memref<32x16xf32, 3> -// CHECK: } {accxp.access_bounds_check, beginMap = #map1, endMap = #map2, index = #accln<"index{j,7}">, kernels = ["cache_internal_loopnest_kernel_active_block_copy"], operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, scheduledIndex = #accln<"index{j,7}">, subdomainIndexOrder = [#accln<"index{i,6}">, #accln<"index{j,7}">], subdomainSize = [32, 16]} -// CHECK: } {accxp.access_bounds_check, beginMap = #map1, endMap = #map3, index = #accln<"index{i,6}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, scheduledIndex = #accln<"index{i,6}">, subdomainIndexOrder = [#accln<"index{i,6}">, #accln<"index{j,7}">], subdomainSize = [32, 16]} +// CHECK: } {accaffine.access_bounds_check, beginMap = #map1, endMap = #map2, index = #accln<"index{j,7}">, kernels = ["cache_internal_loopnest_kernel_active_block_copy"], operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, scheduledIndex = #accln<"index{j,7}">, subdomainIndexOrder = [#accln<"index{i,6}">, #accln<"index{j,7}">], subdomainSize = [32, 16]} +// CHECK: } {accaffine.access_bounds_check, beginMap = #map1, endMap = #map3, index = #accln<"index{i,6}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, scheduledIndex = #accln<"index{i,6}">, subdomainIndexOrder = [#accln<"index{i,6}">, #accln<"index{j,7}">], subdomainSize = [32, 16]} // CHECK: affine.for %arg5 = 0 to 4 { // CHECK: affine.for %arg6 = 0 to 16 { // CHECK: affine.for %arg7 = 0 to 32 { diff --git a/accera/acc-opt/test/vectorization.mlir b/accera/acc-opt/test/vectorization.mlir index 27718b86..e72a907e 100644 --- a/accera/acc-opt/test/vectorization.mlir +++ b/accera/acc-opt/test/vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: acc-opt --verify-each=false --acc-vectorize %s | FileCheck %s +// RUN: acc-opt --verify-each=false --acc-vectorize -split-input-file %s | FileCheck %s module @test_accera_vectorization attributes {accv.target_device_features = "-avx512pf,-tsxldtrk,+cx16,+sahf,-tbm,-avx512ifma,-sha,+crc32,-fma4,-vpclmulqdq,-prfchw,+bmi2,-cldemote,+fsgsbase,-ptwrite,-amx-tile,-uintr,-gfni,+popcnt,-widekl,+aes,-avx512bitalg,-movdiri,-xsaves,-avx512er,-avxvnni,-avx512fp16,-avx512vnni,-amx-bf16,-avx512vpopcntdq,-pconfig,-clwb,-avx512f,-xsavec,-clzero,-pku,+mmx,-lwp,-rdpid,-xop,-rdseed,-waitpkg,-kl,-movdir64b,-sse4a,-avx512bw,-clflushopt,+xsave,-avx512vbmi2,+64bit,-avx512vl,-serialize,-hreset,+invpcid,-avx512cd,+avx,-vaes,-avx512bf16,+cx8,+fma,-rtm,+bmi,-enqcmd,+rdrnd,-mwaitx,+sse4.1,+sse4.2,+avx2,+fxsr,-wbnoinvd,+sse,+lzcnt,+pclmul,-prefetchwt1,+f16c,+ssse3,-sgx,-shstk,+cmov,-avx512vbmi,-amx-int8,+movbe,-avx512vp2intersect,+xsaveopt,-avx512dq,+sse2,-adx,+sse"} { accv.module "test_accera_vectorization" { @@ -80,5 +80,236 @@ module @test_accera_vectorization attributes {accv.target_device_features = "-av } {beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (1536)>, index = #accln<"index{i_o,268}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, subdomainIndexOrder = [#accln<"index{i,266}">, #accln<"index{j,267}">], subdomainSize = [1885, 256]} return } + + // CHECK-LABEL builtin.func nested @test_view_split_dim_interleaved_pack + builtin.func nested @test_int16_to_int32_horizontal_vector_add(%arg0: memref<256x16xi16> loc(unknown), %arg1: memref<256xi32> loc(unknown)) attributes {accv.dyn_arg_size_refs = [[-1, -1], [-1]], accv.usages = [1 : i8, 1 : i8], args_name = ["", ""], args_size = ["256*16", "256"], args_symbol = ["args_symbol_name_0", "args_symbol_name_1"], exec_target = 0 : i64} { + // CHECK: affine.for %arg2 = 0 to 256 step 4 { + affine.for %arg2 = 0 to 256 step 4 { + // %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4096], strides: [1] : memref<256x16xi16> to memref<4096xi16> loc(unknown) + // %1 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 16 + d0)>(%c0, %arg2, %c0) + // %2 = vector.load %0[%1] : memref<4096xi16>, vector<16xi16> + // %3 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4096], strides: [1] : memref<256x16xi16> to memref<4096xi16> loc(unknown) + // %4 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 16 + d0)>(%c0, %arg2, %c1) + // %5 = vector.load %3[%4] : memref<4096xi16>, vector<16xi16> + // %6 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4096], strides: [1] : memref<256x16xi16> to memref<4096xi16> loc(unknown) + // %7 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 16 + d0)>(%c0, %arg2, %c2) + // %8 = vector.load %6[%7] : memref<4096xi16>, vector<16xi16> + // %9 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4096], strides: [1] : memref<256x16xi16> to memref<4096xi16> loc(unknown) + // %10 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 16 + d0)>(%c0, %arg2, %c3) + // %11 = vector.load %9[%10] : memref<4096xi16>, vector<16xi16> + // %12 = "accv.vpmaddwd"(%2, %cst) : (vector<16xi16>, vector<16xi16>) -> vector<8xi32> + // %13 = "accv.vpmaddwd"(%5, %cst) : (vector<16xi16>, vector<16xi16>) -> vector<8xi32> + // %14 = "accv.vpmaddwd"(%8, %cst) : (vector<16xi16>, vector<16xi16>) -> vector<8xi32> + // %15 = "accv.vpmaddwd"(%11, %cst) : (vector<16xi16>, vector<16xi16>) -> vector<8xi32> + // %16 = "accv.vhadd"(%12, %13) : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // %17 = "accv.vhadd"(%14, %15) : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // %18 = "accv.vhadd"(%16, %17) : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // %19 = vector.shuffle %18, %18 [0, 1, 2, 3] : vector<8xi32>, vector<8xi32> + // %20 = vector.shuffle %18, %18 [4, 5, 6, 7] : vector<8xi32>, vector<8xi32> + // %21 = "accv.bin_op"(%19, %20) {predicate = 0 : i64} : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + // %22 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [256], strides: [1] : memref<256xi32> to memref<256xi32> loc(unknown) + // %23 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg2, %c0) + // %24 = vector.load %22[%23] : memref<256xi32>, vector<4xi32> + // %25 = "accv.bin_op"(%24, %21) {predicate = 0 : i64} : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + // %26 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [256], strides: [1] : memref<256xi32> to memref<256xi32> loc(unknown) + // %27 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg2, %c0) + // vector.store %25, %26[%27] : memref<256xi32>, vector<4xi32> + affine.for %arg3 = 0 to 4 { + affine.for %arg4 = 0 to 16 { + %0 = affine.load %arg0[%arg2 + %arg3, %arg4] : memref<256x16xi16> + %1 = "accv.cast"(%0) : (i16) -> i32 + %2 = affine.load %arg1[%arg2 + %arg3] : memref<256xi32> + %3 = "accv.bin_op"(%2, %1) {predicate = 0 : i64} : (i32, i32) -> i32 + affine.store %3, %arg1[%arg2 + %arg3] : memref<256xi32> + } {beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (16)>, index = #accln<"index{j,1}">, kernels = ["_"], operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, subdomainIndexOrder = [#accln<"index{i,0}">, #accln<"index{j,1}">], subdomainSize = [256, 16]} + } {accxp_vectorizationInfo = #accxp<"vectorizationinfo{32,16,0}">, beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (4)>, index = #accln<"index{i_i,3}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, scheduledIndex = #accln<"index{i_i,3}">, subdomainIndexOrder = [#accln<"index{i,0}">, #accln<"index{j,1}">], subdomainSize = [256, 16]} + } {beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (256)>, index = #accln<"index{i_o,2}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, subdomainIndexOrder = [#accln<"index{i,0}">, #accln<"index{j,1}">], subdomainSize = [256, 16]} + return + } + + // CHECK-LABEL builtin.func nested @test_int32_horizontal_vector_add_simple + builtin.func nested @test_int32_horizontal_vector_add_simple(%arg0: memref<256x8xi32> loc(unknown), %arg1: memref<256xi32> loc(unknown)) attributes {accv.dyn_arg_size_refs = [[-1, -1], [-1]], accv.usages = [1 : i8, 1 : i8], args_name = ["", ""], args_size = ["256*8", "256"], args_symbol = ["args_symbol_name_0", "args_symbol_name_1"], exec_target = 0 : i64} { + // CHECK: affine.for %arg2 = 0 to 256 step 4 { + affine.for %arg2 = 0 to 256 step 4 { + // %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2048], strides: [1] : memref<256x8xi32> to memref<2048xi32> loc(unknown) + // %1 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 8 + d0)>(%c0, %arg2, %c0) + // %2 = vector.load %0[%1] : memref<2048xi32>, vector<8xi32> + // %3 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2048], strides: [1] : memref<256x8xi32> to memref<2048xi32> loc(unknown) + // %4 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 8 + d0)>(%c0, %arg2, %c1) + // %5 = vector.load %3[%4] : memref<2048xi32>, vector<8xi32> + // %6 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2048], strides: [1] : memref<256x8xi32> to memref<2048xi32> loc(unknown) + // %7 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 8 + d0)>(%c0, %arg2, %c2) + // %8 = vector.load %6[%7] : memref<2048xi32>, vector<8xi32> + // %9 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2048], strides: [1] : memref<256x8xi32> to memref<2048xi32> loc(unknown) + // %10 = affine.apply affine_map<(d0, d1, d2) -> ((d1 + d2) * 8 + d0)>(%c0, %arg2, %c3) + // %11 = vector.load %9[%10] : memref<2048xi32>, vector<8xi32> + // %12 = "accv.vhadd"(%2, %5) : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // %13 = "accv.vhadd"(%8, %11) : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // %14 = "accv.vhadd"(%12, %13) : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // %15 = vector.shuffle %14, %14 [0, 1, 2, 3] : vector<8xi32>, vector<8xi32> + // %16 = vector.shuffle %14, %14 [4, 5, 6, 7] : vector<8xi32>, vector<8xi32> + // %17 = "accv.bin_op"(%15, %16) {predicate = 0 : i64} : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + // %18 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [256], strides: [1] : memref<256xi32> to memref<256xi32> loc(unknown) + // %19 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg2, %c0) + // %20 = vector.load %18[%19] : memref<256xi32>, vector<4xi32> + // %21 = "accv.bin_op"(%20, %17) {predicate = 0 : i64} : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + // %22 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [256], strides: [1] : memref<256xi32> to memref<256xi32> loc(unknown) + // %23 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg2, %c0) + // vector.store %21, %22[%23] : memref<256xi32>, vector<4xi32> + affine.for %arg3 = 0 to 4 { + affine.for %arg4 = 0 to 8 { + %0 = affine.load %arg1[%arg2 + %arg3] : memref<256xi32> + %1 = affine.load %arg0[%arg2 + %arg3, %arg4] : memref<256x8xi32> + %2 = "accv.bin_op"(%0, %1) {predicate = 0 : i64} : (i32, i32) -> i32 + affine.store %2, %arg1[%arg2 + %arg3] : memref<256xi32> + } {beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (8)>, index = #accln<"index{j,1}">, kernels = ["_"], operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, subdomainIndexOrder = [#accln<"index{i,0}">, #accln<"index{j,1}">], subdomainSize = [256, 8]} + } {accxp_vectorizationInfo = #accxp<"vectorizationinfo{32,16,0}">, beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (4)>, index = #accln<"index{i_i,3}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, scheduledIndex = #accln<"index{i_i,3}">, subdomainIndexOrder = [#accln<"index{i,0}">, #accln<"index{j,1}">], subdomainSize = [256, 8]} + } {beginMap = affine_map<() -> (0)>, endMap = affine_map<() -> (256)>, index = #accln<"index{i_o,2}">, operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>, subdomainIndexOrder = [#accln<"index{i,0}">, #accln<"index{j,1}">], subdomainSize = [256, 8]} + return + } + } +} + +// ----- + +module @test_transpose_8x4 attributes {accv.target_device_features = "+avx2", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"} { + accv.module "test_transpose_8x4" { + + // CHECK-LABEL builtin.func nested @test_transpose_8x4_423849238228f332_impl_17576985214141312005(%arg0: memref<8x4xf32>, %arg1: memref<4x8xf32>) { + // CHECK-NEXT %c0 = arith.constant 0 : index + // CHECK-NEXT %c4 = arith.constant 4 : index + // CHECK-NEXT %c8 = arith.constant 8 : index + // CHECK-NEXT %c12 = arith.constant 12 : index + // CHECK-NEXT %c16 = arith.constant 16 : index + // CHECK-NEXT %c20 = arith.constant 20 : index + // CHECK-NEXT %c24 = arith.constant 24 : index + // CHECK-NEXT %c28 = arith.constant 28 : index + // CHECK-NEXT %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %1 = vector.load %0[%c0] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %2 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %3 = vector.load %2[%c4] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %4 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %5 = vector.load %4[%c8] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %6 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %7 = vector.load %6[%c12] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %8 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %9 = vector.load %8[%c16] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %10 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %11 = vector.load %10[%c20] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %12 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %13 = vector.load %12[%c24] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %14 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] : memref<8x4xf32> to memref<32xf32> + // CHECK-NEXT %15 = vector.load %14[%c28] : memref<32xf32>, vector<4xf32> + // CHECK-NEXT %16 = vector.shuffle %1, %9 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %17 = vector.shuffle %3, %11 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %18 = vector.shuffle %5, %13 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %19 = vector.shuffle %7, %15 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %20 = vector.shuffle %16, %17 [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %21 = vector.shuffle %16, %17 [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %22 = vector.shuffle %18, %19 [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %23 = vector.shuffle %18, %19 [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %24 = vector.shuffle %20, %22 [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %25 = vector.shuffle %20, %22 [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %26 = vector.shuffle %21, %23 [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %27 = vector.shuffle %21, %23 [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %28 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32> + // CHECK-NEXT vector.store %24, %28[%c0] : memref<32xf32>, vector<8xf32> + // CHECK-NEXT %29 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32> + // CHECK-NEXT vector.store %25, %29[%c8] : memref<32xf32>, vector<8xf32> + // CHECK-NEXT %30 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32> + // CHECK-NEXT vector.store %26, %30[%c16] : memref<32xf32>, vector<8xf32> + // CHECK-NEXT %31 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32> + // CHECK-NEXT vector.store %27, %31[%c24] : memref<32xf32>, vector<8xf32> + // CHECK-NEXT return + + builtin.func nested @test_transpose_8x4_423849238228f332_impl_17576985214141312005(%arg0: memref<8x4xf32>, %arg1: memref<4x8xf32>) { + affine.for %arg2 = 0 to 8 { + affine.for %arg3 = 0 to 4 { + %0 = affine.load %arg0[%arg2, %arg3] : memref<8x4xf32> + affine.store %0, %arg1[%arg3, %arg2] : memref<4x8xf32> + } {accxp_vectorizationInfo = #accxp<"vectorizationinfo{32,16,0}">} + } {accxp_vectorizationInfo = #accxp<"vectorizationinfo{32,16,0}">} + return + } + } +} + +// ----- + +module @test_transpose_16x4 attributes {accv.target_device_features = "+avx2", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"} { + accv.module "test_transpose_16x4" { + + // CHECK-LABEL builtin.func nested @test_transpose_16x4_d3ea7863380b434f_impl_10155054494031908713(%arg0: memref<16x4xf32>, %arg1: memref<4x16xf32>) { + // CHECK-NEXT %c0 = arith.constant 0 : index + // CHECK-NEXT %c1 = arith.constant 1 : index + // CHECK-NEXT %c2 = arith.constant 2 : index + // CHECK-NEXT %c3 = arith.constant 3 : index + // CHECK-NEXT %c4 = arith.constant 4 : index + // CHECK-NEXT %c5 = arith.constant 5 : index + // CHECK-NEXT %c6 = arith.constant 6 : index + // CHECK-NEXT %c7 = arith.constant 7 : index + // CHECK-NEXT affine.for %arg2 = 0 to 16 step 8 { + // CHECK-NEXT %0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %1 = affine.apply #map0(%arg2, %c0, %c0) + // CHECK-NEXT %2 = vector.load %0[%1] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %3 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %4 = affine.apply #map0(%arg2, %c1, %c0) + // CHECK-NEXT %5 = vector.load %3[%4] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %6 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %7 = affine.apply #map0(%arg2, %c2, %c0) + // CHECK-NEXT %8 = vector.load %6[%7] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %9 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %10 = affine.apply #map0(%arg2, %c3, %c0) + // CHECK-NEXT %11 = vector.load %9[%10] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %12 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %13 = affine.apply #map0(%arg2, %c4, %c0) + // CHECK-NEXT %14 = vector.load %12[%13] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %15 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %16 = affine.apply #map0(%arg2, %c5, %c0) + // CHECK-NEXT %17 = vector.load %15[%16] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %18 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %19 = affine.apply #map0(%arg2, %c6, %c0) + // CHECK-NEXT %20 = vector.load %18[%19] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %21 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + // CHECK-NEXT %22 = affine.apply #map0(%arg2, %c7, %c0) + // CHECK-NEXT %23 = vector.load %21[%22] : memref<64xf32>, vector<4xf32> + // CHECK-NEXT %24 = vector.shuffle %2, %14 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %25 = vector.shuffle %5, %17 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %26 = vector.shuffle %8, %20 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %27 = vector.shuffle %11, %23 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK-NEXT %28 = vector.shuffle %24, %25 [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %29 = vector.shuffle %24, %25 [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %30 = vector.shuffle %26, %27 [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %31 = vector.shuffle %26, %27 [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %32 = vector.shuffle %28, %30 [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %33 = vector.shuffle %28, %30 [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %34 = vector.shuffle %29, %31 [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %35 = vector.shuffle %29, %31 [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT %36 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> + // CHECK-NEXT %37 = affine.apply #map1(%c0, %arg2, %c0) + // CHECK-NEXT vector.store %32, %36[%37] : memref<64xf32>, vector<8xf32> + // CHECK-NEXT %38 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> + // CHECK-NEXT %39 = affine.apply #map1(%c1, %arg2, %c0) + // CHECK-NEXT vector.store %33, %38[%39] : memref<64xf32>, vector<8xf32> + // CHECK-NEXT %40 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> + // CHECK-NEXT %41 = affine.apply #map1(%c2, %arg2, %c0) + // CHECK-NEXT vector.store %34, %40[%41] : memref<64xf32>, vector<8xf32> + // CHECK-NEXT %42 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> + // CHECK-NEXT %43 = affine.apply #map1(%c3, %arg2, %c0) + // CHECK-NEXT vector.store %35, %42[%43] : memref<64xf32>, vector<8xf32> + // CHECK-NEXT } + // CHECK-NEXT return + + builtin.func nested @test_transpose_16x4_d3ea7863380b434f_impl_10155054494031908713(%arg0: memref<16x4xf32>, %arg1: memref<4x16xf32>) { + affine.for %arg2 = 0 to 16 step 8 { + affine.for %arg3 = 0 to 8 { + affine.for %arg4 = 0 to 4 { + %0 = affine.load %arg0[%arg2 + %arg3, %arg4] : memref<16x4xf32> + affine.store %0, %arg1[%arg4, %arg2 + %arg3] : memref<4x16xf32> + } {accxp_vectorizationInfo = #accxp<"vectorizationinfo{32,16,0}">} + } {accxp_vectorizationInfo = #accxp<"vectorizationinfo{32,16,0}">} + } + return } + } } diff --git a/accera/ir/include/IRUtil.h b/accera/ir/include/IRUtil.h index c4e37f7e..88937dd0 100644 --- a/accera/ir/include/IRUtil.h +++ b/accera/ir/include/IRUtil.h @@ -463,5 +463,7 @@ namespace util std::vector GetDynamicOffsetSymbols(mlir::Value val); + bool AncestorOpContainsAttrOfName(mlir::Operation* op, const mlir::StringRef& name); + } // namespace util } // namespace accera::ir diff --git a/accera/ir/include/exec/ExecutionPlanOps.h b/accera/ir/include/exec/ExecutionPlanOps.h index 98af1caa..18bc1702 100644 --- a/accera/ir/include/exec/ExecutionPlanOps.h +++ b/accera/ir/include/exec/ExecutionPlanOps.h @@ -100,9 +100,6 @@ namespace executionPlan namespace accera::ir::executionPlan { -// Unit attr name for controlling whether bounds checking is done for ops within a marked op -const mlir::StringRef AccessBoundsCheckAttrName = "accxp.access_bounds_check"; - // // Utility functions and EDSC-type intrinsics // diff --git a/accera/ir/include/value/ValueOps.td b/accera/ir/include/value/ValueOps.td index cf0d061e..b57a37c2 100644 --- a/accera/ir/include/value/ValueOps.td +++ b/accera/ir/include/value/ValueOps.td @@ -1577,5 +1577,42 @@ def accv_vminps : accv_Op<"vminps", [NoSideEffect]>{ let results = (outs AnyVector:$result); } +def accv_vhadd : accv_Op<"vhadd", [NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType]>{ + let summary = "Vector horizontal interleaved add operation"; + + let description = [{ + The `accv.vhadd` operation lowers to differently sized vector instructions depending on the element type in the vector operands. + + For 32-bit operands the interleaving and adding follows the pattern: + vhadd( A[0...7], B[0...7] ) -> + [ A[0]+A[1], + A[2]+A[3], + B[0]+B[1], + B[2]+B[3], + A[4]+A[5], + A[6]+A[7], + B[4]+B[5], + B[6]+B[7] ] + + For different bit-width operands, the corresponding byte positions are kept consistent as with the 32-bit operands. + i.e. for i16, the first 4 elements are the pairwise sums of the first 8 elements of A, as opposed to 2 and 4 with i32's + for f64, the first 1 element is the sum of the first 2 elements of A, as opposed to 2 and 4 with f32's + + Supported operand / result types and their corresponding AVX instructions: + Operand / Result Type | Instruction + ----------------------|------------- + vector<8xi32> | vphaddd + vector<16xi16> | vphaddw + vector<8xf32> | vhaddps + vector<4xf64> | vhaddpd + + Note: this lowers to MLIR vector dialect ops, so a particular target architecture is not required, + and instructions other than those listed above are possible on other architectures + }]; + + let arguments = (ins AnyVector:$lhs, AnyVector:$rhs); + let results = (outs AnyVector:$result); +} + #endif // ACCERA_accv_OPS diff --git a/accera/ir/src/IRUtil.cpp b/accera/ir/src/IRUtil.cpp index 31568dfa..1c983048 100644 --- a/accera/ir/src/IRUtil.cpp +++ b/accera/ir/src/IRUtil.cpp @@ -1540,5 +1540,18 @@ namespace util return offsetSymbols; } + bool AncestorOpContainsAttrOfName(mlir::Operation* op, const mlir::StringRef& name) + { + while (op != nullptr) + { + if (op->getAttr(name) != nullptr) + { + return true; + } + op = op->getParentOp(); + } + return false; + } + } // namespace util } // namespace accera::ir diff --git a/accera/python/accera/test/dsl_tests.py b/accera/python/accera/test/dsl_tests.py index b6019958..752ab723 100644 --- a/accera/python/accera/test/dsl_tests.py +++ b/accera/python/accera/test/dsl_tests.py @@ -65,7 +65,7 @@ def _get_test_mode(correctness_check: bool = False): class DSLTest_01Arrays(unittest.TestCase): - def _verify_nest(self, nest, args: Tuple[Array], package_name, correctness_check_values=None) -> None: + def _verify_nest(self, nest, args: Tuple[Array], package_name, correctness_check_values=None, quiet=True) -> None: # create a HAT package and add the function to it package = Package() @@ -74,7 +74,7 @@ def _verify_nest(self, nest, args: Tuple[Array], package_name, correctness_check # build the HAT package with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir) + package.build(package_name, format=TEST_FORMAT, mode=_get_test_mode(correctness_check_values), output_dir=output_dir, _quiet=quiet) if correctness_check_values: v.check_correctness( function.name, @@ -667,13 +667,13 @@ def main(array): package.add(main, args=(arr, )) package_name = "test_reinterpret_cast" - - with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR): + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name + with verifiers.VerifyPackage(self, package_name, output_dir): package.build( package_name, format=TEST_FORMAT, mode=Package.Mode.RELEASE, - output_dir=TEST_PACKAGE_DIR, + output_dir=output_dir, _quiet=False ) @@ -6131,6 +6131,7 @@ def _(): function.name, before=[A_test, B_test, C_test], after=[A_test, B_test, (C_test + A_test) * B_test - 1.0], + tolerance=1e-4 ) def test_debug_mode_fusion_cascading_2(self) -> None: diff --git a/accera/python/accera/test/smoke_tests.py b/accera/python/accera/test/smoke_tests.py index 0f638ca7..113e2272 100644 --- a/accera/python/accera/test/smoke_tests.py +++ b/accera/python/accera/test/smoke_tests.py @@ -5,6 +5,7 @@ #################################################################################################### import inspect +from itertools import product import os import sys import unittest @@ -42,7 +43,10 @@ DEV_MODE = True sys.path.insert(1, os.getcwd()) -INTERNAL_FUNCTION_OPTS = { "no_inline_into": True, "public": False } +INTERNAL_FUNCTION_OPTS = { + "no_inline_into": True, + "public": False +} from accera import Package, ScalarType, Nest, Array, Constants, Scalar, fuse, create_parameters, cast, Target, Role from accera._lang_python._lang import _MemorySpace, _MMAShape, Dimension @@ -210,7 +214,6 @@ def _(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR): package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR) - def test_differently_split_fused_schedules(self) -> None: # Split a dimension twice in one schedule and once in another schedule, then fuse the outermost split indices @@ -246,7 +249,6 @@ def _(): ii1 = schedule1.split(i1, 16) schedule1.reorder(i1, j1, ii1) - schedule = fuse((schedule0, schedule1), partial=2) plan = schedule.create_plan() @@ -259,7 +261,6 @@ def _(): with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR): package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR) - def test_partial_fusion_matmul3_naive(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(16, 11)) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(11, 10)) @@ -518,6 +519,102 @@ def test_mlas_matmul(self) -> None: with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR): package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=TEST_PACKAGE_DIR) + def _test_transpose_MxN(self, M: int, N: int): + In = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, N)) + Out = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(N, M)) + + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _(): + Out[j, i] = In[i, j] + + sched = nest.create_schedule() + ii = sched.split(i, 8) + jj = sched.split(j, 4) + sched.reorder(i, j, ii, jj) + + plan = sched.create_plan() + plan.vectorize(ii) + plan.vectorize(jj) + + In_test = np.arange(stop=M * N, dtype=np.float32).reshape(M, N) + Out_test = np.random.rand(N, M).astype(np.float32) + Out_ref = In_test.T + + # Create a package and add our function definition to it + package_name = f"test_transpose_{M}x{N}" + package = Package() + function = package.add(plan, args=(In, Out), base_name=f"test_transpose_{M}x{N}") + + # Build the HAT package + with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: + package.build( + package_name, + format=self.PACKAGE_FORMAT, + mode=self.PACKAGE_MODE, + output_dir=TEST_PACKAGE_DIR, + _quiet=False + ) + + v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) + + def test_transpose_8x4(self): + self._test_transpose_MxN(8, 4) + + def test_transpose_16x4(self): + self._test_transpose_MxN(16, 4) + + def test_transpose_32x32(self): + self._test_transpose_MxN(32, 32) + + def test_transpose_31x31(self): + self._test_transpose_MxN(31, 31) + + def test_transpose_16x4_packed(self): + In = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(16, 4)) + Out = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(4, 2, 8)) + + nest = Nest(shape=(16, 4)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _(): + Out[j, (i / 8) % 2, i % 8] = In[i, j] + + sched = nest.create_schedule() + ii = sched.split(i, 8) + + plan = sched.create_plan() + plan.unroll(i) + plan.vectorize(ii) + plan.vectorize(j) + + In_test = np.arange(stop=16 * 4, dtype=np.float32).reshape(16, 4) + Out_test = np.random.rand(4, 2, 8).astype(np.float32) + Out_ref = np.ndarray(shape=Out_test.shape, dtype=Out_test.dtype) + for i_ in range(16): + for j_ in range(4): + Out_ref[j_, (i_ // 8) % 2, i_ % 8] = In_test[i_, j_] + + # Create a package and add our function definition to it + package_name = "test_transpose_16x4_packed" + package = Package() + function = package.add(plan, args=(In, Out), base_name="test_transpose_16x4_packed") + + # Build the HAT package + with verifiers.VerifyPackage(self, package_name, TEST_PACKAGE_DIR) as v: + package.build( + package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + mode=self.PACKAGE_MODE, + output_dir=TEST_PACKAGE_DIR, + _quiet=False + ) + + v.check_correctness(function.name, before=(In_test, Out_test), after=(In_test, Out_ref)) + def test_emittime_cache_mlas_matmul(self) -> None: from accera.samples.OfflineCacheMatrixMultiplication import EmitTimeCacheMLAS @@ -653,6 +750,7 @@ def round_up(number, multiple): @nest.iteration_logic def _(): + def if_block(): C[i, j] += A[i, k] * B[k, j] @@ -1196,32 +1294,35 @@ def test_offset_sub_array_packing_flat(self) -> None: test_name = "test_offset_sub_array_packing_flat" N = 8 - output_size = N*N + N + output_size = N * N + N Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N, N)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(output_size,)) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(output_size, )) package = Package() - diagonal_fetch_nest = Nest(shape=(N,)) + diagonal_fetch_nest = Nest(shape=(N, )) diagonal_idx, = diagonal_fetch_nest.get_indices() + @diagonal_fetch_nest.iteration_logic def _diag_fetch(): - diag_vec = Output.sub_array(offsets=(0,), shape=(N,)) + diag_vec = Output.sub_array(offsets=(0, ), shape=(N, )) diag_vec[diagonal_idx] = Input[diagonal_idx, diagonal_idx] diag_fn = package.add(diagonal_fetch_nest, args=(Input, Output), base_name="diagonal_fetch_fn") transpose_nest = Nest(shape=(N, N)) transpose_i, transpose_j = transpose_nest.get_indices() + @transpose_nest.iteration_logic def _transpose(): - packed_output = Output.sub_array(offsets=(N,), shape=(N*N,)) - packed_output[transpose_j*N + transpose_i] = Input[transpose_i, transpose_j] + packed_output = Output.sub_array(offsets=(N, ), shape=(N * N, )) + packed_output[transpose_j * N + transpose_i] = Input[transpose_i, transpose_j] transpose_fn = package.add(transpose_nest, args=(Input, Output), base_name="transpose_fn") - outer_nest = Nest(shape=(1,)) + outer_nest = Nest(shape=(1, )) + @outer_nest.iteration_logic def _(): diag_fn(Input, Output) @@ -1237,12 +1338,12 @@ def _(): # correctness check test_input = np.random.random([N, N]).astype(np.float32) - test_output = np.random.random([N*N + N]).astype(np.float32) - test_output_ref = np.random.random([N*N + N]).astype(np.float32) + test_output = np.random.random([N * N + N]).astype(np.float32) + test_output_ref = np.random.random([N * N + N]).astype(np.float32) for i in range(N): test_output_ref[i] = test_input[i, i] for j in range(N): - test_output_ref[N + (i*N + j)] = test_input[j, i] + test_output_ref[N + (i * N + j)] = test_input[j, i] v.check_correctness(function.name, before=(test_input, test_output), after=(test_input, test_output_ref)) def test_offset_sub_array_packing_split_dim(self) -> None: @@ -1254,33 +1355,36 @@ def test_offset_sub_array_packing_split_dim(self) -> None: test_name = "test_offset_sub_array_packing_split_dim" N = 4 - output_size = N*N + N + output_size = N * N + N Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N, N)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(output_size,)) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(output_size, )) package = Package() - diagonal_fetch_nest = Nest(shape=(N,)) + diagonal_fetch_nest = Nest(shape=(N, )) diagonal_idx, = diagonal_fetch_nest.get_indices() + @diagonal_fetch_nest.iteration_logic def _diag_fetch(): - diag_vec = Output.sub_array(offsets=(0,), shape=(N,)) + diag_vec = Output.sub_array(offsets=(0, ), shape=(N, )) diag_vec[diagonal_idx] = Input[diagonal_idx, diagonal_idx] diag_fn = package.add(diagonal_fetch_nest, args=(Input, Output), base_name="diagonal_fetch_fn") transpose_nest = Nest(shape=(N, N)) transpose_i, transpose_j = transpose_nest.get_indices() + @transpose_nest.iteration_logic def _transpose(): - packed_output = Output.sub_array(offsets=(N,), shape=(N*N,)) + packed_output = Output.sub_array(offsets=(N, ), shape=(N * N, )) packed_output_split = packed_output._split_dimension(0, cast(N, ScalarType.index)) packed_output_split[transpose_j, transpose_i] = Input[transpose_i, transpose_j] transpose_fn = package.add(transpose_nest, args=(Input, Output), base_name="transpose_fn") - outer_nest = Nest(shape=(1,)) + outer_nest = Nest(shape=(1, )) + @outer_nest.iteration_logic def _(): diag_fn(Input, Output) @@ -1296,12 +1400,12 @@ def _(): # correctness check test_input = np.random.random([N, N]).astype(np.float32) - test_output = np.random.random([N*N + N]).astype(np.float32) - test_output_ref = np.random.random([N*N + N]).astype(np.float32) + test_output = np.random.random([N * N + N]).astype(np.float32) + test_output_ref = np.random.random([N * N + N]).astype(np.float32) for i in range(N): test_output_ref[i] = test_input[i, i] for j in range(N): - test_output_ref[N + (i*N + j)] = test_input[j, i] + test_output_ref[N + (i * N + j)] = test_input[j, i] v.check_correctness(function.name, before=(test_input, test_output), after=(test_input, test_output_ref)) def test_offset_sub_array_packing_multiple_split_dim(self) -> None: @@ -1313,28 +1417,30 @@ def test_offset_sub_array_packing_multiple_split_dim(self) -> None: test_name = "test_offset_sub_array_packing_multiple_split_dim" N = 4 N_inner = 2 - output_size = N*N + N + output_size = N * N + N Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N, N)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(output_size,)) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(output_size, )) package = Package() - diagonal_fetch_nest = Nest(shape=(N,)) + diagonal_fetch_nest = Nest(shape=(N, )) diagonal_idx, = diagonal_fetch_nest.get_indices() + @diagonal_fetch_nest.iteration_logic def _diag_fetch(): - diag_vec = Output.sub_array(offsets=(0,), shape=(N,)) + diag_vec = Output.sub_array(offsets=(0, ), shape=(N, )) diag_vec[diagonal_idx] = Input[diagonal_idx, diagonal_idx] diag_fn = package.add(diagonal_fetch_nest, args=(Input, Output), base_name="diagonal_fetch_fn") transpose_nest = Nest(shape=(N // N_inner, N // N_inner, N_inner, N_inner)) transpose_i, transpose_j, transpose_ii, transpose_jj = transpose_nest.get_indices() + @transpose_nest.iteration_logic def _transpose(): # packed_output is an offset vector with shape [ 16 ] - packed_output = Output.sub_array(offsets=(N,), shape=(N*N,)) + packed_output = Output.sub_array(offsets=(N, ), shape=(N * N, )) # packed_output_split_0 is an offset array with shape [ 4, 4 ] packed_output_split_0 = packed_output._split_dimension(0, cast(N, ScalarType.index)) @@ -1351,7 +1457,8 @@ def _transpose(): transpose_fn = package.add(transpose_nest, args=(Input, Output), base_name="transpose_fn") - outer_nest = Nest(shape=(1,)) + outer_nest = Nest(shape=(1, )) + @outer_nest.iteration_logic def _(): diag_fn(Input, Output) @@ -1367,35 +1474,38 @@ def _(): # correctness check test_input = np.random.random([N, N]).astype(np.float32) - test_output = np.random.random([N*N + N]).astype(np.float32) - test_output_ref = np.random.random([N*N + N]).astype(np.float32) + test_output = np.random.random([N * N + N]).astype(np.float32) + test_output_ref = np.random.random([N * N + N]).astype(np.float32) for i in range(0, N, N_inner): for j in range(0, N, N_inner): for ii in range(N_inner): - test_output_ref[i + ii] = test_input[i+ii, i+ii] # fill the beginning with the diagonal elements + test_output_ref[i + ii] = test_input[i + ii, + i + ii] # fill the beginning with the diagonal elements for jj in range(N_inner): # output[i, j, jj, ii] = input[i+ii, j+jj] # output[i*((N//N_inner) * N_inner * N_inner) + j*(N_inner * N_inner) + jj*(N_inner) + ii] = input[i+ii, j+jj] # Then offset output by N to account for the beginning diagonal elements # Note that since i and j each step by N_inner, there's already one multiplication by N_inner accounted for in their values - test_output_ref[N + (i*((N//N_inner)*N_inner) + j*(N_inner) + jj*(N_inner) + ii)] = test_input[i + ii, j + jj] + test_output_ref[N + (i * ((N // N_inner) * N_inner) + j * (N_inner) + jj * + (N_inner) + ii)] = test_input[i + ii, j + jj] v.check_correctness(function.name, before=(test_input, test_output), after=(test_input, test_output_ref)) def test_shifting_shrinking_sub_array(self) -> None: N = 64 test_name = "test_shifting_shrinking_sub_array" - Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N,)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(N,)) + Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N, )) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(N, )) package = Package() - nest = Nest(shape=(N,)) + nest = Nest(shape=(N, )) idx, = nest.get_indices() + @nest.iteration_logic def _fn(): size = N - idx - sub_arr = Input.sub_array(offsets=(idx,), shape=(size,)) + sub_arr = Input.sub_array(offsets=(idx, ), shape=(size, )) Output[0] = sub_arr[0] function = package.add(nest, args=(Input, Output), base_name=test_name) @@ -1419,23 +1529,24 @@ def test_dynamic_sub_array_split_dim_subfunction(self) -> None: test_name = "test_dynamic_sub_array_split_dim_subfunction" N = 64 - tile_size = 20 # Intentionally does not divide N + tile_size = 20 # Intentionally does not divide N inner_split_size = 2 package = Package() - Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N,)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(N,)) + Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(N, )) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(N, )) current_outer_idx, extent = create_dimensions() - inner_nest = Nest(shape=(extent,)) + inner_nest = Nest(shape=(extent, )) inner_idx, = inner_nest.get_indices() + @inner_nest.iteration_logic def _inner_fn(): full_idx = current_outer_idx + inner_idx tile_remaining_elts = accmin(N - current_outer_idx, cast(tile_size, ScalarType.index)) - sub_arr = Output.sub_array(offsets=(current_outer_idx,), shape=(tile_remaining_elts,)) + sub_arr = Output.sub_array(offsets=(current_outer_idx, ), shape=(tile_remaining_elts, )) split_arr = sub_arr._split_dimension(0, cast(inner_split_size, ScalarType.index)) split_arr[inner_idx / inner_split_size, inner_idx % inner_split_size] = Input[full_idx] @@ -1443,10 +1554,12 @@ def _inner_fn(): inner_nest, args=(extent, Input, Output, current_outer_idx), base_name=f"{test_name}_inner_fn", - function_opts=INTERNAL_FUNCTION_OPTS) + function_opts=INTERNAL_FUNCTION_OPTS + ) - outer_nest = Nest(shape=(N,)) + outer_nest = Nest(shape=(N, )) outer_idx, = outer_nest.get_indices() + @outer_nest.iteration_logic def _outer_fn(): extent_val = accmin(N - outer_idx, cast(tile_size, ScalarType.index)) @@ -1457,14 +1570,16 @@ def _outer_fn(): outer_sched.reorder(outer_idx, outer_idx_inner) outer_plan = outer_sched.create_plan() outer_plan._erase_loops([outer_idx_inner]) - + function = package.add(outer_plan, args=(Input, Output), base_name=test_name) output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name shutil.rmtree(output_dir, ignore_errors=True) with verifiers.VerifyPackage(self, test_name, output_dir) as v: - package.build(name=test_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False) + package.build( + name=test_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False + ) # correctness check test_input = np.random.random([N]).astype(np.float32) @@ -1481,8 +1596,8 @@ def test_dynamic_sub_array_multi_split_dim_subfunction(self) -> None: # Take in the data as a packed flat buffer, interpret it as a matrix and copy it in tiles # Copy a matrix by tiles that does not evenly divide the matrix shape - M = 76 # Multiple of m_tile_inner, but not m_tile_outer - N = 92 # Multiple of n_tile_inner, but not n_tile_outer + M = 76 # Multiple of m_tile_inner, but not m_tile_outer + N = 92 # Multiple of n_tile_inner, but not n_tile_outer m_tile_outer = 16 m_tile_inner = 2 @@ -1498,13 +1613,14 @@ def test_dynamic_sub_array_multi_split_dim_subfunction(self) -> None: package = Package() Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, N)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M*N,)) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M * N, )) current_outer_i, current_outer_j = create_dimensions() extent_i, extent_j = create_dimensions() inner_nest = Nest(shape=(extent_i, extent_j)) inner_i, inner_j = inner_nest.get_indices() + @inner_nest.iteration_logic def _inner_fn(): full_i = current_outer_i + inner_i @@ -1512,9 +1628,11 @@ def _inner_fn(): remaining_m_tile = accmin(M - current_outer_i, cast(m_tile_outer, ScalarType.index)) remaining_n_tile = accmin(N - current_outer_j, cast(n_tile_outer, ScalarType.index)) output_offset_pos = current_outer_i * N + current_outer_j * remaining_m_tile - remaining_mn_tile = accmin(remaining_m_tile * remaining_n_tile, cast(m_tile_outer * n_tile_outer, ScalarType.index)) + remaining_mn_tile = accmin( + remaining_m_tile * remaining_n_tile, cast(m_tile_outer * n_tile_outer, ScalarType.index) + ) - sub_arr = Output.sub_array(offsets=(output_offset_pos,), shape=(remaining_mn_tile,)) + sub_arr = Output.sub_array(offsets=(output_offset_pos, ), shape=(remaining_mn_tile, )) # Split [512] -> [128, 4] (main) # Split [448] -> [112, 4] (cleanup 1) @@ -1535,20 +1653,19 @@ def _inner_fn(): dynamic_j_split = remaining_n_tile / n_tile_inner split_arr_3 = split_arr_2._split_dimension(0, cast(dynamic_j_split, ScalarType.index)) - split_arr_3[inner_i / m_tile_inner, - inner_j / n_tile_inner, - inner_i % m_tile_inner, + split_arr_3[inner_i / m_tile_inner, inner_j / n_tile_inner, inner_i % m_tile_inner, inner_j % n_tile_inner] = Input[full_i, full_j] inner_fn = package.add( inner_nest, args=(extent_i, extent_j, Input, Output, current_outer_i, current_outer_j), base_name=f"{test_name}_inner_fn", - function_opts=INTERNAL_FUNCTION_OPTS) - + function_opts=INTERNAL_FUNCTION_OPTS + ) outer_nest = Nest(shape=(M, N)) outer_i, outer_j = outer_nest.get_indices() + @outer_nest.iteration_logic def _outer_fn(): extent_i_val = accmin(M - outer_i, cast(m_tile_outer, ScalarType.index)) @@ -1567,22 +1684,27 @@ def _outer_fn(): shutil.rmtree(output_dir, ignore_errors=True) with verifiers.VerifyPackage(self, test_name, output_dir) as v: - package.build(name=test_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False) + package.build( + name=test_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False + ) # correctness check test_input = np.random.random([M, N]).astype(np.float32) - test_output = np.random.random([M*N]).astype(np.float32) - test_output_ref = np.random.random([M*N]).astype(np.float32) + test_output = np.random.random([M * N]).astype(np.float32) + test_output_ref = np.random.random([M * N]).astype(np.float32) + # m_tile_outer does not divide M and n_tile_outer does not divide N, so we have cleanup cases there # However, m_tile_inner and n_tile_inner do divide M and N respectively, so we won't need cleanup cases there def partition_value(range_val, split): return (range_val // split) * split + M_pv = partition_value(M, m_tile_outer) N_pv = partition_value(N, n_tile_outer) m_tile_outer_cleanup = M - M_pv n_tile_outer_cleanup = N - N_pv + def packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, n_tile_outer_clamped): - return tile_offset + i_middle*n_tile_outer_clamped + j_middle*m_tile_inner + i_inner*n_tile_inner + j_inner + return tile_offset + i_middle * n_tile_outer_clamped + j_middle * m_tile_inner + i_inner * n_tile_inner + j_inner for i_outer in range(0, M_pv, m_tile_outer): for j_outer in range(0, N_pv, n_tile_outer): @@ -1591,7 +1713,10 @@ def packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_of for j_middle in range(0, n_tile_outer, n_tile_inner): for i_inner in range(m_tile_inner): for j_inner in range(n_tile_inner): - test_output_ref[packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, n_tile_outer)] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] + test_output_ref[packed_index( + i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, + n_tile_outer + )] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] # j_outer cleanup for j_outer in range(N_pv, N, n_tile_outer): tile_offset = i_outer * N + j_outer * m_tile_outer @@ -1599,7 +1724,10 @@ def packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_of for j_middle in range(0, n_tile_outer_cleanup, n_tile_inner): for i_inner in range(m_tile_inner): for j_inner in range(n_tile_inner): - test_output_ref[packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, n_tile_outer_cleanup)] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] + test_output_ref[packed_index( + i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, + n_tile_outer_cleanup + )] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] # i_outer cleanup for i_outer in range(M_pv, M, m_tile_outer): for j_outer in range(0, N_pv, n_tile_outer): @@ -1608,7 +1736,10 @@ def packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_of for j_middle in range(0, n_tile_outer, n_tile_inner): for i_inner in range(m_tile_inner): for j_inner in range(n_tile_inner): - test_output_ref[packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, n_tile_outer)] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] + test_output_ref[packed_index( + i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, + n_tile_outer + )] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] # j_outer cleanup for j_outer in range(N_pv, N, n_tile_outer): tile_offset = i_outer * N + j_outer * m_tile_outer_cleanup @@ -1616,7 +1747,10 @@ def packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_of for j_middle in range(0, n_tile_outer_cleanup, n_tile_inner): for i_inner in range(m_tile_inner): for j_inner in range(n_tile_inner): - test_output_ref[packed_index(i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, n_tile_outer_cleanup)] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] + test_output_ref[packed_index( + i_outer, i_middle, i_inner, j_outer, j_middle, j_inner, tile_offset, + n_tile_outer_cleanup + )] = test_input[i_outer + i_middle + i_inner, j_outer + j_middle + j_inner] v.check_correctness(function.name, before=(test_input, test_output), after=(test_input, test_output_ref)) def test_padded_nchwc_conv2d_manual_cache(self) -> None: @@ -1968,9 +2102,7 @@ def test_matmul_last_major_vectorized_cache(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.LAST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.LAST_MAJOR) - C = Array( - role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.LAST_MAJOR - ) + C = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.LAST_MAJOR) nest = Nest(shape=(M, N, K)) i, j, k = nest.get_indices() @@ -2995,10 +3127,7 @@ def _gpu_cache( A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -3081,10 +3210,7 @@ def test_gpu_cache_double_buffering_trigger_index(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -3163,10 +3289,7 @@ def test_cpu_cache_double_buffering_trigger_index(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -3196,14 +3319,18 @@ def _(): self._verify_matrix_multiplication_function(function, package, test_name) - def _matmul_cache_element_type_common(self, test_name, array_element_types, cache_element_types, check_correctness=True) -> None: + def _matmul_cache_element_type_common( + self, test_name, array_element_types, cache_element_types, check_correctness=True + ) -> None: M = 256 N = 256 K = 256 A = Array(role=Role.INPUT, element_type=array_element_types[0], shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=array_element_types[1], shape=(K, N), layout=Array.Layout.FIRST_MAJOR) - C = Array(role=Role.INPUT_OUTPUT, element_type=array_element_types[2], shape=(M, N), layout=Array.Layout.FIRST_MAJOR) + C = Array( + role=Role.INPUT_OUTPUT, element_type=array_element_types[2], shape=(M, N), layout=Array.Layout.FIRST_MAJOR + ) nest = Nest(shape=(M, N, K)) i, j, k = nest.get_indices() @@ -3232,7 +3359,6 @@ def _(): self._verify_matrix_multiplication_function(function, package, test_name, check_correctness=check_correctness) - # TODO : move vpmaddwd tests to a different test file def test_signextend_int16_matmul_vpmaddwd(self): from accera import AllocateFlags, create_dimensions @@ -3252,7 +3378,7 @@ def inout_array(arr: Array): M_kernel_tile = 6 N_kernel_tile = 16 - + N_vector_tile = 8 K_vector_tile = 2 @@ -3260,32 +3386,40 @@ def inout_array(arr: Array): B = Array(role=Role.INPUT, element_type=ScalarType.uint8, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR) - A_cache = Array(role=Role.TEMP, - element_type=ScalarType.int16, - shape=(M_tile, K_tile), - layout=Array.Layout.FIRST_MAJOR, - flags=AllocateFlags.HEAP) - B_cache = Array(role=Role.TEMP, - element_type=ScalarType.uint8, - shape=(N_tile // N_kernel_tile, K_tile // K_vector_tile, N_kernel_tile, K_vector_tile), - layout=Array.Layout.FIRST_MAJOR, - flags=AllocateFlags.HEAP) - - C_cache = Array(role=Role.TEMP, - element_type=ScalarType.int32, - shape=(M_kernel_tile, N_kernel_tile), - layout=Array.Layout.FIRST_MAJOR, - flags=AllocateFlags.STACK) # Stack allocate the small accumulation cache + A_cache = Array( + role=Role.TEMP, + element_type=ScalarType.int16, + shape=(M_tile, K_tile), + layout=Array.Layout.FIRST_MAJOR, + flags=AllocateFlags.HEAP + ) + B_cache = Array( + role=Role.TEMP, + element_type=ScalarType.uint8, + shape=(N_tile // N_kernel_tile, K_tile // K_vector_tile, N_kernel_tile, K_vector_tile), + layout=Array.Layout.FIRST_MAJOR, + flags=AllocateFlags.HEAP + ) + + C_cache = Array( + role=Role.TEMP, + element_type=ScalarType.int32, + shape=(M_kernel_tile, N_kernel_tile), + layout=Array.Layout.FIRST_MAJOR, + flags=AllocateFlags.STACK + ) # Stack allocate the small accumulation cache io_A_cache = inout_array(A_cache) io_B_cache = inout_array(B_cache) io_C_cache = inout_array(C_cache) - B_ext = Array(role=Role.TEMP, - element_type=ScalarType.int16, - shape=(N_kernel_tile, K_vector_tile), - layout=Array.Layout.FIRST_MAJOR, - flags=AllocateFlags.STACK) + B_ext = Array( + role=Role.TEMP, + element_type=ScalarType.int16, + shape=(N_kernel_tile, K_vector_tile), + layout=Array.Layout.FIRST_MAJOR, + flags=AllocateFlags.STACK + ) io_B_ext = inout_array(B_ext) @@ -3299,6 +3433,7 @@ def inout_array(arr: Array): ### Matmul inner kernel tile mmi_nest = Nest(shape=(n_kernel_dim, k_kernel_dim)) mmi_j, mmi_k = mmi_nest.get_indices() + @mmi_nest.iteration_logic def _matmul_inner(): mmi_i = i_vector_idx @@ -3313,16 +3448,20 @@ def _matmul_inner(): mmi_sched.reorder(mmi_j, mmi_k, mmi_jj, mmi_jjj, mmi_kk) mmi_plan = mmi_sched.create_plan() mmi_plan.vectorize(mmi_jjj) - mmi_fn = package.add(mmi_plan, - args=(n_kernel_dim, k_kernel_dim, - io_A_cache, io_B_ext, io_C_cache, - i_kernel_idx, j_kernel_idx, k_kernel_idx, i_vector_idx), + mmi_fn = package.add( + mmi_plan, + args=( + n_kernel_dim, k_kernel_dim, io_A_cache, io_B_ext, io_C_cache, i_kernel_idx, j_kernel_idx, k_kernel_idx, + i_vector_idx + ), base_name="matmul_kernel", - function_opts=INTERNAL_FUNCTION_OPTS) + function_opts=INTERNAL_FUNCTION_OPTS + ) ### B element zero extend bext_nest = Nest((n_kernel_dim, k_kernel_dim)) bext_j, bext_k = bext_nest.get_indices() + @bext_nest.iteration_logic def _bext(): tile_j = j_kernel_idx @@ -3335,17 +3474,17 @@ def _bext(): bext_sched.reorder(bext_j, bext_k, bext_jj, bext_jjj, bext_kk) bext_plan = bext_sched.create_plan() bext_plan.vectorize(bext_jjj) - bext_fn = package.add(bext_plan, - args=(n_kernel_dim, k_kernel_dim, - io_B_cache, io_B_ext, - j_kernel_idx, k_kernel_idx), + bext_fn = package.add( + bext_plan, + args=(n_kernel_dim, k_kernel_dim, io_B_cache, io_B_ext, j_kernel_idx, k_kernel_idx), base_name="b_ext_kernel", - function_opts=INTERNAL_FUNCTION_OPTS) - + function_opts=INTERNAL_FUNCTION_OPTS + ) ### Matmul outer kernel tile mmo_nest = Nest(shape=(m_kernel_dim, k_tile_dim)) mmo_i, mmo_k = mmo_nest.get_indices() + @mmo_nest.iteration_logic def _matmul(): @@ -3358,7 +3497,9 @@ def _matmul(): k_kernel_extent = accmin(k_tile_dim - mmo_k, cast(K_vector_tile, ScalarType.index)) bext_fn(n_kernel_dim, k_kernel_extent, io_B_cache, B_ext, j_kernel_idx, mmo_k) - mmi_fn(n_kernel_dim, k_kernel_extent, io_A_cache, B_ext, io_C_cache, i_kernel_idx, j_kernel_idx, mmo_k, mmo_i) + mmi_fn( + n_kernel_dim, k_kernel_extent, io_A_cache, B_ext, io_C_cache, i_kernel_idx, j_kernel_idx, mmo_k, mmo_i + ) mmo_sched = mmo_nest.create_schedule() mmo_ii, mmo_kk = mmo_sched.tile(dict(zip([mmo_i, mmo_k], [M_kernel_tile, K_tile]))) @@ -3366,28 +3507,33 @@ def _matmul(): mmo_sched.reorder(mmo_k, mmo_i, mmo_kk, mmo_ii, mmo_kkk) mmo_plan = mmo_sched.create_plan() mmo_plan._erase_loops([mmo_kkk]) - mmo_fn = package.add(mmo_plan, - args=(m_kernel_dim, n_kernel_dim, k_tile_dim, - io_A_cache, io_B_cache, io_C_cache, - i_kernel_idx, j_kernel_idx), + mmo_fn = package.add( + mmo_plan, + args=( + m_kernel_dim, n_kernel_dim, k_tile_dim, io_A_cache, io_B_cache, io_C_cache, i_kernel_idx, j_kernel_idx + ), base_name="matmul_kernel", - function_opts=INTERNAL_FUNCTION_OPTS) - + function_opts=INTERNAL_FUNCTION_OPTS + ) ### C cache init cci_nest = Nest(shape=(M_kernel_tile, N_kernel_tile)) cci_i, cci_j = cci_nest.get_indices() + @cci_nest.iteration_logic def _cci(): io_C_cache[cci_i, cci_j] = 0 cci_sched = cci_nest.create_schedule() cci_plan = cci_sched.create_plan() - cci_fn = package.add(cci_plan, args=(io_C_cache,), base_name="c_cache_init_kernel", function_opts=INTERNAL_FUNCTION_OPTS) + cci_fn = package.add( + cci_plan, args=(io_C_cache, ), base_name="c_cache_init_kernel", function_opts=INTERNAL_FUNCTION_OPTS + ) ### C cache reduce ccr_nest = Nest(shape=(m_kernel_dim, n_kernel_dim)) ccr_i, ccr_j = ccr_nest.get_indices() + @ccr_nest.iteration_logic def _ccr(): global_i = i_tile_idx + i_kernel_idx + ccr_i @@ -3399,17 +3545,17 @@ def _ccr(): ccr_sched.reorder(ccr_i, ccr_j, ccr_ii, ccr_jj) ccr_plan = ccr_sched.create_plan() ccr_plan.vectorize(ccr_ii) - ccr_fn = package.add(ccr_plan, - args=(m_kernel_dim, n_kernel_dim, - C, io_C_cache, - i_tile_idx, j_tile_idx, - i_kernel_idx, j_kernel_idx), + ccr_fn = package.add( + ccr_plan, + args=(m_kernel_dim, n_kernel_dim, C, io_C_cache, i_tile_idx, j_tile_idx, i_kernel_idx, j_kernel_idx), base_name="c_cache_reduce_kernel", - function_opts=INTERNAL_FUNCTION_OPTS) + function_opts=INTERNAL_FUNCTION_OPTS + ) ### A cache pack pa_nest = Nest(shape=(m_tile_dim, k_tile_dim)) pa_i, pa_k = pa_nest.get_indices() + @pa_nest.iteration_logic def _pack_a(): global_i = i_tile_idx + pa_i @@ -3420,22 +3566,23 @@ def _pack_a(): pa_ii, pa_kk = pa_sched.tile(dict(zip([pa_i, pa_k], [M_tile, K_tile]))) pa_sched.reorder(pa_i, pa_k, pa_ii, pa_kk) pa_plan = pa_sched.create_plan() - pa_fn = package.add(pa_plan, - args=(m_tile_dim, k_tile_dim, - A, io_A_cache, - i_tile_idx, k_tile_idx), + pa_fn = package.add( + pa_plan, + args=(m_tile_dim, k_tile_dim, A, io_A_cache, i_tile_idx, k_tile_idx), base_name="pack_a", - function_opts=INTERNAL_FUNCTION_OPTS) - + function_opts=INTERNAL_FUNCTION_OPTS + ) ### B cache pack pb_nest = Nest(shape=(n_tile_dim, k_tile_dim)) pb_j, pb_k = pb_nest.get_indices() + @pb_nest.iteration_logic def _pack_b(): global_j = j_tile_idx + pb_j global_k = k_tile_idx + pb_k - io_B_cache[pb_j / N_kernel_tile, pb_k / K_vector_tile, pb_j % N_kernel_tile, pb_k % K_vector_tile] = B[global_k, global_j] + io_B_cache[pb_j / N_kernel_tile, pb_k / K_vector_tile, pb_j % N_kernel_tile, + pb_k % K_vector_tile] = B[global_k, global_j] pb_sched = pb_nest.create_schedule() pb_jj, pb_kk = pb_sched.tile(dict(zip([pb_j, pb_k], [N_tile, K_tile]))) @@ -3443,31 +3590,32 @@ def _pack_b(): pb_sched.reorder(pb_j, pb_k, pb_jj, pb_kk, pb_jjj, pb_kkk) pb_plan = pb_sched.create_plan() pb_plan.vectorize(pb_jjj) - pb_fn = package.add(pb_plan, - args=(n_tile_dim, k_tile_dim, - B, io_B_cache, - j_tile_idx, k_tile_idx), + pb_fn = package.add( + pb_plan, + args=(n_tile_dim, k_tile_dim, B, io_B_cache, j_tile_idx, k_tile_idx), base_name="pack_b", - function_opts=INTERNAL_FUNCTION_OPTS) + function_opts=INTERNAL_FUNCTION_OPTS + ) + compute_kernel_nest = Nest(shape=(1, )) - compute_kernel_nest = Nest(shape=(1,)) @compute_kernel_nest.iteration_logic def _hack(): - cci_fn(C_cache) # Don't need to range-clamp this, we can just zero out the full buffer every time + cci_fn(C_cache) # Don't need to range-clamp this, we can just zero out the full buffer every time mmo_fn(m_kernel_dim, n_kernel_dim, k_tile_dim, io_A_cache, io_B_cache, C_cache, i_kernel_idx, j_kernel_idx) ccr_fn(m_kernel_dim, n_kernel_dim, C, C_cache, i_tile_idx, j_tile_idx, i_kernel_idx, j_kernel_idx) compute_kernel_sched = compute_kernel_nest.create_schedule() compute_kernel_plan = compute_kernel_sched.create_plan() - compute_kernel_fn = package.add(compute_kernel_plan, + compute_kernel_fn = package.add( + compute_kernel_plan, args=( - m_kernel_dim, n_kernel_dim, k_tile_dim, - io_A_cache, io_B_cache, C, - i_tile_idx, j_tile_idx, k_tile_idx, - i_kernel_idx, j_kernel_idx), + m_kernel_dim, n_kernel_dim, k_tile_dim, io_A_cache, io_B_cache, C, i_tile_idx, j_tile_idx, k_tile_idx, + i_kernel_idx, j_kernel_idx + ), base_name="compute_kernel_fn", - function_opts=INTERNAL_FUNCTION_OPTS) + function_opts=INTERNAL_FUNCTION_OPTS + ) tile_nest = Nest(shape=(m_tile_dim, n_tile_dim)) tile_i, tile_j = tile_nest.get_indices() @@ -3476,24 +3624,29 @@ def _hack(): def _tile(): m_kernel_extent = accmin(m_tile_dim - tile_i, cast(M_kernel_tile, ScalarType.index)) n_kernel_extent = accmin(n_tile_dim - tile_j, cast(N_kernel_tile, ScalarType.index)) - compute_kernel_fn(m_kernel_extent, n_kernel_extent, k_tile_dim, - io_A_cache, io_B_cache, C, - i_tile_idx, j_tile_idx, k_tile_idx, - tile_i, tile_j) + compute_kernel_fn( + m_kernel_extent, n_kernel_extent, k_tile_dim, io_A_cache, io_B_cache, C, i_tile_idx, j_tile_idx, + k_tile_idx, tile_i, tile_j + ) tile_sched = tile_nest.create_schedule() - tile_ii, tile_jj = tile_sched.tile({ tile_i: M_tile, tile_j: N_tile }) - tile_iii, tile_jjj = tile_sched.tile({ tile_ii: M_kernel_tile, tile_jj: N_kernel_tile }) + tile_ii, tile_jj = tile_sched.tile({ + tile_i: M_tile, + tile_j: N_tile + }) + tile_iii, tile_jjj = tile_sched.tile({ + tile_ii: M_kernel_tile, + tile_jj: N_kernel_tile + }) tile_sched.reorder(tile_i, tile_j, tile_ii, tile_jj, tile_iii, tile_jjj) tile_plan = tile_sched.create_plan() tile_plan._erase_loops([tile_iii, tile_jjj]) - tile_fn = package.add(tile_plan, - args=(m_tile_dim, n_tile_dim, k_tile_dim, - io_A_cache, io_B_cache, C, - i_tile_idx, j_tile_idx, k_tile_idx), + tile_fn = package.add( + tile_plan, + args=(m_tile_dim, n_tile_dim, k_tile_dim, io_A_cache, io_B_cache, C, i_tile_idx, j_tile_idx, k_tile_idx), base_name="tile_fn", - function_opts=INTERNAL_FUNCTION_OPTS) - + function_opts=INTERNAL_FUNCTION_OPTS + ) global_nest = Nest(shape=(M, N, K)) global_i, global_j, global_k = global_nest.get_indices() @@ -3509,13 +3662,17 @@ def _tile(): tile_fn(m_tile_extent, n_tile_extent, k_tile_extent, A_cache, B_cache, C, global_i, global_j, global_k) global_sched = global_nest.create_schedule() - global_ii, global_jj, global_kk = global_sched.tile({ global_i: M_tile, global_j: N_tile, global_k: K_tile }) + global_ii, global_jj, global_kk = global_sched.tile({ + global_i: M_tile, + global_j: N_tile, + global_k: K_tile + }) global_sched.reorder(global_i, global_j, global_k, global_ii, global_jj, global_kk) global_plan = global_sched.create_plan() global_plan._erase_loops([global_ii, global_jj, global_kk]) function = package.add(global_plan, args=(A, B, C), base_name=test_name) - + A_test = np.random.random((M, K)).astype(np.int16) B_test = np.random.random((K, N)).astype(np.uint8) C_test = np.random.random((M, N)).astype(np.int32) @@ -3529,14 +3686,18 @@ def _tile(): # build the HAT package with verifiers.VerifyPackage(self, test_name, output_dir) as v: - package.build(test_name, format=Package.Format.DEFAULT | Package.Format.MLIR, mode=Package.Mode.RELEASE, output_dir=output_dir) + package.build( + test_name, + format=Package.Format.DEFAULT | Package.Format.MLIR, + mode=Package.Mode.RELEASE, + output_dir=output_dir + ) v.check_correctness( function.name, before=correctness_check_values["pre"], after=correctness_check_values["post"], ) - def test_int16_matmul_vpmaddwd(self): test_name = "test_int16_matmul_vpmaddwd" M = 240 @@ -3555,24 +3716,32 @@ def _(): C[i, j] += A[i, k] * B[k, j] schedule = nest.create_schedule() - ii, jj, kk = schedule.tile({ i: 24, j: 128, k: 128 }) - iii, jjj, kkk = schedule.tile({ ii: 6, jj: 16, kk: 4 }) - jjjj, kkkk = schedule.tile({ jjj: 8, kkk: 2 }) + ii, jj, kk = schedule.tile({ + i: 24, + j: 128, + k: 128 + }) + iii, jjj, kkk = schedule.tile({ + ii: 6, + jj: 16, + kk: 4 + }) + jjjj, kkkk = schedule.tile({ + jjj: 8, + kkk: 2 + }) - schedule.reorder(i, j, k, - ii, jj, kk, - kkk, iii, jjj, - jjjj, kkkk) + schedule.reorder(i, j, k, ii, jj, kk, kkk, iii, jjj, jjjj, kkkk) plan = schedule.create_plan() - plan.cache(A, index = ii, element_type = ScalarType.int16, vectorize=False) - plan.cache(B, index = jjjj, trigger_index = jj, layout = Array.Layout.LAST_MAJOR, vectorize=False) + plan.cache(A, index=ii, element_type=ScalarType.int16, vectorize=False) + plan.cache(B, index=jjjj, trigger_index=jj, layout=Array.Layout.LAST_MAJOR, vectorize=False) plan.cache(C, iii) plan.vectorize(jjjj) package = Package() function = package.add(plan, args=(A, B, C), base_name=test_name) - + A_test = np.random.random((M, K)).astype(np.int16) B_test = np.random.random((K, N)).astype(np.int16) C_test = np.random.random((M, N)).astype(np.int32) @@ -3593,8 +3762,10 @@ def _(): after=correctness_check_values["post"], ) - - @expectedFailure(FailedReason.INVALID, "generated x86_64 lib not readable by MacOS arm64 build tools", sys.platform == "darwin" and platform.machine() == "arm64") + @expectedFailure( + FailedReason.INVALID, "generated x86_64 lib not readable by MacOS arm64 build tools", sys.platform == "darwin" + and platform.machine() == "arm64" + ) def test_int16_matmul_vpmaddwd_16_element_avx512(self): test_name = "test_int16_matmul_vpmaddwd_16_element_avx512" M = 240 @@ -3613,20 +3784,28 @@ def _(): C[i, j] += A[i, k] * B[k, j] schedule = nest.create_schedule() - ii, jj, kk = schedule.tile({ i: 24, j: 128, k: 128 }) - iii, jjj, kkk = schedule.tile({ ii: 6, jj: 32, kk: 4 }) - jjjj, kkkk = schedule.tile({ jjj: 16, kkk: 2 }) + ii, jj, kk = schedule.tile({ + i: 24, + j: 128, + k: 128 + }) + iii, jjj, kkk = schedule.tile({ + ii: 6, + jj: 32, + kk: 4 + }) + jjjj, kkkk = schedule.tile({ + jjj: 16, + kkk: 2 + }) - schedule.reorder(i, j, k, - ii, jj, kk, - kkk, iii, jjj, - jjjj, kkkk) + schedule.reorder(i, j, k, ii, jj, kk, kkk, iii, jjj, jjjj, kkkk) # The Intel 8351N is a known Xeon Platinum with AVX-512 support target = KNOWN_DEVICES[Target.Category.CPU]["Intel 8351N"] plan = schedule.create_plan(target) - plan.cache(A, index = ii, element_type = ScalarType.int16, vectorize=False) - plan.cache(B, index = jjjj, trigger_index = jj, layout = Array.Layout.LAST_MAJOR, vectorize=False) + plan.cache(A, index=ii, element_type=ScalarType.int16, vectorize=False) + plan.cache(B, index=jjjj, trigger_index=jj, layout=Array.Layout.LAST_MAJOR, vectorize=False) plan.cache(C, iii) plan.vectorize(jjjj) @@ -3640,8 +3819,6 @@ def _(): package.build(test_name, format=Package.Format.DEFAULT, mode=Package.Mode.RELEASE, output_dir=output_dir) # Don't check correctness as we've set a target that we may not be running the tests on - - def test_int16_matmul_vpmaddwd_16_element_host(self): test_name = "test_int16_matmul_vpmaddwd_16_element_host" M = 240 @@ -3660,18 +3837,26 @@ def _(): C[i, j] += A[i, k] * B[k, j] schedule = nest.create_schedule() - ii, jj, kk = schedule.tile({ i: 24, j: 128, k: 128 }) - iii, jjj, kkk = schedule.tile({ ii: 6, jj: 32, kk: 4 }) - jjjj, kkkk = schedule.tile({ jjj: 16, kkk: 2 }) + ii, jj, kk = schedule.tile({ + i: 24, + j: 128, + k: 128 + }) + iii, jjj, kkk = schedule.tile({ + ii: 6, + jj: 32, + kk: 4 + }) + jjjj, kkkk = schedule.tile({ + jjj: 16, + kkk: 2 + }) - schedule.reorder(i, j, k, - ii, jj, kk, - kkk, iii, jjj, - jjjj, kkkk) + schedule.reorder(i, j, k, ii, jj, kk, kkk, iii, jjj, jjjj, kkkk) plan = schedule.create_plan() - plan.cache(A, index = ii, element_type = ScalarType.int16, vectorize=False) - plan.cache(B, index = jjjj, trigger_index = jj, layout = Array.Layout.LAST_MAJOR, vectorize=False) + plan.cache(A, index=ii, element_type=ScalarType.int16, vectorize=False) + plan.cache(B, index=jjjj, trigger_index=jj, layout=Array.Layout.LAST_MAJOR, vectorize=False) plan.cache(C, iii) plan.vectorize(jjjj) @@ -3716,24 +3901,32 @@ def _(): C[i, j] += cast(A[i, k], ScalarType.int32) * cast(B[k, j], ScalarType.int32) schedule = nest.create_schedule() - ii, jj, kk = schedule.tile({ i: 24, j: 128, k: 128 }) - iii, jjj, kkk = schedule.tile({ ii: 6, jj: 16, kk: 4 }) - jjjj, kkkk = schedule.tile({ jjj: 8, kkk: 2 }) + ii, jj, kk = schedule.tile({ + i: 24, + j: 128, + k: 128 + }) + iii, jjj, kkk = schedule.tile({ + ii: 6, + jj: 16, + kk: 4 + }) + jjjj, kkkk = schedule.tile({ + jjj: 8, + kkk: 2 + }) - schedule.reorder(i, j, k, - ii, jj, kk, - kkk, iii, jjj, - jjjj, kkkk) + schedule.reorder(i, j, k, ii, jj, kk, kkk, iii, jjj, jjjj, kkkk) plan = schedule.create_plan() - plan.cache(A, index = ii, element_type = ScalarType.int16, vectorize=False) - plan.cache(B, index = jjjj, trigger_index = jj, layout = Array.Layout.LAST_MAJOR, vectorize=False) + plan.cache(A, index=ii, element_type=ScalarType.int16, vectorize=False) + plan.cache(B, index=jjjj, trigger_index=jj, layout=Array.Layout.LAST_MAJOR, vectorize=False) plan.cache(C, iii) plan.vectorize(jjjj) package = Package() function = package.add(plan, args=(A, B, C), base_name=test_name) - + A_test = np.random.random((M, K)).astype(np.int16) B_test = np.random.random((K, N)).astype(np.int16) C_test = np.random.random((M, N)).astype(np.int32) @@ -3760,7 +3953,7 @@ def test_int32_horizontal_vector_add(self): N = 16 A = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR) - B = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M,), layout=Array.Layout.FIRST_MAJOR) + B = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M, ), layout=Array.Layout.FIRST_MAJOR) nest = Nest(shape=(M, N)) i, j = nest.get_indices() @@ -3776,11 +3969,11 @@ def _(): package = Package() function = package.add(plan, args=(A, B), base_name=test_name) - + A_test = np.random.random((M, N)).astype(np.int32) - B_test = np.random.random((M,)).astype(np.int32) + B_test = np.random.random((M, )).astype(np.int32) - B_ref = np.zeros((M,)).astype(np.int32) + B_ref = np.zeros((M, )).astype(np.int32) B_ref[:] = B_test[:] for j in range(N): B_ref[:] += A_test[:, j] @@ -3807,7 +4000,7 @@ def test_int16_to_int32_horizontal_vector_add_simple(self): N = 16 A = Array(role=Role.INPUT, element_type=ScalarType.int16, shape=(M, N), layout=Array.Layout.FIRST_MAJOR) - B = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M,), layout=Array.Layout.FIRST_MAJOR) + B = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M, ), layout=Array.Layout.FIRST_MAJOR) nest = Nest(shape=(M, N)) i, j = nest.get_indices() @@ -3824,12 +4017,11 @@ def _(): package = Package() function = package.add(plan, args=(A, B), base_name=test_name) - + A_test = np.random.random((M, N)).astype(np.int16) - B_test = np.random.random((M,)).astype(np.int32) + B_test = np.random.random((M, )).astype(np.int32) - B_ref = np.zeros((M,)).astype(np.int32) - B_ref[:] = B_test[:] + B_ref = B_test.copy() for j in range(N): B_ref[:] += A_test[:, j] @@ -3842,143 +4034,335 @@ def _(): # build the HAT package with verifiers.VerifyPackage(self, test_name, output_dir) as v: - package.build(test_name, format=Package.Format.DEFAULT, mode=Package.Mode.RELEASE, output_dir=output_dir) + package.build( + test_name, + format=Package.Format.MLIR | Package.Format.DEFAULT, + mode=Package.Mode.RELEASE, + output_dir=output_dir + ) v.check_correctness( function.name, before=correctness_check_values["pre"], after=correctness_check_values["post"], ) + def test_int16_to_int32_horizontal_vector_1_row(self): + test_name = "test_int16_to_int32_horizontal_vector_1_row" + M = 1 + N = 16 + + A = Array(role=Role.INPUT, element_type=ScalarType.int16, shape=(M, N), layout=Array.Layout.FIRST_MAJOR) + B = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M, ), layout=Array.Layout.FIRST_MAJOR) + + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _(): + B[i] += A[i, j] + + schedule = nest.create_schedule() + plan = schedule.create_plan() + plan.vectorize(i) + plan.vectorize(j) + + package = Package() + function = package.add(plan, args=(A, B), base_name=test_name) + + A_test = np.random.random((M, N)).astype(np.int16) + B_test = np.random.random((M, )).astype(np.int32) + + B_ref = B_test.copy() + for j in range(N): + B_ref[:] += A_test[:, j] + + correctness_check_values = { + "pre": (A_test, B_test), + "post": (A_test, B_ref), + } + + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name + + # build the HAT package + with verifiers.VerifyPackage(self, test_name, output_dir) as v: + package.build( + test_name, + format=Package.Format.MLIR | Package.Format.DEFAULT, + mode=Package.Mode.RELEASE, + output_dir=output_dir + ) + v.check_correctness( + function.name, + before=correctness_check_values["pre"], + after=correctness_check_values["post"], + ) + + + def test_int32_horizontal_vector_add_simple(self): + test_name = "test_int32_horizontal_vector_add_simple" + M = 256 + N = 8 + + A = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR) + B = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(M, ), layout=Array.Layout.FIRST_MAJOR) + + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _(): + B[i] += A[i, j] + + schedule = nest.create_schedule() + ii = schedule.split(i, 4) + schedule.reorder(i, ii, j) + plan = schedule.create_plan() + plan.vectorize(ii) + + package = Package() + function = package.add(plan, args=(A, B), base_name=test_name) + + A_test = np.random.random((M, N)).astype(np.int32) + B_test = np.random.random((M, )).astype(np.int32) + + B_ref = B_test.copy() + for j in range(N): + B_ref[:] += A_test[:, j] + + correctness_check_values = { + "pre": (A_test, B_test), + "post": (A_test, B_ref), + } + + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name + + # build the HAT package + with verifiers.VerifyPackage(self, test_name, output_dir) as v: + package.build( + test_name, + format=Package.Format.MLIR | Package.Format.DEFAULT, + mode=Package.Mode.RELEASE, + output_dir=output_dir + ) + v.check_correctness( + function.name, + before=correctness_check_values["pre"], + after=correctness_check_values["post"], + ) + + def test_float32_horizontal_vector_add_simple(self): + test_name = "test_float32_horizontal_vector_add_simple" + M = 256 + N = 8 + + A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR) + B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, ), layout=Array.Layout.FIRST_MAJOR) + + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _(): + B[i] += A[i, j] + + schedule = nest.create_schedule() + ii = schedule.split(i, 4) + schedule.reorder(i, ii, j) + plan = schedule.create_plan() + plan.vectorize(ii) + + package = Package() + function = package.add(plan, args=(A, B), base_name=test_name) + + A_test = np.random.random((M, N)).astype(np.float32) + B_test = np.random.random((M, )).astype(np.float32) + + B_ref = B_test.copy() + for j in range(N): + B_ref[:] += A_test[:, j] + + correctness_check_values = { + "pre": (A_test, B_test), + "post": (A_test, B_ref), + } + + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / test_name + + # build the HAT package + with verifiers.VerifyPackage(self, test_name, output_dir) as v: + package.build( + test_name, + format=Package.Format.MLIR | Package.Format.DEFAULT, + mode=Package.Mode.RELEASE, + output_dir=output_dir + ) + v.check_correctness( + function.name, + before=correctness_check_values["pre"], + after=correctness_check_values["post"], + ) # Cache widening the type def test_matmul_input_cache_element_type_widen(self) -> None: test_name = "test_matmul_input_cache_element_type_widen" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int16)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int16) + ) def test_matmul_output_cache_element_type_widen(self) -> None: test_name = "test_matmul_output_cache_element_type_widen" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16), - cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16), + cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int32) + ) def test_matmul_input_output_cache_element_type_widen(self) -> None: test_name = "test_matmul_input_output_cache_element_type_widen" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32)) - + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32) + ) # Cache narrowing the type def test_matmul_input_cache_element_type_narrow(self) -> None: test_name = "test_matmul_input_cache_element_type_narrow" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int32) + ) def test_matmul_output_cache_element_type_narrow(self) -> None: test_name = "test_matmul_output_cache_element_type_narrow" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int16)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int16) + ) def test_matmul_input_output_cache_element_type_narrow(self) -> None: test_name = "test_matmul_input_output_cache_element_type_narrow" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16)) - + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int16) + ) # Cache converting the type from int to float def test_matmul_input_cache_element_type_int_to_float(self) -> None: test_name = "test_matmul_input_cache_element_type_int_to_float" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.int32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.int32) + ) def test_matmul_output_cache_element_type_int_to_float(self) -> None: test_name = "test_matmul_output_cache_element_type_int_to_float" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.float32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.float32) + ) def test_matmul_input_output_cache_element_type_int_to_float(self) -> None: test_name = "test_matmul_input_output_cache_element_type_int_to_float" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32)) - + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32) + ) # Cache converting the type from float to int def test_matmul_input_cache_element_type_float_to_int(self) -> None: test_name = "test_matmul_input_cache_element_type_float_to_int" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.float32), - check_correctness=False) # float to int results in so much rounding that correctness checks are not useful + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.float32), + check_correctness=False + ) # float to int results in so much rounding that correctness checks are not useful def test_matmul_output_cache_element_type_float_to_int(self) -> None: test_name = "test_matmul_output_cache_element_type_float_to_int" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32), - cache_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.int32), - check_correctness=False) # float to int results in so much rounding that correctness checks are not useful + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32), + cache_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.int32), + check_correctness=False + ) # float to int results in so much rounding that correctness checks are not useful def test_matmul_input_output_cache_element_type_float_to_int(self) -> None: test_name = "test_matmul_input_output_cache_element_type_float_to_int" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - check_correctness=False) # float to int results in so much rounding that correctness checks are not useful - + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.float32, ScalarType.float32, ScalarType.float32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + check_correctness=False + ) # float to int results in so much rounding that correctness checks are not useful # Cache converting the type from int to uint def test_matmul_input_cache_element_type_int_to_uint(self) -> None: test_name = "test_matmul_input_cache_element_type_int_to_uint" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.int32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.int32) + ) def test_matmul_output_cache_element_type_int_to_uint(self) -> None: test_name = "test_matmul_output_cache_element_type_int_to_uint" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.uint32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.uint32) + ) def test_matmul_input_output_cache_element_type_int_to_uint(self) -> None: test_name = "test_matmul_input_output_cache_element_type_int_to_uint" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), - cache_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32)) - + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32), + cache_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32) + ) # Cache converting the type from uint to int def test_matmul_input_cache_element_type_uint_to_int(self) -> None: test_name = "test_matmul_input_cache_element_type_uint_to_int" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.uint32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.uint32) + ) def test_matmul_output_cache_element_type_uint_to_int(self) -> None: test_name = "test_matmul_output_cache_element_type_uint_to_int" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32), - cache_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.int32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32), + cache_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.int32) + ) def test_matmul_input_output_cache_element_type_uint_to_int(self) -> None: test_name = "test_matmul_input_output_cache_element_type_uint_to_int" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32), - cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32)) + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.uint32, ScalarType.uint32, ScalarType.uint32), + cache_element_types=(ScalarType.int32, ScalarType.int32, ScalarType.int32) + ) # Cache converting the type from uint to int and sign extending def test_matmul_input_cache_element_type_uint_to_int(self) -> None: test_name = "test_matmul_input_cache_element_type_uint8_to_int16" - self._matmul_cache_element_type_common(test_name, - array_element_types=(ScalarType.uint8, ScalarType.uint8, ScalarType.int32), - cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int32)) - + self._matmul_cache_element_type_common( + test_name, + array_element_types=(ScalarType.uint8, ScalarType.uint8, ScalarType.int32), + cache_element_types=(ScalarType.int16, ScalarType.int16, ScalarType.int32) + ) def test_gpu_barrier_opt(self) -> None: from accera import Array, Nest, Package, ScalarType, Target @@ -5122,10 +5506,7 @@ def test_gpu_cache_different_input_layouts(self): A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(S, M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(S, K, N), layout=Array.Layout.LAST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(S, M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(S, M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(S, M, N, K)) @@ -5171,15 +5552,11 @@ def file_check_fn(verifier): # Function decl checker.check_label('accv.func nested @test_gpu_cache_different_input_layouts_') - checker.check_same( - '%[[Array_A:[a-z0-9_]+]]: memref<4x2560x2048xf32>' - ) + checker.check_same('%[[Array_A:[a-z0-9_]+]]: memref<4x2560x2048xf32>') checker.check_same( '%[[Array_B:[a-z0-9_]+]]: memref<4x2048x1536xf32, affine_map<(d0, d1, d2) -> (d0 + d1 * 4 + d2 * 8192)>>' ) - checker.check_same( - '%[[Array_C:[a-z0-9_]+]]: memref<4x2560x1536xf32>' - ) + checker.check_same('%[[Array_C:[a-z0-9_]+]]: memref<4x2560x1536xf32>') # Block X/Y checker.check('%[[Block_Y:[0-9_]+]] = gpu.block_id y') @@ -5251,10 +5628,7 @@ def test_gpu_cache_block_level_private_mem(self): A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.LAST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5339,10 +5713,7 @@ def test_gpu_cache_block_level_shared_mem(self): A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.LAST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5401,7 +5772,8 @@ def file_check_fn(verifier): package, package_name, file_check_fn=file_check_fn, - check_correctness=False, # We expect this test to produce incorrect gemm results since we are caching output in shared memory and every thread is repeating each others's work. + check_correctness= + False, # We expect this test to produce incorrect gemm results since we are caching output in shared memory and every thread is repeating each others's work. file_list=[f"{package_name}.cu", f"{package_name}.hat"], package_format=Package.Format.DEFAULT | Package.Format.MLIR ) @@ -5428,10 +5800,7 @@ def test_gpu_cache_block_level_global_mem(self): A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.LAST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5561,10 +5930,7 @@ def test_rocm_cache_double_buffering__with_c_cache_tensorize(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5646,10 +6012,7 @@ def test_rocm_c_cache_private(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5832,10 +6195,7 @@ def test_rocm_tensorize_fp16(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float16, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float16, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5899,10 +6259,7 @@ def test_rocm_cache_double_buffering_tensorize_fp16(self) -> None: A = Array(role=Role.INPUT, element_type=ScalarType.float16, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float16, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -5986,10 +6343,7 @@ def test_rocm_double_buffer_small_cache_vectorized_unvectorized_tensorized(self) A = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(M, K), layout=Array.Layout.FIRST_MAJOR) B = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(K, N), layout=Array.Layout.FIRST_MAJOR) C = Array( - role=Role.INPUT_OUTPUT, - element_type=ScalarType.float32, - shape=(M, N), - layout=Array.Layout.FIRST_MAJOR + role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N), layout=Array.Layout.FIRST_MAJOR ) nest = Nest(shape=(M, N, K)) @@ -6058,7 +6412,6 @@ def _(): package_format=Package.Format.DEFAULT | Package.Format.MLIR ) - def test_loop_erase_hack(self) -> None: # We want to fuse two nests along at least one dimension that only one of them should actually have, but for positioning reasons # it must exist in both. We therefore fuse along all the dimensions and erase the inner unfused loops that we don't actually need @@ -6082,7 +6435,11 @@ def _(): C[i0, j0] += A[i0, k0] * B[k0, j0] schedule0 = nest0.create_schedule() - ii0, jj0, kk0 = schedule0.tile({ i0: M_tile, j0: N_tile, k0: K_tile }) + ii0, jj0, kk0 = schedule0.tile({ + i0: M_tile, + j0: N_tile, + k0: K_tile + }) schedule0.reorder(i0, j0, k0, ii0, jj0, kk0) # Create nest1 and schedule1 @@ -6094,7 +6451,11 @@ def _(): C[i1, j1] = C[i1, j1] * Scalar(0.2) schedule1 = nest1.create_schedule() - ii1, jj1, kk1 = schedule1.tile({ i1: M_tile, j1: N_tile, k1: K_tile }) + ii1, jj1, kk1 = schedule1.tile({ + i1: M_tile, + j1: N_tile, + k1: K_tile + }) schedule1.reorder(i1, j1, k1, ii1, jj1, kk1) schedule = fuse((schedule0, schedule1), partial=3) @@ -6116,11 +6477,12 @@ def test_dynamic_size_redundant_split(self) -> None: split_size = 32 m_extent = Dimension(name='m_extent') - input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,)) - output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,)) + input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent, )) + output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent, )) - nest = Nest((m_extent,)) + nest = Nest((m_extent, )) i, = nest.get_indices() + @nest.iteration_logic def _(): output_arr[i] += input_arr[i] @@ -6137,8 +6499,8 @@ def _(): fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name) M_test = np.int64(67) - input_test = np.random.random((M_test,)).astype(np.float32) - output_test = np.random.random((M_test,)).astype(np.float32) + input_test = np.random.random((M_test, )).astype(np.float32) + output_test = np.random.random((M_test, )).astype(np.float32) correctness_check_values = { "pre": [M_test, input_test, output_test], "post": [M_test, input_test, output_test + input_test], @@ -6160,11 +6522,12 @@ def test_dynamic_size_redundant_split_1(self) -> None: split_size = 1 m_extent = Dimension("m_extent") - input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,)) - output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,)) + input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent, )) + output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent, )) - nest = Nest((m_extent,)) + nest = Nest((m_extent, )) i, = nest.get_indices() + @nest.iteration_logic def _(): output_arr[i] += input_arr[i] @@ -6181,8 +6544,8 @@ def _(): fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name) M_test = np.int64(1) - input_test = np.random.random((M_test,)).astype(np.float32) - output_test = np.random.random((M_test,)).astype(np.float32) + input_test = np.random.random((M_test, )).astype(np.float32) + output_test = np.random.random((M_test, )).astype(np.float32) correctness_check_values = { "pre": [M_test, input_test, output_test], "post": [M_test, input_test, output_test + input_test], @@ -6191,7 +6554,9 @@ def _(): # Build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False) + package.build( + package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False + ) v.check_correctness( fn.name, @@ -6204,11 +6569,12 @@ def test_dynamic_size_split_1(self) -> None: split_size = 1 m_extent = Dimension("m_extent") - input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,)) - output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,)) + input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent, )) + output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent, )) - nest = Nest((m_extent,)) + nest = Nest((m_extent, )) i, = nest.get_indices() + @nest.iteration_logic def _(): output_arr[i] += input_arr[i] @@ -6224,8 +6590,8 @@ def _(): fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name) M_test = np.int64(1) - input_test = np.random.random((M_test,)).astype(np.float32) - output_test = np.random.random((M_test,)).astype(np.float32) + input_test = np.random.random((M_test, )).astype(np.float32) + output_test = np.random.random((M_test, )).astype(np.float32) correctness_check_values = { "pre": [M_test, input_test, output_test], "post": [M_test, input_test, output_test + input_test], @@ -6234,7 +6600,9 @@ def _(): # Build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False) + package.build( + package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False + ) v.check_correctness( fn.name, @@ -6248,11 +6616,12 @@ def test_dynamic_size_split_and_redundant_split_1(self) -> None: inner_split_size = 1 m_extent = Dimension("m_extent") - input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent,)) - output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent,)) + input_arr = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(m_extent, )) + output_arr = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(m_extent, )) - nest = Nest((m_extent,)) + nest = Nest((m_extent, )) i, = nest.get_indices() + @nest.iteration_logic def _(): output_arr[i] += input_arr[i] @@ -6270,8 +6639,8 @@ def _(): fn = package.add(plan, args=(m_extent, input_arr, output_arr), base_name=package_name) M_test = np.int64(37) - input_test = np.random.random((M_test,)).astype(np.float32) - output_test = np.random.random((M_test,)).astype(np.float32) + input_test = np.random.random((M_test, )).astype(np.float32) + output_test = np.random.random((M_test, )).astype(np.float32) correctness_check_values = { "pre": [M_test, input_test, output_test], "post": [M_test, input_test, output_test + input_test], @@ -6280,7 +6649,9 @@ def _(): # Build the HAT package output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False) + package.build( + package_name, format=self.PACKAGE_FORMAT, mode=self.PACKAGE_MODE, output_dir=output_dir, _quiet=False + ) v.check_correctness( fn.name, @@ -6292,18 +6663,21 @@ def test_vectorized_masked_buffer_fill(self) -> None: from accera._lang_python._lang import _If N_input = 5 N_output = 8 - Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(N_input,)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(N_output,)) + Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(N_input, )) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(N_output, )) package = Package() - nest = Nest(shape=(N_output,)) + nest = Nest(shape=(N_output, )) i, = nest.get_indices() @nest.iteration_logic def _nest(): + def store_value(): Output[i] = Input[i] + def store_zero(): Output[i] = 0 + _If(i < N_input, store_value).Else(store_zero) sched = nest.create_schedule() @@ -6314,24 +6688,31 @@ def store_zero(): output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name shutil.rmtree(output_dir, ignore_errors=True) with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(name=package_name, format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, mode=self.PACKAGE_MODE, output_dir=output_dir) - + package.build( + name=package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + mode=self.PACKAGE_MODE, + output_dir=output_dir + ) + def test_vectorized_masked_store(self) -> None: from accera._lang_python._lang import _If N_input = 8 N_output = 5 - Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(N_input,)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(N_output,)) + Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(N_input, )) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(N_output, )) package = Package() - nest = Nest(shape=(N_input,)) + nest = Nest(shape=(N_input, )) i, = nest.get_indices() @nest.iteration_logic def _nest(): + def store_value(): Output[i] = Input[i] + _If(i < N_output, store_value) - + sched = nest.create_schedule() plan = sched.create_plan() plan.vectorize(i) @@ -6340,22 +6721,29 @@ def store_value(): output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name shutil.rmtree(output_dir, ignore_errors=True) with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(name=package_name, format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, mode=self.PACKAGE_MODE, output_dir=output_dir) + package.build( + name=package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + mode=self.PACKAGE_MODE, + output_dir=output_dir + ) def test_vectorized_masked_accumulate(self) -> None: from accera._lang_python._lang import _If N_input = 8 N_output = 5 - Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(N_input,)) - Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(N_output,)) + Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=(N_input, )) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=(N_output, )) package = Package() - nest = Nest(shape=(N_input,)) + nest = Nest(shape=(N_input, )) i, = nest.get_indices() @nest.iteration_logic def _nest(): + def store_value(): Output[i] += Input[i] + _If(i < N_output, store_value) sched = nest.create_schedule() @@ -6366,7 +6754,112 @@ def store_value(): output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name shutil.rmtree(output_dir, ignore_errors=True) with verifiers.VerifyPackage(self, package_name, output_dir) as v: - package.build(name=package_name, format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, mode=self.PACKAGE_MODE, output_dir=output_dir) + package.build( + name=package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR_VERBOSE, + mode=self.PACKAGE_MODE, + output_dir=output_dir + ) + + def test_packing_floordiv_mod_no_splits(self) -> None: + package_name = "test_packing_floordiv_mod_no_splits" + M = 256 + M_inner = 16 + N = 64 + N_inner = 8 + input_shape = (M, N) + packed_shape = (M // M_inner, N // N_inner, M_inner, N_inner) + Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=input_shape) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=packed_shape) + + package = Package() + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _nest(): + Output[i // M_inner, j // N_inner, i % M_inner, j % N_inner] = Input[i, j] + + sched = nest.create_schedule() + plan = sched.create_plan() + fn = package.add(plan, args=(Input, Output), base_name=package_name) + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name + shutil.rmtree(output_dir, ignore_errors=True) + with verifiers.VerifyPackage(self, package_name, output_dir) as v: + package.build( + name=package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR, + mode=self.PACKAGE_MODE, + output_dir=output_dir + ) + + def test_packing_floordiv_mod_splits(self) -> None: + package_name = "test_packing_floordiv_mod_splits" + M = 256 + M_inner = 16 + N = 64 + N_inner = 8 + input_shape = (M, N) + packed_shape = (M // M_inner, N // N_inner, M_inner, N_inner) + Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=input_shape) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=packed_shape) + + package = Package() + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _nest(): + Output[i // M_inner, j // N_inner, i % M_inner, j % N_inner] = Input[i, j] + + sched = nest.create_schedule() + ii, jj = sched.tile(dict(zip([i, j], [M_inner, N_inner]))) + plan = sched.create_plan() + fn = package.add(plan, args=(Input, Output), base_name=package_name) + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name + shutil.rmtree(output_dir, ignore_errors=True) + with verifiers.VerifyPackage(self, package_name, output_dir) as v: + package.build( + name=package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR, + mode=self.PACKAGE_MODE, + output_dir=output_dir + ) + + def test_packing_floordiv_mod_splits_vectorize(self) -> None: + package_name = "test_packing_floordiv_mod_splits_vectorize" + M = 256 + M_inner = 16 + N = 64 + N_inner = 8 + input_shape = (M, N) + packed_shape = (M // M_inner, N // N_inner, M_inner, N_inner) + Input = Array(role=Role.INPUT, element_type=ScalarType.int32, shape=input_shape) + Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.int32, shape=packed_shape) + + package = Package() + nest = Nest(shape=(M, N)) + i, j = nest.get_indices() + + @nest.iteration_logic + def _nest(): + Output[i // M_inner, j // N_inner, i % M_inner, j % N_inner] = Input[i, j] + + sched = nest.create_schedule() + ii, jj = sched.tile(dict(zip([i, j], [M_inner, N_inner]))) + plan = sched.create_plan() + plan.vectorize(jj) + fn = package.add(plan, args=(Input, Output), base_name=package_name) + output_dir = pathlib.Path(TEST_PACKAGE_DIR) / package_name + shutil.rmtree(output_dir, ignore_errors=True) + with verifiers.VerifyPackage(self, package_name, output_dir) as v: + package.build( + name=package_name, + format=self.PACKAGE_FORMAT | Package.Format.MLIR, + mode=self.PACKAGE_MODE, + output_dir=output_dir + ) + if __name__ == '__main__': unittest.main(verbosity=10) diff --git a/accera/transforms/CMakeLists.txt b/accera/transforms/CMakeLists.txt index 52569fed..eca1aab3 100644 --- a/accera/transforms/CMakeLists.txt +++ b/accera/transforms/CMakeLists.txt @@ -70,11 +70,15 @@ set(rcgpu_include ) set(accaffine_src + src/affine/AffineLoopNormalize.cpp src/affine/AffineSimplifications.cpp + src/affine/CheckBoundsPass.cpp ) set(accaffine_include + include/affine/AffineLoopNormalize.h include/affine/AffineSimplifications.h + include/affine/CheckBoundsPass.h ) set(accvec_src diff --git a/accera/transforms/include/AcceraPasses.h b/accera/transforms/include/AcceraPasses.h index 77ac0098..d065115c 100644 --- a/accera/transforms/include/AcceraPasses.h +++ b/accera/transforms/include/AcceraPasses.h @@ -7,6 +7,8 @@ #pragma once #include "affine/AffineSimplifications.h" +#include "affine/AffineLoopNormalize.h" +#include "affine/CheckBoundsPass.h" #include "exec/ExecutionPlanToAffineLoweringPass.h" #include "gpu/AcceraToGPUPass.h" #include "gpu/AcceraVulkanPasses.h" diff --git a/accera/transforms/include/AcceraPasses.td b/accera/transforms/include/AcceraPasses.td index a1edef9b..fdc5f453 100644 --- a/accera/transforms/include/AcceraPasses.td +++ b/accera/transforms/include/AcceraPasses.td @@ -209,6 +209,24 @@ def AcceraAffineSimplification : Pass<"acc-affine-simplify"> { ]; } + +//===----------------------------------------------------------------------===// +// AcceraAffineLoopNormalize +//===----------------------------------------------------------------------===// + +def AcceraAffineLoopNormalize : Pass<"acc-affine-loop-normalize"> { + let summary = "Normalize affine for ops"; + let constructor = "accera::transforms::affine::createAcceraAffineLoopNormalizePass()"; + let dependentDialects = [ + "mlir::StandardOpsDialect", + "mlir::AffineDialect", + "mlir::linalg::LinalgDialect", + "mlir::memref::MemRefDialect", + "mlir::math::MathDialect", + "mlir::scf::SCFDialect", + "mlir::gpu::GPUDialect" + ]; +} //===----------------------------------------------------------------------===// // BarrierOpt //===----------------------------------------------------------------------===// diff --git a/accera/transforms/include/affine/AffineLoopNormalize.h b/accera/transforms/include/affine/AffineLoopNormalize.h new file mode 100644 index 00000000..a80a645d --- /dev/null +++ b/accera/transforms/include/affine/AffineLoopNormalize.h @@ -0,0 +1,20 @@ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include + +namespace mlir +{ +class Pass; +class RewritePatternSet; +using OwningRewritePatternList = RewritePatternSet; +} // namespace mlir + +namespace accera::transforms::affine +{ +std::unique_ptr createAcceraAffineLoopNormalizePass(); +} // namespace accera::transforms::affine diff --git a/accera/transforms/include/affine/CheckBoundsPass.h b/accera/transforms/include/affine/CheckBoundsPass.h new file mode 100644 index 00000000..eb23d766 --- /dev/null +++ b/accera/transforms/include/affine/CheckBoundsPass.h @@ -0,0 +1,31 @@ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include +#include + +namespace mlir +{ +class Pass; +class RewritePatternSet; +} // namespace mlir + +namespace accera::transforms::affine +{ + +// Unit attr name for controlling whether bounds checking has already been performed on an op +const std::string BoundsCheckedAttrName = "accaffine.bounds_checked"; + +// Unit attr name for controlling whether bounds checking is done for ops within a marked op +const std::string AccessBoundsCheckAttrName = "accaffine.access_bounds_check"; + +void populateBoundsCheckingPatterns(mlir::RewritePatternSet& patterns); + +// TODO : implement +// std::unique_ptr createBoundsCheckingPass(); + +} // namespace accera::transforms::affine diff --git a/accera/transforms/src/AcceraPasses.cpp b/accera/transforms/src/AcceraPasses.cpp index ded2358b..e075383d 100644 --- a/accera/transforms/src/AcceraPasses.cpp +++ b/accera/transforms/src/AcceraPasses.cpp @@ -153,6 +153,14 @@ void addAcceraToLLVMPassPipeline(OpPassManager& pm, const AcceraPassPipelineOpti pmAdaptor.addPass(createSymbolDCEPass()); pmAdaptor.addPass(affine::createAffineSimplificationPass()); pmAdaptor.addPass(createCanonicalizerPass()); + pmAdaptor.addPass(createCSEPass()); + pmAdaptor.addPass(affine::createAcceraAffineLoopNormalizePass()); + + pmAdaptor.addPass(createCanonicalizerPass()); + pmAdaptor.addPass(createCSEPass()); + pmAdaptor.addPass(affine::createAffineSimplificationPass()); + pmAdaptor.addPass(createCanonicalizerPass()); + pmAdaptor.addPass(createCSEPass()); pmAdaptor.addPass(vectorization::createVectorizationPass({ options.printVecOpDetails.getValue() })); pmAdaptor.addPass(vectorization::createVectorizationUnrollPass({ options.printVecOpDetails.getValue() })); pmAdaptor.addPass(value::createValueUnrollingPass()); diff --git a/accera/transforms/src/affine/AffineLoopNormalize.cpp b/accera/transforms/src/affine/AffineLoopNormalize.cpp new file mode 100644 index 00000000..d29a9539 --- /dev/null +++ b/accera/transforms/src/affine/AffineLoopNormalize.cpp @@ -0,0 +1,97 @@ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "affine/AffineLoopNormalize.h" + +#include "AcceraPasses.h" + +#include +#include +#include +#include + +#include + +#include +#include + +namespace +{ + +// mlir::normalizeAffineFor has a bug where AffineForOp map bounds are treated as having the same operands. +// To work around this bug, adjust the lower and upper bound maps to have the same operands and adjust +// the upper bound map's indexing accordingly + +// E.g. given +// affine.for %arg5 = affine_map<()[s0] -> (-s0)>()[%arg4] to affine_map<()[s0, s1] -> (s0 - s1)>()[%arg0, %arg3] +// without adjusting, mlir::normalizeAffineFor would turn this into: +// affine.for %arg5 = 0 to affine_map<()[s0, s1, s2] -> (s0 - s1 + s0)>()[%arg0, %0, %1] +// this term should be s2 now, though it *was* s0 before ------^ +// So to work around this bug, we change the for loop maps to be: +// affine.for %arg5 = affine_map<()[s0, s1, s2] -> (-s0)>()[%arg4, %arg0, %arg3] to affine_map<()[s0, s1, s2] -> (s1 - s2)>()[%arg4, %arg0, %arg3] +// Where we simply concatenate operands for each map together, and adjust the indexing used by the upper bound map to account +// for however many dims and symbols the lower bound map had initially +void workaroundModifyAffineForOp(mlir::AffineForOp& op) +{ + mlir::AffineBound lb = op.getLowerBound(); + mlir::AffineBound ub = op.getUpperBound(); + mlir::AffineMap origLbMap = lb.getMap(); + mlir::AffineMap origUbMap = ub.getMap(); + auto lbDimCount = origLbMap.getNumDims(); + auto lbSymCount = origLbMap.getNumSymbols(); + auto ubDimCount = origUbMap.getNumDims(); + auto ubSymCount = origUbMap.getNumSymbols(); + + // Adjust the number of dims and syms to the LB map, but otherwise it doesn't need to change + mlir::MutableAffineMap mutableLbMap(origLbMap); + mutableLbMap.setNumDims(lbDimCount + ubDimCount); + mutableLbMap.setNumSymbols(lbSymCount + ubSymCount); + mlir::AffineMap newLbMap = mutableLbMap.getAffineMap(); + + // Adjust the number of dims and syms to the UB map, and also shift its expressions by the number of lb dims and syms + mlir::AffineMap shiftedUbMap = origUbMap.shiftDims(lbDimCount).shiftSymbols(lbSymCount); + mlir::MutableAffineMap mutableShiftedUbMap(shiftedUbMap); + mutableShiftedUbMap.setNumDims(lbDimCount + ubDimCount); + mutableShiftedUbMap.setNumSymbols(lbSymCount + ubSymCount); + mlir::AffineMap newUbMap = mutableShiftedUbMap.getAffineMap(); + + // interleave [ lbDims..., ubDims..., lbSyms..., ubSyms... ] because dim operands occur before symbol operands when applying a map + std::vector combinedOperands; + combinedOperands.reserve(ub.getNumOperands() + lb.getNumOperands()); + combinedOperands.insert(combinedOperands.end(), lb.operandBegin(), lb.operandBegin() + lbDimCount); + combinedOperands.insert(combinedOperands.end(), ub.operandBegin(), ub.operandBegin() + ubDimCount); + combinedOperands.insert(combinedOperands.end(), lb.operandBegin() + lbDimCount, lb.operandEnd()); + combinedOperands.insert(combinedOperands.end(), ub.operandBegin() + ubDimCount, ub.operandEnd()); + + // Now we have our new maps and operands, so adjust the given for op and return + op.setLowerBound(combinedOperands, newLbMap); + op.setUpperBound(combinedOperands, newUbMap); +} + +struct AcceraAffineLoopNormalizePass : public accera::transforms::AcceraAffineLoopNormalizeBase +{ + // This pass and these patterns only differs from the builtin mlir AffineLoopNormalize pass in that it adjusts the maps on AffineForOps + // before calling into the mlir affine loop normalize function in order to work around a bug in that implementation + void runOnOperation() final + { + auto op = getOperation(); + + // See \mlir\lib\Dialect\Affine\Transforms\AffineLoopNormalize.cpp + op->walk([](mlir::AffineForOp affineForOp) { + workaroundModifyAffineForOp(affineForOp); + mlir::normalizeAffineFor(affineForOp); + }); + } +}; + +} // namespace + +namespace accera::transforms::affine +{ +std::unique_ptr createAcceraAffineLoopNormalizePass() +{ + return std::make_unique(); +} +} // namespace accera::transforms::affine diff --git a/accera/transforms/src/affine/CheckBoundsPass.cpp b/accera/transforms/src/affine/CheckBoundsPass.cpp new file mode 100644 index 00000000..84e18e2f --- /dev/null +++ b/accera/transforms/src/affine/CheckBoundsPass.cpp @@ -0,0 +1,321 @@ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "affine/CheckBoundsPass.h" + +#include "AcceraPasses.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace accera; +namespace v = accera::ir::value; + +namespace +{ + +bool IsBoundsChecked(mlir::Operation* op) +{ + return op->getAttr(transforms::affine::BoundsCheckedAttrName) != nullptr; +} + +void SetBoundsChecked(mlir::OpBuilder& builder, mlir::Operation* op) +{ + op->setAttr(transforms::affine::BoundsCheckedAttrName, builder.getUnitAttr()); +} + +template +bool HasOutOfBoundsAccess(LoadOrStoreOp op, mlir::Location loc) +{ + // This is a pared down version of mlir::boundCheckLoadOrStoreOp, which has a bug currently where it only returns failure (out of bounds) + // if the last thing it checks has a failure, rather than anything it checks. + + mlir::MemRefRegion accessRegion(loc); + auto memRefType = op.getMemRefType(); + unsigned rank = memRefType.getRank(); + (void)accessRegion.compute(op, 0, nullptr /*sliceState*/, false /*addMemRefDimBounds */); + bool outOfBounds = false; + + // TODO : handle dynamic dimension out of bounds checks generically + if (!memRefType.hasStaticShape()) + { + return false; + } + + // For each dimension, check for out of bounds. + for (unsigned dim = 0; dim < rank; ++dim) + { + // Intersect memory region with constraint capturing out of bounds (both out + // of upper and out of lower), and check if the constraint system is + // feasible. If it is, there is at least one point out of bounds. + + // Check for overflow: d_i >= memref dim size. + mlir::FlatAffineValueConstraints upperConstraints(*accessRegion.getConstraints()); + int64_t dimSize = memRefType.getDimSize(dim); + upperConstraints.addBound(mlir::FlatAffineConstraints::LB, dim, dimSize); + + // Check for a negative index: d_i <= -1. + mlir::FlatAffineValueConstraints lowerConstraints(*accessRegion.getConstraints()); + lowerConstraints.addBound(mlir::FlatAffineConstraints::UB, dim, -1); + + if (!upperConstraints.isEmpty() || !lowerConstraints.isEmpty()) + { + outOfBounds = true; + break; + } + } + return outOfBounds; +} + +struct OutOfBoundsLoadRewrite : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::memref::LoadOp loadOp, mlir::PatternRewriter& rewriter) const final; +}; + +struct OutOfBoundsAffineLoadRewrite : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::AffineLoadOp affineLoadOp, mlir::PatternRewriter& rewriter) const final; +}; + +struct OutOfBoundsStoreRewrite : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::memref::StoreOp toreOp, mlir::PatternRewriter& rewriter) const final; +}; + +struct OutOfBoundsAffineStoreRewrite : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::AffineStoreOp affineStoreOp, mlir::PatternRewriter& rewriter) const final; +}; + +mlir::LogicalResult OutOfBoundsLoadRewriteCommon(mlir::AffineLoadOp affineLoadOp, mlir::PatternRewriter& rewriter) +{ + if (IsBoundsChecked(affineLoadOp)) + { + return mlir::success(); + } + auto loc = affineLoadOp.getLoc(); + mlir::AffineLoadOp::Adaptor adaptor{ affineLoadOp }; + + if (HasOutOfBoundsAccess(affineLoadOp, loc)) + { + // This load has a potential out-of-bounds access, so replace it with a conditional load + + auto accessMapAttr = affineLoadOp.getAffineMapAttr(); + auto accessMap = accessMapAttr.getValue(); + auto loadSrc = affineLoadOp.memref(); + auto loadSrcType = loadSrc.getType(); + assert(loadSrcType.isa()); + auto memRefType = loadSrcType.cast(); + + auto loadResultType = affineLoadOp.result().getType(); + + std::vector constraintExprs; + constraintExprs.reserve(accessMap.getNumResults() * 2); // One lower bound and one upper bound check per src dimension + std::vector accessIndices(adaptor.indices().begin(), adaptor.indices().end()); + auto resolvedAccessIndices = ir::util::MultiDimAffineApply(rewriter, loc, accessMap, accessIndices); + mlir::SmallVector constraintEqFlags(accessMap.getNumResults() * 2, false); + for (size_t srcDim = 0; srcDim < accessMap.getNumResults(); srcDim++) + { + // Lower bound check + constraintExprs.push_back(rewriter.getAffineDimExpr(srcDim)); // Will check whether this index is >= 0 + + // Upper bound check + constraintExprs.push_back(memRefType.getDimSize(srcDim) - rewriter.getAffineDimExpr(srcDim) - rewriter.getAffineConstantExpr(1)); // Will check whether (this dimSize - this index - 1) >= 0 (note: -1 since we're doing a >= check with 0-based indices) + } + + std::vector tmpBufferShape{ 1 }; // only one element of type loadResultType + mlir::MemRefType tmpElementType; + std::optional execTargetOpt = ir::util::ResolveExecutionTarget(affineLoadOp); + assert(execTargetOpt.has_value()); + auto execTarget = *execTargetOpt; + mlir::Value tmpBuffer; + if (execTarget == v::ExecutionTarget::GPU) + { + tmpElementType = mlir::MemRefType::get(tmpBufferShape, loadResultType, {}, static_cast(v::MemorySpace::Private)); + tmpBuffer = rewriter.create(loc, tmpElementType, llvm::None); + } + else + { + tmpElementType = mlir::MemRefType::get(tmpBufferShape, loadResultType); + tmpBuffer = rewriter.create(loc, tmpElementType, mlir::ValueRange{}, rewriter.getI64IntegerAttr(ir::executionPlan::AVX2Alignment)); + } + + auto zeroIndex = rewriter.create(loc, 0); + + auto srcBoundsCheckSet = mlir::IntegerSet::get(resolvedAccessIndices.size(), 0, constraintExprs, constraintEqFlags); + auto ifOp = rewriter.create(loc, srcBoundsCheckSet, mlir::ValueRange{ resolvedAccessIndices }, true); // true indicating we want an "else" region + + auto thenBuilder = ifOp.getThenBodyBuilder(); + auto newLoadOp = thenBuilder.create(loc, loadSrc, accessMap, accessIndices); + SetBoundsChecked(thenBuilder, newLoadOp); + + auto thenStoreOp = thenBuilder.create(loc, newLoadOp.getResult(), tmpBuffer, mlir::ValueRange{ zeroIndex }); + SetBoundsChecked(thenBuilder, thenStoreOp); + + auto elseBuilder = ifOp.getElseBodyBuilder(); + // TODO : support user-specified padding value rather than always using 0 + auto constantZero = elseBuilder.create(loc, elseBuilder.getZeroAttr(loadResultType)); + auto elseStoreOp = elseBuilder.create(loc, constantZero.getResult(), tmpBuffer, mlir::ValueRange{ zeroIndex }); + SetBoundsChecked(elseBuilder, elseStoreOp); + + auto tmpSlotLoad = rewriter.create(loc, tmpBuffer, mlir::ValueRange{ zeroIndex }); + SetBoundsChecked(rewriter, tmpSlotLoad); + + affineLoadOp.replaceAllUsesWith(tmpSlotLoad.getResult()); + rewriter.eraseOp(affineLoadOp); + } + + return mlir::success(); +} + +mlir::LogicalResult OutOfBoundsLoadRewrite::matchAndRewrite(mlir::memref::LoadOp loadOp, mlir::PatternRewriter& rewriter) const +{ + // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking + if (!ir::util::AncestorOpContainsAttrOfName(loadOp, transforms::affine::AccessBoundsCheckAttrName)) + { + return mlir::success(); + } + + if (IsBoundsChecked(loadOp)) + { + return mlir::success(); + } + // Convert std.load to affine.load with an identity map + auto loc = loadOp.getLoc(); + mlir::memref::LoadOp::Adaptor adaptor{ loadOp }; + auto memRefType = adaptor.memref().getType().cast(); + auto affineLoadOp = rewriter.create(loc, adaptor.memref(), rewriter.getMultiDimIdentityMap(memRefType.getRank()), adaptor.indices()); + loadOp.replaceAllUsesWith(affineLoadOp.getResult()); + auto result = OutOfBoundsLoadRewriteCommon(affineLoadOp, rewriter); + rewriter.eraseOp(loadOp); + return result; +} + +mlir::LogicalResult OutOfBoundsAffineLoadRewrite::matchAndRewrite(mlir::AffineLoadOp affineLoadOp, mlir::PatternRewriter& rewriter) const +{ + // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking + if (!ir::util::AncestorOpContainsAttrOfName(affineLoadOp, transforms::affine::AccessBoundsCheckAttrName)) + { + return mlir::success(); + } + + return OutOfBoundsLoadRewriteCommon(affineLoadOp, rewriter); +} + +mlir::LogicalResult OutOfBoundsStoreRewriteCommon(mlir::AffineStoreOp affineStoreOp, mlir::PatternRewriter& rewriter) +{ + if (IsBoundsChecked(affineStoreOp)) + { + return mlir::success(); + } + + auto loc = affineStoreOp.getLoc(); + mlir::AffineStoreOp::Adaptor adaptor{ affineStoreOp }; + + if (HasOutOfBoundsAccess(affineStoreOp, loc)) + { + // This store has a potential out-of-bounds access, so replace it with a conditional store + + auto accessMapAttr = affineStoreOp.getAffineMapAttr(); + auto accessMap = accessMapAttr.getValue(); + auto storeDst = affineStoreOp.memref(); + auto storeDstType = storeDst.getType(); + assert(storeDstType.isa()); + auto memRefType = storeDstType.cast(); + + // TODO : de-dupe affine.if constraint code with load case + std::vector constraintExprs; + constraintExprs.reserve(accessMap.getNumResults() * 2); // One lower bound and one upper bound check per src dimension + std::vector accessIndices(adaptor.indices().begin(), adaptor.indices().end()); + auto resolvedAccessIndices = ir::util::MultiDimAffineApply(rewriter, loc, accessMap, accessIndices); + mlir::SmallVector constraintEqFlags(accessMap.getNumResults() * 2, false); + for (size_t srcDim = 0; srcDim < accessMap.getNumResults(); srcDim++) + { + // Lower bound check + constraintExprs.push_back(rewriter.getAffineDimExpr(srcDim)); // Will check whether this index is >= 0 + + // Upper bound check + constraintExprs.push_back(memRefType.getDimSize(srcDim) - rewriter.getAffineDimExpr(srcDim) - rewriter.getAffineConstantExpr(1)); // Will check whether (this dimSize - this index - 1) >= 0 (note: -1 since we're doing a >= check with 0-based indices) + } + + auto srcBoundsCheckSet = mlir::IntegerSet::get(resolvedAccessIndices.size(), 0, constraintExprs, constraintEqFlags); + auto ifOp = rewriter.create(loc, srcBoundsCheckSet, mlir::ValueRange{ resolvedAccessIndices }, true); // true indicating we want an "else" region + + auto thenBuilder = ifOp.getThenBodyBuilder(); + auto newStoreOp = thenBuilder.create(loc, affineStoreOp.value(), affineStoreOp.memref(), accessMap, accessIndices); + SetBoundsChecked(thenBuilder, newStoreOp); + + rewriter.eraseOp(affineStoreOp); + } + + return mlir::success(); +} + +mlir::LogicalResult OutOfBoundsStoreRewrite::matchAndRewrite(mlir::memref::StoreOp storeOp, mlir::PatternRewriter& rewriter) const +{ + // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking + if (!ir::util::AncestorOpContainsAttrOfName(storeOp, transforms::affine::AccessBoundsCheckAttrName)) + { + return mlir::success(); + } + + if (IsBoundsChecked(storeOp)) + { + return mlir::success(); + } + // Convert std.store to affine.store with an identity map + auto loc = storeOp.getLoc(); + mlir::memref::StoreOp::Adaptor adaptor{ storeOp }; + auto memRefType = adaptor.memref().getType().cast(); + auto affineStoreOp = rewriter.create(loc, adaptor.value(), adaptor.memref(), rewriter.getMultiDimIdentityMap(memRefType.getRank()), adaptor.indices()); + auto result = OutOfBoundsStoreRewriteCommon(affineStoreOp, rewriter); + rewriter.eraseOp(storeOp); + return result; +} + +mlir::LogicalResult OutOfBoundsAffineStoreRewrite::matchAndRewrite(mlir::AffineStoreOp affineStoreOp, mlir::PatternRewriter& rewriter) const +{ + // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking + if (!ir::util::AncestorOpContainsAttrOfName(affineStoreOp, transforms::affine::AccessBoundsCheckAttrName)) + { + return mlir::success(); + } + + return OutOfBoundsStoreRewriteCommon(affineStoreOp, rewriter); +} + +} // namespace + +namespace accera::transforms::affine +{ + + +void populateBoundsCheckingPatterns(mlir::RewritePatternSet& patterns) +{ + patterns.insert(patterns.getContext()); +} + +// TODO : implement +// std::unique_ptr createBoundsCheckingPass(){} + +} \ No newline at end of file diff --git a/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp b/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp index 14c71481..21989a90 100644 --- a/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp +++ b/accera/transforms/src/exec/ExecutionPlanToAffineLoweringPass.cpp @@ -5,7 +5,9 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #include "exec/ExecutionPlanToAffineLoweringPass.h" + #include "AcceraPasses.h" +#include "affine/CheckBoundsPass.h" #include #include @@ -77,7 +79,6 @@ namespace // in internal utilities as well as MLIR APIs as a mlir::StringRef. Note that mlir::StringRef // has a constructor that takes a const std::string& for convenience -const std::string BoundsCheckedAttrName = "accxp_bounds_checked"; const std::string BaseArrayAccessMapAttrName = "accxp_base_array_access_map"; const std::string BaseArrayAccessIndicesAttrName = "accxp_base_array_access_indices"; @@ -256,33 +257,6 @@ struct HoistScalingToCacheReduceRewrite : public OpRewritePattern -{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::memref::LoadOp loadOp, PatternRewriter& rewriter) const final; -}; - -struct OutOfBoundsAffineLoadRewrite : public OpRewritePattern -{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::AffineLoadOp affineLoadOp, PatternRewriter& rewriter) const final; -}; - -struct OutOfBoundsStoreRewrite : public OpRewritePattern -{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::memref::StoreOp toreOp, PatternRewriter& rewriter) const final; -}; - -struct OutOfBoundsAffineStoreRewrite : public OpRewritePattern -{ - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::AffineStoreOp affineStoreOp, PatternRewriter& rewriter) const final; -}; struct ConvertLoadsToAffineRewrite : public OpRewritePattern { @@ -511,58 +485,7 @@ std::optional GetDimSizeForBaseIndices(const std::vector& baseIn return dimSize; } -bool IsBoundsChecked(Operation* op) -{ - return op->getAttr(BoundsCheckedAttrName) != nullptr; -} -void SetBoundsChecked(OpBuilder& builder, Operation* op) -{ - op->setAttr(BoundsCheckedAttrName, builder.getUnitAttr()); -} - -template -bool HasOutOfBoundsAccess(LoadOrStoreOp op, mlir::Location loc) -{ - // This is a pared down version of mlir::boundCheckLoadOrStoreOp, which has a bug currently where it only returns failure (out of bounds) - // if the last thing it checks has a failure, rather than anything it checks. - - mlir::MemRefRegion accessRegion(loc); - auto memRefType = op.getMemRefType(); - unsigned rank = memRefType.getRank(); - (void)accessRegion.compute(op, 0, nullptr /*sliceState*/, false /*addMemRefDimBounds */); - bool outOfBounds = false; - - // TODO : handle dynamic dimension out of bounds checks generically - if (!memRefType.hasStaticShape()) - { - return false; - } - - // For each dimension, check for out of bounds. - for (unsigned dim = 0; dim < rank; ++dim) - { - // Intersect memory region with constraint capturing out of bounds (both out - // of upper and out of lower), and check if the constraint system is - // feasible. If it is, there is at least one point out of bounds. - - // Check for overflow: d_i >= memref dim size. - FlatAffineValueConstraints upperConstraints(*accessRegion.getConstraints()); - int64_t dimSize = memRefType.getDimSize(dim); - upperConstraints.addBound(FlatAffineConstraints::LB, dim, dimSize); - - // Check for a negative index: d_i <= -1. - FlatAffineValueConstraints lowerConstraints(*accessRegion.getConstraints()); - lowerConstraints.addBound(FlatAffineConstraints::UB, dim, -1); - - if (!upperConstraints.isEmpty() || !lowerConstraints.isEmpty()) - { - outOfBounds = true; - break; - } - } - return outOfBounds; -} // Returns whether left and right contain the same elements (possibly reordered) template @@ -2533,7 +2456,7 @@ LogicalResult ActiveElementCacheCopyOpRewrite::matchAndRewrite(ActiveElementCach auto copyOrder = copyScheduleOp.getOrder(); for (const auto& loopIndex : copyOrder) { - copyScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(AccessBoundsCheckAttrName), rewriter.getUnitAttr()); + copyScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(affine::AccessBoundsCheckAttrName), rewriter.getUnitAttr()); } if (execTarget == v::ExecutionTarget::GPU && !IsMemspaceLocal(dstMemRefSpace)) @@ -3000,7 +2923,7 @@ LogicalResult ActiveBlockCacheCopyOpRewrite::matchAndRewrite(ActiveBlockCacheCop auto copyOrder = copyScheduleOp.getOrder(); for (const auto& loopIndex : copyOrder) { - copyScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(AccessBoundsCheckAttrName), rewriter.getUnitAttr()); + copyScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(affine::AccessBoundsCheckAttrName), rewriter.getUnitAttr()); } } } @@ -3179,7 +3102,7 @@ LogicalResult ActiveBlockCacheReduceOpRewrite::matchAndRewrite(ActiveBlockCacheR auto copyOrder = reduceScheduleOp.getOrder(); for (const auto& loopIndex : copyOrder) { - reduceScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(AccessBoundsCheckAttrName), rewriter.getUnitAttr()); + reduceScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(affine::AccessBoundsCheckAttrName), rewriter.getUnitAttr()); } } else @@ -3272,7 +3195,7 @@ LogicalResult ActiveElementCacheReduceOpRewrite::matchAndRewrite(ActiveElementCa auto reduceOrder = reduceScheduleOp.getOrder(); for (const auto& loopIndex : reduceOrder) { - reduceScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(AccessBoundsCheckAttrName), rewriter.getUnitAttr()); + reduceScheduleOp.addLoopAttribute(loopIndex, rewriter.getStringAttr(affine::AccessBoundsCheckAttrName), rewriter.getUnitAttr()); } rewriter.eraseOp(cacheReduceOp); @@ -7023,215 +6946,7 @@ LogicalResult HoistScalingToCacheReduceRewrite::matchAndRewrite(mlir::AffineStor return success(); } -bool AncestorOpContainsAttrOfName(Operation* op, const mlir::StringRef& name) -{ - while (op != nullptr) - { - if (op->getAttr(name) != nullptr) - { - return true; - } - op = op->getParentOp(); - } - return false; -} - -LogicalResult OutOfBoundsLoadRewriteCommon(mlir::AffineLoadOp affineLoadOp, PatternRewriter& rewriter) -{ - if (IsBoundsChecked(affineLoadOp)) - { - return success(); - } - auto loc = affineLoadOp.getLoc(); - mlir::AffineLoadOp::Adaptor adaptor{ affineLoadOp }; - if (HasOutOfBoundsAccess(affineLoadOp, loc)) - { - // This load has a potential out-of-bounds access, so replace it with a conditional load - - auto accessMapAttr = affineLoadOp.getAffineMapAttr(); - auto accessMap = accessMapAttr.getValue(); - auto loadSrc = affineLoadOp.memref(); - auto loadSrcType = loadSrc.getType(); - assert(loadSrcType.isa()); - auto memRefType = loadSrcType.cast(); - - auto loadResultType = affineLoadOp.result().getType(); - - std::vector constraintExprs; - constraintExprs.reserve(accessMap.getNumResults() * 2); // One lower bound and one upper bound check per src dimension - std::vector accessIndices(adaptor.indices().begin(), adaptor.indices().end()); - auto resolvedAccessIndices = util::MultiDimAffineApply(rewriter, loc, accessMap, accessIndices); - SmallVector constraintEqFlags(accessMap.getNumResults() * 2, false); - for (size_t srcDim = 0; srcDim < accessMap.getNumResults(); srcDim++) - { - // Lower bound check - constraintExprs.push_back(rewriter.getAffineDimExpr(srcDim)); // Will check whether this index is >= 0 - - // Upper bound check - constraintExprs.push_back(memRefType.getDimSize(srcDim) - rewriter.getAffineDimExpr(srcDim) - rewriter.getAffineConstantExpr(1)); // Will check whether (this dimSize - this index - 1) >= 0 (note: -1 since we're doing a >= check with 0-based indices) - } - - std::vector tmpBufferShape{ 1 }; // only one element of type loadResultType - mlir::MemRefType tmpElementType; - std::optional execTargetOpt = util::ResolveExecutionTarget(affineLoadOp); - assert(execTargetOpt.has_value()); - auto execTarget = *execTargetOpt; - mlir::Value tmpBuffer; - if (execTarget == v::ExecutionTarget::GPU) - { - tmpElementType = mlir::MemRefType::get(tmpBufferShape, loadResultType, {}, static_cast(v::MemorySpace::Private)); - tmpBuffer = rewriter.create(loc, tmpElementType, llvm::None); - } - else - { - tmpElementType = mlir::MemRefType::get(tmpBufferShape, loadResultType); - tmpBuffer = rewriter.create(loc, tmpElementType, mlir::ValueRange{}, rewriter.getI64IntegerAttr(AVX2Alignment)); - } - - auto zeroIndex = rewriter.create(loc, 0); - - auto srcBoundsCheckSet = mlir::IntegerSet::get(resolvedAccessIndices.size(), 0, constraintExprs, constraintEqFlags); - auto ifOp = rewriter.create(loc, srcBoundsCheckSet, ValueRange{ resolvedAccessIndices }, true); // true indicating we want an "else" region - - auto thenBuilder = ifOp.getThenBodyBuilder(); - auto newLoadOp = thenBuilder.create(loc, loadSrc, accessMap, accessIndices); - SetBoundsChecked(thenBuilder, newLoadOp); - - auto thenStoreOp = thenBuilder.create(loc, newLoadOp.getResult(), tmpBuffer, ValueRange{ zeroIndex }); - SetBoundsChecked(thenBuilder, thenStoreOp); - - auto elseBuilder = ifOp.getElseBodyBuilder(); - // TODO : support user-specified padding value rather than always using 0 - auto constantZero = elseBuilder.create(loc, elseBuilder.getZeroAttr(loadResultType)); - auto elseStoreOp = elseBuilder.create(loc, constantZero.getResult(), tmpBuffer, ValueRange{ zeroIndex }); - SetBoundsChecked(elseBuilder, elseStoreOp); - - auto tmpSlotLoad = rewriter.create(loc, tmpBuffer, ValueRange{ zeroIndex }); - SetBoundsChecked(rewriter, tmpSlotLoad); - - affineLoadOp.replaceAllUsesWith(tmpSlotLoad.getResult()); - rewriter.eraseOp(affineLoadOp); - } - - return success(); -} - -LogicalResult OutOfBoundsLoadRewrite::matchAndRewrite(mlir::memref::LoadOp loadOp, PatternRewriter& rewriter) const -{ - // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking - if (!AncestorOpContainsAttrOfName(loadOp, AccessBoundsCheckAttrName)) - { - return success(); - } - - if (IsBoundsChecked(loadOp)) - { - return success(); - } - // Convert std.load to affine.load with an identity map - auto loc = loadOp.getLoc(); - mlir::memref::LoadOp::Adaptor adaptor{ loadOp }; - auto memRefType = adaptor.memref().getType().cast(); - auto affineLoadOp = rewriter.create(loc, adaptor.memref(), rewriter.getMultiDimIdentityMap(memRefType.getRank()), adaptor.indices()); - loadOp.replaceAllUsesWith(affineLoadOp.getResult()); - auto result = OutOfBoundsLoadRewriteCommon(affineLoadOp, rewriter); - rewriter.eraseOp(loadOp); - return result; -} - -LogicalResult OutOfBoundsAffineLoadRewrite::matchAndRewrite(mlir::AffineLoadOp affineLoadOp, PatternRewriter& rewriter) const -{ - // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking - if (!AncestorOpContainsAttrOfName(affineLoadOp, AccessBoundsCheckAttrName)) - { - return success(); - } - - return OutOfBoundsLoadRewriteCommon(affineLoadOp, rewriter); -} - -LogicalResult OutOfBoundsStoreRewriteCommon(mlir::AffineStoreOp affineStoreOp, PatternRewriter& rewriter) -{ - if (IsBoundsChecked(affineStoreOp)) - { - return success(); - } - - auto loc = affineStoreOp.getLoc(); - mlir::AffineStoreOp::Adaptor adaptor{ affineStoreOp }; - - if (HasOutOfBoundsAccess(affineStoreOp, loc)) - { - // This store has a potential out-of-bounds access, so replace it with a conditional store - - auto accessMapAttr = affineStoreOp.getAffineMapAttr(); - auto accessMap = accessMapAttr.getValue(); - auto storeDst = affineStoreOp.memref(); - auto storeDstType = storeDst.getType(); - assert(storeDstType.isa()); - auto memRefType = storeDstType.cast(); - - // TODO : de-dupe affine.if constraint code with load case - std::vector constraintExprs; - constraintExprs.reserve(accessMap.getNumResults() * 2); // One lower bound and one upper bound check per src dimension - std::vector accessIndices(adaptor.indices().begin(), adaptor.indices().end()); - auto resolvedAccessIndices = util::MultiDimAffineApply(rewriter, loc, accessMap, accessIndices); - SmallVector constraintEqFlags(accessMap.getNumResults() * 2, false); - for (size_t srcDim = 0; srcDim < accessMap.getNumResults(); srcDim++) - { - // Lower bound check - constraintExprs.push_back(rewriter.getAffineDimExpr(srcDim)); // Will check whether this index is >= 0 - - // Upper bound check - constraintExprs.push_back(memRefType.getDimSize(srcDim) - rewriter.getAffineDimExpr(srcDim) - rewriter.getAffineConstantExpr(1)); // Will check whether (this dimSize - this index - 1) >= 0 (note: -1 since we're doing a >= check with 0-based indices) - } - - auto srcBoundsCheckSet = mlir::IntegerSet::get(resolvedAccessIndices.size(), 0, constraintExprs, constraintEqFlags); - auto ifOp = rewriter.create(loc, srcBoundsCheckSet, ValueRange{ resolvedAccessIndices }, true); // true indicating we want an "else" region - - auto thenBuilder = ifOp.getThenBodyBuilder(); - auto newStoreOp = thenBuilder.create(loc, affineStoreOp.value(), affineStoreOp.memref(), accessMap, accessIndices); - SetBoundsChecked(thenBuilder, newStoreOp); - - rewriter.eraseOp(affineStoreOp); - } - - return success(); -} - -LogicalResult OutOfBoundsStoreRewrite::matchAndRewrite(mlir::memref::StoreOp storeOp, PatternRewriter& rewriter) const -{ - // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking - if (!AncestorOpContainsAttrOfName(storeOp, AccessBoundsCheckAttrName)) - { - return success(); - } - - if (IsBoundsChecked(storeOp)) - { - return success(); - } - // Convert std.store to affine.store with an identity map - auto loc = storeOp.getLoc(); - mlir::memref::StoreOp::Adaptor adaptor{ storeOp }; - auto memRefType = adaptor.memref().getType().cast(); - auto affineStoreOp = rewriter.create(loc, adaptor.value(), adaptor.memref(), rewriter.getMultiDimIdentityMap(memRefType.getRank()), adaptor.indices()); - auto result = OutOfBoundsStoreRewriteCommon(affineStoreOp, rewriter); - rewriter.eraseOp(storeOp); - return result; -} - -LogicalResult OutOfBoundsAffineStoreRewrite::matchAndRewrite(mlir::AffineStoreOp affineStoreOp, PatternRewriter& rewriter) const -{ - // Only check for out-of-bounds-accesses inside of ops that are marked for bounds checking - if (!AncestorOpContainsAttrOfName(affineStoreOp, AccessBoundsCheckAttrName)) - { - return success(); - } - - return OutOfBoundsStoreRewriteCommon(affineStoreOp, rewriter); -} template LogicalResult ConvertStoreToAffine(PatternRewriter& rewriter, OpType op) @@ -7469,7 +7184,7 @@ void OutOfBoundsAccessHandlingPass::runOnOperation() ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); - accera::transforms::executionPlan::populateOutOfBoundsAccessHandlingPatterns(patterns); + accera::transforms::affine::populateBoundsCheckingPatterns(patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -7601,14 +7316,6 @@ void populateExecutionPlanScaleHoistingPatterns(mlir::RewritePatternSet& pattern patterns.insert(patterns.getContext()); } -void populateOutOfBoundsAccessHandlingPatterns(mlir::RewritePatternSet& patterns) -{ - patterns.insert(patterns.getContext()); -} - void populateConvergeLoadStoresPatterns(mlir::RewritePatternSet& patterns) { patterns.insert PatternRewriter& rewriter) const override; }; +using ValueVHADDOp = vir::vhadd; +struct VhaddLowering : public OpRewritePattern +{ + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ValueVHADDOp op, + PatternRewriter& rewriter) const override; +}; + using vir::EnterProfileRegionOp; using vir::ExitProfileRegionOp; using vir::PrintProfileResultsOp; @@ -2630,6 +2640,67 @@ LogicalResult PrintOpLowering::matchAndRewrite( return success(); } +LogicalResult VhaddLowering::matchAndRewrite(ValueVHADDOp op, PatternRewriter& rewriter) const +{ + // vphaddd(ymm0, ymm1) -> [ymm0[0]+ymm0[1], ymm0[2]+ymm0[3], + // ymm1[0]+ymm1[1], ymm1[2]+ymm1[3], + // ymm0[4]+ymm0[5], ymm0[6]+ymm0[7], + // ymm1[4]+ymm1[5], ymm1[6]+ymm1[7]] + // this is equivalent to + // tmp0 = shuffle ymm0, ymm1 : [0, 2, 8, 10, 4, 6, 12, 14] + // tmp1 = shuffle ymm0, ymm1 : [1, 3, 9, 11, 5, 7, 13, 15] + // res = tmp0 + tmp1 + // (And similarly for other element types within 256-bit vectors) + + auto loc = op.getLoc(); + + auto lhs = op.lhs(); + auto rhs = op.rhs(); + + auto vecType = lhs.getType().cast(); + auto rank = vecType.getRank(); + assert(rank == 1 && "vhadd only supports rank-1 vectors"); + auto elementType = vecType.getElementType(); + auto elementCount = vecType.getNumElements(); + auto bitwidth = elementType.getIntOrFloatBitWidth(); + assert((bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && "vhadd only supports 16-bit, 32-bit, and 64-bit elements"); + + auto elementsPerGroup = 64 / bitwidth; + const size_t groups = static_cast(elementCount * bitwidth) / 64; // 2 groups for 128 bits, 4 groups for 256 bits + + // Group 0 Group 1 Group 2 Group 3 + // 64-bit: [ 0, 4, 2, 6 ] + // 32-bit: [ 0, 2, 8, 10, 4, 6, 12, 14 ] + // 16-bit: [ 0, 2, 4, 6, 16, 18, 20, 22, 8, 10, 12, 14, 24, 26, 28, 30 ] + std::vector evensMask; + evensMask.reserve(elementsPerGroup * groups); + + // Group 0 Group 1 Group 2 Group 3 + // 64-bit: [ 1, 5, 3, 7 ] + // 32-bit: [ 1, 3, 9, 11, 5, 7, 13, 15 ] + // 16-bit: [ 1, 3, 5, 7, 17, 19, 21, 23, 9, 11, 13, 15, 25, 27, 29, 31 ] + std::vector oddsMask; + oddsMask.reserve(elementsPerGroup * groups); + + for (size_t groupIdx = 0; groupIdx < groups; ++groupIdx) + { + auto groupOffset = (groupIdx % 2) * elementCount; // Every odd group is from the second vector, and so is offset by a vectors-worth of elements + auto offsetInVector = (groupIdx / 2) * elementsPerGroup; // Each pair of groups is from the same section of their respective original vectors + for (size_t elementIdx = 0; elementIdx < elementsPerGroup; ++elementIdx) + { + evensMask.push_back(static_cast(groupOffset + (elementIdx + offsetInVector) * 2)); + oddsMask.push_back(static_cast(groupOffset + (elementIdx + offsetInVector) * 2 + 1)); + } + } + + auto evenShuffleOp = rewriter.create(loc, vecType, lhs, rhs, rewriter.getI64ArrayAttr(evensMask)); + auto oddShuffleOp = rewriter.create(loc, vecType, lhs, rhs, rewriter.getI64ArrayAttr(oddsMask)); + + rewriter.replaceOpWithNewOp(op, vir::BinaryOpPredicate::ADD, evenShuffleOp, oddShuffleOp); + + return success(); +} + void ValueToStdLoweringPass::runOnModule() { auto module = getOperation(); @@ -2764,7 +2835,8 @@ void populateValueToStandardPatterns(bool enableProfiling, mlir::RewritePatternS StoreOpLowering, TerminatorLowering, UnaryOpLowering, - ViewOpLowering>(context); + ViewOpLowering, + VhaddLowering>(context); patterns.insert #include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -28,6 +35,7 @@ #include #include #include +#include #include #include @@ -41,6 +49,19 @@ namespace v = accera::ir::value; // TODO : plumb through a sufficient target enum / bitmap so we can dynamically enable/disable vpmaddwd and other pattern matchers #define MATCH_VPMADDWD_INTRINSIC 1 +namespace +{ +mlir::LogicalResult reportMatchOpFailure(mlir::Operation* op, const std::string& message, const std::string& tag = "") +{ + if (!tag.empty()) + { + llvm::dbgs() << "[" << tag << "] "; + } + llvm::dbgs() << "While processing " << *op << ". " << message << "\n"; + return mlir::failure(); +} +} // namespace + namespace accera::transforms { @@ -1100,6 +1121,19 @@ std::optional VectorizeReferenceGlobalOp(mlir::PatternRewriter return clonedOp; } +std::optional VectorizeMemrefReinterpretCastGlobalOp(mlir::PatternRewriter& rewriter, + mlir::memref::ReinterpretCastOp op, + const VectorizedOpMap& vectorizedOps, + std::vector& laneMappings, + mlir::Value inductionVar, + int64_t step, + int64_t vectorSize) +{ + // This is just a passthrough -- just clone it + auto clonedOp = rewriter.clone(*op); + return clonedOp; +} + std::optional VectorizeOp(mlir::PatternRewriter& rewriter, mlir::Operation* op, const VectorizedOpMap& vectorizedOps, @@ -1169,6 +1203,9 @@ std::optional VectorizeOp(mlir::PatternRewriter& rewriter, .Case([&](v::ReferenceGlobalOp refGlobalOp) { return VectorizeReferenceGlobalOp(rewriter, refGlobalOp, vectorizedOps, laneMappings, inductionVar, step, vectorSize); }) + .Case([&](mlir::memref::ReinterpretCastOp reinterpretCastOp) { + return VectorizeMemrefReinterpretCastGlobalOp(rewriter, reinterpretCastOp, vectorizedOps, laneMappings, inductionVar, step, vectorSize); + }) .Default([&](mlir::Operation* defaultOp) -> std::optional { if (op->getNumResults() > 0) { @@ -1186,6 +1223,497 @@ std::optional VectorizeOp(mlir::PatternRewriter& rewriter, return resultOp; } +std::vector CreateLoopIterationMappings(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, int64_t begin, int64_t end, int64_t step, std::stack* tempOps = nullptr) +{ + auto iterCount = (end - begin) / step; + std::vector mappings(iterCount); + for (int64_t idx = begin; idx < end; idx += step) + { + auto offsetMap = mlir::AffineMap::get(1, 0, builder.getAffineDimExpr(0) + (idx * step)); + auto offsetInductionVar = builder.create(loc, offsetMap, ValueRange{ iv }); + if (tempOps) + { + tempOps->push(offsetInductionVar); + } + mappings[idx].map(iv, offsetInductionVar); + } + return mappings; +} + +std::vector CreateLoopIterationMappings(mlir::OpBuilder& builder, mlir::Location loc, mlir::AffineForOp forOp, std::stack* tempOps = nullptr) +{ + int64_t begin = forOp.getConstantLowerBound(); + int64_t end = forOp.getConstantUpperBound(); + int64_t step = forOp.getStep(); + auto iv = forOp.getInductionVar(); + + return CreateLoopIterationMappings(builder, loc, iv, begin, end, step, tempOps); +} + +LogicalResult MatchOptionalCast(mlir::Block::iterator& loopBodyIter, mlir::Value expectedInputVal, mlir::Type& castType, v::CastOp& castOp, mlir::Value& castResultVal, std::stack& matchedOps) +{ + if (isa(*loopBodyIter)) + { + castOp = cast(*loopBodyIter++); + if (castOp.source() != expectedInputVal) + { + return failure(); + } + auto castedValue = castOp.result(); + castType = castedValue.getType(); + castResultVal = castedValue; + matchedOps.push(castOp); + } + return success(); +} + +template +LogicalResult CheckOptionalRedundantLoadStore(mlir::PatternRewriter& rewriter, mlir::Block::iterator& loopBodyIter, mlir::Block::iterator& loopBodyEnd, std::stack& matchedOps, FailureFn&& reportMatchFailure) +{ + if (loopBodyIter != loopBodyEnd && isa(*loopBodyIter)) + { + auto loadOp = cast(*loopBodyIter++); + matchedOps.push(loadOp); + if (loopBodyIter != loopBodyEnd && isa(*loopBodyIter)) + { + auto storeOp = cast(*loopBodyIter++); + if (storeOp.getMemRef() != loadOp.getMemRef()) + { + return reportMatchFailure(storeOp, "Extraneous load/store aren't to the same memref"); + } + + auto strideOpt = GetConstantStrideBetweenAccesses(rewriter, loadOp, storeOp); + if (!strideOpt.has_value() || *strideOpt != 0) + { + return reportMatchFailure(storeOp, "Extraneous load/store aren't to the same location"); + } + + matchedOps.push(storeOp); + } + else + { + return reportMatchFailure(loadOp, "Failed to match extraneous store"); + } + } + return success(); +} + +std::pair GetLowHighSeparately(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value vec) +{ + auto reductionVecType = vec.getType().cast(); + assert(reductionVecType.getRank() == 1 && "Can only separate low and high halfs of a 1-D vector"); + auto vecSize = reductionVecType.getNumElements(); + assert(vecSize % 2 == 0 && "Can only separate low and high halfs of a vector with an even number of elements"); + auto halfVecType = mlir::VectorType::get({ vecSize / 2 }, reductionVecType.getElementType()); + std::vector extractLowMask(vecSize / 2); + std::vector extractHighMask(vecSize / 2); + std::iota(extractLowMask.begin(), extractLowMask.end(), 0); // { 0, 1, 2, 3 } + std::iota(extractHighMask.begin(), extractHighMask.end(), vecSize / 2); // { 4, 5, 6, 7 } + mlir::Value low = builder.create(loc, halfVecType, vec, vec, builder.getI64ArrayAttr(extractLowMask)); + mlir::Value high = builder.create(loc, halfVecType, vec, vec, builder.getI64ArrayAttr(extractHighMask)); + return std::make_pair(low, high); +} + +mlir::LogicalResult vectorize2DHorizontalSumReduction(mlir::AffineForOp affineForOp, mlir::PatternRewriter& rewriter) +{ + // Special case for a 2-loop vectorization of a 4x16 i16 horizontal sum reduction to a 4x1 i32 + // Try to match a pattern like: + // for indices + // for i = 0...4: + // for j = 0...16: + // x = load(input[..., i, j]) : memref -> T1 + // y = load(output[..., i]) : memref (doesn't depend on j) -> T1 + // z = x + y + // store(z, output[..., i]) : (same position as load) + + // And replace it with a precise sequence of vector ops that will reduce the 4 lanes simultaneously + // Copied from onnxruntime for 4x16 i16 -> i32 case + // vpmaddwd ymm0,ymm0,ymm8 # horizontal word+word=dword per row + // vpmaddwd ymm1,ymm1,ymm8 + // vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 + // vpmaddwd ymm2,ymm2,ymm8 + // vpmaddwd ymm3,ymm3,ymm8 + // vphaddd ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 + // vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 + // vextracti128 xmm1,ymm0,1 # extract high dwords + // vpaddd xmm0,xmm0,xmm1 # reduce low/high dwords + // vmovdqu XMMWORD PTR [r9],xmm0 + + // The 4x8 i32 case should just be a subset of these instructions: + // vphaddd ymm0,ymm0,ymm1 + // vphaddd ymm1,ymm2,ymm3 + // vphaddd ymm0,ymm0,ymm1 + // vextracti128 xmm1,ymm0,1 + // vpaddd xmm0,xmm0,xmm1 + // vmovdqu XMMWORD PTR [r9],xmm0 + + // The 4x8 f32 case should be the same pattern, but with slightly different assembly instructions (but the same MLIR ops): + // vhaddps ymm0,ymm0,ymm1 + // vhaddps ymm1,ymm2,ymm3 + // vhaddps ymm0,ymm0,ymm1 + // vextractf128 xmm1,ymm0,1 + // vaddps xmm0,xmm0,xmm1 + // vmovups XMMWORD PTR [r9],xmm0 + + // Implement the matcher + auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { + return reportMatchOpFailure(op, message, "vectorize2DHorizontalSumReduction"); + }; + + auto avx2Support = ir::util::ModuleSupportsTargetDeviceFeature(affineForOp, "avx2"); + if (!avx2Support) + { + // We're going to use the vpmaddwd instruction directly in this vectorization pattern and + // the vpmaddwd instruction is only supported on machines with the AVX2 or AVX512 instruction set extensions + return reportMatchFailure(affineForOp, "Target device does not support vpmaddwd instruction"); + } + + std::stack matchedOps; + std::stack tempOps; + ir::util::TempOpCleanupGuard tempOpGuard(&tempOps, rewriter); + + auto loc = affineForOp.getLoc(); + + SmallVector loops; + mlir::getPerfectlyNestedLoops(loops, affineForOp); + if (loops.size() != 2) // there should be exactly 2 loops in the nest being vectorized + { + return failure(); + } + + for (auto& loop : loops) + { + if (!loop.hasConstantBounds() || loop.getConstantLowerBound() != 0) + { + return failure(); + } + } + + auto outerLoop = loops.front(); // jj loop + int64_t outerLoopBegin = outerLoop.getConstantLowerBound(); + int64_t outerLoopEnd = outerLoop.getConstantUpperBound(); + int64_t outerLoopStep = outerLoop.getStep(); + int64_t outerLoopNumIters = (outerLoopEnd - outerLoopBegin) / outerLoopStep; + + if (outerLoopNumIters != 4) + { + return failure(); + } + + auto innerLoop = loops.back(); // the innermost loop, kk + int64_t innerLoopBegin = innerLoop.getConstantLowerBound(); + int64_t innerLoopEnd = innerLoop.getConstantUpperBound(); + int64_t innerLoopStep = innerLoop.getStep(); + int64_t innerLoopNumIters = (innerLoopEnd - innerLoopBegin) / innerLoopStep; + + std::vector supportedBaseInputElementTypes; + std::vector supportedCastInputElementTypes{ rewriter.getI32Type(), rewriter.getF32Type() }; + bool reduce16To8 = false; + if (innerLoopNumIters == 16) + { + supportedBaseInputElementTypes.push_back(rewriter.getIntegerType(16)); + reduce16To8 = true; + } + else if (innerLoopNumIters == 8) + { + supportedBaseInputElementTypes.push_back(rewriter.getI32Type()); + supportedBaseInputElementTypes.push_back(rewriter.getF32Type()); + } + else + { + return reportMatchFailure(affineForOp, "Unsupported number of inner loop iterations"); + } + + auto isInputTypeSupported = [&supportedBaseInputElementTypes, &supportedCastInputElementTypes](const mlir::Type& type, bool baseInputType) { + if (baseInputType) + return std::find(supportedBaseInputElementTypes.begin(), supportedBaseInputElementTypes.end(), type) != supportedBaseInputElementTypes.end(); + else + return std::find(supportedCastInputElementTypes.begin(), supportedCastInputElementTypes.end(), type) != supportedCastInputElementTypes.end(); + }; + + // iterate on loop body from begin to end to match the ops list + auto loopBodyIter = innerLoop.getBody()->begin(); + auto loopBodyEnd = innerLoop.getBody()->end(); + + auto vectorSize = innerLoopNumIters; + auto unrollMax = vectorSize; + + // Set up sequential mappings for the inner loop + std::vector innerLoopLaneMappings = CreateLoopIterationMappings(rewriter, loc, innerLoop, &tempOps); + + // 1. load from lhs array + if (loopBodyIter == loopBodyEnd || !isa(*loopBodyIter)) + { + return reportMatchFailure(affineForOp, "Failed to match the lhs load op"); + } + + auto lhsLoadOp = cast(*loopBodyIter++); + auto lhsLoadVal = lhsLoadOp.getResult(); // Keep the loaded val separate from the current lhs val for mapping later + auto lhsVal = lhsLoadVal; + matchedOps.push(lhsLoadOp); + + bool lhsLoadIsLoopSequential = IsUnrolledAccessSequential(rewriter, lhsLoadOp, innerLoopLaneMappings, unrollMax); + bool lhsLoadIsLoopConstant = IsUnrolledAccessConstant(rewriter, lhsLoadOp, innerLoopLaneMappings, unrollMax); + + // 1a. (optional) cast + v::CastOp lhsLoadCastOp; + mlir::Type lhsCastType; + if (failed(MatchOptionalCast(loopBodyIter, lhsVal, lhsCastType, lhsLoadCastOp, lhsVal, matchedOps))) + { + return reportMatchFailure(lhsLoadCastOp, "Cast after lhs load isn't casting the loaded value"); + } + + // 2. load from rhs array + if (loopBodyIter == loopBodyEnd || !isa(*loopBodyIter)) + { + return reportMatchFailure(affineForOp, "Failed to match the rhs load op"); + } + + auto rhsLoadOp = cast(*loopBodyIter++); + auto rhsLoadVal = rhsLoadOp.getResult(); + auto rhsVal = rhsLoadVal; + matchedOps.push(rhsLoadOp); + + bool rhsLoadIsLoopSequential = IsUnrolledAccessSequential(rewriter, rhsLoadOp, innerLoopLaneMappings, unrollMax); + bool rhsLoadIsLoopConstant = IsUnrolledAccessConstant(rewriter, rhsLoadOp, innerLoopLaneMappings, unrollMax); + + // 2a. (optional) cast + v::CastOp rhsLoadCastOp; + mlir::Type rhsCastType; + if (failed(MatchOptionalCast(loopBodyIter, rhsVal, rhsCastType, rhsLoadCastOp, rhsVal, matchedOps))) + { + return reportMatchFailure(rhsLoadCastOp, "Cast after rhs load isn't casting the loaded value"); + } + + // 3. bin op + if (loopBodyIter == loopBodyEnd || !isa(*loopBodyIter)) + { + return reportMatchFailure(affineForOp, "Failed to match the bin op"); + } + auto binOp = cast(*loopBodyIter++); + if (binOp.getPredicate() != v::BinaryOpPredicate::ADD) + { + // The specific instructions this lowering produces only works for add reductions + return reportMatchFailure(binOp, "Bin op isn't an add"); + } + auto binOpVal = binOp.getResult(); + bool lhsRhsLineUp = (binOp.lhs() == lhsVal) && (binOp.rhs() == rhsVal); + bool lhsRhsSwap = (binOp.lhs() == rhsVal) && (binOp.rhs() == lhsVal); + if (!lhsRhsLineUp && !lhsRhsSwap) + { + return reportMatchFailure(affineForOp, "Bin op isn't using loaded lhs and rhs values"); + } + matchedOps.push(binOp); + + // 4. store to output array + if (loopBodyIter == loopBodyEnd || !isa(*loopBodyIter)) + { + return reportMatchFailure(affineForOp, "Failed to match the store op"); + } + + auto storeOp = cast(*loopBodyIter++); + auto storedVal = storeOp.value(); + matchedOps.push(storeOp); + + // Check that the value being stored is the result of the BinOp + if (storedVal != binOpVal) + { + return reportMatchFailure(storeOp, "Store op isn't storing the result of the bin op"); + } + + // Check that store is constant wrt to the inner loop + bool storeIsLoopConstant = IsUnrolledAccessConstant(rewriter, storeOp, innerLoopLaneMappings, unrollMax); + if (!storeIsLoopConstant) + { + return reportMatchFailure(storeOp, "Store op isn't constant wrt the inner loop being vectorized"); + } + + // Check which load is sequential wrt the loop and which is constant and which one is being stored to + + mlir::AffineLoadOp outputLoadOp; + mlir::AffineLoadOp reductionLoadOp; + v::CastOp outputCastOp; + v::CastOp reductionCastOp; + if (storeOp.getMemRef() == lhsLoadOp.getMemRef()) + { + if (!lhsLoadIsLoopConstant) + { + return reportMatchFailure(lhsLoadOp, "LHS load op isn't constant wrt the loop being vectorized but is the same memref being stored to"); + } + if (!rhsLoadIsLoopSequential) + { + return reportMatchFailure(rhsLoadOp, "RHS load op isn't sequential when LHS load is constant"); + } + outputLoadOp = lhsLoadOp; + outputCastOp = lhsLoadCastOp; + + reductionLoadOp = rhsLoadOp; + reductionCastOp = rhsLoadCastOp; + } + else if (storeOp.getMemRef() == rhsLoadOp.getMemRef()) + { + if (!rhsLoadIsLoopConstant) + { + return reportMatchFailure(rhsLoadOp, "RHS load op isn't constant wrt the loop being vectorized but is the same memref being stored to"); + } + if (!lhsLoadIsLoopSequential) + { + return reportMatchFailure(lhsLoadOp, "LHS load op isn't sequential when RHS load is constant"); + } + outputLoadOp = rhsLoadOp; + outputCastOp = rhsLoadCastOp; + + reductionLoadOp = lhsLoadOp; + reductionCastOp = lhsLoadCastOp; + } + else + { + return reportMatchFailure(storeOp, "Store op isn't storing to the same memref as either load"); + } + + // Check that the output load and store are at the same position + + auto strideOpt = GetConstantStrideBetweenAccesses(rewriter, outputLoadOp, storeOp); + if (!strideOpt.has_value() || *strideOpt != 0) + { + return reportMatchFailure(storeOp, "Output load and store ops aren't at the same location"); + } + + // Check that the reduction memref has the right element type and that the store op memref has the right element type + auto inputReductionElementType = reductionLoadOp.getMemRefType().getElementType(); + if (!isInputTypeSupported(inputReductionElementType, true /* baseInput */)) + { + return reportMatchFailure(reductionLoadOp, "Reduction load op is not from a supported memref element type"); + } + if (!isInputTypeSupported(outputLoadOp.getMemRefType().getElementType(), false /* not baseInput */)) + { + return reportMatchFailure(outputLoadOp, "Output load op is not from a supported memref element type"); + } + + // Check that all that remains are optionally redundant load-stores and the yield op + + // match the final pair of redundant load and store ops + auto redundantLoadStoreResult = CheckOptionalRedundantLoadStore(rewriter, loopBodyIter, loopBodyEnd, matchedOps, reportMatchFailure); + if (failed(redundantLoadStoreResult)) + { + return redundantLoadStoreResult; + } + + // Ignore the yield op at the end + if (loopBodyIter != loopBodyEnd && isa(*loopBodyIter)) + { + (void)loopBodyIter++; + } + + if (loopBodyIter != loopBodyEnd) + { + LLVM_DEBUG(llvm::dbgs() << "Found additional instructions after the store"); + return failure(); + } + + // Now replace the matched ops with the multi vector load and parallel reduction sequence + + // Set the insertion point inside the loop so we can safely refer to loop induction vars + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(innerLoop.getBody(), innerLoop.getBody()->getTerminator()->getIterator()); + + // Set up sequential mappings for the outer loop + std::vector outerLoopLaneMappings = CreateLoopIterationMappings(rewriter, loc, outerLoop); + + // Unroll and vectorize the reduction operand loads wrt the outer loop + std::vector origVectorValsToReduce; + std::vector currentVectorValsToReduce; + origVectorValsToReduce.reserve(outerLoopNumIters); + currentVectorValsToReduce.reserve(outerLoopNumIters); + for (auto& mapping : outerLoopLaneMappings) + { + // Handle LHS load + auto outerUnrollsReductionLoadOp = mlir::cast(rewriter.clone(*reductionLoadOp.getOperation(), mapping)); + tempOps.push(outerUnrollsReductionLoadOp); + auto vecReductionLoadOp = VectorizeAffineLoadOpHelper(rewriter, outerUnrollsReductionLoadOp, vectorSize); + mlir::Value origVecReductionLoadOpVal = vecReductionLoadOp.getResult(); + mlir::Value vecReductionLoadOpVal = origVecReductionLoadOpVal; + if (reductionCastOp) + { + auto castVecType = mlir::VectorType::get({ vectorSize }, reductionCastOp.getResult().getType()); + vecReductionLoadOpVal = rewriter.create(lhsLoadCastOp.getLoc(), vecReductionLoadOpVal, castVecType); + } + origVectorValsToReduce.push_back(origVecReductionLoadOpVal); + currentVectorValsToReduce.push_back(vecReductionLoadOpVal); + } + + // Now create the vectorized reduction following: + // vpmaddwd ymm0,ymm0,ymm8 # horizontal word+word=dword per row + // vpmaddwd ymm1,ymm1,ymm8 + // vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum1/Sum0 + // vpmaddwd ymm2,ymm2,ymm8 + // vpmaddwd ymm3,ymm3,ymm8 + // vphaddd ymm1,ymm2,ymm3 # reduce and interleave Sum3/Sum2 + // vphaddd ymm0,ymm0,ymm1 # reduce and interleave Sum3/Sum2/Sum1/Sum0 + // vextracti128 xmm1,ymm0,1 # extract high dwords + // vpaddd xmm0,xmm0,xmm1 # reduce low/high dwords + // vmovdqu XMMWORD PTR [r9],xmm0 + + auto inputVecType = mlir::VectorType::get({ vectorSize }, inputReductionElementType); + auto reductionVecType = inputVecType; + if (reduce16To8) + { + // Reduces 4x16xi16 -> 4x8xi32 + + // Create a buffer of int16 1's + auto one = rewriter.create(affineForOp.getLoc(), 1, rewriter.getIntegerType(16)); + mlir::Value vecOnes = rewriter.create(affineForOp.getLoc(), inputVecType, one); + + // Create a vpmaddwd for each row with the vector of ones + auto i32sPerVec = vectorSize / 2; + reductionVecType = mlir::VectorType::get({ i32sPerVec }, rewriter.getI32Type()); + currentVectorValsToReduce[0] = rewriter.create(binOp.getLoc(), reductionVecType, origVectorValsToReduce[0], vecOnes); + currentVectorValsToReduce[1] = rewriter.create(binOp.getLoc(), reductionVecType, origVectorValsToReduce[1], vecOnes); + currentVectorValsToReduce[2] = rewriter.create(binOp.getLoc(), reductionVecType, origVectorValsToReduce[2], vecOnes); + currentVectorValsToReduce[3] = rewriter.create(binOp.getLoc(), reductionVecType, origVectorValsToReduce[3], vecOnes); + } + + // Reduce and interleave (row0, row1) and then (row2, row3), then reduce and interleave those results together + // vphaddd(ymm0, ymm1) -> [ymm0[0]+ymm0[1], ymm0[2]+ymm0[3], + // ymm1[0]+ymm1[1], ymm1[2]+ymm1[3], + // ymm0[4]+ymm0[5], ymm0[6]+ymm0[7], + // ymm1[4]+ymm1[5], ymm1[6]+ymm1[7]] + + // Reduce and interleave (row0, row1) and (row2, row3) + mlir::Value row01_interleaved = rewriter.create(binOp.getLoc(), currentVectorValsToReduce[0], currentVectorValsToReduce[1]); + mlir::Value row23_interleaved = rewriter.create(binOp.getLoc(), currentVectorValsToReduce[2], currentVectorValsToReduce[3]); + + // Reduce and interleave ((row0, row1), (row2, row3)) + mlir::Value row0123_interleaved = rewriter.create(binOp.getLoc(), row01_interleaved, row23_interleaved); + + // Now the elements of row0123_interleaved are from rows [ 0, 1, 2, 3, 0, 1, 2, 3 ] + // So separate the low 4 elements and the high 4 elements separately and add them together elementwise + + auto [low_row0123_interleaved, high_row0123_interleaved] = GetLowHighSeparately(rewriter, binOp.getLoc(), row0123_interleaved); + + mlir::Value reduced_4x1 = rewriter.create(binOp.getLoc(), v::BinaryOpPredicate::ADD, low_row0123_interleaved, high_row0123_interleaved); + + // Now load the original values and add the reduced tile + auto vecOrigValues = VectorizeAffineLoadOpHelper(rewriter, outputLoadOp, outerLoopNumIters); + auto finalAccumulation = rewriter.create(binOp.getLoc(), v::BinaryOpPredicate::ADD, vecOrigValues, reduced_4x1); + + // Vectorize the store op and store the finalAccumulation vector of results + + mlir::AffineStoreOpAdaptor storeAdaptor{ storeOp }; + std::vector storeIndices(storeAdaptor.indices().begin(), storeAdaptor.indices().end()); + auto [flatCastOutputMemRef, flattenedOutputPos] = FlattenAccess(rewriter, storeOp, storeIndices); + rewriter.create(storeOp.getLoc(), finalAccumulation, flatCastOutputMemRef, mlir::ValueRange{ flattenedOutputPos }); + + // Set the step size for the vectorized loops such that they each have a single iteration and will later get simplified away while replacing any IV usage with their begin value + outerLoop.setStep(outerLoopStep * outerLoopNumIters); + innerLoop.setStep(innerLoopStep * innerLoopNumIters); + + // Erase the original non-vectorized ops + ir::util::EraseOps(matchedOps, rewriter); + + return mlir::success(); +} + // TODO : support multi-dim vector reductions mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, mlir::PatternRewriter& rewriter) { @@ -1222,13 +1750,12 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, // Implement the matcher auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { - llvm::dbgs() << "[vectorizeHorizontalReduction] While processing " << *op << ". " << message << "\n"; - return rewriter.notifyMatchFailure(op, message); + return reportMatchOpFailure(op, message, "vectorizeHorizontalReduction"); }; std::stack matchedOps; std::stack tempOps; - ir::util::TempOpCleanupGuard(&tempOps, rewriter); + ir::util::TempOpCleanupGuard tempGuard(&tempOps, rewriter); SmallVector loops; mlir::getPerfectlyNestedLoops(loops, affineForOp); @@ -1263,7 +1790,7 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, } auto lhsLoadOp = cast(*loopBodyIter++); - auto lhsLoadVal = lhsLoadOp.getResult(); // Keep the laoded val separate from the current lhs val for mapping later + auto lhsLoadVal = lhsLoadOp.getResult(); // Keep the loaded val separate from the current lhs val for mapping later auto lhsVal = lhsLoadVal; matchedOps.push(lhsLoadOp); @@ -1481,7 +2008,7 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, // - store BinOp result to location Y // Check that all that remains are optionally redundant load-stores and the yield op - + // match the final pair of redundant load and store ops if (loopBodyIter != loopBodyEnd && isa(*loopBodyIter)) { @@ -1494,7 +2021,7 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, { return reportMatchFailure(storeOp, "Extraneous load/store aren't to the same memref"); } - + auto strideOpt = GetConstantStrideBetweenAccesses(rewriter, loadOp, storeOp); if (!strideOpt.has_value() || *strideOpt != 0) { @@ -1530,55 +2057,95 @@ mlir::LogicalResult vectorizeHorizontalReduction(mlir::AffineForOp affineForOp, // LHS Load mlir::Value vecLhsVal; + mlir::Value vecLhsLoadVal; if (lhsLoadIsLoopSequential) { auto lhsLoadVecOp = VectorizeAffineLoadOpHelper(rewriter, lhsLoadOp, vectorSize); - vecLhsVal = lhsLoadVecOp.getResult(); - mappings.map(lhsLoadVal, vecLhsVal); + vecLhsLoadVal = lhsLoadVecOp.getResult(); } else { - vecLhsVal = mlir::cast(rewriter.clone(*lhsLoadOp.getOperation(), mappings)); + vecLhsLoadVal = mlir::cast(rewriter.clone(*lhsLoadOp.getOperation(), mappings)); } - mappings.map(lhsLoadVal, vecLhsVal); + vecLhsVal = vecLhsLoadVal; + mappings.map(lhsLoadVal, vecLhsLoadVal); // Optional cast if (lhsLoadCastOp) { // Create a vector cast auto castVecType = mlir::VectorType::get({ vectorSize }, lhsCastType); - vecLhsVal = rewriter.create(lhsLoadCastOp.getLoc(), vecLhsVal, castVecType); + vecLhsVal = rewriter.create(lhsLoadCastOp.getLoc(), vecLhsLoadVal, castVecType); } mappings.map(lhsVal, vecLhsVal); // RHS Load mlir::Value vecRhsVal; + mlir::Value vecRhsLoadVal; if (rhsLoadIsLoopSequential) { auto rhsLoadVecOp = VectorizeAffineLoadOpHelper(rewriter, rhsLoadOp, vectorSize); - vecRhsVal = rhsLoadVecOp.getResult(); - mappings.map(rhsLoadVal, vecRhsVal); + vecRhsLoadVal = rhsLoadVecOp.getResult(); } else { - vecRhsVal = mlir::cast(rewriter.clone(*rhsLoadOp.getOperation(), mappings)); + vecRhsLoadVal = mlir::cast(rewriter.clone(*rhsLoadOp.getOperation(), mappings)); } - mappings.map(rhsLoadVal, vecRhsVal); + vecRhsVal = vecRhsLoadVal; + mappings.map(rhsLoadVal, vecRhsLoadVal); // Optional cast if (rhsLoadCastOp) { // Create a vector cast auto castVecType = mlir::VectorType::get({ vectorSize }, rhsCastType); - vecRhsVal = rewriter.create(rhsLoadCastOp.getLoc(), vecRhsVal, castVecType); + vecRhsVal = rewriter.create(rhsLoadCastOp.getLoc(), vecRhsLoadVal, castVecType); } mappings.map(rhsVal, vecRhsVal); // Now create the appropriate vector reduce given the bin op type and apply it to either the LHS vector val or RHS vector val, whichever is the loaded vector auto vectorValToReduce = lhsLoadIsLoopSequential ? vecLhsVal : vecRhsVal; - auto reduceOp = rewriter.create(binOp.getLoc(), storeElementType, mlir::vector::stringifyEnum(reductionKind), vectorValToReduce, mlir::ValueRange{} /* optional accumulate values */); - - mlir::Value reducedVal = reduceOp.getResult(); + auto loadedVectorValToReduce = lhsLoadIsLoopSequential ? vecLhsLoadVal : vecRhsLoadVal; + auto initReductionElementType = lhsLoadIsLoopSequential ? lhsLoadVal.getType() : rhsLoadVal.getType(); + + mlir::Value reducedVal; + if (binOp.getPredicate() == v::BinaryOpPredicate::ADD && + storeElementType.isInteger(32) && + initReductionElementType.isInteger(16) && + vectorSize == 16 && + ir::util::ModuleSupportsTargetDeviceFeature(affineForOp, "avx2")) + { + // Special handling for 16-element i16 sum reduction on AVX2+ machines + // ymm2 = fill(1, i16) + // ymm0 = load vec + // vpmaddwd ymm0,ymm0,ymm2 [0, 1, 2, 3, ..., 14, 15] (16xi16) -> [0+1, 2+3, 4+5, ..., 14+15] (8xi32) + // vextracti128 xmm1,ymm0,1 [0+1, 2+3, 4+5, ..., 14+15] (8xi32) -> [0+1, ..., 6+7] (4xi32), [8+9, ..., 14+15] (4xi32) + // vpaddd xmm0,xmm0,xmm1 [0+1, ..., 6+7] (4xi32), [8+9, ..., 14+15] (4xi32) -> [0+1+8+9, 2+3+10+11, 4+5+12+13, 6+7+14+15] (4xi32) + // vphaddd xmm0,xmm0,xmm0 [0+1+8+9, 2+3+10+11, 4+5+12+13, 6+7+14+15] (4xi32) -> [0+1+8+9+2+3+10+11, 4+5+12+13+6+7+14+15, (duplicate 0...), (duplicate 1...)] (4xi32) + // vphaddd xmm0,xmm0,xmm0 [0+1+8+9+2+3+10+11, 4+5+12+13+6+7+14+15, (duplicate 0...), (duplicate 1...)] (4xi32) -> [0+1+8+9+2+3+10+11+4+5+12+13+6+7+14+15, (duplicate 0...), (duplicate 0...), (duplicate 0...)] (4xi32) + + // Create a buffer of int16 1's + auto one = rewriter.create(affineForOp.getLoc(), 1, rewriter.getIntegerType(16)); + auto inputVecType = mlir::VectorType::get({ vectorSize }, rewriter.getIntegerType(16)); + mlir::Value vecOnes = rewriter.create(affineForOp.getLoc(), inputVecType, one); + + auto i32sPerVec = vectorSize / 2; + auto reductionVecType = mlir::VectorType::get({ i32sPerVec }, rewriter.getI32Type()); + mlir::Value reduced_8xi32 = rewriter.create(binOp.getLoc(), reductionVecType, loadedVectorValToReduce, vecOnes); + + auto [low_reduced_4xi32, high_reduced_4xi32] = GetLowHighSeparately(rewriter, binOp.getLoc(), reduced_8xi32); + mlir::Value reduced_4xi32 = rewriter.create(binOp.getLoc(), v::BinaryOpPredicate::ADD, low_reduced_4xi32, high_reduced_4xi32); + + reduced_4xi32 = rewriter.create(binOp.getLoc(), reduced_4xi32, reduced_4xi32); + reduced_4xi32 = rewriter.create(binOp.getLoc(), reduced_4xi32, reduced_4xi32); + auto zeroPos = rewriter.create(binOp.getLoc(), 0); + reducedVal = rewriter.create(binOp.getLoc(), reduced_4xi32, zeroPos); + } + else + { + reducedVal = rewriter.create(binOp.getLoc(), storeElementType, mlir::vector::stringifyEnum(reductionKind), vectorValToReduce, mlir::ValueRange{} /* optional accumulate values */); + } + auto scalarValThatWasReduced = lhsLoadIsLoopSequential ? lhsVal : rhsVal; mappings.map(scalarValThatWasReduced, reducedVal); @@ -1628,13 +2195,12 @@ mlir::LogicalResult vectorizeSequentialCast(mlir::AffineForOp affineForOp, mlir: // Implement the matcher auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { - llvm::dbgs() << "[vectorizeSequentialCast] While processing " << *op << ". " << message << "\n"; - return rewriter.notifyMatchFailure(op, message); + return reportMatchOpFailure(op, message, "vectorizeSequentialCast"); }; std::stack matchedOps; std::stack tempOps; - ir::util::TempOpCleanupGuard(&tempOps, rewriter); + ir::util::TempOpCleanupGuard tempGuard(&tempOps, rewriter); // Match j and k loop SmallVector loops; @@ -1834,8 +2400,7 @@ mlir::LogicalResult vectorizeTwoRowInterleavedPack(mlir::AffineForOp affineForOp // Implement the matcher auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { - llvm::dbgs() << "[vectorizeTwoRowInterleavedPack] While processing " << *op << ". " << message << "\n"; - return rewriter.notifyMatchFailure(op, message); + return reportMatchOpFailure(op, message, "vectorizeTwoRowInterleavedPack"); }; std::stack matchedOps; @@ -2035,8 +2600,7 @@ mlir::LogicalResult vectorizeInt16MatMul(mlir::AffineForOp affineForOp, { // Implement the matcher auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { - llvm::dbgs() << "[vectorizeInt16MatMul] While processing " << *op << ". " << message << "\n"; - return rewriter.notifyMatchFailure(op, message); + return reportMatchOpFailure(op, message, "vectorizeInt16MatMul"); }; auto avx2Support = ir::util::ModuleSupportsTargetDeviceFeature(affineForOp, "avx2"); @@ -2047,8 +2611,8 @@ mlir::LogicalResult vectorizeInt16MatMul(mlir::AffineForOp affineForOp, return reportMatchFailure(affineForOp, "Target device does not support vpmaddwd instruction"); } - std::vector supportedBaseInputElementTypes { rewriter.getIntegerType(8), rewriter.getIntegerType(8, false /* isSigned */), rewriter.getIntegerType(16) }; - std::vector supportedCastInputElementTypes { rewriter.getIntegerType(16), rewriter.getIntegerType(32) }; + std::vector supportedBaseInputElementTypes{ rewriter.getIntegerType(8), rewriter.getIntegerType(8, false /* isSigned */), rewriter.getIntegerType(16) }; + std::vector supportedCastInputElementTypes{ rewriter.getIntegerType(16), rewriter.getIntegerType(32) }; auto isInputTypeSupported = [&supportedBaseInputElementTypes, &supportedCastInputElementTypes](const mlir::Type& type, bool baseInputType) { if (baseInputType) return std::find(supportedBaseInputElementTypes.begin(), supportedBaseInputElementTypes.end(), type) != supportedBaseInputElementTypes.end(); @@ -2061,6 +2625,7 @@ mlir::LogicalResult vectorizeInt16MatMul(mlir::AffineForOp affineForOp, std::stack matchedOps; std::stack tempOps; + ir::util::TempOpCleanupGuard tempGuard(&tempOps, rewriter); // Match jj and kk loop in int16 matmul for vectorization rewrite rules SmallVector loops; @@ -2398,8 +2963,8 @@ mlir::LogicalResult vectorizeInt16MatMul(mlir::AffineForOp affineForOp, evenMaskVec.reserve(vectorSize / 2); for (int64_t i = 0; i < vectorSize / 2; ++i) { - oddMaskVec.push_back(i*2 + 1); - evenMaskVec.push_back(i*2); + oddMaskVec.push_back(i * 2 + 1); + evenMaskVec.push_back(i * 2); } auto oddMask = rewriter.getI64ArrayAttr(oddMaskVec); auto evenMask = rewriter.getI64ArrayAttr(evenMaskVec); @@ -2429,7 +2994,6 @@ mlir::LogicalResult vectorizeInt16MatMul(mlir::AffineForOp affineForOp, return { loadVecVal, oddShuffleVal, evenShuffleVal }; }; - // If there's only one broadcasted load, make sure it happens first for better vpmaddwd matching mlir::Value firstLoadVec; mlir::Value firstLoadOdds; @@ -2512,19 +3076,18 @@ mlir::LogicalResult vectorizeInt16MatMul(mlir::AffineForOp affineForOp, } mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, - mlir::PatternRewriter& rewriter) + mlir::PatternRewriter& rewriter) { auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { - llvm::dbgs() << "[vectorizeMaskedLoadStore] While processing " << *op << ". " << message << "\n"; - return rewriter.notifyMatchFailure(op, message); + return reportMatchOpFailure(op, message, "vectorizeMaskedLoadStore"); }; + std::stack matchedOps; // Set the insertion point to the end of the loop (just before the terminator) mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(loopOp.getBody(), loopOp.getBody()->getTerminator()->getIterator()); - if (!loopOp.hasConstantBounds() || loopOp.getConstantLowerBound() != 0) { return reportMatchFailure(loopOp, "Failed: loop op either doesn't have constant bounds or lower bound is not equal to zero"); @@ -2551,9 +3114,9 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, auto cmpOpResult = cmpOp.result(); matchedOps.push(cmpOp); - loopBodyStart++; + loopBodyStart++; - // 2. match scf.if op + // 2. match scf.if op if (loopBodyStart == loopBodyEnd || !isa(*loopBodyStart)) { return reportMatchFailure(loopOp, "Failed to match the scf.if op"); @@ -2569,6 +3132,7 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, auto thenOpsIter = thenBlock->getOperations().begin(); auto thenOpsEnd = thenBlock->getOperations().end(); + // load op if (thenOpsIter == thenOpsEnd || !isa(thenOpsIter)) { return reportMatchFailure(ifOp, "Failed to match the load op in then block"); @@ -2577,11 +3141,10 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, auto loadOp = cast(thenOpsIter++); matchedOps.push(loadOp); - // Optionally allow casting the load value mlir::Value loadVal = loadOp.getResult(); - v::CastOp thenCastOp; // Check if there's a cast after the load + v::CastOp thenCastOp; if (thenOpsIter != thenOpsEnd && isa(thenOpsIter)) { thenCastOp = cast(thenOpsIter++); @@ -2589,6 +3152,57 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, loadVal = thenCastOp.result(); } + // match second load op for accumulation case + + mlir::AffineLoadOp loadOp2; + mlir::Value loadVal2; + if (thenOpsIter != thenOpsEnd && isa(thenOpsIter)) + { + loadOp2 = cast(thenOpsIter++); + loadVal2 = loadOp2.getResult(); + matchedOps.push(loadOp2); + } + + // Optionally allow casting the second load value + v::CastOp thenCastOp2; + + // Check if there's a cast after the seond load + if (thenOpsIter != thenOpsEnd && isa(thenOpsIter)) + { + thenCastOp2 = cast(thenOpsIter++); + matchedOps.push(thenCastOp2); + loadVal2 = thenCastOp2.result(); + } + + // binary add op for accumulation case + v::BinOp accOp; + if (thenOpsIter != thenOpsEnd && isa(thenOpsIter)) + { + accOp = cast(thenOpsIter++); + } + + // Check that the operands for the accumulation op are in fact the values from load ops + mlir::Value accVal; + if (accOp) + { + if (!((accOp.lhs() == loadVal && accOp.rhs() == loadVal2) || (accOp.rhs() == loadVal && accOp.lhs() == loadVal2))) + { + return reportMatchFailure(accOp, "Failed to match the accumulation operands"); + } + matchedOps.push(accOp); + accVal = accOp.getResult(); + } + + // optionally check if there is a cast op after accumulation op + v::CastOp thenCastOp3; + if (thenOpsIter != thenOpsEnd && isa(thenOpsIter)) + { + thenCastOp3 = cast(thenOpsIter++); + matchedOps.push(thenCastOp3); + accVal = thenCastOp3.result(); + } + + // store op if (thenOpsIter == thenOpsEnd || !isa(thenOpsIter)) { return reportMatchFailure(ifOp, "Failed to match the store op in then block"); @@ -2621,6 +3235,7 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, matchedOps.push(elseCastOp); } + // store op if (elseOpsIter == elseOpsEnd || !isa(elseOpsIter)) { return reportMatchFailure(ifOp, "Failed to match the store op in else block"); @@ -2638,6 +3253,12 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, paddingOpValue = paddingOp.value(); } + if (loadOp && loadOp2 && elseBlock) + { + return reportMatchFailure(loopOp, "Failed: case of two conditional load ops cannot have an else block"); + } + + ////////////////////////////////////// // match successful, start rewriting here // unroll cmp ops (create lanemappings) std::vector laneMappings(unrollMax); @@ -2683,7 +3304,7 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, // type of padding op value may not be the same as what is required by transfer read op // so, we need a cast here always. auto finalPaddingOpValue = rewriter.create(loopOp.getLoc(), paddingOpValue, loadElementType); - + // create transferRead op with mask value mlir::AffineLoadOpAdaptor adaptor{ loadOp }; std::vector indices(adaptor.indices().begin(), adaptor.indices().end()); @@ -2692,8 +3313,15 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, // create a default identity map for mapping 1:1 dimension mlir::AffineMap permutationMap = mlir::AffineMap::getMinorIdentityMap(1, 1, rewriter.getContext()); mlir::AffineMapAttr permutationMapAttr = mlir::AffineMapAttr::get(permutationMap); + + // The inbound vector for a Transfer(Read|Write)Op has one entry per dimension and indicates whether the access + // in that dimension may go out of bounds. We are controlling for out-of-bounds accesses by creating a mask + // and trusting the rest of our infrastructure and DSL authoring, so we don't need the vector dialect to generate + // runtime inbounds checks for every element. + // Therefore, we set inbound = true here, under the assumption that our mask is sufficient to prevent an out of bounds + // access. llvm::SmallVector inbound_init; - inbound_init.push_back(false); + inbound_init.push_back(true); auto inbounds = rewriter.getBoolArrayAttr(inbound_init); mlir::Value valueToStore = rewriter.create(loadLoc, loadVectorType, flatCastMemref, mlir::ValueRange{ flattenedPosition }, permutationMap, finalPaddingOpValue, mask, inbounds); @@ -2706,6 +3334,43 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, valueToStore = rewriter.create(loopOp.getLoc(), valueToStore, castVecType); } + // create vector masked load op from second load op + + if (loadOp2) + { + auto loadLoc2 = loadOp2.getLoc(); + auto loadMemRefType2 = loadOp2.getMemRefType(); + auto loadElementType2 = loadMemRefType2.getElementType(); + auto loadVectorType2 = mlir::VectorType::get({ unrollMax }, loadElementType2); + + auto finalPaddingOpValue2 = rewriter.create(loopOp.getLoc(), paddingOpValue, loadElementType2); + mlir::AffineLoadOpAdaptor adaptor2{ loadOp2 }; + std::vector indices2(adaptor2.indices().begin(), adaptor2.indices().end()); + auto [flatCastMemref2, flattenedPosition2] = FlattenAccess(rewriter, loadOp2, indices2); + + mlir::Value accumulateOperand2 = rewriter.create(loadLoc2, loadVectorType2, flatCastMemref2, mlir::ValueRange{ flattenedPosition2 }, permutationMap, finalPaddingOpValue2, mask, inbounds); + + // optional cast op after second vector transfer read op + if (thenCastOp2) // then cast op + { + // Create a cast to hold vector of values + auto castVecType2 = mlir::VectorType::get({ unrollMax }, thenCastOp2.getType()); + accumulateOperand2 = rewriter.create(loopOp.getLoc(), accumulateOperand2, castVecType2); + } + + // if there is a second masked load, accumulation operator must follow before final store + // create binary add op to accumulate results from first and second masked load ops + valueToStore = rewriter.create(accOp.getLoc(), v::BinaryOpPredicate::ADD, valueToStore, accumulateOperand2); + + // optional cast op after accumulation op + if (thenCastOp3) // then cast op + { + // Create a cast to hold vector of values + auto castVecType3 = mlir::VectorType::get({ unrollMax }, thenCastOp3.getType()); + valueToStore = rewriter.create(loopOp.getLoc(), valueToStore, castVecType3); + } + } + // create vector store op mlir::AffineStoreOpAdaptor adaptorStore{ storeOp }; std::vector baseIndicesStore(adaptorStore.indices().begin(), adaptorStore.indices().end()); @@ -2719,7 +3384,7 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, } else { - // masked store op + // masked store op for first masked loaded value or accumulated result rewriter.create(storeOp.getLoc(), valueToStore, flatCastMemRefStore, mlir::ValueRange{ flattenedPosStore }, permutationMapAttr, mask, inbounds); } @@ -2731,10 +3396,285 @@ mlir::LogicalResult vectorizeMaskedLoadStore(mlir::AffineForOp loopOp, return mlir::success(); } +mlir::LogicalResult vectorizeTranspose8x4f32(mlir::AffineForOp affineForOp, + mlir::PatternRewriter& rewriter) +{ + // Implement the matcher + auto reportMatchFailure = [&](mlir::Operation* op, std::string message) -> LogicalResult { + return reportMatchOpFailure(op, message, "vectorizeTranspose8x4f32"); + }; + + // TODO: Relax this constraint, as it might still work on other systems + auto avx2Support = ir::util::ModuleSupportsTargetDeviceFeature(affineForOp, "avx2"); + if (!avx2Support) + { + return reportMatchFailure(affineForOp, "Target device does not support avx2 instructions"); + } + + // TODO: Check to see what behavior is like with i32 since it has the same bitwidth + std::vector supportedBaseInputElementTypes{ rewriter.getF32Type() }; + auto isInputTypeSupported = [&supportedBaseInputElementTypes](const mlir::Type& type) { + return std::find(supportedBaseInputElementTypes.begin(), supportedBaseInputElementTypes.end(), type) != supportedBaseInputElementTypes.end(); + }; + + std::stack matchedOps; + // std::stack tempOps; + // ir::util::TempOpCleanupGuard tempGuard(&tempOps, rewriter); + + // Match loops in transpose for vectorization rewrite rules + SmallVector loops; + mlir::getPerfectlyNestedLoops(loops, affineForOp); + if (loops.size() != 2) // there should be exactly 2 loops in the nest + { + return failure(); + } + + for (auto& loop : loops) + { + // TODO: Relax LB constraint + if (!loop.hasConstantBounds() || loop.getConstantLowerBound() != 0) + { + return failure(); + } + } + + // order of nested loops we are looking for is + // i {0 to 8} followed by j {0 to 4} + auto outerLoop = loops.front(); // i loop + int64_t i_begin = outerLoop.getConstantLowerBound(); + int64_t i_end = outerLoop.getConstantUpperBound(); + int64_t i_step = outerLoop.getStep(); + int64_t i_numIters = (i_end - i_begin) / i_step; + if (i_numIters != 8) + return failure(); + auto i_inductionVar = outerLoop.getInductionVar(); + + auto innerLoop = loops.back(); // the innermost loop, j + int64_t j_begin = innerLoop.getConstantLowerBound(); + int64_t j_end = innerLoop.getConstantUpperBound(); + int64_t j_step = innerLoop.getStep(); + int64_t j_numIters = (j_end - j_begin) / j_step; + if (j_numIters != 4) + return failure(); + auto j_inductionVar = innerLoop.getInductionVar(); + + // get unroll max for i and j + int64_t unrollMax_i = std::min(i_numIters, (i_end - i_begin)); + int64_t unrollMax_j = std::min(j_numIters, (j_end - j_begin)); + + // create IV map for i and j + auto inductionVarMap_i = AffineMap::get(1, 1, rewriter.getAffineDimExpr(0) + i_step * rewriter.getAffineSymbolExpr(0)); + auto inductionVarMap_j = AffineMap::get(1, 1, rewriter.getAffineDimExpr(0) + j_step * rewriter.getAffineSymbolExpr(0)); + + // iterate on loop body from begin to end to match the ops list + auto innerLoopBodyIter = innerLoop.getBody()->begin(); + auto innerLoopBodyEnd = innerLoop.getBody()->end(); + + // 1. load from source matrix + if (innerLoopBodyIter == innerLoopBodyEnd || !isa(*innerLoopBodyIter)) + { + return reportMatchFailure(affineForOp, "Failed to match the load from the source array"); + } + auto loadOp = cast(*innerLoopBodyIter++); + auto loadLoc = loadOp.getLoc(); + auto loadElementType = loadOp.getMemRefType().getElementType(); + + if (!isInputTypeSupported(loadElementType)) + { + return reportMatchFailure(affineForOp, "Load array element type is not a supported type"); + } + + // 2. store to destination matrix + if (innerLoopBodyIter == innerLoopBodyEnd || !isa(*innerLoopBodyIter)) + { + return reportMatchFailure(affineForOp, "Failed to match the store to the destination array"); + } + auto storeOp = cast(*innerLoopBodyIter++); + if (storeOp.getValueToStore() != loadOp) + { + return reportMatchFailure(storeOp, "Failed to match the store"); + } + + // Ignore the yield op at the end + if (innerLoopBodyIter != innerLoopBodyEnd && isa(*innerLoopBodyIter)) + { + (void)innerLoopBodyIter++; + } + + if (innerLoopBodyIter != innerLoopBodyEnd) + { + LLVM_DEBUG(llvm::dbgs() << "While processing " << *innerLoopBodyIter << ". The store was not the last instruction\n"; + llvm::dbgs() << "affine for : " << *affineForOp << "\n"; + llvm::dbgs() << "current inst " << *innerLoopBodyIter << "\n"); + return failure(); + } + + // Set the insertion point to the end of the inner loop (just before the terminator) + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(innerLoop.getBody(), innerLoop.getBody()->getTerminator()->getIterator()); + + // create lanemappings for i, j + std::vector laneMappings_i(unrollMax_i); + std::vector laneMappings_j(unrollMax_j); + + for (int64_t i_idx = i_begin; i_idx < i_end; i_idx += i_step) + { + auto offset_i = rewriter.create(outerLoop.getLoc(), i_idx); + auto offsetInductionVar_i = rewriter.create(outerLoop.getLoc(), inductionVarMap_i, ValueRange{ i_inductionVar, offset_i }); + // TODO: Update if assumption of i_begin == 0 is relaxed + laneMappings_i[i_idx].map(i_inductionVar, offsetInductionVar_i); + } + for (int64_t j_idx = j_begin; j_idx < j_end; j_idx += j_step) + { + auto offset_j = rewriter.create(innerLoop.getLoc(), j_idx); + auto offsetInductionVar_j = rewriter.create(innerLoop.getLoc(), inductionVarMap_j, ValueRange{ j_inductionVar, offset_j }); + // TODO: Update if assumption of j_begin == 0 is relaxed + laneMappings_j[j_idx].map(j_inductionVar, offsetInductionVar_j); + } + + int64_t inputVectorSize = unrollMax_j; + if (!IsUnrolledAccessSequential(rewriter, loadOp, laneMappings_j, inputVectorSize)) + { + return reportMatchFailure(loadOp, "Failed: isUnrolledAcessSequential for load"); + } + matchedOps.push(loadOp); + + int64_t outputVectorSize = unrollMax_i; + if (!IsUnrolledAccessSequential(rewriter, storeOp, laneMappings_i, outputVectorSize)) + { + return reportMatchFailure(loadOp, "Failed: isUnrolledAcessSequential for store"); + } + matchedOps.push(storeOp); + + // At this point we know: + // - there are 2 nested loops, the inner of which has 4 iterations + // - the loops have constant bounds + // - the innermost loop contains a load that is sequential wrt the inner loop + // - the innermost loop contains a store of the loaded value that is sequential wrt the outer loop + // - there are no other ops in the innermost loop (other than a loop terminator op) + + // So now we can create the new vectorized version of the loops + + // 1. create vector load of the input rows + auto inputMemRefType = loadOp.getMemRefType(); + auto inputElementType = inputMemRefType.getElementType(); + auto inputVectorType = mlir::VectorType::get({ inputVectorSize }, inputElementType); + + // Clone the load op for each iteration of the i loop and vectorize each of those loads wrt the j loop + // lea rax,[rsi+rdx*2] + // vmovups xmm0,XMMWORD PTR [rsi] // Load 4 floats [0, 0...3] into xmm0 from rsi + // vmovups xmm1,XMMWORD PTR [rsi+rdx] // Load 4 floats [1, 0...3] into xmm1 from rsi + rdx + // lea rsi,[rax+rdx*2] + // vmovups xmm2,XMMWORD PTR [rax] // Load 4 floats [2, 0...3] into xmm2 from rax + // vmovups xmm3,XMMWORD PTR [rax+rdx] // Load 4 floats [3, 0...3] into xmm3 from rax + rdx + // lea rax,[rsi+rdx*2] + std::vector loadedVecs; + for (int64_t i_idx = i_begin; i_idx < i_end; i_idx += i_step) + { + auto clonedLoadOp = mlir::cast(rewriter.clone(*(loadOp.getOperation()), laneMappings_i[i_idx])); + + mlir::AffineLoadOpAdaptor loadAdaptor{ clonedLoadOp }; + std::vector loadIndices(loadAdaptor.indices().begin(), loadAdaptor.indices().end()); + + auto [flatCastInputMemRef, flattenedInputPos] = FlattenAccess(rewriter, clonedLoadOp, loadIndices); + mlir::Value loadedVec = rewriter.create(loadOp.getLoc(), inputVectorType, flatCastInputMemRef, mlir::ValueRange{ flattenedInputPos }); + loadedVecs.push_back(loadedVec); + } + assert(loadedVecs.size() == 8); + + // 2. create a vector.shuffle ops to transpose the inputs + + // shuffle mask is {0, 1, ... N, N + 1, ... 2N - 1}, where N is the number of j iterations + auto concatMask = rewriter.getI64ArrayAttr(llvm::to_vector(llvm::seq(0, unrollMax_j * 2))); + auto outputMemRefType = storeOp.getMemRefType(); + auto outputElementType = outputMemRefType.getElementType(); + auto outputVectorType = mlir::VectorType::get({ outputVectorSize }, outputElementType); + + // Merge the loaded vectors in loadedVec + // vinsertf128 ymm0,ymm0,XMMWORD PTR [rsi],1 // ymm0[0..3] = xmm0, ymm0[4..7] = 4 floats [4, 0...3] from rsi + // vinsertf128 ymm1,ymm1,XMMWORD PTR [rsi+rdx],1 // ymm1[0..3] = xmm1, ymm1[4..7] = 4 floats [5, 0...3] from rsi + rdx + // vinsertf128 ymm2,ymm2,XMMWORD PTR [rax],1 // ymm2[0..3] = xmm2, ymm2[4..7] = 4 floats [6, 0...3] from rax + // vinsertf128 ymm3,ymm3,XMMWORD PTR [rax+rdx],1 // ymm3[0..3] = xmm3, ymm3[4..7] = 4 floats [7, 0...3] from rax + rdx + llvm::SmallVector shuffleLoadOps; + for (int64_t loadIdx = 0; loadIdx < unrollMax_j; ++loadIdx) + { + auto shuffledRowsOp = rewriter.create(loadLoc, outputVectorType, loadedVecs[loadIdx], loadedVecs[loadIdx + unrollMax_j], concatMask); + shuffleLoadOps.push_back(shuffledRowsOp); + } + assert(shuffleLoadOps.size() == 4); + + // Interleave the 8-byte loaded vectors + // vunpcklps ymm4,ymm0,ymm1 // ymm0: [ 0000 4444 ] ymm1: [ 1111 5555 ] -> ymm4: [ 0101 4545 ] + // vunpckhps ymm5,ymm0,ymm1 // ymm0: [ 0000 4444 ] ymm1: [ 1111 5555 ] -> ymm5: [ 0101 4545 ] + // vunpcklps ymm0,ymm2,ymm3 // ymm2: [ 2222 6666 ] ymm3: [ 3333 7777 ] -> ymm0: [ 2323 6767 ] + // vunpckhps ymm1,ymm2,ymm3 // ymm2: [ 2222 6666 ] ymm3: [ 3333 7777 ] -> ymm1: [ 2323 6767 ] + auto lowByteInterleaveMaskF32 = rewriter.getI64ArrayAttr({ 0, outputVectorSize + 0, 1, outputVectorSize + 1, 4, outputVectorSize + 4, 5, outputVectorSize + 5 }); + auto highByteInterleaveMaskF32 = rewriter.getI64ArrayAttr({ 2, outputVectorSize + 2, 3, outputVectorSize + 3, 6, outputVectorSize + 6, 7, outputVectorSize + 7 }); + llvm::SmallVector intermediateShuffledOps; + for (int64_t loadIdx = 0; loadIdx < 4; loadIdx += 2) + { + intermediateShuffledOps.push_back( + rewriter.create(loadLoc, outputVectorType, shuffleLoadOps[loadIdx], shuffleLoadOps[loadIdx + 1], lowByteInterleaveMaskF32)); + intermediateShuffledOps.push_back( + rewriter.create(loadLoc, outputVectorType, shuffleLoadOps[loadIdx], shuffleLoadOps[loadIdx + 1], highByteInterleaveMaskF32)); + } + + // Interleave the vectors again + // vunpcklpd ymm2,ymm4,ymm0 // ymm4: [ 0101 4545 ] ymm0: [ 2323 6767 ] -> ymm2: [ 0123 4567 ] + // vunpckhpd ymm3,ymm4,ymm0 // ymm4: [ 0101 4545 ] ymm0: [ 2323 6767 ] -> ymm3: [ 0123 4567 ] + // vunpcklpd ymm0,ymm5,ymm1 // ymm5: [ 0101 4545 ] ymm1: [ 2323 6767 ] ymm0: [ 0123 4567] + // vunpckhpd ymm4,ymm5,ymm1 // ymm5: [ 0101 4545 ] ymm1: [ 2323 6767 ] ymm4: [ 0123 4567] + auto lowByteInterleaveMaskF64 = rewriter.getI64ArrayAttr({ 0, 1, 8, 9, 4, 5, 12, 13 }); + auto highByteInterleaveMaskF64 = rewriter.getI64ArrayAttr({ 2, 3, 10, 11, 6, 7, 14, 15 }); + llvm::SmallVector finalOutputShuffledOps; + for (int64_t idx = 0; idx < 2; ++idx) + { + finalOutputShuffledOps.push_back( + rewriter.create(loadLoc, outputVectorType, intermediateShuffledOps[idx], intermediateShuffledOps[idx + 2], lowByteInterleaveMaskF64)); + finalOutputShuffledOps.push_back( + rewriter.create(loadLoc, outputVectorType, intermediateShuffledOps[idx], intermediateShuffledOps[idx + 2], highByteInterleaveMaskF64)); + } + + auto storeLoc = storeOp.getLoc(); + { + // Cloning the store ops creates a dependency on the original load op, which means the load ops can't be + // erased unless the temporary store ops are erased first. So we introduce a new scope to ensure any temporary + // dependents are cleaned up before the main matching ops are erased + std::stack tempStoreOps; + ir::util::TempOpCleanupGuard storeGuard(&tempStoreOps, rewriter); + + // 3. create a vector store op of the transposed rows + // Clone the store op for each iteration of the j loop and vectorize each of those stores wrt the i loop + for (int64_t j_idx = j_begin, iter_idx = 0; j_idx < j_end; j_idx += j_step, ++iter_idx) + { + auto unrolledInductionVar_j = rewriter.create(storeLoc, j_idx); + mlir::BlockAndValueMapping jIterMapping; + jIterMapping.map(j_inductionVar, unrolledInductionVar_j); + auto clonedStoreOp = mlir::cast(rewriter.clone(*(storeOp.getOperation()), jIterMapping)); + tempStoreOps.push(clonedStoreOp); + + mlir::AffineStoreOpAdaptor storeAdaptor{ clonedStoreOp }; + std::vector storeIndices(storeAdaptor.indices().begin(), storeAdaptor.indices().end()); + + auto [flatCastOutputMemRef, flattenedOutputPos] = FlattenAccess(rewriter, clonedStoreOp, storeIndices); + (void)rewriter.create(storeLoc, finalOutputShuffledOps[iter_idx], flatCastOutputMemRef, mlir::ValueRange{ flattenedOutputPos }); + } + } + // Set the step size for the vectorized loops such that they each have a single iteration and will later get simplified away while replacing any IV usage with their begin value + outerLoop.setStep(i_step * i_numIters); + innerLoop.setStep(j_step * j_numIters); + + // Erase the original non-vectorized ops + ir::util::EraseOps(matchedOps, rewriter); + return mlir::success(); +} + mlir::LogicalResult TryVectorizeKnownSubgraph(mlir::AffineForOp affineForOp, mlir::PatternRewriter& rewriter) { // TODO : convert these to rewrite pattern structs with benefit weights + if (succeeded(vectorize2DHorizontalSumReduction(affineForOp, rewriter))) + return success(); if (succeeded(vectorizeHorizontalReduction(affineForOp, rewriter))) return success(); if (succeeded(vectorizeSequentialCast(affineForOp, rewriter))) @@ -2745,6 +3685,8 @@ mlir::LogicalResult TryVectorizeKnownSubgraph(mlir::AffineForOp affineForOp, return success(); if (succeeded(vectorizeMaskedLoadStore(affineForOp, rewriter))) return success(); + if (succeeded(vectorizeTranspose8x4f32(affineForOp, rewriter))) + return success(); return failure(); } diff --git a/accera/value/src/MLIREmitterContext.cpp b/accera/value/src/MLIREmitterContext.cpp index 730ee1be..d719ade9 100644 --- a/accera/value/src/MLIREmitterContext.cpp +++ b/accera/value/src/MLIREmitterContext.cpp @@ -2210,6 +2210,11 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) auto map6 = mlir::AffineMap::get(2, 0, d0 + (d1 % inputElementBytewidth)); affineMaps.push_back(map6); + // map7: (d0) -> (d0 floordiv sizeof(outputElementType)) + // => (byte memory space) -> (output element space) + auto map7 = mlir::AffineMap::get(1, 0, d0.floorDiv(outputElementBytewidth)); + affineMaps.push_back(map7); + // case 1: actual: 13, expected: 13 // case 2: actual: 52, expected: 52 // case 2: actual: 52, expected: 52 @@ -2227,8 +2232,8 @@ Value MLIRContext::ReinterpretCastImpl(Value input, ValueType valueType) if (inputMemrefType.hasStaticShape()) { // Case 1 - auto numElementsIninput = inputMemrefType.getNumElements(); - numElementsInOutput = (int64_t)(numElementsIninput * ((float)inputElementBytewidth / outputElementBytewidth)); + auto numElementsInInput = inputMemrefType.getNumElements(); + numElementsInOutput = (int64_t)(numElementsInInput * ((float)inputElementBytewidth / outputElementBytewidth)); if (numElementsInOutput <= 0) {