From 574151c4f66ba9c121723dd2d3df6c1fd3dd0f31 Mon Sep 17 00:00:00 2001 From: m-birke Date: Tue, 12 Nov 2024 13:21:09 +0100 Subject: [PATCH 01/17] fix docs for SQL cross product and cartesian product --- doc/DaphneDSL/Builtins.md | 2 +- doc/tutorial/sqlTutorial.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/DaphneDSL/Builtins.md b/doc/DaphneDSL/Builtins.md index e78eaeb64..3b0318a15 100644 --- a/doc/DaphneDSL/Builtins.md +++ b/doc/DaphneDSL/Builtins.md @@ -490,7 +490,7 @@ We will support set operations such as **`intersect`**, **`merge`**, and **`exce - **`cartesian`**`(lhs:frame, rhs:frame)` - Calculates the cartesian (cross) product of the two input frames. + Calculates the cartesian product of the two input frames. - **`innerJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str)` diff --git a/doc/tutorial/sqlTutorial.md b/doc/tutorial/sqlTutorial.md index d6f2002dd..2b42e5b7a 100644 --- a/doc/tutorial/sqlTutorial.md +++ b/doc/tutorial/sqlTutorial.md @@ -56,7 +56,7 @@ Other features we do and don't support right now can be found below. ### Supported Features -* Cross Product +* SQL Cross Product (Cartesian Product) * Complex Where Clauses * Inner Join with single and multiple join conditions separated by an "AND" Operator * Group By Clauses From a5353cfbc1826094bb81950d3428b5804bc25752 Mon Sep 17 00:00:00 2001 From: Marius Birkenbach <50267572+m-birke@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:28:42 +0100 Subject: [PATCH 02/17] add some docs regarding WSL and UserConfig (#908) extend docs for WSL and UserConfig --- doc/DaphneDSL/Imports.md | 2 ++ doc/development/BuildingDaphne.md | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/doc/DaphneDSL/Imports.md b/doc/DaphneDSL/Imports.md index 1c9470c5d..0802efede 100644 --- a/doc/DaphneDSL/Imports.md +++ b/doc/DaphneDSL/Imports.md @@ -40,6 +40,8 @@ print(utils.x); } ``` +NOTE: to use a user config, the json file path needs to be passed as CLI arg to the DAPHNE binary `daphne --config=` + NOTE: `default_dirs` can hold many paths and it will look for the **one** specified file in each, whereas any other library names have a list consisting of **one** directory, from which **all** files will be imported (can be easily extended to multiple directories). Example: diff --git a/doc/development/BuildingDaphne.md b/doc/development/BuildingDaphne.md index 2a1d6bd87..baddcfa0b 100644 --- a/doc/development/BuildingDaphne.md +++ b/doc/development/BuildingDaphne.md @@ -132,6 +132,10 @@ All possible options for the build script: --- +## Building on WSL + +When using Windows Subsystems for Linux (WSL), the default memory limit for WSL is 50% of the total memory of the underlying Windows host. This can lead to build fails due to SIGKILL for DAPHNE builds. [Advanced settings configuration in WSL](https://learn.microsoft.com/en-us/windows/wsl/wsl-config) describes how the memory limit can be configured. + ## Extension ### Overview over the build script From bc65d5d9a5f8337731128d90f99625102b39bfff Mon Sep 17 00:00:00 2001 From: AlexRTer <74372589+AlexRTer@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:16:25 +0100 Subject: [PATCH 03/17] [codegen] sparsity exploitation codegen pass This PR adds a codegen pass to lower the following expression to a fused operator that exploits sparsity sum(CSRMat * ln(denseLhs @ t(denseRhs))). It can be run simply by enabling the codegen pipeline (--mlir-codegen) and ensuring the lhs of the elementwise multiplication is a CSRMatrix (currently --select-matrix-repr) with the corresponding cli flags. By computing the sum directly, the pass not only avoids materializing potentially large dense matrices in the dense, right matrix multiplication, it also only computes the necessary dot products corresponding to non-zero entries in the CSRMatrix. Thus, it uses constant memory and reduces runtime significantly. closes #919 Co-authored-by: philipportner ortner@tu-berlin.de --- src/compiler/execution/DaphneIrExecutor.cpp | 6 + .../SelectMatrixRepresentationsPass.cpp | 1 - src/compiler/lowering/CMakeLists.txt | 1 + .../lowering/RewriteToCallKernelOpPass.cpp | 20 +- .../lowering/SparsityExploitationPass.cpp | 345 ++++++++++++++++++ src/compiler/utils/CompilerUtils.h | 4 +- src/ir/daphneir/DaphneOps.td | 23 +- src/ir/daphneir/Passes.h | 1 + src/ir/daphneir/Passes.td | 3 + src/parser/catalog/KernelCatalogParser.cpp | 7 +- src/runtime/local/datastructures/CSRMatrix.h | 14 +- .../kernels/ConvertCSRMatrixToColIdxsMemRef.h | 36 ++ .../ConvertCSRMatrixToRowOffsetsMemRef.h | 36 ++ .../kernels/ConvertCSRMatrixToValuesMemRef.h | 36 ++ .../kernels/ConvertDenseMatrixToMemRef.h | 1 + src/runtime/local/kernels/kernels.json | 87 +++++ 16 files changed, 606 insertions(+), 15 deletions(-) create mode 100644 src/compiler/lowering/SparsityExploitationPass.cpp create mode 100644 src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h create mode 100644 src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h create mode 100644 src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index e8f69c5d5..651c4937a 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -258,6 +258,12 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createPrintIRPass("IR before codegen pipeline")); pm.addPass(mlir::daphne::createDaphneOptPass()); + + pm.addPass(mlir::daphne::createSparsityExploitationPass()); + // SparseExploit fuses multiple operations which only need to be lowered if still needed elsewhere. + // Todo: if possible, run only if SparseExploitLowering was successful. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::daphne::createEwOpLoweringPass()); pm.addPass(mlir::daphne::createAggAllOpLoweringPass()); pm.addPass(mlir::daphne::createAggDimOpLoweringPass()); diff --git a/src/compiler/inference/SelectMatrixRepresentationsPass.cpp b/src/compiler/inference/SelectMatrixRepresentationsPass.cpp index f49fa05f1..32a2b6f67 100644 --- a/src/compiler/inference/SelectMatrixRepresentationsPass.cpp +++ b/src/compiler/inference/SelectMatrixRepresentationsPass.cpp @@ -23,7 +23,6 @@ #include #include -#include using namespace mlir; diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 18939f92e..fcbc1c995 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -34,6 +34,7 @@ add_mlir_dialect_library(MLIRDaphneTransforms AggAllOpLowering.cpp AggDimOpLowering.cpp TransposeOpLowering.cpp + SparsityExploitationPass.cpp DEPENDS MLIRDaphneOpsIncGen diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index f99c9e472..fe73b1acd 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -32,12 +32,13 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" #include #include #include #include -#include +#include #include #include #include @@ -133,10 +134,19 @@ class KernelReplacement : public RewritePattern { return mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}); if (auto lt = t.dyn_cast()) return mlir::daphne::ListType::get(mctx, adaptType(lt.getElementType(), generalizeToStructure)); - if (auto mrt = t.dyn_cast()) - // Remove any dimension information ({0, 0}), but retain the element - // type. - return mlir::MemRefType::get({0, 0}, mrt.getElementType()); + if (auto mrt = t.dyn_cast()) { + // Remove any specific dimension information ({0}), but retain the rank and element type. + int64_t mrtRank = mrt.getRank(); + if (mrtRank == 1) { + return mlir::MemRefType::get({0}, mrt.getElementType()); + } else if (mrtRank == 2) { + return mlir::MemRefType::get({0, 0}, mrt.getElementType()); + } else { + throw std::runtime_error( + "RewriteToCallKernelOpPass: expected MemRef to be of rank 1 or 2 but was given " + + std::to_string(mrtRank)); + } + } return t; } diff --git a/src/compiler/lowering/SparsityExploitationPass.cpp b/src/compiler/lowering/SparsityExploitationPass.cpp new file mode 100644 index 000000000..3cc51f245 --- /dev/null +++ b/src/compiler/lowering/SparsityExploitationPass.cpp @@ -0,0 +1,345 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include + +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" + +using namespace mlir; + +/** + * @brief sum(sparse * ln(dense @ dense)) where either dense matrix can be transposed + */ +class SparsityExploitation final : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpAdaptor; + + SparsityExploitation(TypeConverter &typeConverter, mlir::MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx) { + this->setDebugName("SparsityExploitation"); + } + + LogicalResult matchAndRewrite(daphne::AllAggSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // ------------------------------------------------------------------------------------ + // Pattern Matching allAgg( IntersectOp( sparseLhs, ln(denseLhs @ denseRhs) ) ) + // ------------------------------------------------------------------------------------ + + // Rewrite is only called after the pattern has already been matched during legalization, so the results are + // assumed not to be nullptr. + Operation *sparseIntersectOp = + adaptor.getArg().getDefiningOp(); // IntersectOp( sparseLhs, ln(denseLhs @ denseRhs) ) + + Value sparseLhs = sparseIntersectOp->getOperand(0); + Operation *unaryOp = sparseIntersectOp->getOperand(1).getDefiningOp(); // ln(denseLhs @ denseRhs) + + Operation *denseMatmulOp = unaryOp->getOperand(0).getDefiningOp(); // denseLhs @ denseRhs + + // daphne.matMul Op has 4 arguments, the latter two (`transa`, `transb`) are bools indicating whether + // either matrix should be accessed as though it was transposed. + Value denseLhs = denseMatmulOp->getOperand(0); + Value denseRhs = denseMatmulOp->getOperand(1); + bool transa = CompilerUtils::constantOrThrow( + denseMatmulOp->getOperand(2), "SparseExploitLowering: expected transa to be known at compile-time"); + bool transb = CompilerUtils::constantOrThrow( + denseMatmulOp->getOperand(3), "SparseExploitLowering: expected transb to be known at compile-time"); + + auto sparseLhsMatType = sparseLhs.getType().template dyn_cast(); + Type resElementType = sparseLhsMatType.getElementType(); + ssize_t sparseLhsRows = sparseLhsMatType.getNumRows(); + ssize_t sparseLhsCols = sparseLhsMatType.getNumCols(); + + auto denseLhsMatType = denseLhs.getType().template dyn_cast(); + Type dotProdElementType = sparseLhsMatType.getElementType(); + ssize_t denseLhsRows = denseLhsMatType.getNumRows(); + ssize_t denseLhsCols = denseLhsMatType.getNumCols(); + + auto denseRhsMatType = denseRhs.getType().template dyn_cast(); + ssize_t denseRhsRows = denseRhsMatType.getNumRows(); + ssize_t denseRhsCols = denseRhsMatType.getNumCols(); + + if (sparseLhsRows < 0 || sparseLhsCols < 0 || denseLhsRows < 0 || denseLhsCols < 0 || denseRhsRows < 0 || + denseRhsCols < 0) { + return rewriter.notifyMatchFailure( + op, + "sparse exploit codegen currently only works with matrix dimensions that are known at compile time"); + } + + // Verify dense-dense Matmul and sparse-dense intersection op have matching dimensions. + if ((transa ? denseLhsRows : denseLhsCols) != (transb ? denseRhsCols : denseRhsRows)) { + throw ErrorHandler::compilerError( + loc, "SparseExploitLowering", + "dense-dense matrix multiplication operands must have matching inner dimension (given: " + "(" + + std::to_string(transa ? denseLhsCols : denseLhsRows) + "x" + + std::to_string(transa ? denseLhsRows : denseLhsCols) + ") and (" + + std::to_string(transb ? denseRhsCols : denseRhsRows) + "x" + + std::to_string(transb ? denseRhsRows : denseRhsCols) + ")"); + } + if (sparseLhsRows != (transa ? denseLhsCols : denseLhsRows) || + sparseLhsCols != (transb ? denseRhsRows : denseRhsCols)) { + throw ErrorHandler::compilerError( + loc, "SparseExploitLowering", + "sparse-dense intersectionOp operands must have equal dimensions (given: (" + + std::to_string(sparseLhsRows) + "x" + std::to_string(sparseLhsCols) + ") and (" + + std::to_string(transa ? denseLhsCols : denseLhsRows) + "x" + + std::to_string(transb ? denseRhsRows : denseRhsCols) + ")"); + } + + MemRefType denseLhsMemRefType = MemRefType::get({denseLhsRows, denseLhsCols}, denseLhsMatType.getElementType()); + MemRefType denseRhsMemRefType = MemRefType::get({denseRhsRows, denseRhsCols}, denseRhsMatType.getElementType()); + auto denseLhsMemRef = rewriter.create(loc, denseLhsMemRefType, denseLhs); + auto denseRhsMemRef = rewriter.create(loc, denseRhsMemRefType, denseRhs); + + MemRefType sparseLhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, sparseLhsMatType.getElementType()); + MemRefType sparseLhsColIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType sparseLhsRowOffsetsMemRefType = MemRefType::get({sparseLhsRows + 1}, rewriter.getIndexType()); + + auto argValues = + rewriter.create(loc, sparseLhsValuesMemRefType, sparseLhs); + auto argColIdx = + rewriter.create(loc, sparseLhsColIdxsMemRefType, sparseLhs); + auto argRowPtr = + rewriter.create(loc, sparseLhsRowOffsetsMemRefType, sparseLhs); + + /* The loop nest below loops first over the row offsets, meaning over the rows of the CSRMatrix. + * + * Then, in a nested loop, over it iterates over the range created by the difference of an entry in + * row offsets and its successor, which corresponds to the column length of row the first offset was taken of. + * + * Finally, the inner-most loop simply computes the dot product of the dense matrices at the given + * row/column indices. Its result is passed to the first nested loop that computes the logarithm, + * multiplies it with the value at the same indices in the CSRMatrix, and passes it to the outer + * loop which performs a (thread safe) addition. + * + * Todo: Test parallelization of the two inner scf loops (replace with scf::parallel / scf::reduce instead of + * scf::yield) and enable lowering of parallel loops to different threads (e.g. using omp dialect/OpenMP + * library). + */ + auto outerLoop = rewriter.create(loc, TypeRange{resElementType}, + ArrayRef{arith::AtomicRMWKind::addf}, + ArrayRef{sparseLhsRows}); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + { + Value rowPtr = rewriter.create(loc, argRowPtr, ValueRange{outerLoop.getIVs()[0]}); + Value nextRowPtr = rewriter.create( + loc, argRowPtr, AffineMap::get(1, 0, rewriter.getAffineDimExpr(0) + 1, rewriter.getContext()), + ValueRange{outerLoop.getIVs()[0]}); + + Value innerLoopAcc = rewriter.create(loc, rewriter.getZeroAttr(resElementType)); + + auto innerLoop = rewriter.create( + loc, rowPtr, nextRowPtr, rewriter.create(loc, 1), ValueRange{innerLoopAcc}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopInvariants) { + Value colIdx = OpBuilderNested.create(locNested, argColIdx, ValueRange{loopIdx}); + + Value resDotProd = OpBuilderNested.create( + locNested, OpBuilderNested.getZeroAttr(dotProdElementType)); + + auto dotProdLoop = OpBuilderNested.create( + locNested, rewriter.create(loc, 0), + rewriter.create(loc, denseLhsCols), + rewriter.create(loc, 1), ValueRange{resDotProd}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, + ValueRange loopInvariantsNested) { + Value currentValLhs = OpBuilderTwiceNested.create( + locTwiceNested, denseLhsMemRef, + transa ? ValueRange{loopIdxNested, outerLoop.getIVs()[0]} + : ValueRange{outerLoop.getIVs()[0], loopIdxNested}); + Value currentValRhs = OpBuilderTwiceNested.create( + locTwiceNested, denseRhsMemRef, + transb ? ValueRange{colIdx, loopIdxNested} : ValueRange{loopIdxNested, colIdx}); + + Value accDotProd; + if (llvm::isa(dotProdElementType)) { + currentValLhs = convertToSignlessInt(OpBuilderTwiceNested, locTwiceNested, + typeConverter, currentValLhs, dotProdElementType); + currentValRhs = convertToSignlessInt(OpBuilderTwiceNested, locTwiceNested, + typeConverter, currentValRhs, dotProdElementType); + accDotProd = OpBuilderTwiceNested.create(locTwiceNested, currentValLhs, + currentValRhs); + accDotProd = OpBuilderTwiceNested.create(locTwiceNested, accDotProd, + loopInvariantsNested[0]); + } else { + accDotProd = OpBuilderTwiceNested.create( + locTwiceNested, currentValLhs, currentValRhs, loopInvariantsNested[0]); + } + + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{accDotProd}); + }); // end dot product + + Value dotProdLoopRes = dotProdLoop.getResult(0); + if (llvm::isa(dotProdElementType)) { + dotProdLoopRes = this->typeConverter->materializeTargetConversion( + OpBuilderNested, locNested, resElementType, dotProdLoopRes); + } + Value mappedDotProd = OpBuilderNested.create(locNested, dotProdLoopRes); + + Value currentVal = + OpBuilderNested.create(locNested, argValues, ValueRange{loopIdx}); + + // After log has been applied, arg will always be a float type so FMA is no restriction + Value intersectAggRes = + OpBuilderNested.create(locNested, currentVal, mappedDotProd, loopInvariants[0]); + + OpBuilderNested.create(locNested, ValueRange{intersectAggRes}); + }); // end inner loop + + rewriter.create(loc, ValueRange{innerLoop.getResult(0)}); + } // end outer loop + rewriter.setInsertionPointAfter(outerLoop); + + rewriter.replaceOp(op, outerLoop->getResult(0)); + return success(); + } +}; + +// **************************************************************************** +// General Pass Setup +// **************************************************************************** + +namespace { +/** + * @brief This pass lowers a specific pattern of sparse-dense operations + * to a loop nest that avoids materializing potentially large dense intermediates. + * + * The matched pattern is sum(sparse * ln(dense @ dense)) + * where either dense matrix can be transposed. + */ +struct SparsityExploitationPass + : public mlir::PassWrapper> { + explicit SparsityExploitationPass() = default; + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; + + [[nodiscard]] StringRef getArgument() const final { return "lower-sparse-exploit"; } + [[nodiscard]] StringRef getDescription() const final { + return "This pass lowers a sum(sparse * ln(dense @ dense)) pattern " + "to 3 nested loops that avoid materializing potentially large dense intermediates."; + } +}; +} // end anonymous namespace + +void SparsityExploitationPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + mlir::LowerToLLVMOptions llvmOptions(&getContext()); + mlir::LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + // Marks Op as illegal only if it matches the pattern: sum(sparse * ln(dense @ dense)) + target.addDynamicallyLegalOp([](Operation *op) { + Operation *definingEwMulOp = op->getOperand(0).getDefiningOp(); + if (definingEwMulOp == nullptr) { + return true; + } + + Type lhsType = definingEwMulOp->getOperand(0).getType(); + auto lhsMatType = lhsType.dyn_cast(); + + Value rhs = definingEwMulOp->getOperand(1); + Type rhsType = rhs.getType(); + auto rhsMatType = rhsType.dyn_cast(); + + if (!lhsMatType || !rhsMatType) { + return true; + } + if (lhsMatType.getRepresentation() != daphne::MatrixRepresentation::Sparse || + rhsMatType.getRepresentation() != daphne::MatrixRepresentation::Dense) { + return true; + } + + Operation *definingEwLnOp = rhs.getDefiningOp(); + if (definingEwLnOp == nullptr) { + return true; + } + + Operation *definingMatMulOp = definingEwLnOp->getOperand(0).getDefiningOp(); + if (definingMatMulOp == nullptr) { + return true; + } + + Type matmulLhsType = definingMatMulOp->getOperand(0).getType(); + Type matmulRhsType = definingMatMulOp->getOperand(1).getType(); + auto matmulLhsMatType = matmulLhsType.dyn_cast(); + auto matmulRhsMatType = matmulRhsType.dyn_cast(); + + if (matmulLhsMatType.getRepresentation() != daphne::MatrixRepresentation::Dense || + matmulRhsMatType.getRepresentation() != daphne::MatrixRepresentation::Dense) { + return true; + } + + return false; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr daphne::createSparsityExploitationPass() { + return std::make_unique(); +} diff --git a/src/compiler/utils/CompilerUtils.h b/src/compiler/utils/CompilerUtils.h index 5a72d8bf2..9af38755b 100644 --- a/src/compiler/utils/CompilerUtils.h +++ b/src/compiler/utils/CompilerUtils.h @@ -212,7 +212,9 @@ struct CompilerUtils { return "Target"; else if (auto memRefType = t.dyn_cast()) { const std::string vtName = mlirTypeToCppTypeName(memRefType.getElementType(), angleBrackets, false); - return angleBrackets ? ("StridedMemRefType<" + vtName + ",2>") : ("StridedMemRefType_" + vtName + "_2"); + const std::string rankStr = std::to_string(memRefType.getRank()); + return angleBrackets ? ("StridedMemRefType<" + vtName + "," + rankStr + ">") + : ("StridedMemRefType_" + vtName + "_" + rankStr); } std::string typeName; diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index a9b437477..0f717cd9b 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -91,6 +91,27 @@ def Daphne_ConvertDenseMatrixToMemRef : Daphne_Op<"convertDenseMatrixToMemRef", let results = (outs AnyMemRef:$output); } +def Daphne_ConvertCSRMatrixToValuesMemRef : Daphne_Op<"convertCSRMatrixToValuesMemRef", [Pure]> { + let summary = "Given a CSRMatrix, return a StridedMemRefType."; + let description = [{ Constructs a StridedMemRefType with rank 1 containing the values of a CSRMatrix* with already allocated memory. }]; + let arguments = (ins MatrixOrU:$arg); + let results = (outs AnyMemRef:$resValues); +} + +def Daphne_ConvertCSRMatrixToColIdxsMemRef : Daphne_Op<"convertCSRMatrixToColIdxsMemRef", [Pure]> { + let summary = "Given a CSRMatrix, return a StridedMemRefType."; + let description = [{ Constructs a StridedMemRefType with rank 1 containing the column indices of a CSRMatrix* with already allocated memory. }]; + let arguments = (ins MatrixOrU:$arg); + let results = (outs AnyMemRef:$resColIdxs); +} + +def Daphne_ConvertCSRMatrixToRowOffsetsMemRef : Daphne_Op<"convertCSRMatrixToRowOffsetsMemRef", [Pure]> { + let summary = "Given a CSRMatrix, return a StridedMemRefType."; + let description = [{ Constructs a StridedMemRefType with rank 1 containing the row offsets of a CSRMatrix* with already allocated memory. }]; + let arguments = (ins MatrixOrU:$arg); + let results = (outs AnyMemRef:$resRowOffsets); +} + // **************************************************************************** // Data generation // **************************************************************************** @@ -195,7 +216,7 @@ def Daphne_MatMulOp : Daphne_Op<"matMul", [ DataTypeMat, ValueTypeFromArgs, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, CUDASupport, FPGAOPENCLSupport, - CastFirstTwoArgsToResType + CastFirstTwoArgsToResType, NoMemoryEffect ]> { let arguments = (ins MatrixOf<[NumScalar]>:$lhs, MatrixOf<[NumScalar]>:$rhs, BoolScalar:$transa, BoolScalar:$transb); let results = (outs MatrixOf<[NumScalar]>:$res); diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index d90b7a9a7..464aeb577 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -47,6 +47,7 @@ std::unique_ptr createDaphneOptPass(); std::unique_ptr createDistributeComputationsPass(); std::unique_ptr createDistributePipelinesPass(); std::unique_ptr createEwOpLoweringPass(); +std::unique_ptr createSparsityExploitationPass(); std::unique_ptr createInferencePass(InferenceConfig cfg = {false, true, true, true, true}); std::unique_ptr createInsertDaphneContextPass(const DaphneUserConfig &cfg); std::unique_ptr createLowerToLLVMPass(const DaphneUserConfig &cfg); diff --git a/src/ir/daphneir/Passes.td b/src/ir/daphneir/Passes.td index 46fe5295a..0f5ab3144 100644 --- a/src/ir/daphneir/Passes.td +++ b/src/ir/daphneir/Passes.td @@ -256,5 +256,8 @@ def LowerEwOpPass: Pass<"lower-ew", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createEwOpLoweringPass()"; } +def SparsityExploitationPass: Pass<"lower-sparse-exploit", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSparsityExploitationPass()"; +} #endif // SRC_IR_DAPHNEIR_PASSES_TD diff --git a/src/parser/catalog/KernelCatalogParser.cpp b/src/parser/catalog/KernelCatalogParser.cpp index 7560e7f4e..3c4c0fcfb 100644 --- a/src/parser/catalog/KernelCatalogParser.cpp +++ b/src/parser/catalog/KernelCatalogParser.cpp @@ -72,10 +72,11 @@ KernelCatalogParser::KernelCatalogParser(mlir::MLIRContext *mctx) { // MemRef type. if (!st.isa()) { // DAPHNE's StringType is not supported as the element type of a - // MemRef. The dimensions of the MemRef are irrelevant here, so we - // use {0, 0}. + // MemRef. The dimensions of the MemRef are irrelevant here. mlir::Type mrt = mlir::MemRefType::get({0, 0}, st); typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mrt), mrt); + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mlir::MemRefType::get({0}, st)), + mlir::MemRefType::get({0}, st)); } } @@ -134,4 +135,4 @@ void KernelCatalogParser::parseKernelCatalog(const std::string &filePath, Kernel } catch (std::exception &e) { throw std::runtime_error("error while parsing kernel catalog file `" + filePath + "`: " + e.what()); } -} \ No newline at end of file +} diff --git a/src/runtime/local/datastructures/CSRMatrix.h b/src/runtime/local/datastructures/CSRMatrix.h index 08dc4a3e1..0298622bb 100644 --- a/src/runtime/local/datastructures/CSRMatrix.h +++ b/src/runtime/local/datastructures/CSRMatrix.h @@ -66,9 +66,9 @@ template class CSRMatrix : public Matrix { */ size_t maxNumNonZeros; - std::shared_ptr values; - std::shared_ptr colIdxs; - std::shared_ptr rowOffsets; + std::shared_ptr values; + std::shared_ptr colIdxs; + std::shared_ptr rowOffsets; size_t lastAppendedRowIdx; @@ -126,7 +126,7 @@ template class CSRMatrix : public Matrix { maxNumNonZeros = src->maxNumNonZeros; values = src->values; colIdxs = src->colIdxs; - rowOffsets = std::shared_ptr(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl); + rowOffsets = std::shared_ptr(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl); } virtual ~CSRMatrix() { @@ -174,6 +174,8 @@ template class CSRMatrix : public Matrix { const ValueType *getValues() const { return values.get(); } + std::shared_ptr getValuesSharedPtr() const { return values; } + ValueType *getValues(size_t rowIdx) { // We allow equality here to enable retrieving a pointer to the end. if (rowIdx > numRows) @@ -189,6 +191,8 @@ template class CSRMatrix : public Matrix { const size_t *getColIdxs() const { return colIdxs.get(); } + std::shared_ptr getColIdxsSharedPtr() const { return colIdxs; } + size_t *getColIdxs(size_t rowIdx) { // We allow equality here to enable retrieving a pointer to the end. if (rowIdx > numRows) @@ -204,6 +208,8 @@ template class CSRMatrix : public Matrix { const size_t *getRowOffsets() const { return rowOffsets.get(); } + std::shared_ptr getRowOffsetsSharedPtr() const { return rowOffsets; } + ValueType get(size_t rowIdx, size_t colIdx) const override { if (rowIdx >= numRows) throw std::runtime_error("CSRMatrix (get): rowIdx is out of bounds"); diff --git a/src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h b/src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h new file mode 100644 index 000000000..64f50b5cf --- /dev/null +++ b/src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline StridedMemRefType convertCSRMatrixToColIdxsMemRef(const CSRMatrix *input, DCTX(ctx)) { + StridedMemRefType colIdxsMemRef{}; + + colIdxsMemRef.basePtr = input->getColIdxsSharedPtr().get(); + colIdxsMemRef.data = colIdxsMemRef.basePtr; + colIdxsMemRef.offset = 0; + colIdxsMemRef.sizes[0] = input->getNumNonZeros(); + colIdxsMemRef.strides[0] = 1; + + input->increaseRefCounter(); + + return colIdxsMemRef; +} diff --git a/src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h b/src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h new file mode 100644 index 000000000..0354e7df9 --- /dev/null +++ b/src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline StridedMemRefType convertCSRMatrixToRowOffsetsMemRef(const CSRMatrix *input, DCTX(ctx)) { + StridedMemRefType rowOffsetsMemRef{}; + + rowOffsetsMemRef.basePtr = input->getRowOffsetsSharedPtr().get(); + rowOffsetsMemRef.data = rowOffsetsMemRef.basePtr; + rowOffsetsMemRef.offset = 0; + rowOffsetsMemRef.sizes[0] = input->getNumRows() + 1; + rowOffsetsMemRef.strides[0] = 1; + + input->increaseRefCounter(); + + return rowOffsetsMemRef; +} diff --git a/src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h b/src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h new file mode 100644 index 000000000..eca515fa0 --- /dev/null +++ b/src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline StridedMemRefType convertCSRMatrixToValuesMemRef(const CSRMatrix *input, DCTX(ctx)) { + StridedMemRefType valuesMemRef{}; + + valuesMemRef.basePtr = input->getValuesSharedPtr().get(); + valuesMemRef.data = valuesMemRef.basePtr; + valuesMemRef.offset = 0; + valuesMemRef.sizes[0] = input->getNumNonZeros(); // Is numRowsAllocated needed to account for views? + valuesMemRef.strides[0] = 1; + + input->increaseRefCounter(); + + return valuesMemRef; +} diff --git a/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h b/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h index 1cabfcaa8..19191bada 100644 --- a/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h +++ b/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h @@ -17,6 +17,7 @@ #pragma once #include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" #include "runtime/local/datastructures/DenseMatrix.h" template diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 4058f45d0..4a11448b6 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -1539,6 +1539,93 @@ ["double"] ] }, + { + "kernelTemplate": { + "header": "ConvertCSRMatrixToValuesMemRef.h", + "opName": "convertCSRMatrixToValuesMemRef", + "returnType": "StridedMemRefType", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *", + "name": "input" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, + { + "kernelTemplate": { + "header": "ConvertCSRMatrixToColIdxsMemRef.h", + "opName": "convertCSRMatrixToColIdxsMemRef", + "returnType": "StridedMemRefType", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *", + "name": "input" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, + { + "kernelTemplate": { + "header": "ConvertCSRMatrixToRowOffsetsMemRef.h", + "opName": "convertCSRMatrixToRowOffsetsMemRef", + "returnType": "StridedMemRefType", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *", + "name": "input" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, { "kernelTemplate": { "header": "CreateFrame.h", From c062ba14e87fe96b696d226f32ad2a333c37cad3 Mon Sep 17 00:00:00 2001 From: AlexRTer <74372589+AlexRTer@users.noreply.github.com> Date: Tue, 19 Nov 2024 07:10:29 +0100 Subject: [PATCH 04/17] [codegen] broadcasting of rhs in ewBinary codegen This PR changes the existing ewBinaryOp codegen to broadcast rhs if possible. So far, the ewBinary codegen only accepted scalars, matrices of equal shape or combinations of lhs being a matrix and rhs being either a scalar or singleton. Now rhs can also be given as a matching (equal to lhs in one dimension and size 1 in the other) row or column vector. - changed index map for rhs in existing linalg genericOp to allow broadcasting of row/column vectors - added shape checks to ensure matching dimensions when rhs is broadcast closes #920 --- src/compiler/lowering/EwOpsLowering.cpp | 106 ++++++++++++++++++------ 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 0762f171d..79044471a 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" @@ -142,14 +143,69 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { this->setDebugName("EwDaphneOpLowering"); } + /** + * @brief Returns an affine map for indexing the rhs operand. + * Assumes that neither matrix is a singleton and lhs is not broadcast. + * + * If rhs has no dimensions of size 1, returns an identity map. + * Else, returns a map (i,j)->(0,j) or (i,j)->(i,0) to enable broadcasting of rhs. + */ + AffineMap buildRhsAffineMap(Location loc, ConversionPatternRewriter &rewriter, ssize_t lhsRows, ssize_t lhsCols, + ssize_t rhsRows, ssize_t rhsCols) const { + + AffineMap rhsAffineMap; + + // lhs could also be a row/column vector which should not be handled as broadcasting (even though the resulting + // affine maps coincide). This allows for a clearer error message as well. + if (lhsRows != 1 && rhsRows == 1) { + // rhs is a row vector, broadcast along columns + if (lhsCols != rhsCols) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "could not broadcast rhs along columns. Rhs must " + "be a scalar value, singleton matrix or have an equal amount of column to " + "be broadcast but operands have dimensions (" + + std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + std::to_string(rhsRows) + + "," + std::to_string(rhsCols) + ")"); + } + rhsAffineMap = AffineMap::get(2, 0, {rewriter.getAffineConstantExpr(0), rewriter.getAffineDimExpr(1)}, + rewriter.getContext()); + } else if (lhsCols != 1 && rhsCols == 1) { + // rhs is a column vector, broadcast along rows + if (lhsRows != rhsRows) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "could not broadcast rhs along rows. Rhs must " + "be a scalar value, singleton matrix or have an equal amount of rows to " + "be broadcast but operands have dimensions (" + + std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + std::to_string(rhsRows) + + "," + std::to_string(rhsCols) + ")"); + } + rhsAffineMap = AffineMap::get(2, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineConstantExpr(0)}, + rewriter.getContext()); + } else { + // rhs is not broadcasted, return identity mapping + if (lhsRows != rhsRows || lhsCols != rhsCols) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "lhs and rhs must have equal dimensions or allow for broadcasting but operands have dimensions (" + + std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + std::to_string(rhsRows) + + "," + std::to_string(rhsCols) + ")"); + } + rhsAffineMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); + } + + return rhsAffineMap; + } + LogicalResult matchAndRewriteScalarVal(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOp(op, binaryFunc(rewriter, op.getLoc(), this->typeConverter, adaptor.getLhs(), adaptor.getRhs())); return mlir::success(); } - LogicalResult matchAndRewriteBroadcastRhs(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &rhs) const { + LogicalResult matchAndRewriteBroadcastScalarRhs(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + Value &rhs) const { Location loc = op->getLoc(); Value lhs = adaptor.getLhs(); @@ -160,7 +216,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { Type matrixElementType = lhsMatrixType.getElementType(); MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); - Value lhsMemref = rewriter.create(loc, argMemRefType, lhs); + auto lhsMemref = rewriter.create(loc, argMemRefType, lhs); Value resMemref = rewriter.create(loc, argMemRefType); @@ -189,13 +245,13 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { auto lhsMatrixType = lhs.getType().template dyn_cast(); auto rhsMatrixType = rhs.getType().template dyn_cast(); - // Match Scalar-Scalar and Matrix-Scalar broadcasting (assuming scalar values are always switched to rhs). - // Broadcasting where either Matrix is a singleton needs to be handled separately below. + // Match Scalar-Scalar and Matrix-Scalar broadcasting (assuming scalar values are always switched to + // rhs). Broadcasting where either Matrix is a singleton or vector needs to be handled separately below. if (!rhsMatrixType) { if (!lhsMatrixType) { return matchAndRewriteScalarVal(op, adaptor, rewriter); } - return matchAndRewriteBroadcastRhs(op, adaptor, rewriter, rhs); + return matchAndRewriteBroadcastScalarRhs(op, adaptor, rewriter, rhs); } Type matrixElementType = lhsMatrixType.getElementType(); @@ -210,9 +266,10 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { op, "ewOps codegen currently only works with matrix dimensions that are known at compile time"); } - // Assume that if only one matrix contains a single value for broadcasting it is rhs. + // For efficiency, broadcasting a singleton is handled separately here (assumes singleton is always rhs). + // Broadcasting of row/column vectors is handled during the construction of the index map for rhs below. if ((lhsRows != 1 || lhsCols != 1) && rhsRows == 1 && rhsCols == 1) { - Value rhsMemref = rewriter.create( + auto rhsMemref = rewriter.create( loc, MemRefType::get({1, 1}, matrixElementType), rhs); Value rhsBroadcastVal = rewriter @@ -220,25 +277,21 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ValueRange{rewriter.create(loc, 0), rewriter.create(loc, 0)}) .getResult(); - return matchAndRewriteBroadcastRhs(op, adaptor, rewriter, rhsBroadcastVal); + return matchAndRewriteBroadcastScalarRhs(op, adaptor, rewriter, rhsBroadcastVal); } - if (lhsRows != rhsRows || lhsCols != rhsCols) { - throw ErrorHandler::compilerError(loc, "EwOpsLowering (BinaryOp)", - "lhs and rhs must have equal dimensions or either one must " - "be a scalar value but have dimensions (" + - std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + - std::to_string(rhsRows) + "," + std::to_string(rhsCols) + ")"); - } + MemRefType lhsMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); + MemRefType rhsMemRefType = MemRefType::get({rhsRows, rhsCols}, matrixElementType); + auto lhsMemref = rewriter.create(loc, lhsMemRefType, lhs); + auto rhsMemref = rewriter.create(loc, rhsMemRefType, rhs); - MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); - Value lhsMemref = rewriter.create(loc, argMemRefType, lhs); - Value rhsMemref = rewriter.create(loc, argMemRefType, rhs); - - Value resMemref = rewriter.create(loc, argMemRefType); + // If any broadcasting occurs, it is assumed to be rhs so res inherits its shape from lhs. + Value resMemref = rewriter.create(loc, lhsMemRefType); + // Builds an affine map to index the args and accounts for broadcasting of rhs. + // Creation of rhs indexing map checks whether or not the dimensions match and returns a compiler error if not. SmallVector indexMaps = {AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + buildRhsAffineMap(loc, rewriter, lhsRows, lhsCols, rhsRows, rhsCols), AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; SmallVector iterTypes = {utils::IteratorType::parallel, utils::IteratorType::parallel}; @@ -364,7 +417,8 @@ using CosOpLowering = UnaryOpLowering>; // Rounding -// Prior canonicalization pass removes rounding ops on integers, meaning only f32/f64 types need to be dealt with +// Prior canonicalization pass removes rounding ops on integers, meaning only f32/f64 types need to be dealt +// with using FloorOpLowering = UnaryOpLowering>; using CeilOpLowering = UnaryOpLowering>; using RoundOpLowering = UnaryOpLowering>; @@ -388,8 +442,10 @@ using MinOpLowering = // Logical // using AndOpLowering = -// BinaryOpLowering>; // distinguish AndFOp -// using OrOpLowering = BinaryOpLowering>; // - " - +// BinaryOpLowering>; // distinguish +// AndFOp +// using OrOpLowering = BinaryOpLowering>; +// // - " - // **************************************************************************** // General Pass Setup From 9d33ec57f7550f7bf8ae573f308e97c681c6f317 Mon Sep 17 00:00:00 2001 From: philipportner Date: Tue, 26 Nov 2024 10:38:37 +0100 Subject: [PATCH 05/17] [codegen] EwOps ErrorHandler instead of notify As DAPHNE currently does not properly integrate with the MLIR diagnostics infrastructure, but rather relies on exceptions, this replaces a couple of useages of the notifyMatchFailure function with the DAPHNE ErrorHandler. --- src/compiler/lowering/EwOpsLowering.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 79044471a..256795529 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -92,8 +92,9 @@ template struct UnaryOpLowering : publi ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "ewOps codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "ewOps codegen currently only works with matrix dimensions that are known at compile time"); } Value argMemref = rewriter.create( @@ -262,8 +263,9 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ssize_t rhsCols = rhsMatrixType.getNumCols(); if (lhsRows < 0 || lhsCols < 0 || rhsRows < 0 || rhsCols < 0) { - return rewriter.notifyMatchFailure( - op, "ewOps codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "ewOps codegen currently only works with matrix dimensions that are known at compile time"); } // For efficiency, broadcasting a singleton is handled separately here (assumes singleton is always rhs). From 6b06ad2786a289ed97bdf4554f59cdb7583df0d6 Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Tue, 26 Nov 2024 14:03:04 +0100 Subject: [PATCH 06/17] Allowing DaphneDSL kernel hints for some operator symbols. - So far, DaphneDSL kernel hints were only supported for built-in functions, e.g., "sum::mySumKernel(...)". - However, many DaphneIR operations can only be created through DaphneDSL operator symbols. - This commit adds support for attaching kernel hints to some DaphneDSL operator symbols ("+", "-", "*", "/"), e.g., "[2] +::myAddKernel [3]". - Adapted the DaphneDSL grammar and visitors accordingly. - Created a few script-level test cases. --- src/parser/daphnedsl/DaphneDSLGrammar.g4 | 4 +- src/parser/daphnedsl/DaphneDSLVisitor.cpp | 48 ++++++++++++++++--- test/api/cli/extensibility/HintTest.cpp | 2 +- .../hint_kernel_success_1.daphne | 2 +- .../hint_kernel_success_2.daphne | 2 +- .../hint_kernel_success_3.daphne | 2 +- .../hint_kernel_success_4.daphne | 4 ++ .../extensibility/hint_kernel_success_4.txt | 2 + .../hint_kernel_success_5.daphne | 4 ++ .../extensibility/hint_kernel_success_5.txt | 2 + 10 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 test/api/cli/extensibility/hint_kernel_success_4.daphne create mode 100644 test/api/cli/extensibility/hint_kernel_success_4.txt create mode 100644 test/api/cli/extensibility/hint_kernel_success_5.daphne create mode 100644 test/api/cli/extensibility/hint_kernel_success_5.txt diff --git a/src/parser/daphnedsl/DaphneDSLGrammar.g4 b/src/parser/daphnedsl/DaphneDSLGrammar.g4 index 53b8782f5..82ec2e095 100644 --- a/src/parser/daphnedsl/DaphneDSLGrammar.g4 +++ b/src/parser/daphnedsl/DaphneDSLGrammar.g4 @@ -91,8 +91,8 @@ expr: | lhs=expr op='@' rhs=expr # matmulExpr | lhs=expr op='^' rhs=expr # powExpr | lhs=expr op='%' rhs=expr # modExpr - | lhs=expr op=('*'|'/') rhs=expr # mulExpr - | lhs=expr op=('+'|'-') rhs=expr # addExpr + | lhs=expr op=('*'|'/') ('::' kernel=IDENTIFIER)? rhs=expr # mulExpr + | lhs=expr op=('+'|'-') ('::' kernel=IDENTIFIER)? rhs=expr # addExpr | lhs=expr op=('=='|'!='|'<'|'<='|'>'|'>=') rhs=expr # cmpExpr | lhs=expr op='&&' rhs=expr # conjExpr | lhs=expr op='||' rhs=expr # disjExpr diff --git a/src/parser/daphnedsl/DaphneDSLVisitor.cpp b/src/parser/daphnedsl/DaphneDSLVisitor.cpp index fab458673..21d6e0424 100644 --- a/src/parser/daphnedsl/DaphneDSLVisitor.cpp +++ b/src/parser/daphnedsl/DaphneDSLVisitor.cpp @@ -1197,13 +1197,30 @@ antlrcpp::Any DaphneDSLVisitor::visitMulExpr(DaphneDSLGrammarParser::MulExprCont mlir::Location loc = utils.getLoc(ctx->op); mlir::Value lhs = valueOrErrorOnVisit(ctx->lhs); mlir::Value rhs = valueOrErrorOnVisit(ctx->rhs); + bool hasKernelHint = ctx->kernel != nullptr; + mlir::Value res = nullptr; if (op == "*") - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); if (op == "/") - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); - throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); + if (hasKernelHint) { + std::string kernel = ctx->kernel->getText(); + + // We deliberately don't check if the specified kernel + // is registered for the created kind of operation, + // since this is checked in RewriteToCallKernelOpPass. + + mlir::Operation *op = res.getDefiningOp(); + // TODO Don't hardcode the attribute name. + op->setAttr("kernel_hint", builder.getStringAttr(kernel)); + } + + if (res) + return res; + else + throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); } antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprContext *ctx) { @@ -1211,7 +1228,9 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont mlir::Location loc = utils.getLoc(ctx->op); mlir::Value lhs = valueOrErrorOnVisit(ctx->lhs); mlir::Value rhs = valueOrErrorOnVisit(ctx->rhs); + bool hasKernelHint = ctx->kernel != nullptr; + mlir::Value res = nullptr; if (op == "+") // Note that we use '+' for both addition (EwAddOp) and concatenation // (EwConcatOp). The choice is made based on the types of the operands @@ -1219,11 +1238,28 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont // types might not be known at this point in time. Thus, we always // create an EwAddOp here. Note that EwAddOp has a canonicalize method // rewriting it to EwConcatOp if necessary. - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); if (op == "-") - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); - throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); + if (hasKernelHint) { + std::string kernel = ctx->kernel->getText(); + + // We deliberately don't check if the specified kernel + // is registered for the created kind of operation, + // since this is checked in RewriteToCallKernelOpPass. + + mlir::Operation *op = res.getDefiningOp(); + // TODO Don't hardcode the attribute name. + op->setAttr("kernel_hint", builder.getStringAttr(kernel)); + + // TODO retain the attr in case EwAddOp is rewritten to EwConcatOp. + } + + if (res) + return res; + else + throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); } antlrcpp::Any DaphneDSLVisitor::visitCmpExpr(DaphneDSLGrammarParser::CmpExprContext *ctx) { diff --git a/test/api/cli/extensibility/HintTest.cpp b/test/api/cli/extensibility/HintTest.cpp index 94de4c6f1..cf120c142 100644 --- a/test/api/cli/extensibility/HintTest.cpp +++ b/test/api/cli/extensibility/HintTest.cpp @@ -58,7 +58,7 @@ MAKE_FAILURE_TEST_CASE("hint_kernel", 3) // Check if DAPHNE terminates normally when expected and produces the expected // output. -MAKE_SUCCESS_TEST_CASE("hint_kernel", 3) +MAKE_SUCCESS_TEST_CASE("hint_kernel", 5) // Check if DAPHNE terminates normally when expected and if the IR really // contains the kernel hint. diff --git a/test/api/cli/extensibility/hint_kernel_success_1.daphne b/test/api/cli/extensibility/hint_kernel_success_1.daphne index bc235a592..f3c4a73a2 100644 --- a/test/api/cli/extensibility/hint_kernel_success_1.daphne +++ b/test/api/cli/extensibility/hint_kernel_success_1.daphne @@ -1,3 +1,3 @@ -// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly zero results. +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly zero results (DaphneDSL built-in function). print::_print__int64_t__bool__bool(42); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_2.daphne b/test/api/cli/extensibility/hint_kernel_success_2.daphne index 5bff4eb99..3cd1a27c8 100644 --- a/test/api/cli/extensibility/hint_kernel_success_2.daphne +++ b/test/api/cli/extensibility/hint_kernel_success_2.daphne @@ -1,4 +1,4 @@ -// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result. +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result (DaphneDSL built-in function). res = sum::_sumAll__int64_t__DenseMatrix_int64_t([21, 21]); print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_3.daphne b/test/api/cli/extensibility/hint_kernel_success_3.daphne index f17fa1353..165b5ed47 100644 --- a/test/api/cli/extensibility/hint_kernel_success_3.daphne +++ b/test/api/cli/extensibility/hint_kernel_success_3.daphne @@ -1,4 +1,4 @@ -// Hint to use an existing pre-compiled kernel for a DaphneIR op with more than one result. +// Hint to use an existing pre-compiled kernel for a DaphneIR op with more than one result (DaphneDSL built-in function). codes, dict = recode::_recode__DenseMatrix_int64_t__DenseMatrix_double__DenseMatrix_double__bool([1.1, 3.3, 1.1, 2.2], false); print(codes); diff --git a/test/api/cli/extensibility/hint_kernel_success_4.daphne b/test/api/cli/extensibility/hint_kernel_success_4.daphne new file mode 100644 index 000000000..40be8637c --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_4.daphne @@ -0,0 +1,4 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result (DaphneDSL operator symbol, addExpr (binary +-)). + +res = [2] +::_ewAdd__DenseMatrix_int64_t__DenseMatrix_int64_t__DenseMatrix_int64_t [3]; +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_4.txt b/test/api/cli/extensibility/hint_kernel_success_4.txt new file mode 100644 index 000000000..6a5ef86f3 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_4.txt @@ -0,0 +1,2 @@ +DenseMatrix(1x1, int64_t) +5 diff --git a/test/api/cli/extensibility/hint_kernel_success_5.daphne b/test/api/cli/extensibility/hint_kernel_success_5.daphne new file mode 100644 index 000000000..05444480e --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_5.daphne @@ -0,0 +1,4 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result (DaphneDSL operator symbol, mulExpr (binary */)). + +res = [2] *::_ewMul__DenseMatrix_int64_t__DenseMatrix_int64_t__DenseMatrix_int64_t [3]; +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_5.txt b/test/api/cli/extensibility/hint_kernel_success_5.txt new file mode 100644 index 000000000..4d02485d0 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_5.txt @@ -0,0 +1,2 @@ +DenseMatrix(1x1, int64_t) +6 From 741faf4f20b0fbc2515baede3da186146601c644 Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Tue, 26 Nov 2024 14:37:42 +0100 Subject: [PATCH 07/17] More detailed --explain steps for MLIR codegen. - So far, there was a single explain flag for the DAPHNE compiler's MLIR-based codegen backend. - However, the MLIR-based codegen involves multiple steps and it would be great to view them separately. - This commit adds more detailed explain flags for MLIR-based codegen. - "--explain mlir_codegen_sparsity_exploiting_op_fusion" - "--explain mlir_codegen_daphneir_to_mlir" - "--explain mlir_codegen_mlir_specific" - The existing flag "--explain mlir_codegen" still exists and works as usual. - To simplify things for those who were used to it. - In the future, we can think about merging it with the newly added flags. --- UserConfig.json | 3 ++ src/api/cli/DaphneUserConfig.h | 3 ++ src/api/internal/daphne_internal.cpp | 46 ++++++++++++++------- src/compiler/execution/DaphneIrExecutor.cpp | 7 ++++ src/parser/config/ConfigParser.cpp | 9 ++++ src/parser/config/JsonParams.h | 7 ++++ 6 files changed, 61 insertions(+), 14 deletions(-) diff --git a/UserConfig.json b/UserConfig.json index 316c22938..4754c4fd9 100644 --- a/UserConfig.json +++ b/UserConfig.json @@ -24,6 +24,9 @@ "explain_vectorized": false, "explain_obj_ref_mgnt": false, "explain_mlir_codegen": false, + "explain_mlir_codegen_sparsity_exploiting_op_fusion": false, + "explain_mlir_codegen_daphneir_to_mlir": false, + "explain_mlir_codegen_mlir_specific": false, "taskPartitioningScheme": "STATIC", "numberOfThreads": -1, "minimumTaskSize": 1, diff --git a/src/api/cli/DaphneUserConfig.h b/src/api/cli/DaphneUserConfig.h index 13bfdf4b1..8409d7546 100644 --- a/src/api/cli/DaphneUserConfig.h +++ b/src/api/cli/DaphneUserConfig.h @@ -74,6 +74,9 @@ struct DaphneUserConfig { bool explain_vectorized = false; bool explain_obj_ref_mgnt = false; bool explain_mlir_codegen = false; + bool explain_mlir_codegen_sparsity_exploiting_op_fusion = false; + bool explain_mlir_codegen_daphneir_to_mlir = false; + bool explain_mlir_codegen_mlir_specific = false; bool statistics = false; bool force_cuda = false; diff --git a/src/api/internal/daphne_internal.cpp b/src/api/internal/daphne_internal.cpp index b63be6f0b..e5c8fb0c0 100644 --- a/src/api/internal/daphne_internal.cpp +++ b/src/api/internal/daphne_internal.cpp @@ -287,26 +287,35 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int type_adaptation, vectorized, obj_ref_mgnt, - mlir_codegen + mlir_codegen, + mlir_codegen_sparsity_exploiting_op_fusion, + mlir_codegen_daphneir_to_mlir, + mlir_codegen_mlir_specific }; static llvm::cl::list explainArgList( "explain", cat(daphneOptions), llvm::cl::desc("Show DaphneIR after certain compiler passes (separate " "multiple values by comma, the order is irrelevant)"), - llvm::cl::values(clEnumVal(parsing, "Show DaphneIR after parsing"), - clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), - clEnumVal(sql, "Show DaphneIR after SQL parsing"), - clEnumVal(property_inference, "Show DaphneIR after property inference"), - clEnumVal(select_matrix_repr, "Show DaphneIR after selecting " - "physical matrix representations"), - clEnumVal(phy_op_selection, "Show DaphneIR after selecting physical operators"), - clEnumVal(type_adaptation, "Show DaphneIR after adapting types to available kernels"), - clEnumVal(vectorized, "Show DaphneIR after vectorization"), - clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), - clEnumVal(kernels, "Show DaphneIR after kernel lowering"), - clEnumVal(llvm, "Show DaphneIR after llvm lowering"), - clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen")), + llvm::cl::values( + clEnumVal(parsing, "Show DaphneIR after parsing"), + clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), + clEnumVal(sql, "Show DaphneIR after SQL parsing"), + clEnumVal(property_inference, "Show DaphneIR after property inference"), + clEnumVal(select_matrix_repr, "Show DaphneIR after selecting " + "physical matrix representations"), + clEnumVal(phy_op_selection, "Show DaphneIR after selecting physical operators"), + clEnumVal(type_adaptation, "Show DaphneIR after adapting types to available kernels"), + clEnumVal(vectorized, "Show DaphneIR after vectorization"), + clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), + clEnumVal(kernels, "Show DaphneIR after kernel lowering"), + clEnumVal(llvm, "Show DaphneIR after llvm lowering"), + clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen"), + clEnumVal(mlir_codegen_sparsity_exploiting_op_fusion, + "Show DaphneIR after MLIR codegen (sparsity-exploiting operator fusion)"), + clEnumVal(mlir_codegen_daphneir_to_mlir, "Show DaphneIR after MLIR codegen (DaphneIR to MLIR)"), + clEnumVal(mlir_codegen_mlir_specific, "Show DaphneIR after MLIR codegen (MLIR-specific)"), + clEnumVal(llvm, "Show DaphneIR after llvm lowering")), CommaSeparated); static llvm::cl::list scriptArgs1("args", cat(daphneOptions), @@ -479,6 +488,15 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int case mlir_codegen: user_config.explain_mlir_codegen = true; break; + case mlir_codegen_sparsity_exploiting_op_fusion: + user_config.explain_mlir_codegen_sparsity_exploiting_op_fusion = true; + break; + case mlir_codegen_daphneir_to_mlir: + user_config.explain_mlir_codegen_daphneir_to_mlir = true; + break; + case mlir_codegen_mlir_specific: + user_config.explain_mlir_codegen_mlir_specific = true; + break; } } diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 651c4937a..0e5c0c4fd 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -263,6 +263,8 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { // SparseExploit fuses multiple operations which only need to be lowered if still needed elsewhere. // Todo: if possible, run only if SparseExploitLowering was successful. pm.addPass(mlir::createCanonicalizerPass()); + if (userConfig_.explain_mlir_codegen_sparsity_exploiting_op_fusion) + pm.addPass(mlir::daphne::createPrintIRPass("IR after MLIR codegen (sparsity-exploiting operator fusion):")); pm.addPass(mlir::daphne::createEwOpLoweringPass()); pm.addPass(mlir::daphne::createAggAllOpLoweringPass()); @@ -283,6 +285,9 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createPrintIRPass("IR directly after lowering MatMulOp.")); } + if (userConfig_.explain_mlir_codegen_daphneir_to_mlir) + pm.addPass(mlir::daphne::createPrintIRPass("IR after MLIR codegen (DaphneIR to MLIR):")); + pm.addPass(mlir::createConvertMathToLLVMPass()); pm.addPass(mlir::daphne::createModOpLoweringPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -301,4 +306,6 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { if (userConfig_.explain_mlir_codegen) pm.addPass(mlir::daphne::createPrintIRPass("IR after codegen pipeline")); + if (userConfig_.explain_mlir_codegen_mlir_specific) + pm.addPass(mlir::daphne::createPrintIRPass("IR after MLIR codegen (MLIR-specific):")); } diff --git a/src/parser/config/ConfigParser.cpp b/src/parser/config/ConfigParser.cpp index 2a7b6a28c..47d3def14 100644 --- a/src/parser/config/ConfigParser.cpp +++ b/src/parser/config/ConfigParser.cpp @@ -104,6 +104,15 @@ void ConfigParser::readUserConfig(const std::string &filename, DaphneUserConfig config.explain_obj_ref_mgnt = jf.at(DaphneConfigJsonParams::EXPLAIN_OBJ_REF_MGNT).get(); if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN)) config.explain_mlir_codegen = jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION)) + config.explain_mlir_codegen_sparsity_exploiting_op_fusion = + jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR)) + config.explain_mlir_codegen_daphneir_to_mlir = + jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC)) + config.explain_mlir_codegen_mlir_specific = + jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC).get(); if (keyExists(jf, DaphneConfigJsonParams::TASK_PARTITIONING_SCHEME)) { config.taskPartitioningScheme = jf.at(DaphneConfigJsonParams::TASK_PARTITIONING_SCHEME).get(); diff --git a/src/parser/config/JsonParams.h b/src/parser/config/JsonParams.h index d2c964cf9..3a6717497 100644 --- a/src/parser/config/JsonParams.h +++ b/src/parser/config/JsonParams.h @@ -54,6 +54,10 @@ struct DaphneConfigJsonParams { inline static const std::string EXPLAIN_VECTORIZED = "explain_vectorized"; inline static const std::string EXPLAIN_OBJ_REF_MGNT = "explain_obj_ref_mgnt"; inline static const std::string EXPLAIN_MLIR_CODEGEN = "explain_mlir_codegen"; + inline static const std::string EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION = + "explain_mlir_codegen_sparsity_exploiting_op_fusion"; + inline static const std::string EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR = "explain_mlir_codegen_daphneir_to_mlir"; + inline static const std::string EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC = "explain_mlir_codegen_mlir_specific"; inline static const std::string TASK_PARTITIONING_SCHEME = "taskPartitioningScheme"; inline static const std::string NUMBER_OF_THREADS = "numberOfThreads"; inline static const std::string MINIMUM_TASK_SIZE = "minimumTaskSize"; @@ -95,6 +99,9 @@ struct DaphneConfigJsonParams { EXPLAIN_TYPE_ADAPTATION, EXPLAIN_VECTORIZED, EXPLAIN_MLIR_CODEGEN, + EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION, + EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR, + EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC, EXPLAIN_OBJ_REF_MGNT, TASK_PARTITIONING_SCHEME, NUMBER_OF_THREADS, From d6376f8572472f57e8811586e5ab4ecc25ac578d Mon Sep 17 00:00:00 2001 From: AlexRTer <74372589+AlexRTer@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:09:43 +0100 Subject: [PATCH 08/17] Replace `notifyMatchFailure` with `ErrorHandler::compilerError` The former is currently not properly supported by Daphne. While it does work, it is replaced with Daphne's own error handler to display the error messages. --- src/compiler/lowering/AggAllOpLowering.cpp | 6 ++++-- src/compiler/lowering/AggDimOpLowering.cpp | 12 ++++++++---- src/compiler/lowering/SparsityExploitationPass.cpp | 4 ++-- src/compiler/lowering/TransposeOpLowering.cpp | 6 ++++-- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/compiler/lowering/AggAllOpLowering.cpp b/src/compiler/lowering/AggAllOpLowering.cpp index b324e73fa..213da40df 100644 --- a/src/compiler/lowering/AggAllOpLowering.cpp +++ b/src/compiler/lowering/AggAllOpLowering.cpp @@ -18,6 +18,7 @@ #include #include "compiler/utils/LoweringUtils.h" +#include #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" @@ -90,8 +91,9 @@ class AggAllOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "aggAllOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "AggAllOpLowering", + "aggAllOp codegen currently only works with matrix dimensions that are known at compile time"); } Type matrixElementType = matrixType.getElementType(); diff --git a/src/compiler/lowering/AggDimOpLowering.cpp b/src/compiler/lowering/AggDimOpLowering.cpp index 9f75934f9..8452354f1 100644 --- a/src/compiler/lowering/AggDimOpLowering.cpp +++ b/src/compiler/lowering/AggDimOpLowering.cpp @@ -18,6 +18,8 @@ #include #include "compiler/utils/LoweringUtils.h" +#include + #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" @@ -100,8 +102,9 @@ class AggDimOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "AggDimOpLowering", + "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); } Type matrixElementType = matrixType.getElementType(); @@ -236,8 +239,9 @@ class AggDimIdxOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "AggDimOpLowering", + "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); } Type matrixElementType = matrixType.getElementType(); diff --git a/src/compiler/lowering/SparsityExploitationPass.cpp b/src/compiler/lowering/SparsityExploitationPass.cpp index 3cc51f245..234b1dc0b 100644 --- a/src/compiler/lowering/SparsityExploitationPass.cpp +++ b/src/compiler/lowering/SparsityExploitationPass.cpp @@ -108,8 +108,8 @@ class SparsityExploitation final : public mlir::OpConversionPattern #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -80,8 +81,9 @@ class TransposeOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "transposeOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "TransposeOpLowering", + "transposeOp codegen currently only works with matrix dimensions that are known at compile time"); } Value argMemref = rewriter.create( From f2ba9c3e25e398debfdd3e881ba114eee0de0855 Mon Sep 17 00:00:00 2001 From: AlexRTer <74372589+AlexRTer@users.noreply.github.com> Date: Fri, 29 Nov 2024 18:17:23 +0100 Subject: [PATCH 09/17] Added tests for sparse exploit codegen - added script level and filecheck tests - replaced Operation ptr with cast to specific op to use `get` functions for readability - replaced notifyMatchFailure with ErrorHandler::compilerError --- .../lowering/SparsityExploitationPass.cpp | 32 ++++++++-------- test/CMakeLists.txt | 1 + test/api/cli/codegen/SparsityExploitTest.cpp | 29 ++++++++++++++ test/api/cli/codegen/sparsityExploit.daphne | 5 +++ test/codegen/sparseExploit.mlir | 38 +++++++++++++++++++ 5 files changed, 89 insertions(+), 16 deletions(-) create mode 100644 test/api/cli/codegen/SparsityExploitTest.cpp create mode 100644 test/api/cli/codegen/sparsityExploit.daphne create mode 100644 test/codegen/sparseExploit.mlir diff --git a/src/compiler/lowering/SparsityExploitationPass.cpp b/src/compiler/lowering/SparsityExploitationPass.cpp index 234b1dc0b..f1f057873 100644 --- a/src/compiler/lowering/SparsityExploitationPass.cpp +++ b/src/compiler/lowering/SparsityExploitationPass.cpp @@ -75,22 +75,22 @@ class SparsityExploitation final : public mlir::OpConversionPattern(); // IntersectOp( sparseLhs, ln(denseLhs @ denseRhs) ) - Value sparseLhs = sparseIntersectOp->getOperand(0); - Operation *unaryOp = sparseIntersectOp->getOperand(1).getDefiningOp(); // ln(denseLhs @ denseRhs) + Value sparseLhs = sparseIntersectOp.getLhs(); + auto unaryOp = sparseIntersectOp.getRhs().getDefiningOp(); // ln(denseLhs @ denseRhs) - Operation *denseMatmulOp = unaryOp->getOperand(0).getDefiningOp(); // denseLhs @ denseRhs + auto denseMatmulOp = unaryOp.getArg().getDefiningOp(); // denseLhs @ denseRhs // daphne.matMul Op has 4 arguments, the latter two (`transa`, `transb`) are bools indicating whether // either matrix should be accessed as though it was transposed. - Value denseLhs = denseMatmulOp->getOperand(0); - Value denseRhs = denseMatmulOp->getOperand(1); + Value denseLhs = denseMatmulOp.getLhs(); + Value denseRhs = denseMatmulOp.getRhs(); bool transa = CompilerUtils::constantOrThrow( - denseMatmulOp->getOperand(2), "SparseExploitLowering: expected transa to be known at compile-time"); + denseMatmulOp.getTransa(), "SparseExploitLowering: expected transa to be known at compile-time"); bool transb = CompilerUtils::constantOrThrow( - denseMatmulOp->getOperand(3), "SparseExploitLowering: expected transb to be known at compile-time"); + denseMatmulOp.getTransb(), "SparseExploitLowering: expected transb to be known at compile-time"); auto sparseLhsMatType = sparseLhs.getType().template dyn_cast(); Type resElementType = sparseLhsMatType.getElementType(); @@ -291,15 +291,15 @@ void SparsityExploitationPass::runOnOperation() { // Marks Op as illegal only if it matches the pattern: sum(sparse * ln(dense @ dense)) target.addDynamicallyLegalOp([](Operation *op) { - Operation *definingEwMulOp = op->getOperand(0).getDefiningOp(); + auto definingEwMulOp = op->getOperand(0).getDefiningOp(); if (definingEwMulOp == nullptr) { return true; } - Type lhsType = definingEwMulOp->getOperand(0).getType(); + Type lhsType = definingEwMulOp.getLhs().getType(); auto lhsMatType = lhsType.dyn_cast(); - Value rhs = definingEwMulOp->getOperand(1); + Value rhs = definingEwMulOp.getRhs(); Type rhsType = rhs.getType(); auto rhsMatType = rhsType.dyn_cast(); @@ -311,18 +311,18 @@ void SparsityExploitationPass::runOnOperation() { return true; } - Operation *definingEwLnOp = rhs.getDefiningOp(); + auto definingEwLnOp = rhs.getDefiningOp(); if (definingEwLnOp == nullptr) { return true; } - Operation *definingMatMulOp = definingEwLnOp->getOperand(0).getDefiningOp(); + auto definingMatMulOp = definingEwLnOp.getArg().getDefiningOp(); if (definingMatMulOp == nullptr) { return true; } - Type matmulLhsType = definingMatMulOp->getOperand(0).getType(); - Type matmulRhsType = definingMatMulOp->getOperand(1).getType(); + Type matmulLhsType = definingMatMulOp.getLhs().getType(); + Type matmulRhsType = definingMatMulOp.getRhs().getType(); auto matmulLhsMatType = matmulLhsType.dyn_cast(); auto matmulRhsMatType = matmulRhsType.dyn_cast(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a8051d1b7..e4e6b758b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -63,6 +63,7 @@ set(TEST_SOURCES api/cli/codegen/EwUnaryTest.cpp api/cli/codegen/EwOpLoopFusionTest.cpp api/cli/codegen/MapOpTest.cpp + api/cli/codegen/SparsityExploitTest.cpp api/cli/codegen/TransposeTest.cpp ir/daphneir/InferTypesTest.cpp diff --git a/test/api/cli/codegen/SparsityExploitTest.cpp b/test/api/cli/codegen/SparsityExploitTest.cpp new file mode 100644 index 000000000..1cb495507 --- /dev/null +++ b/test/api/cli/codegen/SparsityExploitTest.cpp @@ -0,0 +1,29 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +const std::string dirPath = "test/api/cli/codegen/"; + +TEST_CASE("sparsity exploit (sparse-dense cross-entropy)", TAG_CODEGEN) { + std::string result = "1.61078\n"; + compareDaphneToStr(result, dirPath + "sparsityExploit.daphne", "--select-matrix-repr"); + compareDaphneToStr(result, dirPath + "sparsityExploit.daphne", "--select-matrix-repr", "--mlir-codegen"); +} diff --git a/test/api/cli/codegen/sparsityExploit.daphne b/test/api/cli/codegen/sparsityExploit.daphne new file mode 100644 index 000000000..eb57018be --- /dev/null +++ b/test/api/cli/codegen/sparsityExploit.daphne @@ -0,0 +1,5 @@ +sparseLhs = rand(15, 10, 0.0, 1.0, 0.1, 1); // size: 15x10 range: (0,1) sparsity: 0.1 +DenseU = rand(15, 5, 0.0, 1.0, 1.0, 2); // size: 15x5 range: (0,1) sparsity: 1 +DenseV = rand(10, 5, 0.0, 1.0, 1.0, 3); // size: 10x5 range: (0,1) sparsity: 1 + +print(sum(sparseLhs * ln(DenseU @ t(DenseV)))); diff --git a/test/codegen/sparseExploit.mlir b/test/codegen/sparseExploit.mlir new file mode 100644 index 000000000..3b5789de0 --- /dev/null +++ b/test/codegen/sparseExploit.mlir @@ -0,0 +1,38 @@ +// RUN: daphne-opt -pass-pipeline="builtin.module(lower-sparse-exploit, canonicalize)" %s | FileCheck %s + +// COM: Canonicalizer is run to guarantee matMul, ewLn, and ewMul are also removed as they become redundant but are not removed by the sparse pass itself. +// COM: They are tested here regardless as it is crucial to the pass to replace the entire pattern. + +module { + func.func @double() { + %0 = "daphne.constant"() {value = true} : () -> i1 + %1 = "daphne.constant"() {value = false} : () -> i1 + %2 = "daphne.constant"() {value = 2 : index} : () -> index + %3 = "daphne.constant"() {value = 10 : index} : () -> index + %4 = "daphne.constant"() {value = 3 : si64} : () -> si64 + %5 = "daphne.constant"() {value = 2 : si64} : () -> si64 + %6 = "daphne.constant"() {value = 1 : si64} : () -> si64 + %7 = "daphne.constant"() {value = 0.000000e+00 : f64} : () -> f64 + %8 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64 + %9 = "daphne.constant"() {value = 2.000000e-01 : f64} : () -> f64 + %10 = "daphne.randMatrix"(%3, %3, %7, %8, %9, %6) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]> + %11 = "daphne.randMatrix"(%3, %2, %7, %8, %8, %5) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x2xf64:sp[1.000000e+00]> + %12 = "daphne.randMatrix"(%3, %2, %7, %8, %8, %4) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x2xf64:sp[1.000000e+00]> + // CHECK-NOT: daphne.matMul + // CHECK-NOT: daphne.ewLn + // CHECK-NOT: daphne.ewMul + // CHECK-NOT: daphne.sumAll + // CHECK: affine.parallel + // CHECK: scf.for + // CHECK: scf.for + // CHECK: math.fma + // CHECK: math.log + // CHECK: math.fma + %13 = "daphne.matMul"(%11, %12, %1, %0) : (!daphne.Matrix<10x2xf64:sp[1.000000e+00]>, !daphne.Matrix<10x2xf64:sp[1.000000e+00]>, i1, i1) -> !daphne.Matrix<10x10xf64:sp[1.000000e+00]> + %14 = "daphne.ewLn"(%13) : (!daphne.Matrix<10x10xf64:sp[1.000000e+00]>) -> !daphne.Matrix<10x10xf64> + %15 = "daphne.ewMul"(%10, %14) : (!daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]>, !daphne.Matrix<10x10xf64>) -> !daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]> + %16 = "daphne.sumAll"(%15) : (!daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]>) -> f64 + "daphne.print"(%16, %0, %1) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +} \ No newline at end of file From 701286e4739180995071da0970851bc1cf261000 Mon Sep 17 00:00:00 2001 From: philipportner Date: Tue, 3 Dec 2024 09:27:12 +0100 Subject: [PATCH 10/17] duplicate `clEnumVal` causes crash in --debug When building with debug symbols a recent commit introduced a duplicate `clEnumVal` which caused an exception. This patch removes the duplication. ``` daphne: /home/philipportner/daphne/thirdparty/installed/include/llvm/Support/CommandLine.h:864: void llvm::cl::parser::addLiteralOption(llvm::StringRef, const DT&, llvm::StringRef) [with DT = int; DataType = startDAPHNE(int, const char**, DaphneLibResult*, int*, DaphneUserConfig&)::ExplainArgs]: Assertion `findOption(Name) == Values.size() && "Option already exists!"' failed. ``` --- src/api/internal/daphne_internal.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/api/internal/daphne_internal.cpp b/src/api/internal/daphne_internal.cpp index e5c8fb0c0..fc728ffcb 100644 --- a/src/api/internal/daphne_internal.cpp +++ b/src/api/internal/daphne_internal.cpp @@ -309,7 +309,6 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int clEnumVal(vectorized, "Show DaphneIR after vectorization"), clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), clEnumVal(kernels, "Show DaphneIR after kernel lowering"), - clEnumVal(llvm, "Show DaphneIR after llvm lowering"), clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen"), clEnumVal(mlir_codegen_sparsity_exploiting_op_fusion, "Show DaphneIR after MLIR codegen (sparsity-exploiting operator fusion)"), From 7136e0d93ccc882d68bbded41bebab7d3d61baf5 Mon Sep 17 00:00:00 2001 From: Samin <69201742+saminbassiri@users.noreply.github.com> Date: Tue, 3 Dec 2024 20:10:59 +0100 Subject: [PATCH 11/17] [DAPHNE-#901] Inner join and semi join with result cardinality hint (#918) This commit introduces a new optional parameter, numRowRes, to the innerJoin() and semiJoin() DaphneDSL built-in functions, enabling precise control over result size allocation. It fixes #901. Key Changes: 1. Kernel changes: - innerJoin: - If `numRowRes = -1`, the result size defaults to `numRowRhs * numRowLhs` (cartesian product). - Otherwise, the result size is defined by `numRowRes`. - semiJoin: - If `numRowRes = -1`, the result size defaults to `numRowLhs`. - Otherwise, the result size is defined by `numRowRes`. 2. DaphneDSL updates: - `numRowRes` is now an optional argument for `innerJoin` and `semiJoin`. - Defaults to `-1` if not provided. 3. DaphneIR adjustments: - `numRowRes` is now a mandatory argument for `InnerJoinOp` and `SemiJoinOp`. 4. Implementation updates: - Modified `DaphneDSLBuiltins.cpp` to set default values for `numRowRes`. - Updated `SQLVisitor.cpp` to ensure compatibility by passing `-1` as `numRowRes`. - Adjusted `kernels.json` to reflect the new parameter for relevant operations. 5. Testing: - Added script-level test cases to validate correct behavior across various scenarios. --- doc/DaphneDSL/Builtins.md | 12 +++++++++-- src/ir/daphneir/DaphneOps.td | 4 ++-- src/parser/daphnedsl/DaphneDSLBuiltins.cpp | 18 +++++++++++++---- src/parser/sql/SQLVisitor.cpp | 5 ++++- src/runtime/local/kernels/InnerJoin.h | 5 ++++- src/runtime/local/kernels/SemiJoin.h | 21 ++++++++++++++------ src/runtime/local/kernels/kernels.json | 10 +++++++++- test/api/cli/operations/OperationsTest.cpp | 2 ++ test/api/cli/operations/innerJoin_1.daphne | 10 ++++++++++ test/api/cli/operations/innerJoin_1.txt | 6 ++++++ test/api/cli/operations/semiJoin_1.daphne | 10 ++++++++++ test/api/cli/operations/semiJoin_1.txt | 6 ++++++ test/runtime/local/kernels/InnerJoinTest.cpp | 4 ++-- test/runtime/local/kernels/SemiJoinTest.cpp | 2 +- 14 files changed, 95 insertions(+), 20 deletions(-) create mode 100644 test/api/cli/operations/innerJoin_1.daphne create mode 100644 test/api/cli/operations/innerJoin_1.txt create mode 100644 test/api/cli/operations/semiJoin_1.daphne create mode 100644 test/api/cli/operations/semiJoin_1.txt diff --git a/doc/DaphneDSL/Builtins.md b/doc/DaphneDSL/Builtins.md index 3b0318a15..acbbd332e 100644 --- a/doc/DaphneDSL/Builtins.md +++ b/doc/DaphneDSL/Builtins.md @@ -492,14 +492,22 @@ We will support set operations such as **`intersect`**, **`merge`**, and **`exce Calculates the cartesian product of the two input frames. -- **`innerJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str)` +- **`innerJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str[, numRowRes:si64])` Performs an inner join of the two input frames on `lhs`.`lhsOn` == `rhs`.`rhsOn`. -- **`semiJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str)` + The parameter `numRowRes` is an optional hint for an upper bound of the number or result rows. + If specified, it determines the number of rows that will be allocated for the result, whereby `-1` stands for an automatically chosen size. + Otherwise, it defaults to `-1`. + +- **`semiJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str[, numRowRes:si64])` Performs a semi join of the two input frames on `lhs`.`lhsOn` == `rhs`.`rhsOn`. Returns only the columns belonging to `lhs`. + + The parameter `numRowRes` is an optional hint for an upper bound of the number or result rows. + If specified, it determines the number of rows that will be allocated for the result, whereby `-1` stands for an automatically chosen size. + Otherwise, it defaults to `-1`. - **`groupJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str, rhsAgg:str)` diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 0f717cd9b..b2d91310b 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -1097,7 +1097,7 @@ def Daphne_InnerJoinOp : Daphne_Op<"innerJoin", [ DataTypeFrm, ValueTypesConcat, DeclareOpInterfaceMethods, ]> { - let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn); + let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes); let results = (outs FrameOrU:$res); } @@ -1141,7 +1141,7 @@ def Daphne_SemiJoinOp : Daphne_Op<"semiJoin", [ DeclareOpInterfaceMethods, NumColsFromArg ]> { - let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn); + let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes); let results = (outs FrameOrU:$res, MatrixOf<[Size]>:$lhsTids); } diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index bc136b9d0..7bb1fd1b1 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -993,13 +993,18 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu builder.create(loc, FrameType::get(builder.getContext(), colTypes), args[0], args[1])); } if (func == "innerJoin") { - checkNumArgsExact(loc, func, numArgs, 4); + checkNumArgsBetween(loc, func, numArgs, 4, 5); std::vector colTypes; + mlir::Value numRowRes; for (int i = 0; i < 2; i++) for (mlir::Type t : args[i].getType().dyn_cast().getColumnTypes()) colTypes.push_back(t); + if (numArgs == 5) + numRowRes = utils.castSI64If(args[4]); + else + numRowRes = builder.create(loc, int64_t(-1)); return static_cast(builder.create(loc, FrameType::get(builder.getContext(), colTypes), - args[0], args[1], args[2], args[3])); + args[0], args[1], args[2], args[3], numRowRes)); } if (func == "fullOuterJoin") return createJoinOp(loc, func, args); @@ -1011,14 +1016,19 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu // TODO Reconcile this with the other join ops, but we need it to work // quickly now. // return createJoinOp(loc, func, args); - checkNumArgsExact(loc, func, numArgs, 4); + checkNumArgsBetween(loc, func, numArgs, 4, 5); mlir::Value lhs = args[0]; mlir::Value rhs = args[1]; mlir::Value lhsOn = args[2]; mlir::Value rhsOn = args[3]; + mlir::Value numRowRes; + if (numArgs == 5) + numRowRes = utils.castSI64If(args[4]); + else + numRowRes = builder.create(loc, int64_t(-1)); return builder .create(loc, FrameType::get(builder.getContext(), {utils.unknownType}), utils.matrixOfSizeType, - lhs, rhs, lhsOn, rhsOn) + lhs, rhs, lhsOn, rhsOn, numRowRes) .getResults(); } if (func == "groupJoin") { diff --git a/src/parser/sql/SQLVisitor.cpp b/src/parser/sql/SQLVisitor.cpp index d3c2a3d9f..e098ced3e 100644 --- a/src/parser/sql/SQLVisitor.cpp +++ b/src/parser/sql/SQLVisitor.cpp @@ -551,8 +551,11 @@ antlrcpp::Any SQLVisitor::visitInnerJoin(SQLGrammarParser::InnerJoinContext *ctx mlir::Value rhsName = valueOrErrorOnVisit(ctx->rhs); mlir::Value lhsName = valueOrErrorOnVisit(ctx->lhs); + mlir::Value numRowRes = + static_cast(builder.create(queryLoc, static_cast(-1))); + return static_cast( - builder.create(loc, t, currentFrame, tojoin, rhsName, lhsName)); + builder.create(loc, t, currentFrame, tojoin, rhsName, lhsName, numRowRes)); } std::vector rhsNames; diff --git a/src/runtime/local/kernels/InnerJoin.h b/src/runtime/local/kernels/InnerJoin.h index a6185056b..55451a4ba 100644 --- a/src/runtime/local/kernels/InnerJoin.h +++ b/src/runtime/local/kernels/InnerJoin.h @@ -80,8 +80,11 @@ inline void innerJoin( const Frame *lhs, const Frame *rhs, // input column names const char *lhsOn, const char *rhsOn, + // result size + int64_t numRowRes, // context DCTX(ctx)) { + // Find out the value types of the columns to process. ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn); ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn); @@ -89,7 +92,7 @@ inline void innerJoin( // Perhaps check if res already allocated. const size_t numRowRhs = rhs->getNumRows(); const size_t numRowLhs = lhs->getNumRows(); - const size_t totalRows = numRowRhs * numRowLhs; + const size_t totalRows = numRowRes == -1 ? numRowRhs * numRowLhs : numRowRes; const size_t numColRhs = rhs->getNumCols(); const size_t numColLhs = lhs->getNumCols(); const size_t totalCols = numColRhs + numColLhs; diff --git a/src/runtime/local/kernels/SemiJoin.h b/src/runtime/local/kernels/SemiJoin.h index eb815e0b7..4e308d087 100644 --- a/src/runtime/local/kernels/SemiJoin.h +++ b/src/runtime/local/kernels/SemiJoin.h @@ -46,6 +46,8 @@ void semiJoinCol( Frame *&res, DenseMatrix *&resLhsTid, // arguments const DenseMatrix *argLhs, const DenseMatrix *argRhs, + // result size + int64_t numRowRes, // context DCTX(ctx)) { if (argLhs->getNumCols() != 1) @@ -72,11 +74,14 @@ void semiJoinCol( // Create the output data objects. if (res == nullptr) { ValueTypeCode schema[] = {ValueTypeUtils::codeFor}; - res = DataObjectFactory::create(numArgLhs, 1, schema, nullptr, false); + const size_t resSize = numRowRes == -1 ? numArgLhs : numRowRes; + res = DataObjectFactory::create(resSize, 1, schema, nullptr, false); } auto resLhs = res->getColumn(0); - if (resLhsTid == nullptr) - resLhsTid = DataObjectFactory::create>(numArgLhs, 1, false); + if (resLhsTid == nullptr) { + const size_t resLhsTidSize = numRowRes == -1 ? numArgLhs : numRowRes; + resLhsTid = DataObjectFactory::create>(resLhsTidSize, 1, false); + } size_t pos = 0; for (size_t i = 0; i < numArgLhs; i++) { @@ -107,11 +112,13 @@ void semiJoinColIf( const Frame *lhs, const Frame *rhs, // input column names const char *lhsOn, const char *rhsOn, + // result size + int64_t numRowRes, // context DCTX(ctx)) { if (vtcLhs == ValueTypeUtils::codeFor && vtcRhs == ValueTypeUtils::codeFor) { semiJoinCol(res, resLhsTid, lhs->getColumn(lhsOn), rhs->getColumn(rhsOn), - ctx); + numRowRes, ctx); } } @@ -127,6 +134,8 @@ void semiJoin( const Frame *lhs, const Frame *rhs, // input column names const char *lhsOn, const char *rhsOn, + // result size + int64_t numRowRes, // context DCTX(ctx)) { // Find out the value types of the columns to process. @@ -136,8 +145,8 @@ void semiJoin( // Call the semiJoin-kernel on columns for the actual combination of // value types. // Repeat this for all type combinations... - semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx); - semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx); + semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx); + semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx); // Set the column labels of the result frame. std::string labels[] = {lhsOn}; diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 4a11448b6..783872553 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -3061,7 +3061,11 @@ { "type": "const char *", "name": "rhsOn" - } + }, + { + "type": "int64_t", + "name": "numRowRes" + } ] }, "instantiations": [[]] @@ -4354,6 +4358,10 @@ { "type": "const char *", "name": "rhsOn" + }, + { + "type": "int64_t", + "name": "numRowRes" } ] }, diff --git a/test/api/cli/operations/OperationsTest.cpp b/test/api/cli/operations/OperationsTest.cpp index 47857ceb0..4dd0a7bb9 100644 --- a/test/api/cli/operations/OperationsTest.cpp +++ b/test/api/cli/operations/OperationsTest.cpp @@ -42,6 +42,7 @@ MAKE_TEST_CASE("fill", 1) MAKE_TEST_CASE("gemv", 1) MAKE_TEST_CASE("idxMax", 1) MAKE_TEST_CASE("idxMin", 1) +MAKE_TEST_CASE("innerJoin", 1) MAKE_TEST_CASE("isNan", 1) MAKE_TEST_CASE("lower", 1) MAKE_TEST_CASE("mean", 1) @@ -59,6 +60,7 @@ MAKE_TEST_CASE("rbind", 1) MAKE_TEST_CASE("recode", 4) MAKE_TEST_CASE("replace", 1) MAKE_TEST_CASE("reverse", 1) +MAKE_TEST_CASE("semiJoin", 1) MAKE_TEST_CASE("seq", 2) MAKE_TEST_CASE("solve", 1) MAKE_TEST_CASE("sqrt", 1) diff --git a/test/api/cli/operations/innerJoin_1.daphne b/test/api/cli/operations/innerJoin_1.daphne new file mode 100644 index 000000000..b2e97b3b0 --- /dev/null +++ b/test/api/cli/operations/innerJoin_1.daphne @@ -0,0 +1,10 @@ +# Test inner join without and with optional arg for result size. + +f1 = createFrame([ 1, 2 ], [ 3, 4 ], "a", "b"); +f2 = createFrame([ 3, 4, 5 ], [ 6, 7, 8 ], "c", "d"); + +f3 = innerJoin(f1, f2, "b", "c"); +f4 = innerJoin(f1, f2, "b", "c", 2); + +print(f3); +print(f4); \ No newline at end of file diff --git a/test/api/cli/operations/innerJoin_1.txt b/test/api/cli/operations/innerJoin_1.txt new file mode 100644 index 000000000..193f56e12 --- /dev/null +++ b/test/api/cli/operations/innerJoin_1.txt @@ -0,0 +1,6 @@ +Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t]) +1 3 3 6 +2 4 4 7 +Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t]) +1 3 3 6 +2 4 4 7 diff --git a/test/api/cli/operations/semiJoin_1.daphne b/test/api/cli/operations/semiJoin_1.daphne new file mode 100644 index 000000000..8a3b3b321 --- /dev/null +++ b/test/api/cli/operations/semiJoin_1.daphne @@ -0,0 +1,10 @@ +# Test semi-join without and with optional arg for result size. + +f1 = createFrame([ 1, 2 ], [ 3, 4 ], "a", "b"); +f2 = createFrame([ 3, 4, 5 ], [ 6, 7, 8 ], "c", "d"); + +keys1, tids1 = semiJoin(f1, f2, "b", "c"); +keys2, tids2 = semiJoin(f1, f2, "b", "c", 2); + +print(f1[tids1, ]); +print(f1[tids2, ]); \ No newline at end of file diff --git a/test/api/cli/operations/semiJoin_1.txt b/test/api/cli/operations/semiJoin_1.txt new file mode 100644 index 000000000..b135e0a8a --- /dev/null +++ b/test/api/cli/operations/semiJoin_1.txt @@ -0,0 +1,6 @@ +Frame(2x2, [a:int64_t, b:int64_t]) +1 3 +2 4 +Frame(2x2, [a:int64_t, b:int64_t]) +1 3 +2 4 diff --git a/test/runtime/local/kernels/InnerJoinTest.cpp b/test/runtime/local/kernels/InnerJoinTest.cpp index 371c20cfe..6e4c264c1 100644 --- a/test/runtime/local/kernels/InnerJoinTest.cpp +++ b/test/runtime/local/kernels/InnerJoinTest.cpp @@ -31,7 +31,7 @@ #include -TEST_CASE("innerJoin", TAG_KERNELS) { +TEST_CASE("InnerJoin", TAG_KERNELS) { auto lhsC0 = genGivenVals>(4, {1, 2, 3, 4}); auto lhsC1 = genGivenVals>(4, {11.0, 22.0, 33.0, 44.00}); std::vector lhsCols = {lhsC0, lhsC1}; @@ -46,7 +46,7 @@ TEST_CASE("innerJoin", TAG_KERNELS) { auto rhs = DataObjectFactory::create(rhsCols, rhsLabels); Frame *res = nullptr; - innerJoin(res, lhs, rhs, "a", "c", nullptr); + innerJoin(res, lhs, rhs, "a", "c", -1, nullptr); // Check the meta data. CHECK(res->getNumRows() == 2); diff --git a/test/runtime/local/kernels/SemiJoinTest.cpp b/test/runtime/local/kernels/SemiJoinTest.cpp index 5b0d76017..ea6a47070 100644 --- a/test/runtime/local/kernels/SemiJoinTest.cpp +++ b/test/runtime/local/kernels/SemiJoinTest.cpp @@ -58,7 +58,7 @@ TEST_CASE("SemiJoin", TAG_KERNELS) { // res Frame *res = nullptr; DenseMatrix *lhsTid = nullptr; - semiJoin(res, lhsTid, lhs, rhs, "a", "c", nullptr); + semiJoin(res, lhsTid, lhs, rhs, "a", "c", -1, nullptr); CHECK(*res == *expRes); CHECK(*lhsTid == *expTid); From 44ff77bb96cb7d847508220147cdaaf982ed653b Mon Sep 17 00:00:00 2001 From: Samin <69201742+saminbassiri@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:13:18 +0100 Subject: [PATCH 12/17] [DAPHNE-#903] Add 'groupSum()' built-in function to DaphneDSL. (#921) This commit introduces a new `groupSum()` built-in function to DaphneDSL, enabling the creation of a `GroupOp` in DaphneIR and closes #903. Changes Implemented 1. `groupSum()` Built-in Function in DaphneDSL: - Interface: `group(arg:frame, groupCols:str, ..., sumCol:str)` - Accepts: - A frame as input. - At least one column to group on. - A single column to compute the sum. - Aggregation Support: - Only supports `SUM` as the aggregation function. 3. **Test Cases**: - Added script-level tests to validate the functionality of the `group()` function in DaphneDSL. --- doc/DaphneDSL/Builtins.md | 8 +++++ src/parser/daphnedsl/DaphneDSLBuiltins.cpp | 39 ++++++++++++++++++++++ test/api/cli/operations/OperationsTest.cpp | 1 + test/api/cli/operations/groupSum_1.daphne | 8 +++++ test/api/cli/operations/groupSum_1.txt | 5 +++ 5 files changed, 61 insertions(+) create mode 100644 test/api/cli/operations/groupSum_1.daphne create mode 100644 test/api/cli/operations/groupSum_1.txt diff --git a/doc/DaphneDSL/Builtins.md b/doc/DaphneDSL/Builtins.md index acbbd332e..535003f0d 100644 --- a/doc/DaphneDSL/Builtins.md +++ b/doc/DaphneDSL/Builtins.md @@ -515,6 +515,14 @@ We will support set operations such as **`intersect`**, **`merge`**, and **`exce We will support more variants of joins, including (left/right) outer joins, theta joins, anti-joins, etc. +### Grouping and aggregation + +- **`groupSum`**`(arg:frame, grpColNames:str[, grpColNames, ...], sumColName:str)` + + Groups the rows in the given frame `arg` by the specified columns `grpColNames` (at least one column) and calculates the per-group sum of the column denoted by `sumColName`. + + *This built-in function is currently limited in terms of functionality (aggregation only on a single column, sum as the only aggregation function). It will be extended in the future. Meanwhile, consider using DAPHNE's `sql()` built-in function for more comprehensive grouping and aggregation support.* + ### Frame label manipulation - **`setColLabels`**`(arg:frame, labels:str, ...)` diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index 7bb1fd1b1..10208d788 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -1044,6 +1044,45 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu .getResults(); } + // -------------------------------------------------------------------- + // Grouping and aggregation + // -------------------------------------------------------------------- + + if (func == "groupSum") { + // Arbitrary number of columns to group on. + // A single column to calculate the sum on. + checkNumArgsMin(loc, func, numArgs, 3); + mlir::Value currentFrame = args[0]; + mlir::Value aggCol = args[numArgs - 1]; + std::vector groupName; + std::vector columnName; + std::vector colTypes; + + // set aggregaton function to SUM + auto aggFunc = static_cast( + mlir::daphne::GroupEnumAttr::get(builder.getContext(), mlir::daphne::GroupEnum::SUM)); + std::vector functionName; + functionName.push_back(aggFunc); + + // get group columns + for (size_t i = 1; i < numArgs - 1; i++) { + groupName.push_back(args[i]); + } + + // get agg column + columnName.push_back(aggCol); + + // result column types + mlir::Type vt = utils.unknownType; + for (size_t i = 0; i < groupName.size() + columnName.size(); i++) { + colTypes.push_back(vt); + } + + return static_cast(builder.create(loc, FrameType::get(builder.getContext(), colTypes), + currentFrame, groupName, columnName, + builder.getArrayAttr(functionName))); + } + // ******************************************************************** // Frame label manipulation // ******************************************************************** diff --git a/test/api/cli/operations/OperationsTest.cpp b/test/api/cli/operations/OperationsTest.cpp index 4dd0a7bb9..4b8f44aa6 100644 --- a/test/api/cli/operations/OperationsTest.cpp +++ b/test/api/cli/operations/OperationsTest.cpp @@ -40,6 +40,7 @@ MAKE_TEST_CASE("createFrame", 1) MAKE_TEST_CASE("ctable", 1) MAKE_TEST_CASE("fill", 1) MAKE_TEST_CASE("gemv", 1) +MAKE_TEST_CASE("groupSum", 1) MAKE_TEST_CASE("idxMax", 1) MAKE_TEST_CASE("idxMin", 1) MAKE_TEST_CASE("innerJoin", 1) diff --git a/test/api/cli/operations/groupSum_1.daphne b/test/api/cli/operations/groupSum_1.daphne new file mode 100644 index 000000000..82b08ca51 --- /dev/null +++ b/test/api/cli/operations/groupSum_1.daphne @@ -0,0 +1,8 @@ +// Grouping and sum aggregation (two grouping columns). + +fr = createFrame([ 11, 11, 22, 33, 33 ], [ 1, 1, 2, 3, 4 ], + [ 100, 200, 300, 400, 500 ], "a", "b", "c"); + +res = groupSum(fr, "a", "b", "c"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/operations/groupSum_1.txt b/test/api/cli/operations/groupSum_1.txt new file mode 100644 index 000000000..4e24eb7a4 --- /dev/null +++ b/test/api/cli/operations/groupSum_1.txt @@ -0,0 +1,5 @@ +Frame(4x3, [a:int64_t, b:int64_t, SUM(c):int64_t]) +11 1 300 +22 2 300 +33 3 400 +33 4 500 From 47d6fd035960df80b27f95f0389bbf201dba4512 Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Fri, 6 Dec 2024 22:48:24 +0100 Subject: [PATCH 13/17] Clean-up of script-level tests for file readers/writers. - Cleaned up the directory "test/api/cli/io/". - Problems with previous state: - The test cases mixed some things that should be tested independently, e.g., reading a frame and reading from a file whose path is given as the parameter of a DaphneDSL UDF. - The directory was cluttered with test data/metadata files (reference files and git-ignored generated files). - Inconsistent naming of test files. - Consequently, adding new test cases into the existing structure was not straightforward. - Changes by this commit: - Rewrote test cases that read matrices/frames with and without string data. - Removed the old .csv and .csv.meta files and replaced them by new (sometimes simpler) files (in subdirectory "ref/") that still check the same interesting cases (and more). - Removed the old "testRead*.{daphne,txt}" files. - The DaphneDSL files were replaced by the new ones in subdirectory "read/". - The reference txt files are omitted, since comparisons are not based on DaphneDSL print() anymore (since its output can be ambiguous for string data). - Replaced "ReadTest.cpp" and "WriteTest.cpp" by a single "ReadWriteTest.cpp" to have all test cases at a glance. - Removed the dynamic path test, replaced it by several new ones. - So far, this test case was commented out. - Now, variants that just concatenate strings are tested, because they already work. - However, variants concatenating strings and numbers don't work yet and remain commented (see #931). - Removed unused files "readSparse.daphne" and "readSparse.txt": They used to be part of a test case for reading sparse data into a CSRMatrix, but that test case was commented out and the file "readSparse.coo" referenced in it didn't even exist. Thus, we don't test reading sparse data into CSRMatrix now, but we didn't do that before this clean-up, either. - Created new DaphneDSL scripts that test writing matrices/frames (subdirectory "write/"). - After this clean-up, the tests still cover the same interesting cases (though sometimes with different examples). - Fixed a bug in CSV reader for string data that was revealed by rewriting the read test cases. - Escaped double-quote characters in quoted CSV values were not unescaped, i.e., they remained double double-quotes "" in memory, although they should become single double-quotes ". - Consistently, a bug in the unit test cases was fixed. - Added some required cast-kernel specializations. - A specialization of CastObj from Frame to DenseMatrix. - A specialization of CastSca from std::string to std::string (required for avoiding template ambiguities). - Thanks to @saminbassiri for this cast-kernel code. --- src/runtime/local/io/utils.h | 2 +- src/runtime/local/kernels/CastObj.h | 5 +- src/runtime/local/kernels/CastSca.h | 8 ++ src/runtime/local/kernels/kernels.json | 1 + test/CMakeLists.txt | 4 +- test/api/cli/io/.gitignore | 4 - test/api/cli/io/ReadCsv1.csv | 2 - test/api/cli/io/ReadCsv1.csv.meta | 6 - test/api/cli/io/ReadCsv2.csv | 10 -- test/api/cli/io/ReadCsv3.csv | 4 - test/api/cli/io/ReadCsv3.csv.meta | 18 --- test/api/cli/io/ReadTest.cpp | 64 ---------- test/api/cli/io/ReadWriteTest.cpp | 109 ++++++++++++++++++ test/api/cli/io/WriteTest.cpp | 44 ------- test/api/cli/io/check_frame.daphne | 53 +++++++++ test/api/cli/io/check_matrix.daphne | 36 ++++++ test/api/cli/io/do_check_frame.daphne | 21 ++++ test/api/cli/io/do_check_matrix.daphne | 21 ++++ test/api/cli/io/matrix_full_ref.csv | 4 - test/api/cli/io/matrix_view_ref.csv | 4 - test/api/cli/io/out/.gitignore | 2 + .../io/read/read_frame_dynamic-path-1.daphne | 7 ++ .../io/read/read_frame_dynamic-path-2.daphne | 7 ++ .../io/read/read_frame_dynamic-path-3.daphne | 7 ++ .../io/read/read_frame_mixed-no-str.daphne | 11 ++ .../cli/io/read/read_frame_mixed-str.daphne | 11 ++ .../cli/io/read/read_frame_read-in-udf.daphne | 13 +++ .../io/read/read_matrix_dynamic-path-1.daphne | 7 ++ .../io/read/read_matrix_dynamic-path-2.daphne | 7 ++ .../io/read/read_matrix_dynamic-path-3.daphne | 7 ++ test/api/cli/io/read/read_matrix_f64.daphne | 8 ++ .../io/read/read_matrix_read-in-udf.daphne | 11 ++ test/api/cli/io/read/read_matrix_si64.daphne | 8 ++ test/api/cli/io/read/read_matrix_str.daphne | 8 ++ test/api/cli/io/readMatrix.daphne | 4 - test/api/cli/io/readSparse.daphne | 2 - test/api/cli/io/readSparse.txt | 11 -- test/api/cli/io/ref/frame_123_ref.csv | 1 + test/api/cli/io/ref/frame_123_ref.csv.meta | 9 ++ .../api/cli/io/ref/frame_mixed-no-str_ref.csv | 5 + .../io/ref/frame_mixed-no-str_ref.csv.meta | 9 ++ test/api/cli/io/ref/frame_mixed-str_ref.csv | 6 + .../cli/io/ref/frame_mixed-str_ref.csv.meta | 9 ++ test/api/cli/io/ref/matrix_123_ref.csv | 1 + test/api/cli/io/ref/matrix_123_ref.csv.meta | 5 + test/api/cli/io/ref/matrix_f64_ref.csv | 2 + test/api/cli/io/ref/matrix_f64_ref.csv.meta | 5 + test/api/cli/io/ref/matrix_si64_ref.csv | 2 + test/api/cli/io/ref/matrix_si64_ref.csv.meta | 5 + test/api/cli/io/ref/matrix_str_ref.csv | 3 + .../matrix_str_ref.csv.meta} | 2 +- test/api/cli/io/ref/matrix_view_ref.csv | 2 + test/api/cli/io/ref/matrix_view_ref.csv.meta | 5 + test/api/cli/io/testReadFrame.daphne | 6 - test/api/cli/io/testReadFrame.txt | 3 - test/api/cli/io/testReadMatrix.daphne | 6 - test/api/cli/io/testReadMatrix.txt | 3 - .../cli/io/testReadMatrix_DynamicPath.daphne | 5 - .../api/cli/io/testReadStringIntoFrame.daphne | 3 - test/api/cli/io/testReadStringIntoFrame.txt | 5 - test/api/cli/io/testReadStringMatrix.daphne | 3 - test/api/cli/io/testReadStringMatrix.txt | 11 -- .../io/write/write_frame_mixed-no-str.daphne | 8 ++ .../cli/io/write/write_frame_mixed-str.daphne | 8 ++ test/api/cli/io/write/write_matrix_f64.daphne | 5 + .../api/cli/io/write/write_matrix_si64.daphne | 5 + test/api/cli/io/write/write_matrix_str.daphne | 5 + .../api/cli/io/write/write_matrix_view.daphne | 6 + test/api/cli/io/writeMatrix_full.daphne | 4 - test/api/cli/io/writeMatrix_view.daphne | 5 - test/runtime/local/io/ReadCsvTest.cpp | 8 +- 71 files changed, 480 insertions(+), 241 deletions(-) delete mode 100644 test/api/cli/io/.gitignore delete mode 100644 test/api/cli/io/ReadCsv1.csv delete mode 100644 test/api/cli/io/ReadCsv1.csv.meta delete mode 100644 test/api/cli/io/ReadCsv2.csv delete mode 100644 test/api/cli/io/ReadCsv3.csv delete mode 100644 test/api/cli/io/ReadCsv3.csv.meta delete mode 100644 test/api/cli/io/ReadTest.cpp create mode 100644 test/api/cli/io/ReadWriteTest.cpp delete mode 100644 test/api/cli/io/WriteTest.cpp create mode 100644 test/api/cli/io/check_frame.daphne create mode 100644 test/api/cli/io/check_matrix.daphne create mode 100644 test/api/cli/io/do_check_frame.daphne create mode 100644 test/api/cli/io/do_check_matrix.daphne delete mode 100644 test/api/cli/io/matrix_full_ref.csv delete mode 100644 test/api/cli/io/matrix_view_ref.csv create mode 100644 test/api/cli/io/out/.gitignore create mode 100644 test/api/cli/io/read/read_frame_dynamic-path-1.daphne create mode 100644 test/api/cli/io/read/read_frame_dynamic-path-2.daphne create mode 100644 test/api/cli/io/read/read_frame_dynamic-path-3.daphne create mode 100644 test/api/cli/io/read/read_frame_mixed-no-str.daphne create mode 100644 test/api/cli/io/read/read_frame_mixed-str.daphne create mode 100644 test/api/cli/io/read/read_frame_read-in-udf.daphne create mode 100644 test/api/cli/io/read/read_matrix_dynamic-path-1.daphne create mode 100644 test/api/cli/io/read/read_matrix_dynamic-path-2.daphne create mode 100644 test/api/cli/io/read/read_matrix_dynamic-path-3.daphne create mode 100644 test/api/cli/io/read/read_matrix_f64.daphne create mode 100644 test/api/cli/io/read/read_matrix_read-in-udf.daphne create mode 100644 test/api/cli/io/read/read_matrix_si64.daphne create mode 100644 test/api/cli/io/read/read_matrix_str.daphne delete mode 100644 test/api/cli/io/readMatrix.daphne delete mode 100644 test/api/cli/io/readSparse.daphne delete mode 100644 test/api/cli/io/readSparse.txt create mode 100644 test/api/cli/io/ref/frame_123_ref.csv create mode 100644 test/api/cli/io/ref/frame_123_ref.csv.meta create mode 100644 test/api/cli/io/ref/frame_mixed-no-str_ref.csv create mode 100644 test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta create mode 100644 test/api/cli/io/ref/frame_mixed-str_ref.csv create mode 100644 test/api/cli/io/ref/frame_mixed-str_ref.csv.meta create mode 100644 test/api/cli/io/ref/matrix_123_ref.csv create mode 100644 test/api/cli/io/ref/matrix_123_ref.csv.meta create mode 100644 test/api/cli/io/ref/matrix_f64_ref.csv create mode 100644 test/api/cli/io/ref/matrix_f64_ref.csv.meta create mode 100644 test/api/cli/io/ref/matrix_si64_ref.csv create mode 100644 test/api/cli/io/ref/matrix_si64_ref.csv.meta create mode 100644 test/api/cli/io/ref/matrix_str_ref.csv rename test/api/cli/io/{ReadCsv2.csv.meta => ref/matrix_str_ref.csv.meta} (70%) create mode 100644 test/api/cli/io/ref/matrix_view_ref.csv create mode 100644 test/api/cli/io/ref/matrix_view_ref.csv.meta delete mode 100644 test/api/cli/io/testReadFrame.daphne delete mode 100644 test/api/cli/io/testReadFrame.txt delete mode 100644 test/api/cli/io/testReadMatrix.daphne delete mode 100644 test/api/cli/io/testReadMatrix.txt delete mode 100644 test/api/cli/io/testReadMatrix_DynamicPath.daphne delete mode 100644 test/api/cli/io/testReadStringIntoFrame.daphne delete mode 100644 test/api/cli/io/testReadStringIntoFrame.txt delete mode 100644 test/api/cli/io/testReadStringMatrix.daphne delete mode 100644 test/api/cli/io/testReadStringMatrix.txt create mode 100644 test/api/cli/io/write/write_frame_mixed-no-str.daphne create mode 100644 test/api/cli/io/write/write_frame_mixed-str.daphne create mode 100644 test/api/cli/io/write/write_matrix_f64.daphne create mode 100644 test/api/cli/io/write/write_matrix_si64.daphne create mode 100644 test/api/cli/io/write/write_matrix_str.daphne create mode 100644 test/api/cli/io/write/write_matrix_view.daphne delete mode 100644 test/api/cli/io/writeMatrix_full.daphne delete mode 100644 test/api/cli/io/writeMatrix_view.daphne diff --git a/src/runtime/local/io/utils.h b/src/runtime/local/io/utils.h index 9f6fb821b..97ef3b002 100644 --- a/src/runtime/local/io/utils.h +++ b/src/runtime/local/io/utils.h @@ -108,7 +108,7 @@ inline size_t setCString(struct File *file, size_t start_pos, std::string *res, if (!is_not_end) break; if (is_multiLine && str[pos] == '"' && str[pos + 1] == '"') { - res->append("\"\""); + res->append("\""); pos += 2; } else if (is_multiLine && str[pos] == '\\' && str[pos + 1] == '"') { res->append("\\\""); diff --git a/src/runtime/local/kernels/CastObj.h b/src/runtime/local/kernels/CastObj.h index ecb53bcbb..3984276d4 100644 --- a/src/runtime/local/kernels/CastObj.h +++ b/src/runtime/local/kernels/CastObj.h @@ -68,7 +68,7 @@ template class CastObj, Frame> { const size_t numRows = argFrm->getNumRows(); const DenseMatrix *argCol = argFrm->getColumn(c); for (size_t r = 0; r < numRows; r++) - res->set(r, c, static_cast(argCol->get(r, 0))); + res->set(r, c, castSca(argCol->get(r, 0), nullptr)); DataObjectFactory::destroy(argCol); } @@ -126,6 +126,9 @@ template class CastObj, Frame> { case ValueTypeCode::UI8: castCol(res, arg, c); break; + case ValueTypeCode::STR: + castCol(res, arg, c); + break; default: throw std::runtime_error("CastObj::apply: unknown value type code"); } diff --git a/src/runtime/local/kernels/CastSca.h b/src/runtime/local/kernels/CastSca.h index 7e0b84b85..547fc6c0a 100644 --- a/src/runtime/local/kernels/CastSca.h +++ b/src/runtime/local/kernels/CastSca.h @@ -123,4 +123,12 @@ template struct CastSca { } }; +// ---------------------------------------------------------------------------- +// string <- string +// ---------------------------------------------------------------------------- + +template <> struct CastSca { + static std::string apply(const std::string arg, DaphneContext *ctx) { return arg; } +}; + #endif // SRC_RUNTIME_LOCAL_KERNELS_CASTSCA_H diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 783872553..250a70845 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -419,6 +419,7 @@ [["DenseMatrix", "double"], "Frame"], [["DenseMatrix", "int64_t"], "Frame"], [["DenseMatrix", "uint64_t"], "Frame"], + [["DenseMatrix", "std::string"], "Frame"], ["Frame", ["DenseMatrix", "double"]], ["Frame", ["DenseMatrix", "int64_t"]], ["Frame", ["DenseMatrix", "uint64_t"]], diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e4e6b758b..7b707e746 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -32,12 +32,10 @@ set(TEST_SOURCES api/cli/extensibility/HintTest.cpp api/cli/functions/FunctionsTest.cpp api/cli/functions/RecursiveFunctionsTest.cpp - api/cli/io/ReadTest.cpp - api/cli/io/WriteTest.cpp api/cli/import/ImportTest.cpp api/cli/indexing/IndexingTest.cpp api/cli/inference/InferenceTest.cpp - api/cli/io/ReadTest.cpp + api/cli/io/ReadWriteTest.cpp api/cli/lists/ListsTest.cpp api/cli/literals/LiteralsTest.cpp api/cli/operations/ConstantFoldingTest.cpp diff --git a/test/api/cli/io/.gitignore b/test/api/cli/io/.gitignore deleted file mode 100644 index 32aa93035..000000000 --- a/test/api/cli/io/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -matrix_full.csv -matrix_full.csv.meta -matrix_view.csv -matrix_view.csv.meta \ No newline at end of file diff --git a/test/api/cli/io/ReadCsv1.csv b/test/api/cli/io/ReadCsv1.csv deleted file mode 100644 index 79d814f7b..000000000 --- a/test/api/cli/io/ReadCsv1.csv +++ /dev/null @@ -1,2 +0,0 @@ --0.1,-0.2,0.1,0.2 -3.14,5.41,6.22216,5 diff --git a/test/api/cli/io/ReadCsv1.csv.meta b/test/api/cli/io/ReadCsv1.csv.meta deleted file mode 100644 index 82073032e..000000000 --- a/test/api/cli/io/ReadCsv1.csv.meta +++ /dev/null @@ -1,6 +0,0 @@ -{ - "numRows": 2, - "numCols": 4, - "valueType": "f64", - "numNonZeros": 0 -} \ No newline at end of file diff --git a/test/api/cli/io/ReadCsv2.csv b/test/api/cli/io/ReadCsv2.csv deleted file mode 100644 index e09ff2e54..000000000 --- a/test/api/cli/io/ReadCsv2.csv +++ /dev/null @@ -1,10 +0,0 @@ -"banana, grape",36,Fruit Basket -"xyz""uvw",34,"No Category\"" -"parrot, rabbit",31,Pets -"line1 -line2",26,"with newline" -chair,28,Furniture Set -"green, yellow\n",51, -"\n\"xyz""uvw\"",29,"Mixed string" -"\"blue, \"\"",42,"" -"unknown, item",23,Unknown Item diff --git a/test/api/cli/io/ReadCsv3.csv b/test/api/cli/io/ReadCsv3.csv deleted file mode 100644 index 81fa66d93..000000000 --- a/test/api/cli/io/ReadCsv3.csv +++ /dev/null @@ -1,4 +0,0 @@ -1,-1,"" -2,-2, -3,-3,"multi-line," -4,-4,simple string diff --git a/test/api/cli/io/ReadCsv3.csv.meta b/test/api/cli/io/ReadCsv3.csv.meta deleted file mode 100644 index fc14ce73d..000000000 --- a/test/api/cli/io/ReadCsv3.csv.meta +++ /dev/null @@ -1,18 +0,0 @@ -{ - "numRows": 4, - "numCols": 3, - "schema": [ - { - "label": "a", - "valueType": "si8" - }, - { - "label": "b", - "valueType": "ui8" - }, - { - "label": "c", - "valueType": "str" - } - ] -} diff --git a/test/api/cli/io/ReadTest.cpp b/test/api/cli/io/ReadTest.cpp deleted file mode 100644 index 885ce81bd..000000000 --- a/test/api/cli/io/ReadTest.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2021 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -#include - -#include -#include - -const std::string dirPath = "test/api/cli/io/"; - -// TODO Add script-level test cases for reading files in various formats. -// This test case used to read a COO file, but being a quick fix, it was not -// integrated cleanly. There should either be a seprate reader for COO, or we -// should do that via the respective Matrix Market format. -#if 0 -TEST_CASE("readSparse", TAG_IO) { - auto arg = "filename=\"" + dirPath + "readSparse.coo\""; - compareDaphneToRef(dirPath + "readSparse.txt", - dirPath + "readSparse.daphne", - "--select-matrix-representations", - "--args", - arg.c_str()); -} -#endif - -TEST_CASE("readFrameFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadFrame.txt", dirPath + "testReadFrame.daphne"); -} - -TEST_CASE("readStringValuesIntoFrameFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadStringIntoFrame.txt", dirPath + "testReadStringIntoFrame.daphne"); -} - -TEST_CASE("readMatrixFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadMatrix.txt", dirPath + "testReadMatrix.daphne"); -} - -TEST_CASE("readStringMatrixFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadStringMatrix.txt", dirPath + "testReadStringMatrix.daphne"); -} - -// does not yet work! -// TEST_CASE("readReadMatrixFromCSV_DynamicPath", TAG_IO) -// { -// compareDaphneToRef(dirPath + "testReadMatrix.txt", dirPath + -// "testReadMatrix_DynamicPath.daphne"); -// } \ No newline at end of file diff --git a/test/api/cli/io/ReadWriteTest.cpp b/test/api/cli/io/ReadWriteTest.cpp new file mode 100644 index 000000000..9a8f045b9 --- /dev/null +++ b/test/api/cli/io/ReadWriteTest.cpp @@ -0,0 +1,109 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include + +#include +#include + +const std::string dirPath = "test/api/cli/io/"; + +// ******************************************************************************** +// Read test cases +// ******************************************************************************** + +// These test cases check if the CSV reader reads the expected data from a CSV file. +// The data is chosen to trigger various interesting cases. The tests perform the following steps: +// - read a matrix/frame from a specified CSV file +// - compare it to a matrix/frame hard-coded in a DaphneDSL script (nan-safe for floating-point value types) +// - check if the number of incorrect elements is zero +#define MAKE_READ_TEST_CASE(dt, name, nanSafe) \ + TEST_CASE("read_" dt "_" name, TAG_IO) { \ + const std::string scriptPath = dirPath + "read/read_" + dt + "_" + name + ".daphne"; \ + const std::string inPath = dirPath + "ref/" + dt + "_" + name + "_ref.csv"; \ + compareDaphneToStr("0\n", scriptPath.c_str(), "--args", \ + ("inPath=\"" + inPath + "\",nanSafe=" + nanSafe).c_str()); \ + } + +MAKE_READ_TEST_CASE("matrix", "si64", "false") +MAKE_READ_TEST_CASE("matrix", "f64", "true") +MAKE_READ_TEST_CASE("matrix", "str", "false") +MAKE_READ_TEST_CASE("frame", "mixed-no-str", "false") +MAKE_READ_TEST_CASE("frame", "mixed-str", "false") + +// These test cases check if the CSV reader can be used in various relevant ways in DaphneDSL. +// The data is rather simple. The tests perform the following steps: +// - read a matrix/frame from a specified CSV file +// - compare it to a matrix/frame hard-coded in a DaphneDSL script +// - check if the number of incorrect elements is zero +#define MAKE_READ_TEST_CASE_2(scriptName) \ + TEST_CASE("read_" scriptName ".daphne", TAG_IO) { \ + const std::string scriptPath = dirPath + "read/read_" + scriptName + ".daphne"; \ + compareDaphneToStr("0\n", scriptPath.c_str()); \ + } + +// TODO The commented out test cases don't work yet (see #931). +MAKE_READ_TEST_CASE_2("matrix_read-in-udf") +MAKE_READ_TEST_CASE_2("matrix_dynamic-path-1") +// MAKE_READ_TEST_CASE_2("matrix_dynamic-path-2") +// MAKE_READ_TEST_CASE_2("matrix_dynamic-path-3") +MAKE_READ_TEST_CASE_2("frame_read-in-udf") +MAKE_READ_TEST_CASE_2("frame_dynamic-path-1") +// MAKE_READ_TEST_CASE_2("frame_dynamic-path-2") +// MAKE_READ_TEST_CASE_2("frame_dynamic-path-3") + +// ******************************************************************************** +// Write test cases +// ******************************************************************************** + +// These test cases check if the CSV writer produces the expected CSV files. +// The data is chosen to trigger various interesting cases. The tests perform the following steps: +// - start with matrix/frame hard-coded in a DaphneDSL script +// - write it to a CSV file +// - don't compare the file contents to a reference file, as there can be multiple equivalent CSV representations, +// because +// - any field may be quoted, but only certain fields need to be quoted +// - there could be trailing zeros in floating-point numbers +// - there are different valid notations for floating-point numbers (e.g., 0.001 vs 1e-3) +// - there are different valid notations for integers (e.g., 32 vs 0x20) +// - ... +// - instead, read the written file and a reference CSV file and compare the two read matrices/frames in DaphneDSL +// - note: these tests don't check if the matrix/frame read from the reference CSV file looks as expected (that's done +// by the read tests) +#define MAKE_WRITE_TEST_CASE(dt, name, nanSafe) \ + TEST_CASE("write_" dt "_" name, TAG_IO) { \ + const std::string scriptPathWrt = dirPath + "write/write_" + dt + "_" + name + ".daphne"; \ + const std::string scriptPathCmp = dirPath + "do_check_" + dt + ".daphne"; \ + const std::string outPath = dirPath + "out/" + dt + "_" + name + ".csv"; \ + const std::string refPath = dirPath + "ref/" + dt + "_" + name + "_ref.csv"; \ + std::filesystem::remove(outPath); /* remove old output file if it still exists */ \ + checkDaphneStatusCode(StatusCode::SUCCESS, scriptPathWrt.c_str(), "--args", \ + ("outPath=\"" + outPath + "\"").c_str()); \ + compareDaphneToStr("0\n", scriptPathCmp.c_str(), "--args", \ + ("chkPath=\"" + outPath + "\",refPath=\"" + refPath + "\",nanSafe=" + nanSafe).c_str()); \ + } + +// TODO The commented out test cases don't work yet, as the CSV writer doesn't support strings yet. +MAKE_WRITE_TEST_CASE("matrix", "si64", "false") +MAKE_WRITE_TEST_CASE("matrix", "f64", "true") +// MAKE_WRITE_TEST_CASE("matrix", "str", "false") +MAKE_WRITE_TEST_CASE("matrix", "view", "false") +MAKE_WRITE_TEST_CASE("frame", "mixed-no-str", "false") +// MAKE_WRITE_TEST_CASE("frame", "mixed-str", "false") \ No newline at end of file diff --git a/test/api/cli/io/WriteTest.cpp b/test/api/cli/io/WriteTest.cpp deleted file mode 100644 index b58b8ec09..000000000 --- a/test/api/cli/io/WriteTest.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2024 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -#include - -#include -#include - -const std::string dirPath = "test/api/cli/io/"; - -TEST_CASE("writeMatrixCSV_Full", TAG_IO) { - std::string csvPath = dirPath + "matrix_full.csv"; - std::filesystem::remove(csvPath); // remove old file if it still exists - checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "writeMatrix_full.daphne", "--args", - std::string("outPath=\"" + csvPath + "\"").c_str()); - compareDaphneToRef(dirPath + "matrix_full_ref.csv", dirPath + "readMatrix.daphne", "--args", - std::string("inPath=\"" + csvPath + "\"").c_str()); -} - -TEST_CASE("writeMatrixCSV_View", TAG_IO) { - std::string csvPath = dirPath + "matrix_view.csv"; - std::filesystem::remove(csvPath); // remove old file if it still exists - checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "writeMatrix_view.daphne", "--args", - std::string("outPath=\"" + csvPath + "\"").c_str()); - compareDaphneToRef(dirPath + "matrix_view_ref.csv", dirPath + "readMatrix.daphne", "--args", - std::string("inPath=\"" + csvPath + "\"").c_str()); -} \ No newline at end of file diff --git a/test/api/cli/io/check_frame.daphne b/test/api/cli/io/check_frame.daphne new file mode 100644 index 000000000..8084373ee --- /dev/null +++ b/test/api/cli/io/check_frame.daphne @@ -0,0 +1,53 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Counts in how many corresponding elements the two given frames differ. +# If nanSafe is true, then two corresponding nan elements are considered equal. +# If there are differences, the two input frames are printed for debug information. +def checkFrame(chk, ref, nanSafe:bool) { + if(nanSafe) + stop("nan-safe comparison of frames is not supported yet"); + + if(ncol(chk) != ncol(ref)) + stop("the two input frames have different #cols: found " + ncol(chk) + " but expected " + ncol(ref)); + + numDiff = 0; + + # TODO Doesn't work (missing elementwise binary ops on two frames, see #932). + # numDiff = numDiff + sum(as.matrix(chk != ref)); + + # TODO Doesn't work, because the type of the c-th column isn't known at compile-time. + # for(c in 0:ncol(chk)-1) { + # colChk = as.matrix(chk[, c]); + # colRef = as.matrix(ref[, c]); + # numDiff = numDiff + sum(as.matrix(colChk != colRef)); + # } + + # TODO The following code only works for a hard-coded number of columns. + if(ncol(chk) != 3) + stop("this script only works for frames with exactly three columns"); + numDiff = numDiff + sum(as.si64(as.matrix(chk[, 0]) != as.matrix(ref[, 0]))); + numDiff = numDiff + sum(as.si64(as.matrix(chk[, 1]) != as.matrix(ref[, 1]))); + numDiff = numDiff + sum(as.si64(as.matrix(chk[, 2]) != as.matrix(ref[, 2]))); + + print(numDiff); + + # Debug output. + if(numDiff > 0) { + print("chk"); + print(chk); + print("ref"); + print(ref); + } +} \ No newline at end of file diff --git a/test/api/cli/io/check_matrix.daphne b/test/api/cli/io/check_matrix.daphne new file mode 100644 index 000000000..6ed473dca --- /dev/null +++ b/test/api/cli/io/check_matrix.daphne @@ -0,0 +1,36 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Counts in how many corresponding elements the two given matrices differ. +# If nanSafe is true, then two corresponding nan elements are considered equal. +# If there are differences, the two input matrices are printed for debug information. +def checkMatrix(chk, ref, nanSafe:bool) { + numDiff = 0; + if(nanSafe) + # The corresponding values are not equal and at least one of them is not nan (they're not both nan). + numDiff = sum(as.matrix(chk != ref) && (as.matrix(isNan(chk)) == 0 || as.matrix(isNan(ref)) == 0)); + else + # The corresponding values are not equal. + numDiff = sum(as.matrix(chk != ref)); + + print(numDiff); + + # Debug output. + if(numDiff > 0) { + print("chk"); + print(chk); + print("ref"); + print(ref); + } +} \ No newline at end of file diff --git a/test/api/cli/io/do_check_frame.daphne b/test/api/cli/io/do_check_frame.daphne new file mode 100644 index 000000000..75bbc3d71 --- /dev/null +++ b/test/api/cli/io/do_check_frame.daphne @@ -0,0 +1,21 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Reads the files $chkPath and $refPath and compares the two read frames. + +import "check_frame.daphne"; + +chk = readFrame($chkPath); +ref = readFrame($refPath); +check_frame.checkFrame(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/do_check_matrix.daphne b/test/api/cli/io/do_check_matrix.daphne new file mode 100644 index 000000000..578017b9f --- /dev/null +++ b/test/api/cli/io/do_check_matrix.daphne @@ -0,0 +1,21 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Reads the files $chkPath and $refPath and compares the two read matrices. + +import "check_matrix.daphne"; + +chk = readMatrix($chkPath); +ref = readMatrix($refPath); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/matrix_full_ref.csv b/test/api/cli/io/matrix_full_ref.csv deleted file mode 100644 index b828190bc..000000000 --- a/test/api/cli/io/matrix_full_ref.csv +++ /dev/null @@ -1,4 +0,0 @@ -DenseMatrix(3x2, int64_t) -1 2 -3 4 -5 6 diff --git a/test/api/cli/io/matrix_view_ref.csv b/test/api/cli/io/matrix_view_ref.csv deleted file mode 100644 index 767f21dd4..000000000 --- a/test/api/cli/io/matrix_view_ref.csv +++ /dev/null @@ -1,4 +0,0 @@ -DenseMatrix(3x1, int64_t) -2 -4 -6 diff --git a/test/api/cli/io/out/.gitignore b/test/api/cli/io/out/.gitignore new file mode 100644 index 000000000..9e06bd6c2 --- /dev/null +++ b/test/api/cli/io/out/.gitignore @@ -0,0 +1,2 @@ +*.csv +*.csv.meta \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_dynamic-path-1.daphne b/test/api/cli/io/read/read_frame_dynamic-path-1.daphne new file mode 100644 index 000000000..ade414fde --- /dev/null +++ b/test/api/cli/io/read/read_frame_dynamic-path-1.daphne @@ -0,0 +1,7 @@ +# Read a frame from a file when the file path is the result of an expression (only string concat). + +import "../check_frame.daphne"; + +chk = readFrame("test/api/cli/io/ref/frame_" + "123" + "_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_dynamic-path-2.daphne b/test/api/cli/io/read/read_frame_dynamic-path-2.daphne new file mode 100644 index 000000000..03666a93c --- /dev/null +++ b/test/api/cli/io/read/read_frame_dynamic-path-2.daphne @@ -0,0 +1,7 @@ +# Read a frame from a file when the file path is the result of an expression (string/number concat). + +import "../check_frame.daphne"; + +chk = readFrame("test/api/cli/io/ref/frame_" + 123 + "_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_dynamic-path-3.daphne b/test/api/cli/io/read/read_frame_dynamic-path-3.daphne new file mode 100644 index 000000000..94a8623cd --- /dev/null +++ b/test/api/cli/io/read/read_frame_dynamic-path-3.daphne @@ -0,0 +1,7 @@ +# Read a frame from a file when the file path is the result of an expression (string/casted number concat). + +import "../check_frame.daphne"; + +chk = readFrame("test/api/cli/io/ref/frame_" + as.str(123) + "_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_mixed-no-str.daphne b/test/api/cli/io/read/read_frame_mixed-no-str.daphne new file mode 100644 index 000000000..c04df7052 --- /dev/null +++ b/test/api/cli/io/read/read_frame_mixed-no-str.daphne @@ -0,0 +1,11 @@ +# Read a frame with columns of various value types (not including string) from a file. + +import "../check_frame.daphne"; + +# TODO Test nan values. + +chk = readFrame($inPath); +ref = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_si64_": [1000, 2000, 3000, 4000, 5000]}; +check_frame.checkFrame(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_mixed-str.daphne b/test/api/cli/io/read/read_frame_mixed-str.daphne new file mode 100644 index 000000000..d75f2668a --- /dev/null +++ b/test/api/cli/io/read/read_frame_mixed-str.daphne @@ -0,0 +1,11 @@ +# Read a frame with columns of various value types (including string) from a file. + +import "../check_frame.daphne"; + +# TODO test nan + +chk = readFrame($inPath); +ref = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_str": ["abc", "", "d\"e", "fg\nhi", "mn,op"]}; +check_frame.checkFrame(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_read-in-udf.daphne b/test/api/cli/io/read/read_frame_read-in-udf.daphne new file mode 100644 index 000000000..da3e34fcd --- /dev/null +++ b/test/api/cli/io/read/read_frame_read-in-udf.daphne @@ -0,0 +1,13 @@ +# Read a frame from a file when the file path is a parameter to a UDF. + +import "../check_frame.daphne"; + +# TODO Test nan values. + +def myReadFrame(path:str) { + return readFrame(path); +} + +chk = myReadFrame("test/api/cli/io/ref/frame_123_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_dynamic-path-1.daphne b/test/api/cli/io/read/read_matrix_dynamic-path-1.daphne new file mode 100644 index 000000000..69fb87368 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_dynamic-path-1.daphne @@ -0,0 +1,7 @@ +# Read a matrix from a file when the file path is the result of an expression (only string concat). + +import "../check_matrix.daphne"; + +chk = readMatrix("test/api/cli/io/ref/matrix_" + "123" + "_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_dynamic-path-2.daphne b/test/api/cli/io/read/read_matrix_dynamic-path-2.daphne new file mode 100644 index 000000000..551a36713 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_dynamic-path-2.daphne @@ -0,0 +1,7 @@ +# Read a matrix from a file when the file path is the result of an expression (string/number concat). + +import "../check_matrix.daphne"; + +chk = readMatrix("test/api/cli/io/ref/matrix_" + 123 + "_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_dynamic-path-3.daphne b/test/api/cli/io/read/read_matrix_dynamic-path-3.daphne new file mode 100644 index 000000000..99736cd5a --- /dev/null +++ b/test/api/cli/io/read/read_matrix_dynamic-path-3.daphne @@ -0,0 +1,7 @@ +# Read a matrix from a file when the file path is the result of an expression (string/casted number concat). + +import "../check_matrix.daphne"; + +chk = readMatrix("test/api/cli/io/ref/matrix_" + as.str(123) + "_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_f64.daphne b/test/api/cli/io/read/read_matrix_f64.daphne new file mode 100644 index 000000000..7c551f164 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_f64.daphne @@ -0,0 +1,8 @@ +# Read a matrix of value type f64 from a file. + +import "../check_matrix.daphne"; + +chk = readMatrix($inPath); +ref = [1.1, -22.2, 3.3, -44.4, 5.5, + -66.6, 0.0, nan, inf, -inf](2, 5); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_read-in-udf.daphne b/test/api/cli/io/read/read_matrix_read-in-udf.daphne new file mode 100644 index 000000000..50dc821fb --- /dev/null +++ b/test/api/cli/io/read/read_matrix_read-in-udf.daphne @@ -0,0 +1,11 @@ +# Read a matrix from a file when the file path is a parameter to a UDF. + +import "../check_matrix.daphne"; + +def myReadMatrix(path:str) { + return readMatrix(path); +} + +chk = myReadMatrix("test/api/cli/io/ref/matrix_123_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_si64.daphne b/test/api/cli/io/read/read_matrix_si64.daphne new file mode 100644 index 000000000..aa412bf73 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_si64.daphne @@ -0,0 +1,8 @@ +# Read a matrix of value type si64 from a file. + +import "../check_matrix.daphne"; + +chk = readMatrix($inPath); +ref = [1, -22, 3, -44, + 5, -66, 0, 0](2, 4); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_str.daphne b/test/api/cli/io/read/read_matrix_str.daphne new file mode 100644 index 000000000..27a02a8a4 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_str.daphne @@ -0,0 +1,8 @@ +# Read a matrix of value type str from a file. + +import "../check_matrix.daphne"; + +chk = readMatrix($inPath); +ref = ["abc", "", "d\"e", + "fg\nhi", "jkl", "mn,op"](2, 3); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/readMatrix.daphne b/test/api/cli/io/readMatrix.daphne deleted file mode 100644 index c47563c51..000000000 --- a/test/api/cli/io/readMatrix.daphne +++ /dev/null @@ -1,4 +0,0 @@ -// Read a matrix from the file $inPath and print it. - -X = readMatrix($inPath); -print(X); \ No newline at end of file diff --git a/test/api/cli/io/readSparse.daphne b/test/api/cli/io/readSparse.daphne deleted file mode 100644 index 578ece638..000000000 --- a/test/api/cli/io/readSparse.daphne +++ /dev/null @@ -1,2 +0,0 @@ -M = readMatrix($filename); -print(M); \ No newline at end of file diff --git a/test/api/cli/io/readSparse.txt b/test/api/cli/io/readSparse.txt deleted file mode 100644 index a6ed7d416..000000000 --- a/test/api/cli/io/readSparse.txt +++ /dev/null @@ -1,11 +0,0 @@ -CSRMatrix(10x5, double) -1 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 1 1 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 1 diff --git a/test/api/cli/io/ref/frame_123_ref.csv b/test/api/cli/io/ref/frame_123_ref.csv new file mode 100644 index 000000000..b0246d596 --- /dev/null +++ b/test/api/cli/io/ref/frame_123_ref.csv @@ -0,0 +1 @@ +1,2,3 diff --git a/test/api/cli/io/ref/frame_123_ref.csv.meta b/test/api/cli/io/ref/frame_123_ref.csv.meta new file mode 100644 index 000000000..6a5eb1cc3 --- /dev/null +++ b/test/api/cli/io/ref/frame_123_ref.csv.meta @@ -0,0 +1,9 @@ +{ + "numRows": 1, + "numCols": 3, + "schema": [ + {"label": "a", "valueType": "si64"}, + {"label": "b", "valueType": "si64"}, + {"label": "c", "valueType": "si64"} + ] +} \ No newline at end of file diff --git a/test/api/cli/io/ref/frame_mixed-no-str_ref.csv b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv new file mode 100644 index 000000000..aa53316e6 --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv @@ -0,0 +1,5 @@ +0,0.0,1000 +1,1.1,2000 +-22,-22.2,3000 +3,inf,4000 +-44,-inf,5000 diff --git a/test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta new file mode 100644 index 000000000..ca2b6b377 --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta @@ -0,0 +1,9 @@ +{ + "numRows": 5, + "numCols": 3, + "schema": [ + {"label": "c_si64", "valueType": "si64"}, + {"label": "c_f64", "valueType": "f64"}, + {"label": "c_si64_", "valueType": "si64"} + ] +} \ No newline at end of file diff --git a/test/api/cli/io/ref/frame_mixed-str_ref.csv b/test/api/cli/io/ref/frame_mixed-str_ref.csv new file mode 100644 index 000000000..f55d63405 --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-str_ref.csv @@ -0,0 +1,6 @@ +0,0.0,abc +1,1.1, +-22,-22.2,"d""e" +3,inf,"fg +hi" +-44,-inf,"mn,op" diff --git a/test/api/cli/io/ref/frame_mixed-str_ref.csv.meta b/test/api/cli/io/ref/frame_mixed-str_ref.csv.meta new file mode 100644 index 000000000..986e25f0c --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-str_ref.csv.meta @@ -0,0 +1,9 @@ +{ + "numRows": 5, + "numCols": 3, + "schema": [ + {"label": "c_si64", "valueType": "si64"}, + {"label": "c_f64", "valueType": "f64"}, + {"label": "c_str", "valueType": "str"} + ] +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_123_ref.csv b/test/api/cli/io/ref/matrix_123_ref.csv new file mode 100644 index 000000000..d00491fd7 --- /dev/null +++ b/test/api/cli/io/ref/matrix_123_ref.csv @@ -0,0 +1 @@ +1 diff --git a/test/api/cli/io/ref/matrix_123_ref.csv.meta b/test/api/cli/io/ref/matrix_123_ref.csv.meta new file mode 100644 index 000000000..97ad56a68 --- /dev/null +++ b/test/api/cli/io/ref/matrix_123_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 1, + "numCols": 1, + "valueType": "si64" +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_f64_ref.csv b/test/api/cli/io/ref/matrix_f64_ref.csv new file mode 100644 index 000000000..f0205d55d --- /dev/null +++ b/test/api/cli/io/ref/matrix_f64_ref.csv @@ -0,0 +1,2 @@ +1.1,-22.2,3.3,-44.4,5.5 +-66.6,0.0,nan,inf,-inf diff --git a/test/api/cli/io/ref/matrix_f64_ref.csv.meta b/test/api/cli/io/ref/matrix_f64_ref.csv.meta new file mode 100644 index 000000000..d5b38231d --- /dev/null +++ b/test/api/cli/io/ref/matrix_f64_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 2, + "numCols": 5, + "valueType": "f64" +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_si64_ref.csv b/test/api/cli/io/ref/matrix_si64_ref.csv new file mode 100644 index 000000000..332b71f5c --- /dev/null +++ b/test/api/cli/io/ref/matrix_si64_ref.csv @@ -0,0 +1,2 @@ +1,-22,3,-44 +5,-66,0,0 diff --git a/test/api/cli/io/ref/matrix_si64_ref.csv.meta b/test/api/cli/io/ref/matrix_si64_ref.csv.meta new file mode 100644 index 000000000..0b44c2898 --- /dev/null +++ b/test/api/cli/io/ref/matrix_si64_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 2, + "numCols": 4, + "valueType": "si64" +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_str_ref.csv b/test/api/cli/io/ref/matrix_str_ref.csv new file mode 100644 index 000000000..d18309603 --- /dev/null +++ b/test/api/cli/io/ref/matrix_str_ref.csv @@ -0,0 +1,3 @@ +abc,,"d""e" +"fg +hi","jkl","mn,op" diff --git a/test/api/cli/io/ReadCsv2.csv.meta b/test/api/cli/io/ref/matrix_str_ref.csv.meta similarity index 70% rename from test/api/cli/io/ReadCsv2.csv.meta rename to test/api/cli/io/ref/matrix_str_ref.csv.meta index 0fe44b980..4e96a97fc 100644 --- a/test/api/cli/io/ReadCsv2.csv.meta +++ b/test/api/cli/io/ref/matrix_str_ref.csv.meta @@ -1,5 +1,5 @@ { - "numRows": 9, + "numRows": 2, "numCols": 3, "valueType": "str" } \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_view_ref.csv b/test/api/cli/io/ref/matrix_view_ref.csv new file mode 100644 index 000000000..332b71f5c --- /dev/null +++ b/test/api/cli/io/ref/matrix_view_ref.csv @@ -0,0 +1,2 @@ +1,-22,3,-44 +5,-66,0,0 diff --git a/test/api/cli/io/ref/matrix_view_ref.csv.meta b/test/api/cli/io/ref/matrix_view_ref.csv.meta new file mode 100644 index 000000000..0b44c2898 --- /dev/null +++ b/test/api/cli/io/ref/matrix_view_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 2, + "numCols": 4, + "valueType": "si64" +} \ No newline at end of file diff --git a/test/api/cli/io/testReadFrame.daphne b/test/api/cli/io/testReadFrame.daphne deleted file mode 100644 index 59bde6fe6..000000000 --- a/test/api/cli/io/testReadFrame.daphne +++ /dev/null @@ -1,6 +0,0 @@ -# Test reading from a file when the file path is not trivially constant (i.e., a parameter to a UDF) -def readFrameFromCSV(path: str) { - print(readFrame(path)); -} - -readFrameFromCSV("test/api/cli/io/ReadCsv1.csv"); \ No newline at end of file diff --git a/test/api/cli/io/testReadFrame.txt b/test/api/cli/io/testReadFrame.txt deleted file mode 100644 index e5af505cc..000000000 --- a/test/api/cli/io/testReadFrame.txt +++ /dev/null @@ -1,3 +0,0 @@ -Frame(2x4, [col_0:double, col_1:double, col_2:double, col_3:double]) --0.1 -0.2 0.1 0.2 -3.14 5.41 6.22216 5 diff --git a/test/api/cli/io/testReadMatrix.daphne b/test/api/cli/io/testReadMatrix.daphne deleted file mode 100644 index cefdb281c..000000000 --- a/test/api/cli/io/testReadMatrix.daphne +++ /dev/null @@ -1,6 +0,0 @@ -# Test reading from a file when the file path is not trivially constant (i.e., a parameter to a UDF) -def readMatrixFromCSV(path: str) { - print(readMatrix(path)); -} - -readMatrixFromCSV("test/api/cli/io/ReadCsv1.csv"); \ No newline at end of file diff --git a/test/api/cli/io/testReadMatrix.txt b/test/api/cli/io/testReadMatrix.txt deleted file mode 100644 index 26ed2a258..000000000 --- a/test/api/cli/io/testReadMatrix.txt +++ /dev/null @@ -1,3 +0,0 @@ -DenseMatrix(2x4, double) --0.1 -0.2 0.1 0.2 -3.14 5.41 6.22216 5 diff --git a/test/api/cli/io/testReadMatrix_DynamicPath.daphne b/test/api/cli/io/testReadMatrix_DynamicPath.daphne deleted file mode 100644 index 6c31fba84..000000000 --- a/test/api/cli/io/testReadMatrix_DynamicPath.daphne +++ /dev/null @@ -1,5 +0,0 @@ -# Test dynamic computation of string path -> does not yet work! -i = 1; -filename = "test/api/cli/io/ReadCsv" + i + ".csv"; -m = readMatrix(filename); -print(m); \ No newline at end of file diff --git a/test/api/cli/io/testReadStringIntoFrame.daphne b/test/api/cli/io/testReadStringIntoFrame.daphne deleted file mode 100644 index d2b457193..000000000 --- a/test/api/cli/io/testReadStringIntoFrame.daphne +++ /dev/null @@ -1,3 +0,0 @@ -// Test reading frame with string columns. - -print(readFrame("test/api/cli/io/ReadCsv3.csv")); \ No newline at end of file diff --git a/test/api/cli/io/testReadStringIntoFrame.txt b/test/api/cli/io/testReadStringIntoFrame.txt deleted file mode 100644 index 1cca28d7b..000000000 --- a/test/api/cli/io/testReadStringIntoFrame.txt +++ /dev/null @@ -1,5 +0,0 @@ -Frame(4x3, [a:int8_t, b:uint8_t, c:std::string]) -1 255 -2 254 -3 253 multi-line, -4 252 simple string diff --git a/test/api/cli/io/testReadStringMatrix.daphne b/test/api/cli/io/testReadStringMatrix.daphne deleted file mode 100644 index f3499212f..000000000 --- a/test/api/cli/io/testReadStringMatrix.daphne +++ /dev/null @@ -1,3 +0,0 @@ -// Test reading matrix of strings - -print(readMatrix("test/api/cli/io/ReadCsv2.csv")); \ No newline at end of file diff --git a/test/api/cli/io/testReadStringMatrix.txt b/test/api/cli/io/testReadStringMatrix.txt deleted file mode 100644 index fc4085f8c..000000000 --- a/test/api/cli/io/testReadStringMatrix.txt +++ /dev/null @@ -1,11 +0,0 @@ -DenseMatrix(9x3, std::string) -banana, grape 36 Fruit Basket -xyz""uvw 34 No Category\" -parrot, rabbit 31 Pets -line1 -line2 26 with newline -chair 28 Furniture Set -green, yellow\n 51 -\n\"xyz""uvw\" 29 Mixed string -\"blue, \"\" 42 -unknown, item 23 Unknown Item diff --git a/test/api/cli/io/write/write_frame_mixed-no-str.daphne b/test/api/cli/io/write/write_frame_mixed-no-str.daphne new file mode 100644 index 000000000..e7f1261e7 --- /dev/null +++ b/test/api/cli/io/write/write_frame_mixed-no-str.daphne @@ -0,0 +1,8 @@ +# Write a frame with columns of various value types (not including string) to a file. + +# TODO Test nan values. + +X = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_si64_": [1000, 2000, 3000, 4000, 5000]}; +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_frame_mixed-str.daphne b/test/api/cli/io/write/write_frame_mixed-str.daphne new file mode 100644 index 000000000..eed5a7f51 --- /dev/null +++ b/test/api/cli/io/write/write_frame_mixed-str.daphne @@ -0,0 +1,8 @@ +# Write a frame with columns of various value types (including string) to a file. + +# TODO Test nan values. + +X = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_str": ["abc", "", "d\"e", "fg\nhi", "mn,op"]}; +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_f64.daphne b/test/api/cli/io/write/write_matrix_f64.daphne new file mode 100644 index 000000000..69982599f --- /dev/null +++ b/test/api/cli/io/write/write_matrix_f64.daphne @@ -0,0 +1,5 @@ +# Write a matrix of value type f64 to a file. + +X = [1.1, -22.2, 3.3, -44.4, 5.5, + -66.6, 0.0, nan, inf, -inf](2, 5); +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_si64.daphne b/test/api/cli/io/write/write_matrix_si64.daphne new file mode 100644 index 000000000..aaebe897b --- /dev/null +++ b/test/api/cli/io/write/write_matrix_si64.daphne @@ -0,0 +1,5 @@ +# Write a matrix of value type si64 to a file. + +X = [1, -22, 3, -44, + 5, -66, 0, 0](2, 4); +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_str.daphne b/test/api/cli/io/write/write_matrix_str.daphne new file mode 100644 index 000000000..289052275 --- /dev/null +++ b/test/api/cli/io/write/write_matrix_str.daphne @@ -0,0 +1,5 @@ +# Write a matrix of value type str to a file. + +X = ["abc", "", "d\"e", + "fg\nhi", "jkl", "mn,op"](2, 3); +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_view.daphne b/test/api/cli/io/write/write_matrix_view.daphne new file mode 100644 index 000000000..d5b1f55fb --- /dev/null +++ b/test/api/cli/io/write/write_matrix_view.daphne @@ -0,0 +1,6 @@ +# Write a matrix which is a view into another one to a file. + +X = [0, 1, -22, 3, -44, 0, + 0, 5, -66, 0, 0, 0](2, 6); +X = X[, 1:ncol(X)-1]; # omit the first and last column +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/writeMatrix_full.daphne b/test/api/cli/io/writeMatrix_full.daphne deleted file mode 100644 index a66f44b22..000000000 --- a/test/api/cli/io/writeMatrix_full.daphne +++ /dev/null @@ -1,4 +0,0 @@ -// Write an entire small matrix to the file $outPath. - -X = reshape(seq(1, 6, 1), 3, 2); -write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/writeMatrix_view.daphne b/test/api/cli/io/writeMatrix_view.daphne deleted file mode 100644 index 90538095d..000000000 --- a/test/api/cli/io/writeMatrix_view.daphne +++ /dev/null @@ -1,5 +0,0 @@ -// Write a view into a small matrix to the file $outPath. - -X = reshape(seq(1, 6, 1), 3, 2); -X = X[, 1]; # only the 2nd column -write(X, $outPath); \ No newline at end of file diff --git a/test/runtime/local/io/ReadCsvTest.cpp b/test/runtime/local/io/ReadCsvTest.cpp index eed5e30e0..fd68ece99 100644 --- a/test/runtime/local/io/ReadCsvTest.cpp +++ b/test/runtime/local/io/ReadCsvTest.cpp @@ -226,7 +226,7 @@ TEST_CASE("ReadCsv, frame of numbers and strings", TAG_IO) { CHECK(m->getColumn(2)->get(1, 0) == "sample,"); CHECK(m->getColumn(2)->get(2, 0) == "line1\nline2"); CHECK(m->getColumn(2)->get(3, 0) == ""); - CHECK(m->getColumn(2)->get(4, 0) == "\"\"\\n\\\"abc\"\"def\\\""); + CHECK(m->getColumn(2)->get(4, 0) == "\"\\n\\\"abc\"def\\\""); CHECK(m->getColumn(2)->get(5, 0) == ""); CHECK(m->getColumn(3)->get(0, 0) == 444); @@ -337,10 +337,10 @@ TEMPLATE_PRODUCT_TEST_CASE("ReadCsv", TAG_IO, (DenseMatrix), (ALL_STRING_VALUE_T CHECK(m->get(0, 0) == "apple, orange"); CHECK(m->get(1, 0) == "dog, cat"); CHECK(m->get(2, 0) == "table"); - CHECK(m->get(3, 0) == "\"\""); - CHECK(m->get(4, 0) == "abc\"\"def"); + CHECK(m->get(3, 0) == "\""); + CHECK(m->get(4, 0) == "abc\"def"); CHECK(m->get(5, 0) == "red, blue\\n"); - CHECK(m->get(6, 0) == "\\n\\\"abc\"\"def\\\""); + CHECK(m->get(6, 0) == "\\n\\\"abc\"def\\\""); CHECK(m->get(7, 0) == "line1\nline2"); CHECK(m->get(8, 0) == "\\\"red, \\\"\\\""); From 3afd47a30cb4743328aedb799f47c3826fa13a1e Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Thu, 12 Dec 2024 19:05:40 +0100 Subject: [PATCH 14/17] CSV writer supports string data. - The CSV writer supports matrices of std::string value type and frames with columns of std::string value type. - Strings written to CSV files are properly quoted if necessary, i.e., if the string contains the value separator, newlines, or quotation marks. Quotes inside quoted strings are escaped by doubling them. - Added an instantiation of the write-kernel for string-valued matrices (frames with string columns are covered by the general write-kernel for frames). - Added script-level test cases. --- src/runtime/local/io/WriteCsv.h | 50 ++++++++++++++++++++++++-- src/runtime/local/kernels/kernels.json | 1 + test/api/cli/io/ReadWriteTest.cpp | 6 ++-- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/runtime/local/io/WriteCsv.h b/src/runtime/local/io/WriteCsv.h index 34f4efcc0..e5b55b804 100644 --- a/src/runtime/local/io/WriteCsv.h +++ b/src/runtime/local/io/WriteCsv.h @@ -48,6 +48,42 @@ template struct WriteCsv { template void writeCsv(const DTArg *arg, File *file) { WriteCsv::apply(arg, file); } +// **************************************************************************** +// Helper functions +// **************************************************************************** + +/** + * @brief Returns a CSV value representation of the given string. + * + * If necessary, the given string is wrapped in quotation marks (`"`) and quoation marks occurring in the string are + * escaped by doubling them, i.e., `"` is replaced by `""`. Quoting the string is always allowed in CSV, but only + * *necessary* and done by this function when the string contains the value separator, newlines, or quotes. If the + * string does not need to be quoted, it is returned as it is. + * + * @param s The string to convert to a CSV value representation. + * @return The CSV value representation of the given string. + */ +std::string quoteStrCsvIf(const std::string &s) { + if (s.find_first_of(",\n\r\"") != std::string::npos) { + // String needs to be quoted. + // Inside the quoted string, quotes ('"') must be escaped by duplicating them. + std::stringstream strm; + strm << '"'; + for (size_t i = 0; i < s.length(); i++) { + char c = s[i]; + if (c == '"') + strm << '"' << '"'; + else + strm << c; + } + strm << '"'; + return strm.str(); + } else { + // String does not need to be quoted. + return s; + } +} + // **************************************************************************** // (Partial) template specializations for different data/value types // **************************************************************************** @@ -67,9 +103,13 @@ template struct WriteCsv> { for (size_t i = 0; i < arg->getNumRows(); ++i) { for (size_t j = 0; j < argNumCols; ++j) { - fprintf(file->identifier, - std::is_floating_point::value ? "%f" : (std::is_same::value ? "%ld" : "%d"), - valuesArg[i * rowSkip + j]); + if constexpr (std::is_same::value) + fprintf(file->identifier, "%s", quoteStrCsvIf(valuesArg[i * rowSkip + j]).c_str()); + else + fprintf(file->identifier, + std::is_floating_point::value ? "%f" + : (std::is_same::value ? "%ld" : "%d"), + valuesArg[i * rowSkip + j]); if (j < (arg->getNumCols() - 1)) fprintf(file->identifier, ","); else @@ -125,6 +165,10 @@ template <> struct WriteCsv { case ValueTypeCode::F64: fprintf(file->identifier, "%f", reinterpret_cast(array)[i]); break; + case ValueTypeCode::STR: + fprintf(file->identifier, "%s", + quoteStrCsvIf(reinterpret_cast(array)[i]).c_str()); + break; default: throw std::runtime_error("unknown value type code"); } diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 250a70845..c46b29439 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -4029,6 +4029,7 @@ [["DenseMatrix", "double"]], [["DenseMatrix", "int64_t"]], [["DenseMatrix", "uint8_t"]], + [["DenseMatrix", "std::string"]], ["Frame"] ] }, diff --git a/test/api/cli/io/ReadWriteTest.cpp b/test/api/cli/io/ReadWriteTest.cpp index 9a8f045b9..e782ac80d 100644 --- a/test/api/cli/io/ReadWriteTest.cpp +++ b/test/api/cli/io/ReadWriteTest.cpp @@ -96,6 +96,8 @@ MAKE_READ_TEST_CASE_2("frame_dynamic-path-1") std::filesystem::remove(outPath); /* remove old output file if it still exists */ \ checkDaphneStatusCode(StatusCode::SUCCESS, scriptPathWrt.c_str(), "--args", \ ("outPath=\"" + outPath + "\"").c_str()); \ + /* TODO REQUIRE() the status code to be success (don't only CHECK() it), because in case of failure, */ \ + /* the next check doesn't make sense and might produce a misleading output. */ \ compareDaphneToStr("0\n", scriptPathCmp.c_str(), "--args", \ ("chkPath=\"" + outPath + "\",refPath=\"" + refPath + "\",nanSafe=" + nanSafe).c_str()); \ } @@ -103,7 +105,7 @@ MAKE_READ_TEST_CASE_2("frame_dynamic-path-1") // TODO The commented out test cases don't work yet, as the CSV writer doesn't support strings yet. MAKE_WRITE_TEST_CASE("matrix", "si64", "false") MAKE_WRITE_TEST_CASE("matrix", "f64", "true") -// MAKE_WRITE_TEST_CASE("matrix", "str", "false") +MAKE_WRITE_TEST_CASE("matrix", "str", "false") MAKE_WRITE_TEST_CASE("matrix", "view", "false") MAKE_WRITE_TEST_CASE("frame", "mixed-no-str", "false") -// MAKE_WRITE_TEST_CASE("frame", "mixed-str", "false") \ No newline at end of file +MAKE_WRITE_TEST_CASE("frame", "mixed-str", "false") \ No newline at end of file From a3748141adc98d0544416492b04c48feb3b31c24 Mon Sep 17 00:00:00 2001 From: Samin <69201742+saminbassiri@users.noreply.github.com> Date: Thu, 12 Dec 2024 19:59:00 +0100 Subject: [PATCH 15/17] String value type support in more kernels. - EwBinaryOpCode: Updated specification which elementwise binary ops are supported on which string value types (std::string, FixedStr16, const char *). - Needed to generalize some macros from one common value type for both operands to individual value types, e.g., to support "matrix of std::string op string scalar". - Fixed a small type bug in EwBinaryObjSca-kernel. - ExtractRow- and FilterRow-kernels on frames work with frames with std::string columns. - Added instantiations of EwBinaryObjSca-kernel for comparisons of string matrices with string scalar (integer result). - Order- and Group-kernels can handle frames with std::string columns. - These changes were originally included in PRs #918, #921, and #926 by @saminbassiri. - @pdamme is committing them in the name of @saminbassiri (for correct attribution) in a separate commit, since they don't really fit the topics of those PRs. --- src/runtime/local/kernels/BinaryOpCode.h | 39 +++++--- src/runtime/local/kernels/EwBinaryObjSca.h | 2 +- src/runtime/local/kernels/ExtractRow.h | 27 ++++-- src/runtime/local/kernels/FilterRow.h | 21 +++-- src/runtime/local/kernels/Group.h | 101 ++++++++++++++++----- src/runtime/local/kernels/Order.h | 21 +++-- src/runtime/local/kernels/kernels.json | 5 + 7 files changed, 153 insertions(+), 63 deletions(-) diff --git a/src/runtime/local/kernels/BinaryOpCode.h b/src/runtime/local/kernels/BinaryOpCode.h index 30023586c..9fc55fd53 100644 --- a/src/runtime/local/kernels/BinaryOpCode.h +++ b/src/runtime/local/kernels/BinaryOpCode.h @@ -132,16 +132,20 @@ static constexpr bool supportsBinaryOp = false; SUPPORT(BITWISE_AND, VT) // Generates code specifying that all binary operations of a certain category should be -// supported on the given argument value type `VTArg` (for the left and right-hand-side -// arguments, for simplicity) and the given result value type `VTRes`. -#define SUPPORT_COMPARISONS_RA(VTRes, VTArg) \ +// supported on the given argument value types `VTLhs` and `VTRhs` (for the left and right-hand-side +// arguments, respectively) and the given result value type `VTRes`. +#define SUPPORT_COMPARISONS_RLR(VTRes, VTLhs, VTRhs) \ /* string Comparisons operations. */ \ - SUPPORT_RLR(LT, VTRes, VTArg, VTArg) \ - SUPPORT_RLR(GT, VTRes, VTArg, VTArg) -#define SUPPORT_EQUALITY_RA(VTRes, VTArg) \ + SUPPORT_RLR(LT, VTRes, VTLhs, VTRhs) \ + SUPPORT_RLR(GT, VTRes, VTLhs, VTRhs) +#define SUPPORT_COMPARISONS_EQUAL_RLR(VTRes, VTLhs, VTRhs) \ /* string Comparisons operations. */ \ - SUPPORT_RLR(EQ, VTRes, VTArg, VTArg) \ - SUPPORT_RLR(NEQ, VTRes, VTArg, VTArg) + SUPPORT_RLR(LE, VTRes, VTLhs, VTRhs) \ + SUPPORT_RLR(GE, VTRes, VTLhs, VTRhs) +#define SUPPORT_EQUALITY_RLR(VTRes, VTLhs, VTRhs) \ + /* string Comparisons operations. */ \ + SUPPORT_RLR(EQ, VTRes, VTLhs, VTRhs) \ + SUPPORT_RLR(NEQ, VTRes, VTLhs, VTRhs) #define SUPPORT_STRING_RLR(VTRes, VTLhs, VTRhs) \ /* string concatenation operations. */ \ /* Since the result may not fit in FixedStr16,*/ \ @@ -175,11 +179,15 @@ SUPPORT_NUMERIC_INT(uint64_t) SUPPORT_NUMERIC_INT(uint32_t) SUPPORT_NUMERIC_INT(uint8_t) // Strings binary operations. -SUPPORT_EQUALITY_RA(int64_t, std::string) -SUPPORT_EQUALITY_RA(int64_t, FixedStr16) -SUPPORT_EQUALITY_RA(int64_t, const char *) -SUPPORT_COMPARISONS_RA(int64_t, std::string) -SUPPORT_COMPARISONS_RA(int64_t, FixedStr16) +SUPPORT_EQUALITY_RLR(int64_t, std::string, std::string) +SUPPORT_EQUALITY_RLR(int64_t, FixedStr16, FixedStr16) +SUPPORT_EQUALITY_RLR(int64_t, const char *, const char *) +SUPPORT_EQUALITY_RLR(int64_t, std::string, const char *) +SUPPORT_COMPARISONS_RLR(int64_t, std::string, std::string) +SUPPORT_COMPARISONS_RLR(int64_t, FixedStr16, FixedStr16) +SUPPORT_COMPARISONS_RLR(int64_t, std::string, const char *) +SUPPORT_COMPARISONS_EQUAL_RLR(int64_t, std::string, std::string) +SUPPORT_COMPARISONS_EQUAL_RLR(int64_t, std::string, const char *) SUPPORT_STRING_RLR(std::string, std::string, std::string) SUPPORT_STRING_RLR(std::string, FixedStr16, FixedStr16) SUPPORT_STRING_RLR(const char *, const char *, const char *) @@ -195,6 +203,7 @@ SUPPORT_STRING_RLR(std::string, std::string, const char *) #undef SUPPORT_BITWISE #undef SUPPORT_NUMERIC_FP #undef SUPPORT_NUMERIC_INT -#undef SUPPORT_EQUALITY_RA -#undef SUPPORT_COMPARISONS_RA +#undef SUPPORT_EQUALITY_RLR +#undef SUPPORT_COMPARISONS_RLR +#undef SUPPORT_COMPARISONS_EQUAL_RLR #undef SUPPORT_STRING_RLR diff --git a/src/runtime/local/kernels/EwBinaryObjSca.h b/src/runtime/local/kernels/EwBinaryObjSca.h index 61b5fa8e8..62d704ded 100644 --- a/src/runtime/local/kernels/EwBinaryObjSca.h +++ b/src/runtime/local/kernels/EwBinaryObjSca.h @@ -63,7 +63,7 @@ struct EwBinaryObjSca, DenseMatrix, VTRhs> { if (res == nullptr) res = DataObjectFactory::create>(numRows, numCols, false); - const VTRes *valuesLhs = lhs->getValues(); + const VTLhs *valuesLhs = lhs->getValues(); VTRes *valuesRes = res->getValues(); EwBinaryScaFuncPtr func = getEwBinaryScaFuncPtr(opCode); diff --git a/src/runtime/local/kernels/ExtractRow.h b/src/runtime/local/kernels/ExtractRow.h index 442bccdd0..a4367d198 100644 --- a/src/runtime/local/kernels/ExtractRow.h +++ b/src/runtime/local/kernels/ExtractRow.h @@ -118,16 +118,23 @@ template struct ExtractRow { throw std::out_of_range(errMsg.str()); } for (size_t c = 0; c < numCols; c++) { - // We always copy in units of 8 bytes (uint64_t). If the - // actual element size is lower, the superfluous bytes will - // be overwritten by the next match. With this approach, we - // do not need to call memcpy for each element, nor - // interpret the types for a L/S of fitting size. - // TODO Don't multiply by elementSize, but left-shift by - // ld(elementSize). - *reinterpret_cast(resCols[c]) = - *reinterpret_cast(argCols[c] + pos * elementSizes[c]); - resCols[c] += elementSizes[c]; + if (schema[c] == ValueTypeCode::STR) { + // Handle std::string column + *reinterpret_cast(resCols[c]) = + *reinterpret_cast(argCols[c] + pos * elementSizes[c]); + resCols[c] += elementSizes[c]; + } else { + // We always copy in units of 8 bytes (uint64_t). If the + // actual element size is lower, the superfluous bytes will + // be overwritten by the next match. With this approach, we + // do not need to call memcpy for each element, nor + // interpret the types for a L/S of fitting size. + // TODO Don't multiply by elementSize, but left-shift by + // ld(elementSize). + *reinterpret_cast(resCols[c]) = + *reinterpret_cast(argCols[c] + pos * elementSizes[c]); + resCols[c] += elementSizes[c]; + } } } res->shrinkNumRows(numRowsSel); diff --git a/src/runtime/local/kernels/FilterRow.h b/src/runtime/local/kernels/FilterRow.h index 0c53ec2ab..7b6a8f257 100644 --- a/src/runtime/local/kernels/FilterRow.h +++ b/src/runtime/local/kernels/FilterRow.h @@ -131,13 +131,20 @@ template struct FilterRow { for (size_t r = 0; r < numRows; r++) { if (valuesSel[r]) { for (size_t c = 0; c < numCols; c++) { - // We always copy in units of 8 bytes (uint64_t). If the - // actual element size is lower, the superfluous bytes will - // be overwritten by the next match. With this approach, we - // do not need to call memcpy for each element, nor - // interpret the types for a L/S of fitting size. - *reinterpret_cast(resCols[c]) = *reinterpret_cast(argCols[c]); - resCols[c] += elementSizes[c]; + if (schema[c] == ValueTypeCode::STR) { + // Handle std::string column + *reinterpret_cast(resCols[c]) = + *reinterpret_cast(argCols[c]); // Deep copy the string + resCols[c] += elementSizes[c]; + } else { + // We always copy in units of 8 bytes (uint64_t). If the + // actual element size is lower, the superfluous bytes will + // be overwritten by the next match. With this approach, we + // do not need to call memcpy for each element, nor + // interpret the types for a L/S of fitting size. + *reinterpret_cast(resCols[c]) = *reinterpret_cast(argCols[c]); + resCols[c] += elementSizes[c]; + } } } for (size_t c = 0; c < numCols; c++) diff --git a/src/runtime/local/kernels/Group.h b/src/runtime/local/kernels/Group.h index 0d8df6410..5e192d93b 100644 --- a/src/runtime/local/kernels/Group.h +++ b/src/runtime/local/kernels/Group.h @@ -58,6 +58,24 @@ void group(DT *&res, const DT *arg, const char **keyCols, size_t numKeyCols, con // Frame <- Frame // ---------------------------------------------------------------------------- +// TODO If possible, reuse the stringifyGroupEnum() from the DAPHNE compiler. +inline std::string myStringifyGroupEnum(mlir::daphne::GroupEnum val) { + using mlir::daphne::GroupEnum; + switch (val) { + case GroupEnum::COUNT: + return "COUNT"; + case GroupEnum::SUM: + return "SUM"; + case GroupEnum::MIN: + return "MIN"; + case GroupEnum::MAX: + return "MAX"; + case GroupEnum::AVG: + return "AVG"; + } + throw std::runtime_error("invalid GroupEnum value"); +} + // returns the result of the aggregation function aggFunc over the (contiguous) // memory between the begin and end pointer template @@ -65,26 +83,60 @@ VTRes aggregate(const mlir::daphne::GroupEnum &aggFunc, const VTArg *begin, cons using mlir::daphne::GroupEnum; switch (aggFunc) { case GroupEnum::COUNT: - return end - begin; + if constexpr (std::is_same::value) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return end - begin; break; // TODO: Do we need to check for Null elements here? case GroupEnum::SUM: - return std::accumulate(begin, end, (VTRes)0); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return std::accumulate(begin, end, (VTRes)0); break; case GroupEnum::MIN: - return *std::min_element(begin, end); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return *std::min_element(begin, end); break; case GroupEnum::MAX: - return *std::max_element(begin, end); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return *std::max_element(begin, end); break; case GroupEnum::AVG: - return std::accumulate(begin, end, (double)0) / (double)(end - begin); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return std::accumulate(begin, end, (double)0) / (double)(end - begin); break; default: - return *begin; + if constexpr (std::is_same::value || std::is_same::value) + throw std::invalid_argument("aggregate: Unsupported aggregation operation for string types."); + else + return *begin; break; } } +template <> +std::string aggregate(const mlir::daphne::GroupEnum &aggFunc, const std::string *begin, const std::string *end) { + using mlir::daphne::GroupEnum; + if (aggFunc == GroupEnum::MIN) + return *std::min_element(begin, end); + if (aggFunc == GroupEnum::MAX) + return *std::max_element(begin, end); + else + return *begin; +} + // struct which calls the aggregate() function (specified via aggFunc) on each // duplicate group in the groups vector and on all implied single groups for a // sepcified column (colIdx) of the argument frame (arg) and stores the result @@ -117,22 +169,14 @@ template struct ColumnGroupAgg { } }; -inline std::string myStringifyGroupEnum(mlir::daphne::GroupEnum val) { - using mlir::daphne::GroupEnum; - switch (val) { - case GroupEnum::COUNT: - return "COUNT"; - case GroupEnum::SUM: - return "SUM"; - case GroupEnum::MIN: - return "MIN"; - case GroupEnum::MAX: - return "MAX"; - case GroupEnum::AVG: - return "AVG"; +// Since DeduceValueTypeAndExecute can not handle string values, +// we add special ColumnGroupAgg function for arg with std::string values. +template struct ColumnGroupAggStringVTArg { + static void apply(Frame *res, const Frame *arg, size_t colIdx, std::vector> *groups, + mlir::daphne::GroupEnum aggFunc, DCTX(ctx)) { + ColumnGroupAgg::apply(res, arg, colIdx, groups, aggFunc, ctx); } - return ""; -} +}; template <> struct Group { static void apply(Frame *&res, const Frame *arg, const char **keyCols, size_t numKeyCols, const char **aggCols, @@ -270,9 +314,18 @@ template <> struct Group { // copying key columns and column-wise group aggregation for (size_t i = 0; i < numColsRes; i++) { - DeduceValueTypeAndExecute::apply( - res->getSchema()[i], ordered->getSchema()[i], res, ordered, i, groups, - (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); + if (ordered->getSchema()[i] == ValueTypeCode::STR) { + if (res->getSchema()[i] == ValueTypeCode::STR) + ColumnGroupAgg::apply( + res, ordered, i, groups, (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); + else + DeduceValueTypeAndExecute::apply( + res->getSchema()[i], res, ordered, i, groups, + (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); + } else + DeduceValueTypeAndExecute::apply( + res->getSchema()[i], ordered->getSchema()[i], res, ordered, i, groups, + (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); } delete groups; DataObjectFactory::destroy(ordered); diff --git a/src/runtime/local/kernels/Order.h b/src/runtime/local/kernels/Order.h index afa6d1a10..5b78e6a6a 100644 --- a/src/runtime/local/kernels/Order.h +++ b/src/runtime/local/kernels/Order.h @@ -197,8 +197,11 @@ struct OrderFrame { if (numColIdxs > 1) { for (size_t i = 0; i < numColIdxs - 1; i++) { - DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdxs[i]], arg, idx, groups, - ascending[i], colIdxs[i], ctx); + if (arg->getSchema()[colIdxs[i]] == ValueTypeCode::STR) + MultiColumnIDSort::apply(arg, idx, groups, ascending[i], colIdxs[i], ctx); + else + DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdxs[i]], arg, idx, groups, + ascending[i], colIdxs[i], ctx); } } @@ -206,11 +209,17 @@ struct OrderFrame { // use size_t colIdx = colIdxs[numColIdxs - 1]; if (groupsRes == nullptr) { - DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, - ascending[numColIdxs - 1], colIdx, ctx); + if (arg->getSchema()[colIdx] == ValueTypeCode::STR) + ColumnIDSort::apply(arg, idx, groups, ascending[numColIdxs - 1], colIdx, ctx); + else + DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, + ascending[numColIdxs - 1], colIdx, ctx); } else { - DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, - ascending[numColIdxs - 1], colIdx, ctx); + if (arg->getSchema()[colIdx] == ValueTypeCode::STR) + MultiColumnIDSort::apply(arg, idx, groups, ascending[numColIdxs - 1], colIdx, ctx); + else + DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, + ascending[numColIdxs - 1], colIdx, ctx); groupsRes->insert(groupsRes->end(), groups.begin(), groups.end()); } } diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index c46b29439..19597655e 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -2197,6 +2197,11 @@ ["DenseMatrix", "std::string"], ["DenseMatrix", "std::string"], "const char *" + ], + [ + ["DenseMatrix", "int64_t"], + ["DenseMatrix", "std::string"], + "const char *" ] ], "opCodes": [ From bd1f04bed2d01d202a2b91ef25a6dd435f396de1 Mon Sep 17 00:00:00 2001 From: Samin <69201742+saminbassiri@users.noreply.github.com> Date: Thu, 12 Dec 2024 22:41:59 +0100 Subject: [PATCH 16/17] Hash-based implementation for innerJoin-kernel (#926) - The existing innerJoin-kernel was very inefficient as it was a nested-loop-join and incurred a significant function call overhead per value. Furthermore, it suffered from several memory leaks related to the use of Frame::getColumn(). - This commit replaces the innerJoin-kernel implementation by a hash-based one. - Adapted a unit test case of the innerJoin-kernel to trigger a case where the build side of the join is not unique. --- src/runtime/local/kernels/InnerJoin.h | 166 +++++++++++-------- test/runtime/local/kernels/InnerJoinTest.cpp | 18 +- 2 files changed, 110 insertions(+), 74 deletions(-) diff --git a/src/runtime/local/kernels/InnerJoin.h b/src/runtime/local/kernels/InnerJoin.h index 55451a4ba..e0c8be1a3 100644 --- a/src/runtime/local/kernels/InnerJoin.h +++ b/src/runtime/local/kernels/InnerJoin.h @@ -1,3 +1,19 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #ifndef SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H #define SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H @@ -10,7 +26,8 @@ #include #include -#include +#include +#include #include #include @@ -30,43 +47,85 @@ template void innerJoinSet(ValueTypeCode vtcType, Frame *&res, const Frame *arg, const int64_t toRow, const int64_t toCol, const int64_t fromRow, const int64_t fromCol, DCTX(ctx)) { if (vtcType == ValueTypeUtils::codeFor) { - innerJoinSetValue(res->getColumn(toCol), arg->getColumn(fromCol), toRow, fromRow, ctx); + const DenseMatrix *colArg = arg->getColumn(fromCol); + DenseMatrix *colRes = res->getColumn(toCol); + innerJoinSetValue(colRes, colArg, toRow, fromRow, ctx); + DataObjectFactory::destroy(colArg, colRes); } } -template -bool innerJoinEqual( - // results - Frame *&res, - // arguments - const DenseMatrix *argLhs, const DenseMatrix *argRhs, const int64_t targetLhs, - const int64_t targetRhs, - // context - DCTX(ctx)) { - const VTLhs l = argLhs->get(targetLhs, 0); - const VTRhs r = argRhs->get(targetRhs, 0); - return l == r; +// Create a hash table for rhs +template +std::unordered_map> BuildHashRhs(const Frame *rhs, const char *rhsOn, + const size_t numRowRhs) { + std::unordered_map> res; + const DenseMatrix *col = rhs->getColumn(rhsOn); + for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) { + VTRhs key = col->get(row_idx_r, 0); + res[key].push_back(row_idx_r); + } + DataObjectFactory::destroy(col); + return res; } -template -bool innerJoinProbeIf( - // value type known only at run-time - ValueTypeCode vtcLhs, ValueTypeCode vtcRhs, - // results - Frame *&res, +template +int64_t ProbeHashLhs( + // results and results schema + Frame *&res, ValueTypeCode *schema, // input frames const Frame *lhs, const Frame *rhs, // input column names - const char *lhsOn, const char *rhsOn, - // input rows - const int64_t targetL, const int64_t targetR, + const char *lhsOn, + // num columns + const size_t numColRhs, const size_t numColLhs, // context - DCTX(ctx)) { - if (vtcLhs == ValueTypeUtils::codeFor && vtcRhs == ValueTypeUtils::codeFor) { - return innerJoinEqual(res, lhs->getColumn(lhsOn), rhs->getColumn(rhsOn), targetL, - targetR, ctx); + DCTX(ctx), + // hashed map of Rhs + std::unordered_map> hashRhsIndex, + // Lhs rowa + const size_t numRowLhs) { + int64_t row_idx_res = 0; + int64_t col_idx_res = 0; + auto lhsFKCol = lhs->getColumn(lhsOn); + for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) { + auto key = lhsFKCol->get(row_idx_l, 0); + auto it = hashRhsIndex.find(key); + + if (it != hashRhsIndex.end()) { + for (size_t row_idx_r : it->second) { + col_idx_res = 0; + + // Populate result row from lhs columns + for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) { + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + + col_idx_res++; + } + + // Populate result row from rhs columns + for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) { + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + + col_idx_res++; + } + + row_idx_res++; + } + } } - return false; + DataObjectFactory::destroy(lhsFKCol); + return row_idx_res; } // **************************************************************************** @@ -84,10 +143,8 @@ inline void innerJoin( int64_t numRowRes, // context DCTX(ctx)) { - // Find out the value types of the columns to process. ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn); - ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn); // Perhaps check if res already allocated. const size_t numRowRhs = rhs->getNumRows(); @@ -100,55 +157,34 @@ inline void innerJoin( const std::string *oldlabels_r = rhs->getLabels(); int64_t col_idx_res = 0; - int64_t row_idx_res = 0; + int64_t row_idx_res; + // Set up schema and labels ValueTypeCode schema[totalCols]; std::string newlabels[totalCols]; - // Setting Schema and Labels for (size_t col_idx_l = 0; col_idx_l < numColLhs; col_idx_l++) { schema[col_idx_res] = lhs->getColumnType(col_idx_l); - newlabels[col_idx_res] = oldlabels_l[col_idx_l]; - col_idx_res++; + newlabels[col_idx_res++] = oldlabels_l[col_idx_l]; } for (size_t col_idx_r = 0; col_idx_r < numColRhs; col_idx_r++) { schema[col_idx_res] = rhs->getColumnType(col_idx_r); - newlabels[col_idx_res] = oldlabels_r[col_idx_r]; - col_idx_res++; + newlabels[col_idx_res++] = oldlabels_r[col_idx_r]; } - // Creating Result Frame + // Initialize result frame with an estimate res = DataObjectFactory::create(totalRows, totalCols, schema, newlabels, false); - for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) { - for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) { - col_idx_res = 0; - // PROBE ROWS - bool hit = false; - hit = hit || innerJoinProbeIf(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l, - row_idx_r, ctx); - hit = hit || innerJoinProbeIf(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l, - row_idx_r, ctx); - if (hit) { - for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) { - innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, - ctx); - innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, - ctx); - col_idx_res++; - } - for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) { - innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, - ctx); - - innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, - ctx); - col_idx_res++; - } - row_idx_res++; - } - } + // Build hash table and prob left table + if (vtcLhsOn == ValueTypeCode::STR) { + row_idx_res = ProbeHashLhs(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx, + BuildHashRhs(rhs, rhsOn, numRowRhs), numRowLhs); + } else { + row_idx_res = ProbeHashLhs(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx, + BuildHashRhs(rhs, rhsOn, numRowRhs), numRowLhs); } + // Shrink result frame to actual size res->shrinkNumRows(row_idx_res); } + #endif // SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H diff --git a/test/runtime/local/kernels/InnerJoinTest.cpp b/test/runtime/local/kernels/InnerJoinTest.cpp index 6e4c264c1..c2ee660f0 100644 --- a/test/runtime/local/kernels/InnerJoinTest.cpp +++ b/test/runtime/local/kernels/InnerJoinTest.cpp @@ -38,9 +38,9 @@ TEST_CASE("InnerJoin", TAG_KERNELS) { std::string lhsLabels[] = {"a", "b"}; auto lhs = DataObjectFactory::create(lhsCols, lhsLabels); - auto rhsC0 = genGivenVals>(3, {1, 4, 5}); - auto rhsC1 = genGivenVals>(3, {-1, -4, -5}); - auto rhsC2 = genGivenVals>(3, {0.1, 0.2, 0.3}); + auto rhsC0 = genGivenVals>(4, {1, 4, 5, 4}); + auto rhsC1 = genGivenVals>(4, {-1, -4, -5, -6}); + auto rhsC2 = genGivenVals>(4, {0.1, 0.2, 0.3, 0.4}); std::vector rhsCols = {rhsC0, rhsC1, rhsC2}; std::string rhsLabels[] = {"c", "d", "e"}; auto rhs = DataObjectFactory::create(rhsCols, rhsLabels); @@ -49,7 +49,7 @@ TEST_CASE("InnerJoin", TAG_KERNELS) { innerJoin(res, lhs, rhs, "a", "c", -1, nullptr); // Check the meta data. - CHECK(res->getNumRows() == 2); + CHECK(res->getNumRows() == 3); CHECK(res->getNumCols() == 5); CHECK(res->getColumnType(0) == ValueTypeCode::SI64); @@ -64,11 +64,11 @@ TEST_CASE("InnerJoin", TAG_KERNELS) { CHECK(res->getLabels()[3] == "d"); CHECK(res->getLabels()[4] == "e"); - auto resC0Exp = genGivenVals>(2, {1, 4}); - auto resC1Exp = genGivenVals>(2, {11.0, 44.0}); - auto resC2Exp = genGivenVals>(2, {1, 4}); - auto resC3Exp = genGivenVals>(2, {-1, -4}); - auto resC4Exp = genGivenVals>(2, {0.1, 0.2}); + auto resC0Exp = genGivenVals>(3, {1, 4, 4}); + auto resC1Exp = genGivenVals>(3, {11.0, 44.0, 44.0}); + auto resC2Exp = genGivenVals>(3, {1, 4, 4}); + auto resC3Exp = genGivenVals>(3, {-1, -4, -6}); + auto resC4Exp = genGivenVals>(3, {0.1, 0.2, 0.4}); CHECK(*(res->getColumn(0)) == *resC0Exp); CHECK(*(res->getColumn(1)) == *resC1Exp); From ab157d907f8278b862ac1f7cf047acb877f63a76 Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Fri, 13 Dec 2024 00:42:52 +0100 Subject: [PATCH 17/17] Performance improvement of the FilterCol-kernel on DenseMatrix. - This commit changes the algorithm used to do the main work of the kernel. - The previous algorithm evaluated all entries of the selection bit vector for every row of the input matrix. - The new algorithm creates a vector of positions from the input bit vector first and evaluates only those positions for every row of the input matrix. - With the value types we currently use, the new approach should always be faster than the old one; but the trade-offs may change with other value types for the bit vector and positions (depending on the number of 1 bits). - With the old algorithm, FilterCol was the most expensive operation in decision trees/random forests for a particular use-case script; the new algorithm speeds up that script by 4x on my machine. --- src/runtime/local/kernels/FilterCol.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/runtime/local/kernels/FilterCol.h b/src/runtime/local/kernels/FilterCol.h index 3e9671456..1d6b68227 100644 --- a/src/runtime/local/kernels/FilterCol.h +++ b/src/runtime/local/kernels/FilterCol.h @@ -73,6 +73,16 @@ template struct FilterCol, DenseMa VT *valuesRes = res->getValues(); const size_t rowSkipArg = arg->getRowSkip(); const size_t rowSkipRes = res->getRowSkip(); + + // Two alternative approaches for doing the main work. + // Note that even though sel is essentially a bit vector, we represent it as a 64-bit integer at the moment, for + // simplicity. As both the elements in sel used in approach 1 and the positions in approach 2 are currently + // represented as 64-bit integers, approach 2 should always be faster than approach 1, because in approach 2 we + // iterate over at most as many 64-bit values as in approach 1 while we can omit the check. Once we change to a + // 1-bit representation for the values in sel, we should rethink the trade-off between the two approaches. +#if 0 // approach 1 + // For every row in arg, iterate over all elements in sel (one per column in arg), check if the respective + // column should be part of the output and if so, copy the value to the output. for (size_t r = 0; r < numRows; r++) { for (size_t ca = 0, cr = 0; ca < numColsArg; ca++) if (sel->get(ca, 0)) @@ -80,6 +90,22 @@ template struct FilterCol, DenseMa valuesArg += rowSkipArg; valuesRes += rowSkipRes; } +#else // approach 2 + // Once in the beginning, create a vector of the positions of the columns we want to copy to the output. + // Negligible effort, unless the number of rows in arg is very small. + std::vector positions; + for (size_t c = 0; c < numColsArg; c++) + if (sel->get(c, 0)) + positions.push_back(c); + // For every row in arg, iterate over the array of those positions and copy the value in the respective column + // to the output. + for (size_t r = 0; r < numRows; r++) { + for (size_t i = 0; i < positions.size(); i++) + valuesRes[i] = valuesArg[positions[i]]; + valuesArg += rowSkipArg; + valuesRes += rowSkipRes; + } +#endif } };