From 1ead0f7a37c637f4daa8d6aa6f14a32346ac8524 Mon Sep 17 00:00:00 2001 From: AlexRTer <74372589+AlexRTer@users.noreply.github.com> Date: Fri, 29 Nov 2024 18:17:23 +0100 Subject: [PATCH] Added tests for sparse exploit codegen - added script level and filecheck tests - replaced Operation ptr with cast to specific op to use `get` functions for readability - replaced notifyMatchFailure with ErrorHandler::compilerError --- .../lowering/SparsityExploitationPass.cpp | 36 +++++++++--------- test/CMakeLists.txt | 1 + test/api/cli/codegen/SparsityExploitTest.cpp | 29 ++++++++++++++ test/api/cli/codegen/sparsityExploit.daphne | 5 +++ test/codegen/sparseExploit.mlir | 38 +++++++++++++++++++ 5 files changed, 91 insertions(+), 18 deletions(-) create mode 100644 test/api/cli/codegen/SparsityExploitTest.cpp create mode 100644 test/api/cli/codegen/sparsityExploit.daphne create mode 100644 test/codegen/sparseExploit.mlir diff --git a/src/compiler/lowering/SparsityExploitationPass.cpp b/src/compiler/lowering/SparsityExploitationPass.cpp index 3cc51f245..f1f057873 100644 --- a/src/compiler/lowering/SparsityExploitationPass.cpp +++ b/src/compiler/lowering/SparsityExploitationPass.cpp @@ -75,22 +75,22 @@ class SparsityExploitation final : public mlir::OpConversionPattern(); // IntersectOp( sparseLhs, ln(denseLhs @ denseRhs) ) - Value sparseLhs = sparseIntersectOp->getOperand(0); - Operation *unaryOp = sparseIntersectOp->getOperand(1).getDefiningOp(); // ln(denseLhs @ denseRhs) + Value sparseLhs = sparseIntersectOp.getLhs(); + auto unaryOp = sparseIntersectOp.getRhs().getDefiningOp(); // ln(denseLhs @ denseRhs) - Operation *denseMatmulOp = unaryOp->getOperand(0).getDefiningOp(); // denseLhs @ denseRhs + auto denseMatmulOp = unaryOp.getArg().getDefiningOp(); // denseLhs @ denseRhs // daphne.matMul Op has 4 arguments, the latter two (`transa`, `transb`) are bools indicating whether // either matrix should be accessed as though it was transposed. - Value denseLhs = denseMatmulOp->getOperand(0); - Value denseRhs = denseMatmulOp->getOperand(1); + Value denseLhs = denseMatmulOp.getLhs(); + Value denseRhs = denseMatmulOp.getRhs(); bool transa = CompilerUtils::constantOrThrow( - denseMatmulOp->getOperand(2), "SparseExploitLowering: expected transa to be known at compile-time"); + denseMatmulOp.getTransa(), "SparseExploitLowering: expected transa to be known at compile-time"); bool transb = CompilerUtils::constantOrThrow( - denseMatmulOp->getOperand(3), "SparseExploitLowering: expected transb to be known at compile-time"); + denseMatmulOp.getTransb(), "SparseExploitLowering: expected transb to be known at compile-time"); auto sparseLhsMatType = sparseLhs.getType().template dyn_cast(); Type resElementType = sparseLhsMatType.getElementType(); @@ -108,8 +108,8 @@ class SparsityExploitation final : public mlir::OpConversionPattern([](Operation *op) { - Operation *definingEwMulOp = op->getOperand(0).getDefiningOp(); + auto definingEwMulOp = op->getOperand(0).getDefiningOp(); if (definingEwMulOp == nullptr) { return true; } - Type lhsType = definingEwMulOp->getOperand(0).getType(); + Type lhsType = definingEwMulOp.getLhs().getType(); auto lhsMatType = lhsType.dyn_cast(); - Value rhs = definingEwMulOp->getOperand(1); + Value rhs = definingEwMulOp.getRhs(); Type rhsType = rhs.getType(); auto rhsMatType = rhsType.dyn_cast(); @@ -311,18 +311,18 @@ void SparsityExploitationPass::runOnOperation() { return true; } - Operation *definingEwLnOp = rhs.getDefiningOp(); + auto definingEwLnOp = rhs.getDefiningOp(); if (definingEwLnOp == nullptr) { return true; } - Operation *definingMatMulOp = definingEwLnOp->getOperand(0).getDefiningOp(); + auto definingMatMulOp = definingEwLnOp.getArg().getDefiningOp(); if (definingMatMulOp == nullptr) { return true; } - Type matmulLhsType = definingMatMulOp->getOperand(0).getType(); - Type matmulRhsType = definingMatMulOp->getOperand(1).getType(); + Type matmulLhsType = definingMatMulOp.getLhs().getType(); + Type matmulRhsType = definingMatMulOp.getRhs().getType(); auto matmulLhsMatType = matmulLhsType.dyn_cast(); auto matmulRhsMatType = matmulRhsType.dyn_cast(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a8051d1b7..e4e6b758b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -63,6 +63,7 @@ set(TEST_SOURCES api/cli/codegen/EwUnaryTest.cpp api/cli/codegen/EwOpLoopFusionTest.cpp api/cli/codegen/MapOpTest.cpp + api/cli/codegen/SparsityExploitTest.cpp api/cli/codegen/TransposeTest.cpp ir/daphneir/InferTypesTest.cpp diff --git a/test/api/cli/codegen/SparsityExploitTest.cpp b/test/api/cli/codegen/SparsityExploitTest.cpp new file mode 100644 index 000000000..1cb495507 --- /dev/null +++ b/test/api/cli/codegen/SparsityExploitTest.cpp @@ -0,0 +1,29 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +const std::string dirPath = "test/api/cli/codegen/"; + +TEST_CASE("sparsity exploit (sparse-dense cross-entropy)", TAG_CODEGEN) { + std::string result = "1.61078\n"; + compareDaphneToStr(result, dirPath + "sparsityExploit.daphne", "--select-matrix-repr"); + compareDaphneToStr(result, dirPath + "sparsityExploit.daphne", "--select-matrix-repr", "--mlir-codegen"); +} diff --git a/test/api/cli/codegen/sparsityExploit.daphne b/test/api/cli/codegen/sparsityExploit.daphne new file mode 100644 index 000000000..eb57018be --- /dev/null +++ b/test/api/cli/codegen/sparsityExploit.daphne @@ -0,0 +1,5 @@ +sparseLhs = rand(15, 10, 0.0, 1.0, 0.1, 1); // size: 15x10 range: (0,1) sparsity: 0.1 +DenseU = rand(15, 5, 0.0, 1.0, 1.0, 2); // size: 15x5 range: (0,1) sparsity: 1 +DenseV = rand(10, 5, 0.0, 1.0, 1.0, 3); // size: 10x5 range: (0,1) sparsity: 1 + +print(sum(sparseLhs * ln(DenseU @ t(DenseV)))); diff --git a/test/codegen/sparseExploit.mlir b/test/codegen/sparseExploit.mlir new file mode 100644 index 000000000..3b5789de0 --- /dev/null +++ b/test/codegen/sparseExploit.mlir @@ -0,0 +1,38 @@ +// RUN: daphne-opt -pass-pipeline="builtin.module(lower-sparse-exploit, canonicalize)" %s | FileCheck %s + +// COM: Canonicalizer is run to guarantee matMul, ewLn, and ewMul are also removed as they become redundant but are not removed by the sparse pass itself. +// COM: They are tested here regardless as it is crucial to the pass to replace the entire pattern. + +module { + func.func @double() { + %0 = "daphne.constant"() {value = true} : () -> i1 + %1 = "daphne.constant"() {value = false} : () -> i1 + %2 = "daphne.constant"() {value = 2 : index} : () -> index + %3 = "daphne.constant"() {value = 10 : index} : () -> index + %4 = "daphne.constant"() {value = 3 : si64} : () -> si64 + %5 = "daphne.constant"() {value = 2 : si64} : () -> si64 + %6 = "daphne.constant"() {value = 1 : si64} : () -> si64 + %7 = "daphne.constant"() {value = 0.000000e+00 : f64} : () -> f64 + %8 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64 + %9 = "daphne.constant"() {value = 2.000000e-01 : f64} : () -> f64 + %10 = "daphne.randMatrix"(%3, %3, %7, %8, %9, %6) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]> + %11 = "daphne.randMatrix"(%3, %2, %7, %8, %8, %5) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x2xf64:sp[1.000000e+00]> + %12 = "daphne.randMatrix"(%3, %2, %7, %8, %8, %4) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x2xf64:sp[1.000000e+00]> + // CHECK-NOT: daphne.matMul + // CHECK-NOT: daphne.ewLn + // CHECK-NOT: daphne.ewMul + // CHECK-NOT: daphne.sumAll + // CHECK: affine.parallel + // CHECK: scf.for + // CHECK: scf.for + // CHECK: math.fma + // CHECK: math.log + // CHECK: math.fma + %13 = "daphne.matMul"(%11, %12, %1, %0) : (!daphne.Matrix<10x2xf64:sp[1.000000e+00]>, !daphne.Matrix<10x2xf64:sp[1.000000e+00]>, i1, i1) -> !daphne.Matrix<10x10xf64:sp[1.000000e+00]> + %14 = "daphne.ewLn"(%13) : (!daphne.Matrix<10x10xf64:sp[1.000000e+00]>) -> !daphne.Matrix<10x10xf64> + %15 = "daphne.ewMul"(%10, %14) : (!daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]>, !daphne.Matrix<10x10xf64>) -> !daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]> + %16 = "daphne.sumAll"(%15) : (!daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]>) -> f64 + "daphne.print"(%16, %0, %1) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +} \ No newline at end of file