diff --git a/UserConfig.json b/UserConfig.json index 316c22938..4754c4fd9 100644 --- a/UserConfig.json +++ b/UserConfig.json @@ -24,6 +24,9 @@ "explain_vectorized": false, "explain_obj_ref_mgnt": false, "explain_mlir_codegen": false, + "explain_mlir_codegen_sparsity_exploiting_op_fusion": false, + "explain_mlir_codegen_daphneir_to_mlir": false, + "explain_mlir_codegen_mlir_specific": false, "taskPartitioningScheme": "STATIC", "numberOfThreads": -1, "minimumTaskSize": 1, diff --git a/doc/DaphneDSL/Builtins.md b/doc/DaphneDSL/Builtins.md index e78eaeb64..535003f0d 100644 --- a/doc/DaphneDSL/Builtins.md +++ b/doc/DaphneDSL/Builtins.md @@ -490,16 +490,24 @@ We will support set operations such as **`intersect`**, **`merge`**, and **`exce - **`cartesian`**`(lhs:frame, rhs:frame)` - Calculates the cartesian (cross) product of the two input frames. + Calculates the cartesian product of the two input frames. -- **`innerJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str)` +- **`innerJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str[, numRowRes:si64])` Performs an inner join of the two input frames on `lhs`.`lhsOn` == `rhs`.`rhsOn`. -- **`semiJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str)` + The parameter `numRowRes` is an optional hint for an upper bound of the number of result rows. + If specified, it determines the number of rows that will be allocated for the result, whereby `-1` stands for an automatically chosen size. + Otherwise, it defaults to `-1`. + +- **`semiJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str[, numRowRes:si64])` Performs a semi join of the two input frames on `lhs`.`lhsOn` == `rhs`.`rhsOn`. Returns only the columns belonging to `lhs`. + + The parameter `numRowRes` is an optional hint for an upper bound of the number of result rows. + If specified, it determines the number of rows that will be allocated for the result, whereby `-1` stands for an automatically chosen size. + Otherwise, it defaults to `-1`. 
- **`groupJoin`**`(lhs:frame, rhs:frame, lhsOn:str, rhsOn:str, rhsAgg:str)` @@ -507,6 +515,14 @@ We will support set operations such as **`intersect`**, **`merge`**, and **`exce We will support more variants of joins, including (left/right) outer joins, theta joins, anti-joins, etc. +### Grouping and aggregation + +- **`groupSum`**`(arg:frame, grpColNames:str[, grpColNames, ...], sumColName:str)` + + Groups the rows in the given frame `arg` by the specified columns `grpColNames` (at least one column) and calculates the per-group sum of the column denoted by `sumColName`. + + *This built-in function is currently limited in terms of functionality (aggregation only on a single column, sum as the only aggregation function). It will be extended in the future. Meanwhile, consider using DAPHNE's `sql()` built-in function for more comprehensive grouping and aggregation support.* + ### Frame label manipulation - **`setColLabels`**`(arg:frame, labels:str, ...)` diff --git a/doc/DaphneDSL/Imports.md b/doc/DaphneDSL/Imports.md index 1c9470c5d..0802efede 100644 --- a/doc/DaphneDSL/Imports.md +++ b/doc/DaphneDSL/Imports.md @@ -40,6 +40,8 @@ print(utils.x); } ``` +NOTE: to use a user config, the json file path needs to be passed as CLI arg to the DAPHNE binary `daphne --config=` + NOTE: `default_dirs` can hold many paths and it will look for the **one** specified file in each, whereas any other library names have a list consisting of **one** directory, from which **all** files will be imported (can be easily extended to multiple directories). Example: diff --git a/doc/development/BuildingDaphne.md b/doc/development/BuildingDaphne.md index 2a1d6bd87..baddcfa0b 100644 --- a/doc/development/BuildingDaphne.md +++ b/doc/development/BuildingDaphne.md @@ -132,6 +132,10 @@ All possible options for the build script: --- +## Building on WSL + +When using Windows Subsystem for Linux (WSL), the default memory limit for WSL is 50% of the total memory of the underlying Windows host. 
This can lead to build failures due to SIGKILL for DAPHNE builds. [Advanced settings configuration in WSL](https://learn.microsoft.com/en-us/windows/wsl/wsl-config) describes how the memory limit can be configured. + ## Extension ### Overview over the build script diff --git a/doc/tutorial/sqlTutorial.md b/doc/tutorial/sqlTutorial.md index d6f2002dd..2b42e5b7a 100644 --- a/doc/tutorial/sqlTutorial.md +++ b/doc/tutorial/sqlTutorial.md @@ -56,7 +56,7 @@ Other features we do and don't support right now can be found below. ### Supported Features -* Cross Product +* SQL Cross Product (Cartesian Product) * Complex Where Clauses * Inner Join with single and multiple join conditions separated by an "AND" Operator * Group By Clauses diff --git a/src/api/cli/DaphneUserConfig.h b/src/api/cli/DaphneUserConfig.h index 13bfdf4b1..8409d7546 100644 --- a/src/api/cli/DaphneUserConfig.h +++ b/src/api/cli/DaphneUserConfig.h @@ -74,6 +74,9 @@ struct DaphneUserConfig { bool explain_vectorized = false; bool explain_obj_ref_mgnt = false; bool explain_mlir_codegen = false; + bool explain_mlir_codegen_sparsity_exploiting_op_fusion = false; + bool explain_mlir_codegen_daphneir_to_mlir = false; + bool explain_mlir_codegen_mlir_specific = false; bool statistics = false; bool force_cuda = false; diff --git a/src/api/internal/daphne_internal.cpp b/src/api/internal/daphne_internal.cpp index b63be6f0b..fc728ffcb 100644 --- a/src/api/internal/daphne_internal.cpp +++ b/src/api/internal/daphne_internal.cpp @@ -287,26 +287,34 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int type_adaptation, vectorized, obj_ref_mgnt, - mlir_codegen + mlir_codegen, + mlir_codegen_sparsity_exploiting_op_fusion, + mlir_codegen_daphneir_to_mlir, + mlir_codegen_mlir_specific }; static llvm::cl::list explainArgList( "explain", cat(daphneOptions), llvm::cl::desc("Show DaphneIR after certain compiler passes (separate " "multiple values by comma, the order is irrelevant)"), - 
llvm::cl::values(clEnumVal(parsing, "Show DaphneIR after parsing"), - clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), - clEnumVal(sql, "Show DaphneIR after SQL parsing"), - clEnumVal(property_inference, "Show DaphneIR after property inference"), - clEnumVal(select_matrix_repr, "Show DaphneIR after selecting " - "physical matrix representations"), - clEnumVal(phy_op_selection, "Show DaphneIR after selecting physical operators"), - clEnumVal(type_adaptation, "Show DaphneIR after adapting types to available kernels"), - clEnumVal(vectorized, "Show DaphneIR after vectorization"), - clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), - clEnumVal(kernels, "Show DaphneIR after kernel lowering"), - clEnumVal(llvm, "Show DaphneIR after llvm lowering"), - clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen")), + llvm::cl::values( + clEnumVal(parsing, "Show DaphneIR after parsing"), + clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), + clEnumVal(sql, "Show DaphneIR after SQL parsing"), + clEnumVal(property_inference, "Show DaphneIR after property inference"), + clEnumVal(select_matrix_repr, "Show DaphneIR after selecting " + "physical matrix representations"), + clEnumVal(phy_op_selection, "Show DaphneIR after selecting physical operators"), + clEnumVal(type_adaptation, "Show DaphneIR after adapting types to available kernels"), + clEnumVal(vectorized, "Show DaphneIR after vectorization"), + clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), + clEnumVal(kernels, "Show DaphneIR after kernel lowering"), + clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen"), + clEnumVal(mlir_codegen_sparsity_exploiting_op_fusion, + "Show DaphneIR after MLIR codegen (sparsity-exploiting operator fusion)"), + clEnumVal(mlir_codegen_daphneir_to_mlir, "Show DaphneIR after MLIR codegen (DaphneIR to MLIR)"), + clEnumVal(mlir_codegen_mlir_specific, "Show DaphneIR 
after MLIR codegen (MLIR-specific)"), + clEnumVal(llvm, "Show DaphneIR after llvm lowering")), CommaSeparated); static llvm::cl::list scriptArgs1("args", cat(daphneOptions), @@ -479,6 +487,15 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int case mlir_codegen: user_config.explain_mlir_codegen = true; break; + case mlir_codegen_sparsity_exploiting_op_fusion: + user_config.explain_mlir_codegen_sparsity_exploiting_op_fusion = true; + break; + case mlir_codegen_daphneir_to_mlir: + user_config.explain_mlir_codegen_daphneir_to_mlir = true; + break; + case mlir_codegen_mlir_specific: + user_config.explain_mlir_codegen_mlir_specific = true; + break; } } diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 7d3a4fecb..0123923cc 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -262,6 +262,14 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createPrintIRPass("IR before codegen pipeline")); pm.addPass(mlir::daphne::createDaphneOptPass()); + + pm.addPass(mlir::daphne::createSparsityExploitationPass()); + // SparseExploit fuses multiple operations which only need to be lowered if still needed elsewhere. + // Todo: if possible, run only if SparseExploitLowering was successful. 
+ pm.addPass(mlir::createCanonicalizerPass()); + if (userConfig_.explain_mlir_codegen_sparsity_exploiting_op_fusion) + pm.addPass(mlir::daphne::createPrintIRPass("IR after MLIR codegen (sparsity-exploiting operator fusion):")); + pm.addPass(mlir::daphne::createEwOpLoweringPass()); pm.addPass(mlir::daphne::createAggAllOpLoweringPass()); pm.addPass(mlir::daphne::createAggDimOpLoweringPass()); @@ -287,6 +295,9 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { pm.addPass(mlir::daphne::createPrintIRPass("IR directly after lowering MatMulOp.")); } + if (userConfig_.explain_mlir_codegen_daphneir_to_mlir) + pm.addPass(mlir::daphne::createPrintIRPass("IR after MLIR codegen (DaphneIR to MLIR):")); + pm.addPass(mlir::createConvertMathToLLVMPass()); pm.addPass(mlir::daphne::createModOpLoweringPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -307,4 +318,6 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { if (userConfig_.explain_mlir_codegen) pm.addPass(mlir::daphne::createPrintIRPass("IR after codegen pipeline")); + if (userConfig_.explain_mlir_codegen_mlir_specific) + pm.addPass(mlir::daphne::createPrintIRPass("IR after MLIR codegen (MLIR-specific):")); } diff --git a/src/compiler/inference/SelectMatrixRepresentationsPass.cpp b/src/compiler/inference/SelectMatrixRepresentationsPass.cpp index f49fa05f1..32a2b6f67 100644 --- a/src/compiler/inference/SelectMatrixRepresentationsPass.cpp +++ b/src/compiler/inference/SelectMatrixRepresentationsPass.cpp @@ -23,7 +23,6 @@ #include #include -#include using namespace mlir; diff --git a/src/compiler/lowering/AggAllOpLowering.cpp b/src/compiler/lowering/AggAllOpLowering.cpp index b324e73fa..213da40df 100644 --- a/src/compiler/lowering/AggAllOpLowering.cpp +++ b/src/compiler/lowering/AggAllOpLowering.cpp @@ -18,6 +18,7 @@ #include #include "compiler/utils/LoweringUtils.h" +#include #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" @@ -90,8 +91,9 @@ class AggAllOpLowering : 
public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "aggAllOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "AggAllOpLowering", + "aggAllOp codegen currently only works with matrix dimensions that are known at compile time"); } Type matrixElementType = matrixType.getElementType(); diff --git a/src/compiler/lowering/AggDimOpLowering.cpp b/src/compiler/lowering/AggDimOpLowering.cpp index 9f75934f9..8452354f1 100644 --- a/src/compiler/lowering/AggDimOpLowering.cpp +++ b/src/compiler/lowering/AggDimOpLowering.cpp @@ -18,6 +18,8 @@ #include #include "compiler/utils/LoweringUtils.h" +#include + #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" @@ -100,8 +102,9 @@ class AggDimOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "AggDimOpLowering", + "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); } Type matrixElementType = matrixType.getElementType(); @@ -236,8 +239,9 @@ class AggDimIdxOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "AggDimOpLowering", + "aggDimOp codegen currently only works with matrix dimensions that are known at compile time"); } Type matrixElementType = matrixType.getElementType(); diff --git a/src/compiler/lowering/CMakeLists.txt 
b/src/compiler/lowering/CMakeLists.txt index d9386808e..937b07d1a 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -34,6 +34,7 @@ add_mlir_dialect_library(MLIRDaphneTransforms AggAllOpLowering.cpp AggDimOpLowering.cpp TransposeOpLowering.cpp + SparsityExploitationPass.cpp SliceRowOpLowering.cpp SliceColOpLowering.cpp diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 0762f171d..256795529 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -31,6 +31,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" @@ -91,8 +92,9 @@ template struct UnaryOpLowering : publi ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "ewOps codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "ewOps codegen currently only works with matrix dimensions that are known at compile time"); } Value argMemref = rewriter.create( @@ -142,14 +144,69 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { this->setDebugName("EwDaphneOpLowering"); } + /** + * @brief Returns an affine map for indexing the rhs operand. + * Assumes that neither matrix is a singleton and lhs is not broadcast. + * + * If rhs has no dimensions of size 1, returns an identity map. + * Else, returns a map (i,j)->(0,j) or (i,j)->(i,0) to enable broadcasting of rhs. 
+ */ + AffineMap buildRhsAffineMap(Location loc, ConversionPatternRewriter &rewriter, ssize_t lhsRows, ssize_t lhsCols, + ssize_t rhsRows, ssize_t rhsCols) const { + + AffineMap rhsAffineMap; + + // lhs could also be a row/column vector which should not be handled as broadcasting (even though the resulting + // affine maps coincide). This allows for a clearer error message as well. + if (lhsRows != 1 && rhsRows == 1) { + // rhs is a row vector, broadcast along columns + if (lhsCols != rhsCols) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "could not broadcast rhs along columns. Rhs must " + "be a scalar value, singleton matrix or have an equal amount of column to " + "be broadcast but operands have dimensions (" + + std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + std::to_string(rhsRows) + + "," + std::to_string(rhsCols) + ")"); + } + rhsAffineMap = AffineMap::get(2, 0, {rewriter.getAffineConstantExpr(0), rewriter.getAffineDimExpr(1)}, + rewriter.getContext()); + } else if (lhsCols != 1 && rhsCols == 1) { + // rhs is a column vector, broadcast along rows + if (lhsRows != rhsRows) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "could not broadcast rhs along rows. 
Rhs must " + "be a scalar value, singleton matrix or have an equal amount of rows to " + "be broadcast but operands have dimensions (" + + std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + std::to_string(rhsRows) + + "," + std::to_string(rhsCols) + ")"); + } + rhsAffineMap = AffineMap::get(2, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineConstantExpr(0)}, + rewriter.getContext()); + } else { + // rhs is not broadcasted, return identity mapping + if (lhsRows != rhsRows || lhsCols != rhsCols) { + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "lhs and rhs must have equal dimensions or allow for broadcasting but operands have dimensions (" + + std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + std::to_string(rhsRows) + + "," + std::to_string(rhsCols) + ")"); + } + rhsAffineMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); + } + + return rhsAffineMap; + } + LogicalResult matchAndRewriteScalarVal(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOp(op, binaryFunc(rewriter, op.getLoc(), this->typeConverter, adaptor.getLhs(), adaptor.getRhs())); return mlir::success(); } - LogicalResult matchAndRewriteBroadcastRhs(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &rhs) const { + LogicalResult matchAndRewriteBroadcastScalarRhs(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + Value &rhs) const { Location loc = op->getLoc(); Value lhs = adaptor.getLhs(); @@ -160,7 +217,7 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { Type matrixElementType = lhsMatrixType.getElementType(); MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); - Value lhsMemref = rewriter.create(loc, argMemRefType, lhs); + auto lhsMemref = rewriter.create(loc, argMemRefType, lhs); Value resMemref = rewriter.create(loc, argMemRefType); @@ -189,13 +246,13 @@ class BinaryOpLowering 
final : public mlir::OpConversionPattern { auto lhsMatrixType = lhs.getType().template dyn_cast(); auto rhsMatrixType = rhs.getType().template dyn_cast(); - // Match Scalar-Scalar and Matrix-Scalar broadcasting (assuming scalar values are always switched to rhs). - // Broadcasting where either Matrix is a singleton needs to be handled separately below. + // Match Scalar-Scalar and Matrix-Scalar broadcasting (assuming scalar values are always switched to + // rhs). Broadcasting where either Matrix is a singleton or vector needs to be handled separately below. if (!rhsMatrixType) { if (!lhsMatrixType) { return matchAndRewriteScalarVal(op, adaptor, rewriter); } - return matchAndRewriteBroadcastRhs(op, adaptor, rewriter, rhs); + return matchAndRewriteBroadcastScalarRhs(op, adaptor, rewriter, rhs); } Type matrixElementType = lhsMatrixType.getElementType(); @@ -206,13 +263,15 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ssize_t rhsCols = rhsMatrixType.getNumCols(); if (lhsRows < 0 || lhsCols < 0 || rhsRows < 0 || rhsCols < 0) { - return rewriter.notifyMatchFailure( - op, "ewOps codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "EwOpsLowering (BinaryOp)", + "ewOps codegen currently only works with matrix dimensions that are known at compile time"); } - // Assume that if only one matrix contains a single value for broadcasting it is rhs. + // For efficiency, broadcasting a singleton is handled separately here (assumes singleton is always rhs). + // Broadcasting of row/column vectors is handled during the construction of the index map for rhs below. 
if ((lhsRows != 1 || lhsCols != 1) && rhsRows == 1 && rhsCols == 1) { - Value rhsMemref = rewriter.create( + auto rhsMemref = rewriter.create( loc, MemRefType::get({1, 1}, matrixElementType), rhs); Value rhsBroadcastVal = rewriter @@ -220,25 +279,21 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { ValueRange{rewriter.create(loc, 0), rewriter.create(loc, 0)}) .getResult(); - return matchAndRewriteBroadcastRhs(op, adaptor, rewriter, rhsBroadcastVal); + return matchAndRewriteBroadcastScalarRhs(op, adaptor, rewriter, rhsBroadcastVal); } - if (lhsRows != rhsRows || lhsCols != rhsCols) { - throw ErrorHandler::compilerError(loc, "EwOpsLowering (BinaryOp)", - "lhs and rhs must have equal dimensions or either one must " - "be a scalar value but have dimensions (" + - std::to_string(lhsRows) + "," + std::to_string(lhsCols) + ") and (" + - std::to_string(rhsRows) + "," + std::to_string(rhsCols) + ")"); - } + MemRefType lhsMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); + MemRefType rhsMemRefType = MemRefType::get({rhsRows, rhsCols}, matrixElementType); + auto lhsMemref = rewriter.create(loc, lhsMemRefType, lhs); + auto rhsMemref = rewriter.create(loc, rhsMemRefType, rhs); - MemRefType argMemRefType = MemRefType::get({lhsRows, lhsCols}, matrixElementType); - Value lhsMemref = rewriter.create(loc, argMemRefType, lhs); - Value rhsMemref = rewriter.create(loc, argMemRefType, rhs); - - Value resMemref = rewriter.create(loc, argMemRefType); + // If any broadcasting occurs, it is assumed to be rhs so res inherits its shape from lhs. + Value resMemref = rewriter.create(loc, lhsMemRefType); + // Builds an affine map to index the args and accounts for broadcasting of rhs. + // Creation of rhs indexing map checks whether or not the dimensions match and returns a compiler error if not. 
SmallVector indexMaps = {AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), - AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), + buildRhsAffineMap(loc, rewriter, lhsRows, lhsCols, rhsRows, rhsCols), AffineMap::getMultiDimIdentityMap(2, rewriter.getContext())}; SmallVector iterTypes = {utils::IteratorType::parallel, utils::IteratorType::parallel}; @@ -364,7 +419,8 @@ using CosOpLowering = UnaryOpLowering>; // Rounding -// Prior canonicalization pass removes rounding ops on integers, meaning only f32/f64 types need to be dealt with +// Prior canonicalization pass removes rounding ops on integers, meaning only f32/f64 types need to be dealt +// with using FloorOpLowering = UnaryOpLowering>; using CeilOpLowering = UnaryOpLowering>; using RoundOpLowering = UnaryOpLowering>; @@ -388,8 +444,10 @@ using MinOpLowering = // Logical // using AndOpLowering = -// BinaryOpLowering>; // distinguish AndFOp -// using OrOpLowering = BinaryOpLowering>; // - " - +// BinaryOpLowering>; // distinguish +// AndFOp +// using OrOpLowering = BinaryOpLowering>; +// // - " - // **************************************************************************** // General Pass Setup diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index f99c9e472..fe73b1acd 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -32,12 +32,13 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" #include #include #include #include -#include +#include #include #include #include @@ -133,10 +134,19 @@ class KernelReplacement : public RewritePattern { return mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}); if (auto lt = t.dyn_cast()) return mlir::daphne::ListType::get(mctx, adaptType(lt.getElementType(), generalizeToStructure)); - if (auto mrt = 
t.dyn_cast()) - // Remove any dimension information ({0, 0}), but retain the element - // type. - return mlir::MemRefType::get({0, 0}, mrt.getElementType()); + if (auto mrt = t.dyn_cast()) { + // Remove any specific dimension information ({0}), but retain the rank and element type. + int64_t mrtRank = mrt.getRank(); + if (mrtRank == 1) { + return mlir::MemRefType::get({0}, mrt.getElementType()); + } else if (mrtRank == 2) { + return mlir::MemRefType::get({0, 0}, mrt.getElementType()); + } else { + throw std::runtime_error( + "RewriteToCallKernelOpPass: expected MemRef to be of rank 1 or 2 but was given " + + std::to_string(mrtRank)); + } + } return t; } diff --git a/src/compiler/lowering/SparsityExploitationPass.cpp b/src/compiler/lowering/SparsityExploitationPass.cpp new file mode 100644 index 000000000..f1f057873 --- /dev/null +++ b/src/compiler/lowering/SparsityExploitationPass.cpp @@ -0,0 +1,345 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include +#include + +#include "ir/daphneir/Daphne.h" +#include "ir/daphneir/Passes.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" + +using namespace mlir; + +/** + * @brief sum(sparse * ln(dense @ dense)) where either dense matrix can be transposed + */ +class SparsityExploitation final : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpAdaptor; + + SparsityExploitation(TypeConverter &typeConverter, mlir::MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx) { + this->setDebugName("SparsityExploitation"); + } + + LogicalResult matchAndRewrite(daphne::AllAggSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // ------------------------------------------------------------------------------------ + // Pattern Matching allAgg( IntersectOp( sparseLhs, ln(denseLhs @ denseRhs) ) ) + // 
------------------------------------------------------------------------------------ + + // Rewrite is only called after the pattern has already been matched during legalization, so the results are + // assumed not to be nullptr. + auto sparseIntersectOp = + adaptor.getArg().getDefiningOp(); // IntersectOp( sparseLhs, ln(denseLhs @ denseRhs) ) + + Value sparseLhs = sparseIntersectOp.getLhs(); + auto unaryOp = sparseIntersectOp.getRhs().getDefiningOp(); // ln(denseLhs @ denseRhs) + + auto denseMatmulOp = unaryOp.getArg().getDefiningOp(); // denseLhs @ denseRhs + + // daphne.matMul Op has 4 arguments, the latter two (`transa`, `transb`) are bools indicating whether + // either matrix should be accessed as though it was transposed. + Value denseLhs = denseMatmulOp.getLhs(); + Value denseRhs = denseMatmulOp.getRhs(); + bool transa = CompilerUtils::constantOrThrow( + denseMatmulOp.getTransa(), "SparseExploitLowering: expected transa to be known at compile-time"); + bool transb = CompilerUtils::constantOrThrow( + denseMatmulOp.getTransb(), "SparseExploitLowering: expected transb to be known at compile-time"); + + auto sparseLhsMatType = sparseLhs.getType().template dyn_cast(); + Type resElementType = sparseLhsMatType.getElementType(); + ssize_t sparseLhsRows = sparseLhsMatType.getNumRows(); + ssize_t sparseLhsCols = sparseLhsMatType.getNumCols(); + + auto denseLhsMatType = denseLhs.getType().template dyn_cast(); + Type dotProdElementType = sparseLhsMatType.getElementType(); + ssize_t denseLhsRows = denseLhsMatType.getNumRows(); + ssize_t denseLhsCols = denseLhsMatType.getNumCols(); + + auto denseRhsMatType = denseRhs.getType().template dyn_cast(); + ssize_t denseRhsRows = denseRhsMatType.getNumRows(); + ssize_t denseRhsCols = denseRhsMatType.getNumCols(); + + if (sparseLhsRows < 0 || sparseLhsCols < 0 || denseLhsRows < 0 || denseLhsCols < 0 || denseRhsRows < 0 || + denseRhsCols < 0) { + throw ErrorHandler::compilerError( + loc, "SparseExploitLowering", + "sparse exploit 
codegen currently only works with matrix dimensions that are known at compile time"); + } + + // Verify dense-dense Matmul and sparse-dense intersection op have matching dimensions. + if ((transa ? denseLhsRows : denseLhsCols) != (transb ? denseRhsCols : denseRhsRows)) { + throw ErrorHandler::compilerError( + loc, "SparseExploitLowering", + "dense-dense matrix multiplication operands must have matching inner dimension (given: " + "(" + + std::to_string(transa ? denseLhsCols : denseLhsRows) + "x" + + std::to_string(transa ? denseLhsRows : denseLhsCols) + ") and (" + + std::to_string(transb ? denseRhsCols : denseRhsRows) + "x" + + std::to_string(transb ? denseRhsRows : denseRhsCols) + ")"); + } + if (sparseLhsRows != (transa ? denseLhsCols : denseLhsRows) || + sparseLhsCols != (transb ? denseRhsRows : denseRhsCols)) { + throw ErrorHandler::compilerError( + loc, "SparseExploitLowering", + "sparse-dense intersectionOp operands must have equal dimensions (given: (" + + std::to_string(sparseLhsRows) + "x" + std::to_string(sparseLhsCols) + ") and (" + + std::to_string(transa ? denseLhsCols : denseLhsRows) + "x" + + std::to_string(transb ? 
denseRhsRows : denseRhsCols) + ")"); + } + + MemRefType denseLhsMemRefType = MemRefType::get({denseLhsRows, denseLhsCols}, denseLhsMatType.getElementType()); + MemRefType denseRhsMemRefType = MemRefType::get({denseRhsRows, denseRhsCols}, denseRhsMatType.getElementType()); + auto denseLhsMemRef = rewriter.create(loc, denseLhsMemRefType, denseLhs); + auto denseRhsMemRef = rewriter.create(loc, denseRhsMemRefType, denseRhs); + + MemRefType sparseLhsValuesMemRefType = + MemRefType::get({ShapedType::kDynamic}, sparseLhsMatType.getElementType()); + MemRefType sparseLhsColIdxsMemRefType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); + MemRefType sparseLhsRowOffsetsMemRefType = MemRefType::get({sparseLhsRows + 1}, rewriter.getIndexType()); + + auto argValues = + rewriter.create(loc, sparseLhsValuesMemRefType, sparseLhs); + auto argColIdx = + rewriter.create(loc, sparseLhsColIdxsMemRefType, sparseLhs); + auto argRowPtr = + rewriter.create(loc, sparseLhsRowOffsetsMemRefType, sparseLhs); + + /* The loop nest below loops first over the row offsets, meaning over the rows of the CSRMatrix. + * + * Then, in a nested loop, over it iterates over the range created by the difference of an entry in + * row offsets and its successor, which corresponds to the column length of row the first offset was taken of. + * + * Finally, the inner-most loop simply computes the dot product of the dense matrices at the given + * row/column indices. Its result is passed to the first nested loop that computes the logarithm, + * multiplies it with the value at the same indices in the CSRMatrix, and passes it to the outer + * loop which performs a (thread safe) addition. + * + * Todo: Test parallelization of the two inner scf loops (replace with scf::parallel / scf::reduce instead of + * scf::yield) and enable lowering of parallel loops to different threads (e.g. using omp dialect/OpenMP + * library). 
+ */ + auto outerLoop = rewriter.create(loc, TypeRange{resElementType}, + ArrayRef{arith::AtomicRMWKind::addf}, + ArrayRef{sparseLhsRows}); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + { + Value rowPtr = rewriter.create(loc, argRowPtr, ValueRange{outerLoop.getIVs()[0]}); + Value nextRowPtr = rewriter.create( + loc, argRowPtr, AffineMap::get(1, 0, rewriter.getAffineDimExpr(0) + 1, rewriter.getContext()), + ValueRange{outerLoop.getIVs()[0]}); + + Value innerLoopAcc = rewriter.create(loc, rewriter.getZeroAttr(resElementType)); + + auto innerLoop = rewriter.create( + loc, rowPtr, nextRowPtr, rewriter.create(loc, 1), ValueRange{innerLoopAcc}, + [&](OpBuilder &OpBuilderNested, Location locNested, Value loopIdx, ValueRange loopInvariants) { + Value colIdx = OpBuilderNested.create(locNested, argColIdx, ValueRange{loopIdx}); + + Value resDotProd = OpBuilderNested.create( + locNested, OpBuilderNested.getZeroAttr(dotProdElementType)); + + auto dotProdLoop = OpBuilderNested.create( + locNested, rewriter.create(loc, 0), + rewriter.create(loc, denseLhsCols), + rewriter.create(loc, 1), ValueRange{resDotProd}, + [&](OpBuilder &OpBuilderTwiceNested, Location locTwiceNested, Value loopIdxNested, + ValueRange loopInvariantsNested) { + Value currentValLhs = OpBuilderTwiceNested.create( + locTwiceNested, denseLhsMemRef, + transa ? ValueRange{loopIdxNested, outerLoop.getIVs()[0]} + : ValueRange{outerLoop.getIVs()[0], loopIdxNested}); + Value currentValRhs = OpBuilderTwiceNested.create( + locTwiceNested, denseRhsMemRef, + transb ? 
ValueRange{colIdx, loopIdxNested} : ValueRange{loopIdxNested, colIdx}); + + Value accDotProd; + if (llvm::isa(dotProdElementType)) { + currentValLhs = convertToSignlessInt(OpBuilderTwiceNested, locTwiceNested, + typeConverter, currentValLhs, dotProdElementType); + currentValRhs = convertToSignlessInt(OpBuilderTwiceNested, locTwiceNested, + typeConverter, currentValRhs, dotProdElementType); + accDotProd = OpBuilderTwiceNested.create(locTwiceNested, currentValLhs, + currentValRhs); + accDotProd = OpBuilderTwiceNested.create(locTwiceNested, accDotProd, + loopInvariantsNested[0]); + } else { + accDotProd = OpBuilderTwiceNested.create( + locTwiceNested, currentValLhs, currentValRhs, loopInvariantsNested[0]); + } + + OpBuilderTwiceNested.create(locTwiceNested, ValueRange{accDotProd}); + }); // end dot product + + Value dotProdLoopRes = dotProdLoop.getResult(0); + if (llvm::isa(dotProdElementType)) { + dotProdLoopRes = this->typeConverter->materializeTargetConversion( + OpBuilderNested, locNested, resElementType, dotProdLoopRes); + } + Value mappedDotProd = OpBuilderNested.create(locNested, dotProdLoopRes); + + Value currentVal = + OpBuilderNested.create(locNested, argValues, ValueRange{loopIdx}); + + // After log has been applied, arg will always be a float type so FMA is no restriction + Value intersectAggRes = + OpBuilderNested.create(locNested, currentVal, mappedDotProd, loopInvariants[0]); + + OpBuilderNested.create(locNested, ValueRange{intersectAggRes}); + }); // end inner loop + + rewriter.create(loc, ValueRange{innerLoop.getResult(0)}); + } // end outer loop + rewriter.setInsertionPointAfter(outerLoop); + + rewriter.replaceOp(op, outerLoop->getResult(0)); + return success(); + } +}; + +// **************************************************************************** +// General Pass Setup +// **************************************************************************** + +namespace { +/** + * @brief This pass lowers a specific pattern of sparse-dense operations + * 
to a loop nest that avoids materializing potentially large dense intermediates. + * + * The matched pattern is sum(sparse * ln(dense @ dense)) + * where either dense matrix can be transposed. + */ +struct SparsityExploitationPass + : public mlir::PassWrapper> { + explicit SparsityExploitationPass() = default; + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; + + [[nodiscard]] StringRef getArgument() const final { return "lower-sparse-exploit"; } + [[nodiscard]] StringRef getDescription() const final { + return "This pass lowers a sum(sparse * ln(dense @ dense)) pattern " + "to 3 nested loops that avoid materializing potentially large dense intermediates."; + } +}; +} // end anonymous namespace + +void SparsityExploitationPass::runOnOperation() { + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + mlir::LowerToLLVMOptions llvmOptions(&getContext()); + mlir::LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + + // Marks Op as illegal only if it matches the pattern: sum(sparse * ln(dense @ dense)) + target.addDynamicallyLegalOp([](Operation *op) { + auto definingEwMulOp = op->getOperand(0).getDefiningOp(); + if (definingEwMulOp == nullptr) { + return true; + } + + Type lhsType = definingEwMulOp.getLhs().getType(); + auto lhsMatType = lhsType.dyn_cast(); + + Value rhs = definingEwMulOp.getRhs(); + Type rhsType = rhs.getType(); + auto rhsMatType = rhsType.dyn_cast(); + + if (!lhsMatType || !rhsMatType) { + return true; + } + if 
(lhsMatType.getRepresentation() != daphne::MatrixRepresentation::Sparse || + rhsMatType.getRepresentation() != daphne::MatrixRepresentation::Dense) { + return true; + } + + auto definingEwLnOp = rhs.getDefiningOp(); + if (definingEwLnOp == nullptr) { + return true; + } + + auto definingMatMulOp = definingEwLnOp.getArg().getDefiningOp(); + if (definingMatMulOp == nullptr) { + return true; + } + + Type matmulLhsType = definingMatMulOp.getLhs().getType(); + Type matmulRhsType = definingMatMulOp.getRhs().getType(); + auto matmulLhsMatType = matmulLhsType.dyn_cast(); + auto matmulRhsMatType = matmulRhsType.dyn_cast(); + + if (matmulLhsMatType.getRepresentation() != daphne::MatrixRepresentation::Dense || + matmulRhsMatType.getRepresentation() != daphne::MatrixRepresentation::Dense) { + return true; + } + + return false; + }); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr daphne::createSparsityExploitationPass() { + return std::make_unique(); +} diff --git a/src/compiler/lowering/TransposeOpLowering.cpp b/src/compiler/lowering/TransposeOpLowering.cpp index 6203286d1..1ae536e5c 100644 --- a/src/compiler/lowering/TransposeOpLowering.cpp +++ b/src/compiler/lowering/TransposeOpLowering.cpp @@ -17,6 +17,7 @@ #include "compiler/utils/LoweringUtils.h" #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" +#include #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -80,8 +81,9 @@ class TransposeOpLowering : public OpConversionPattern { ssize_t numCols = matrixType.getNumCols(); if (numRows < 0 || numCols < 0) { - return rewriter.notifyMatchFailure( - op, "transposeOp codegen currently only works with matrix dimensions that are known at compile time"); + throw ErrorHandler::compilerError( + loc, "TransposeOpLowering", + "transposeOp codegen currently 
only works with matrix dimensions that are known at compile time"); } Value argMemref = rewriter.create( diff --git a/src/compiler/utils/CompilerUtils.h b/src/compiler/utils/CompilerUtils.h index 5a72d8bf2..9af38755b 100644 --- a/src/compiler/utils/CompilerUtils.h +++ b/src/compiler/utils/CompilerUtils.h @@ -212,7 +212,9 @@ struct CompilerUtils { return "Target"; else if (auto memRefType = t.dyn_cast()) { const std::string vtName = mlirTypeToCppTypeName(memRefType.getElementType(), angleBrackets, false); - return angleBrackets ? ("StridedMemRefType<" + vtName + ",2>") : ("StridedMemRefType_" + vtName + "_2"); + const std::string rankStr = std::to_string(memRefType.getRank()); + return angleBrackets ? ("StridedMemRefType<" + vtName + "," + rankStr + ">") + : ("StridedMemRefType_" + vtName + "_" + rankStr); } std::string typeName; diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index a9b437477..b2d91310b 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -91,6 +91,27 @@ def Daphne_ConvertDenseMatrixToMemRef : Daphne_Op<"convertDenseMatrixToMemRef", let results = (outs AnyMemRef:$output); } +def Daphne_ConvertCSRMatrixToValuesMemRef : Daphne_Op<"convertCSRMatrixToValuesMemRef", [Pure]> { + let summary = "Given a CSRMatrix, return a StridedMemRefType."; + let description = [{ Constructs a StridedMemRefType with rank 1 containing the values of a CSRMatrix* with already allocated memory. }]; + let arguments = (ins MatrixOrU:$arg); + let results = (outs AnyMemRef:$resValues); +} + +def Daphne_ConvertCSRMatrixToColIdxsMemRef : Daphne_Op<"convertCSRMatrixToColIdxsMemRef", [Pure]> { + let summary = "Given a CSRMatrix, return a StridedMemRefType."; + let description = [{ Constructs a StridedMemRefType with rank 1 containing the column indices of a CSRMatrix* with already allocated memory. 
}]; + let arguments = (ins MatrixOrU:$arg); + let results = (outs AnyMemRef:$resColIdxs); +} + +def Daphne_ConvertCSRMatrixToRowOffsetsMemRef : Daphne_Op<"convertCSRMatrixToRowOffsetsMemRef", [Pure]> { + let summary = "Given a CSRMatrix, return a StridedMemRefType."; + let description = [{ Constructs a StridedMemRefType with rank 1 containing the row offsets of a CSRMatrix* with already allocated memory. }]; + let arguments = (ins MatrixOrU:$arg); + let results = (outs AnyMemRef:$resRowOffsets); +} + // **************************************************************************** // Data generation // **************************************************************************** @@ -195,7 +216,7 @@ def Daphne_MatMulOp : Daphne_Op<"matMul", [ DataTypeMat, ValueTypeFromArgs, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, CUDASupport, FPGAOPENCLSupport, - CastFirstTwoArgsToResType + CastFirstTwoArgsToResType, NoMemoryEffect ]> { let arguments = (ins MatrixOf<[NumScalar]>:$lhs, MatrixOf<[NumScalar]>:$rhs, BoolScalar:$transa, BoolScalar:$transb); let results = (outs MatrixOf<[NumScalar]>:$res); @@ -1076,7 +1097,7 @@ def Daphne_InnerJoinOp : Daphne_Op<"innerJoin", [ DataTypeFrm, ValueTypesConcat, DeclareOpInterfaceMethods, ]> { - let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn); + let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes); let results = (outs FrameOrU:$res); } @@ -1120,7 +1141,7 @@ def Daphne_SemiJoinOp : Daphne_Op<"semiJoin", [ DeclareOpInterfaceMethods, NumColsFromArg ]> { - let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn); + let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes); let results = (outs FrameOrU:$res, MatrixOf<[Size]>:$lhsTids); } diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index ad11c5e67..5ff3755bd 100644 --- a/src/ir/daphneir/Passes.h +++ 
b/src/ir/daphneir/Passes.h @@ -47,6 +47,7 @@ std::unique_ptr createDaphneOptPass(); std::unique_ptr createDistributeComputationsPass(); std::unique_ptr createDistributePipelinesPass(); std::unique_ptr createEwOpLoweringPass(); +std::unique_ptr createSparsityExploitationPass(); std::unique_ptr createInferencePass(InferenceConfig cfg = {false, true, true, true, true}); std::unique_ptr createInsertDaphneContextPass(const DaphneUserConfig &cfg); std::unique_ptr createLowerToLLVMPass(const DaphneUserConfig &cfg); diff --git a/src/ir/daphneir/Passes.td b/src/ir/daphneir/Passes.td index d1676dd95..9aad66829 100644 --- a/src/ir/daphneir/Passes.td +++ b/src/ir/daphneir/Passes.td @@ -256,6 +256,9 @@ def LowerEwOpPass: Pass<"lower-ew", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createEwOpLoweringPass()"; } +def SparsityExploitationPass: Pass<"lower-sparse-exploit", "::mlir::func::FuncOp"> { + let constructor = "mlir::daphne::createSparsityExploitationPass()"; +} def SliceRowOpLoweringPass: Pass<"lower-slice-row", "::mlir::func::FuncOp"> { let constructor = "mlir::daphne::createSliceRowOpLoweringPass()"; } diff --git a/src/parser/catalog/KernelCatalogParser.cpp b/src/parser/catalog/KernelCatalogParser.cpp index 7560e7f4e..3c4c0fcfb 100644 --- a/src/parser/catalog/KernelCatalogParser.cpp +++ b/src/parser/catalog/KernelCatalogParser.cpp @@ -72,10 +72,11 @@ KernelCatalogParser::KernelCatalogParser(mlir::MLIRContext *mctx) { // MemRef type. if (!st.isa()) { // DAPHNE's StringType is not supported as the element type of a - // MemRef. The dimensions of the MemRef are irrelevant here, so we - // use {0, 0}. + // MemRef. The dimensions of the MemRef are irrelevant here. 
mlir::Type mrt = mlir::MemRefType::get({0, 0}, st); typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mrt), mrt); + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mlir::MemRefType::get({0}, st)), + mlir::MemRefType::get({0}, st)); } } @@ -134,4 +135,4 @@ void KernelCatalogParser::parseKernelCatalog(const std::string &filePath, Kernel } catch (std::exception &e) { throw std::runtime_error("error while parsing kernel catalog file `" + filePath + "`: " + e.what()); } -} \ No newline at end of file +} diff --git a/src/parser/config/ConfigParser.cpp b/src/parser/config/ConfigParser.cpp index 2a7b6a28c..47d3def14 100644 --- a/src/parser/config/ConfigParser.cpp +++ b/src/parser/config/ConfigParser.cpp @@ -104,6 +104,15 @@ void ConfigParser::readUserConfig(const std::string &filename, DaphneUserConfig config.explain_obj_ref_mgnt = jf.at(DaphneConfigJsonParams::EXPLAIN_OBJ_REF_MGNT).get(); if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN)) config.explain_mlir_codegen = jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION)) + config.explain_mlir_codegen_sparsity_exploiting_op_fusion = + jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR)) + config.explain_mlir_codegen_daphneir_to_mlir = + jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC)) + config.explain_mlir_codegen_mlir_specific = + jf.at(DaphneConfigJsonParams::EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC).get(); if (keyExists(jf, DaphneConfigJsonParams::TASK_PARTITIONING_SCHEME)) { config.taskPartitioningScheme = jf.at(DaphneConfigJsonParams::TASK_PARTITIONING_SCHEME).get(); diff --git a/src/parser/config/JsonParams.h b/src/parser/config/JsonParams.h index d2c964cf9..3a6717497 100644 
--- a/src/parser/config/JsonParams.h +++ b/src/parser/config/JsonParams.h @@ -54,6 +54,10 @@ struct DaphneConfigJsonParams { inline static const std::string EXPLAIN_VECTORIZED = "explain_vectorized"; inline static const std::string EXPLAIN_OBJ_REF_MGNT = "explain_obj_ref_mgnt"; inline static const std::string EXPLAIN_MLIR_CODEGEN = "explain_mlir_codegen"; + inline static const std::string EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION = + "explain_mlir_codegen_sparsity_exploiting_op_fusion"; + inline static const std::string EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR = "explain_mlir_codegen_daphneir_to_mlir"; + inline static const std::string EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC = "explain_mlir_codegen_mlir_specific"; inline static const std::string TASK_PARTITIONING_SCHEME = "taskPartitioningScheme"; inline static const std::string NUMBER_OF_THREADS = "numberOfThreads"; inline static const std::string MINIMUM_TASK_SIZE = "minimumTaskSize"; @@ -95,6 +99,9 @@ struct DaphneConfigJsonParams { EXPLAIN_TYPE_ADAPTATION, EXPLAIN_VECTORIZED, EXPLAIN_MLIR_CODEGEN, + EXPLAIN_MLIR_CODEGEN_SPARSITY_EXPLOITING_OP_FUSION, + EXPLAIN_MLIR_CODEGEN_DAPHNEIR_TO_MLIR, + EXPLAIN_MLIR_CODEGEN_MLIR_SPECIFIC, EXPLAIN_OBJ_REF_MGNT, TASK_PARTITIONING_SCHEME, NUMBER_OF_THREADS, diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index bc136b9d0..10208d788 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -993,13 +993,18 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu builder.create(loc, FrameType::get(builder.getContext(), colTypes), args[0], args[1])); } if (func == "innerJoin") { - checkNumArgsExact(loc, func, numArgs, 4); + checkNumArgsBetween(loc, func, numArgs, 4, 5); std::vector colTypes; + mlir::Value numRowRes; for (int i = 0; i < 2; i++) for (mlir::Type t : args[i].getType().dyn_cast().getColumnTypes()) colTypes.push_back(t); + if (numArgs == 5) 
+ numRowRes = utils.castSI64If(args[4]); + else + numRowRes = builder.create(loc, int64_t(-1)); return static_cast(builder.create(loc, FrameType::get(builder.getContext(), colTypes), - args[0], args[1], args[2], args[3])); + args[0], args[1], args[2], args[3], numRowRes)); } if (func == "fullOuterJoin") return createJoinOp(loc, func, args); @@ -1011,14 +1016,19 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu // TODO Reconcile this with the other join ops, but we need it to work // quickly now. // return createJoinOp(loc, func, args); - checkNumArgsExact(loc, func, numArgs, 4); + checkNumArgsBetween(loc, func, numArgs, 4, 5); mlir::Value lhs = args[0]; mlir::Value rhs = args[1]; mlir::Value lhsOn = args[2]; mlir::Value rhsOn = args[3]; + mlir::Value numRowRes; + if (numArgs == 5) + numRowRes = utils.castSI64If(args[4]); + else + numRowRes = builder.create(loc, int64_t(-1)); return builder .create(loc, FrameType::get(builder.getContext(), {utils.unknownType}), utils.matrixOfSizeType, - lhs, rhs, lhsOn, rhsOn) + lhs, rhs, lhsOn, rhsOn, numRowRes) .getResults(); } if (func == "groupJoin") { @@ -1034,6 +1044,45 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu .getResults(); } + // -------------------------------------------------------------------- + // Grouping and aggregation + // -------------------------------------------------------------------- + + if (func == "groupSum") { + // Arbitrary number of columns to group on. + // A single column to calculate the sum on. 
+ checkNumArgsMin(loc, func, numArgs, 3); + mlir::Value currentFrame = args[0]; + mlir::Value aggCol = args[numArgs - 1]; + std::vector groupName; + std::vector columnName; + std::vector colTypes; + + // set aggregation function to SUM + auto aggFunc = static_cast( + mlir::daphne::GroupEnumAttr::get(builder.getContext(), mlir::daphne::GroupEnum::SUM)); + std::vector functionName; + functionName.push_back(aggFunc); + + // get group columns + for (size_t i = 1; i < numArgs - 1; i++) { + groupName.push_back(args[i]); + } + + // get agg column + columnName.push_back(aggCol); + + // result column types + mlir::Type vt = utils.unknownType; + for (size_t i = 0; i < groupName.size() + columnName.size(); i++) { + colTypes.push_back(vt); + } + + return static_cast(builder.create(loc, FrameType::get(builder.getContext(), colTypes), + currentFrame, groupName, columnName, + builder.getArrayAttr(functionName))); + } + // ******************************************************************** // Frame label manipulation // ******************************************************************** diff --git a/src/parser/daphnedsl/DaphneDSLGrammar.g4 b/src/parser/daphnedsl/DaphneDSLGrammar.g4 index 53b8782f5..82ec2e095 100644 --- a/src/parser/daphnedsl/DaphneDSLGrammar.g4 +++ b/src/parser/daphnedsl/DaphneDSLGrammar.g4 @@ -91,8 +91,8 @@ expr: | lhs=expr op='@' rhs=expr # matmulExpr | lhs=expr op='^' rhs=expr # powExpr | lhs=expr op='%' rhs=expr # modExpr - | lhs=expr op=('*'|'/') rhs=expr # mulExpr - | lhs=expr op=('+'|'-') rhs=expr # addExpr + | lhs=expr op=('*'|'/') ('::' kernel=IDENTIFIER)? rhs=expr # mulExpr + | lhs=expr op=('+'|'-') ('::' kernel=IDENTIFIER)? 
rhs=expr # addExpr | lhs=expr op=('=='|'!='|'<'|'<='|'>'|'>=') rhs=expr # cmpExpr | lhs=expr op='&&' rhs=expr # conjExpr | lhs=expr op='||' rhs=expr # disjExpr diff --git a/src/parser/daphnedsl/DaphneDSLVisitor.cpp b/src/parser/daphnedsl/DaphneDSLVisitor.cpp index fab458673..21d6e0424 100644 --- a/src/parser/daphnedsl/DaphneDSLVisitor.cpp +++ b/src/parser/daphnedsl/DaphneDSLVisitor.cpp @@ -1197,13 +1197,30 @@ antlrcpp::Any DaphneDSLVisitor::visitMulExpr(DaphneDSLGrammarParser::MulExprCont mlir::Location loc = utils.getLoc(ctx->op); mlir::Value lhs = valueOrErrorOnVisit(ctx->lhs); mlir::Value rhs = valueOrErrorOnVisit(ctx->rhs); + bool hasKernelHint = ctx->kernel != nullptr; + mlir::Value res = nullptr; if (op == "*") - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); if (op == "/") - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); - throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); + if (hasKernelHint) { + std::string kernel = ctx->kernel->getText(); + + // We deliberately don't check if the specified kernel + // is registered for the created kind of operation, + // since this is checked in RewriteToCallKernelOpPass. + + mlir::Operation *op = res.getDefiningOp(); + // TODO Don't hardcode the attribute name. 
+ op->setAttr("kernel_hint", builder.getStringAttr(kernel)); + } + + if (res) + return res; + else + throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); } antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprContext *ctx) { @@ -1211,7 +1228,9 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont mlir::Location loc = utils.getLoc(ctx->op); mlir::Value lhs = valueOrErrorOnVisit(ctx->lhs); mlir::Value rhs = valueOrErrorOnVisit(ctx->rhs); + bool hasKernelHint = ctx->kernel != nullptr; + mlir::Value res = nullptr; if (op == "+") // Note that we use '+' for both addition (EwAddOp) and concatenation // (EwConcatOp). The choice is made based on the types of the operands @@ -1219,11 +1238,28 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont // types might not be known at this point in time. Thus, we always // create an EwAddOp here. Note that EwAddOp has a canonicalize method // rewriting it to EwConcatOp if necessary. - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); if (op == "-") - return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); + res = utils.retValWithInferedType(builder.create(loc, lhs, rhs)); - throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); + if (hasKernelHint) { + std::string kernel = ctx->kernel->getText(); + + // We deliberately don't check if the specified kernel + // is registered for the created kind of operation, + // since this is checked in RewriteToCallKernelOpPass. + + mlir::Operation *op = res.getDefiningOp(); + // TODO Don't hardcode the attribute name. + op->setAttr("kernel_hint", builder.getStringAttr(kernel)); + + // TODO retain the attr in case EwAddOp is rewritten to EwConcatOp. 
+ } + + if (res) + return res; + else + throw ErrorHandler::compilerError(utils.getLoc(ctx->start), "DSLVisitor", "unexpected op symbol"); } antlrcpp::Any DaphneDSLVisitor::visitCmpExpr(DaphneDSLGrammarParser::CmpExprContext *ctx) { diff --git a/src/parser/sql/SQLVisitor.cpp b/src/parser/sql/SQLVisitor.cpp index d3c2a3d9f..e098ced3e 100644 --- a/src/parser/sql/SQLVisitor.cpp +++ b/src/parser/sql/SQLVisitor.cpp @@ -551,8 +551,11 @@ antlrcpp::Any SQLVisitor::visitInnerJoin(SQLGrammarParser::InnerJoinContext *ctx mlir::Value rhsName = valueOrErrorOnVisit(ctx->rhs); mlir::Value lhsName = valueOrErrorOnVisit(ctx->lhs); + mlir::Value numRowRes = + static_cast(builder.create(queryLoc, static_cast(-1))); + return static_cast( - builder.create(loc, t, currentFrame, tojoin, rhsName, lhsName)); + builder.create(loc, t, currentFrame, tojoin, rhsName, lhsName, numRowRes)); } std::vector rhsNames; diff --git a/src/runtime/local/datastructures/CSRMatrix.h b/src/runtime/local/datastructures/CSRMatrix.h index 08dc4a3e1..0298622bb 100644 --- a/src/runtime/local/datastructures/CSRMatrix.h +++ b/src/runtime/local/datastructures/CSRMatrix.h @@ -66,9 +66,9 @@ template class CSRMatrix : public Matrix { */ size_t maxNumNonZeros; - std::shared_ptr values; - std::shared_ptr colIdxs; - std::shared_ptr rowOffsets; + std::shared_ptr values; + std::shared_ptr colIdxs; + std::shared_ptr rowOffsets; size_t lastAppendedRowIdx; @@ -126,7 +126,7 @@ template class CSRMatrix : public Matrix { maxNumNonZeros = src->maxNumNonZeros; values = src->values; colIdxs = src->colIdxs; - rowOffsets = std::shared_ptr(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl); + rowOffsets = std::shared_ptr(src->rowOffsets, src->rowOffsets.get() + rowLowerIncl); } virtual ~CSRMatrix() { @@ -174,6 +174,8 @@ template class CSRMatrix : public Matrix { const ValueType *getValues() const { return values.get(); } + std::shared_ptr getValuesSharedPtr() const { return values; } + ValueType *getValues(size_t rowIdx) { // We 
allow equality here to enable retrieving a pointer to the end. if (rowIdx > numRows) @@ -189,6 +191,8 @@ template class CSRMatrix : public Matrix { const size_t *getColIdxs() const { return colIdxs.get(); } + std::shared_ptr getColIdxsSharedPtr() const { return colIdxs; } + size_t *getColIdxs(size_t rowIdx) { // We allow equality here to enable retrieving a pointer to the end. if (rowIdx > numRows) @@ -204,6 +208,8 @@ template class CSRMatrix : public Matrix { const size_t *getRowOffsets() const { return rowOffsets.get(); } + std::shared_ptr getRowOffsetsSharedPtr() const { return rowOffsets; } + ValueType get(size_t rowIdx, size_t colIdx) const override { if (rowIdx >= numRows) throw std::runtime_error("CSRMatrix (get): rowIdx is out of bounds"); diff --git a/src/runtime/local/io/WriteCsv.h b/src/runtime/local/io/WriteCsv.h index 34f4efcc0..e5b55b804 100644 --- a/src/runtime/local/io/WriteCsv.h +++ b/src/runtime/local/io/WriteCsv.h @@ -48,6 +48,42 @@ template struct WriteCsv { template void writeCsv(const DTArg *arg, File *file) { WriteCsv::apply(arg, file); } +// **************************************************************************** +// Helper functions +// **************************************************************************** + +/** + * @brief Returns a CSV value representation of the given string. + * + * If necessary, the given string is wrapped in quotation marks (`"`) and quotation marks occurring in the string are + * escaped by doubling them, i.e., `"` is replaced by `""`. Quoting the string is always allowed in CSV, but only + * *necessary* and done by this function when the string contains the value separator, newlines, or quotes. If the + * string does not need to be quoted, it is returned as it is. + * + * @param s The string to convert to a CSV value representation. + * @return The CSV value representation of the given string. 
+ */ +std::string quoteStrCsvIf(const std::string &s) { + if (s.find_first_of(",\n\r\"") != std::string::npos) { + // String needs to be quoted. + // Inside the quoted string, quotes ('"') must be escaped by duplicating them. + std::stringstream strm; + strm << '"'; + for (size_t i = 0; i < s.length(); i++) { + char c = s[i]; + if (c == '"') + strm << '"' << '"'; + else + strm << c; + } + strm << '"'; + return strm.str(); + } else { + // String does not need to be quoted. + return s; + } +} + // **************************************************************************** // (Partial) template specializations for different data/value types // **************************************************************************** @@ -67,9 +103,13 @@ template struct WriteCsv> { for (size_t i = 0; i < arg->getNumRows(); ++i) { for (size_t j = 0; j < argNumCols; ++j) { - fprintf(file->identifier, - std::is_floating_point::value ? "%f" : (std::is_same::value ? "%ld" : "%d"), - valuesArg[i * rowSkip + j]); + if constexpr (std::is_same::value) + fprintf(file->identifier, "%s", quoteStrCsvIf(valuesArg[i * rowSkip + j]).c_str()); + else + fprintf(file->identifier, + std::is_floating_point::value ? "%f" + : (std::is_same::value ? 
"%ld" : "%d"), + valuesArg[i * rowSkip + j]); if (j < (arg->getNumCols() - 1)) fprintf(file->identifier, ","); else @@ -125,6 +165,10 @@ template <> struct WriteCsv { case ValueTypeCode::F64: fprintf(file->identifier, "%f", reinterpret_cast(array)[i]); break; + case ValueTypeCode::STR: + fprintf(file->identifier, "%s", + quoteStrCsvIf(reinterpret_cast(array)[i]).c_str()); + break; default: throw std::runtime_error("unknown value type code"); } diff --git a/src/runtime/local/io/utils.h b/src/runtime/local/io/utils.h index 9f6fb821b..97ef3b002 100644 --- a/src/runtime/local/io/utils.h +++ b/src/runtime/local/io/utils.h @@ -108,7 +108,7 @@ inline size_t setCString(struct File *file, size_t start_pos, std::string *res, if (!is_not_end) break; if (is_multiLine && str[pos] == '"' && str[pos + 1] == '"') { - res->append("\"\""); + res->append("\""); pos += 2; } else if (is_multiLine && str[pos] == '\\' && str[pos + 1] == '"') { res->append("\\\""); diff --git a/src/runtime/local/kernels/BinaryOpCode.h b/src/runtime/local/kernels/BinaryOpCode.h index 30023586c..9fc55fd53 100644 --- a/src/runtime/local/kernels/BinaryOpCode.h +++ b/src/runtime/local/kernels/BinaryOpCode.h @@ -132,16 +132,20 @@ static constexpr bool supportsBinaryOp = false; SUPPORT(BITWISE_AND, VT) // Generates code specifying that all binary operations of a certain category should be -// supported on the given argument value type `VTArg` (for the left and right-hand-side -// arguments, for simplicity) and the given result value type `VTRes`. -#define SUPPORT_COMPARISONS_RA(VTRes, VTArg) \ +// supported on the given argument value types `VTLhs` and `VTRhs` (for the left and right-hand-side +// arguments, respectively) and the given result value type `VTRes`. +#define SUPPORT_COMPARISONS_RLR(VTRes, VTLhs, VTRhs) \ /* string Comparisons operations. 
*/ \ - SUPPORT_RLR(LT, VTRes, VTArg, VTArg) \ - SUPPORT_RLR(GT, VTRes, VTArg, VTArg) -#define SUPPORT_EQUALITY_RA(VTRes, VTArg) \ + SUPPORT_RLR(LT, VTRes, VTLhs, VTRhs) \ + SUPPORT_RLR(GT, VTRes, VTLhs, VTRhs) +#define SUPPORT_COMPARISONS_EQUAL_RLR(VTRes, VTLhs, VTRhs) \ /* string Comparisons operations. */ \ - SUPPORT_RLR(EQ, VTRes, VTArg, VTArg) \ - SUPPORT_RLR(NEQ, VTRes, VTArg, VTArg) + SUPPORT_RLR(LE, VTRes, VTLhs, VTRhs) \ + SUPPORT_RLR(GE, VTRes, VTLhs, VTRhs) +#define SUPPORT_EQUALITY_RLR(VTRes, VTLhs, VTRhs) \ + /* string Comparisons operations. */ \ + SUPPORT_RLR(EQ, VTRes, VTLhs, VTRhs) \ + SUPPORT_RLR(NEQ, VTRes, VTLhs, VTRhs) #define SUPPORT_STRING_RLR(VTRes, VTLhs, VTRhs) \ /* string concatenation operations. */ \ /* Since the result may not fit in FixedStr16,*/ \ @@ -175,11 +179,15 @@ SUPPORT_NUMERIC_INT(uint64_t) SUPPORT_NUMERIC_INT(uint32_t) SUPPORT_NUMERIC_INT(uint8_t) // Strings binary operations. -SUPPORT_EQUALITY_RA(int64_t, std::string) -SUPPORT_EQUALITY_RA(int64_t, FixedStr16) -SUPPORT_EQUALITY_RA(int64_t, const char *) -SUPPORT_COMPARISONS_RA(int64_t, std::string) -SUPPORT_COMPARISONS_RA(int64_t, FixedStr16) +SUPPORT_EQUALITY_RLR(int64_t, std::string, std::string) +SUPPORT_EQUALITY_RLR(int64_t, FixedStr16, FixedStr16) +SUPPORT_EQUALITY_RLR(int64_t, const char *, const char *) +SUPPORT_EQUALITY_RLR(int64_t, std::string, const char *) +SUPPORT_COMPARISONS_RLR(int64_t, std::string, std::string) +SUPPORT_COMPARISONS_RLR(int64_t, FixedStr16, FixedStr16) +SUPPORT_COMPARISONS_RLR(int64_t, std::string, const char *) +SUPPORT_COMPARISONS_EQUAL_RLR(int64_t, std::string, std::string) +SUPPORT_COMPARISONS_EQUAL_RLR(int64_t, std::string, const char *) SUPPORT_STRING_RLR(std::string, std::string, std::string) SUPPORT_STRING_RLR(std::string, FixedStr16, FixedStr16) SUPPORT_STRING_RLR(const char *, const char *, const char *) @@ -195,6 +203,7 @@ SUPPORT_STRING_RLR(std::string, std::string, const char *) #undef SUPPORT_BITWISE #undef SUPPORT_NUMERIC_FP 
#undef SUPPORT_NUMERIC_INT -#undef SUPPORT_EQUALITY_RA -#undef SUPPORT_COMPARISONS_RA +#undef SUPPORT_EQUALITY_RLR +#undef SUPPORT_COMPARISONS_RLR +#undef SUPPORT_COMPARISONS_EQUAL_RLR #undef SUPPORT_STRING_RLR diff --git a/src/runtime/local/kernels/CastObj.h b/src/runtime/local/kernels/CastObj.h index ecb53bcbb..3984276d4 100644 --- a/src/runtime/local/kernels/CastObj.h +++ b/src/runtime/local/kernels/CastObj.h @@ -68,7 +68,7 @@ template class CastObj, Frame> { const size_t numRows = argFrm->getNumRows(); const DenseMatrix *argCol = argFrm->getColumn(c); for (size_t r = 0; r < numRows; r++) - res->set(r, c, static_cast(argCol->get(r, 0))); + res->set(r, c, castSca(argCol->get(r, 0), nullptr)); DataObjectFactory::destroy(argCol); } @@ -126,6 +126,9 @@ template class CastObj, Frame> { case ValueTypeCode::UI8: castCol(res, arg, c); break; + case ValueTypeCode::STR: + castCol(res, arg, c); + break; default: throw std::runtime_error("CastObj::apply: unknown value type code"); } diff --git a/src/runtime/local/kernels/CastSca.h b/src/runtime/local/kernels/CastSca.h index 7e0b84b85..547fc6c0a 100644 --- a/src/runtime/local/kernels/CastSca.h +++ b/src/runtime/local/kernels/CastSca.h @@ -123,4 +123,12 @@ template struct CastSca { } }; +// ---------------------------------------------------------------------------- +// string <- string +// ---------------------------------------------------------------------------- + +template <> struct CastSca { + static std::string apply(const std::string arg, DaphneContext *ctx) { return arg; } +}; + #endif // SRC_RUNTIME_LOCAL_KERNELS_CASTSCA_H diff --git a/src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h b/src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h new file mode 100644 index 000000000..64f50b5cf --- /dev/null +++ b/src/runtime/local/kernels/ConvertCSRMatrixToColIdxsMemRef.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline StridedMemRefType convertCSRMatrixToColIdxsMemRef(const CSRMatrix *input, DCTX(ctx)) { + StridedMemRefType colIdxsMemRef{}; + + colIdxsMemRef.basePtr = input->getColIdxsSharedPtr().get(); + colIdxsMemRef.data = colIdxsMemRef.basePtr; + colIdxsMemRef.offset = 0; + colIdxsMemRef.sizes[0] = input->getNumNonZeros(); + colIdxsMemRef.strides[0] = 1; + + input->increaseRefCounter(); + + return colIdxsMemRef; +} diff --git a/src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h b/src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h new file mode 100644 index 000000000..0354e7df9 --- /dev/null +++ b/src/runtime/local/kernels/ConvertCSRMatrixToRowOffsetsMemRef.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline StridedMemRefType convertCSRMatrixToRowOffsetsMemRef(const CSRMatrix *input, DCTX(ctx)) { + StridedMemRefType rowOffsetsMemRef{}; + + rowOffsetsMemRef.basePtr = input->getRowOffsetsSharedPtr().get(); + rowOffsetsMemRef.data = rowOffsetsMemRef.basePtr; + rowOffsetsMemRef.offset = 0; + rowOffsetsMemRef.sizes[0] = input->getNumRows() + 1; + rowOffsetsMemRef.strides[0] = 1; + + input->increaseRefCounter(); + + return rowOffsetsMemRef; +} diff --git a/src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h b/src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h new file mode 100644 index 000000000..eca515fa0 --- /dev/null +++ b/src/runtime/local/kernels/ConvertCSRMatrixToValuesMemRef.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" +#include "runtime/local/datastructures/CSRMatrix.h" + +template +inline StridedMemRefType convertCSRMatrixToValuesMemRef(const CSRMatrix *input, DCTX(ctx)) { + StridedMemRefType valuesMemRef{}; + + valuesMemRef.basePtr = input->getValuesSharedPtr().get(); + valuesMemRef.data = valuesMemRef.basePtr; + valuesMemRef.offset = 0; + valuesMemRef.sizes[0] = input->getNumNonZeros(); // Is numRowsAllocated needed to account for views? + valuesMemRef.strides[0] = 1; + + input->increaseRefCounter(); + + return valuesMemRef; +} diff --git a/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h b/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h index 1cabfcaa8..19191bada 100644 --- a/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h +++ b/src/runtime/local/kernels/ConvertDenseMatrixToMemRef.h @@ -17,6 +17,7 @@ #pragma once #include "mlir/ExecutionEngine/CRunnerUtils.h" +#include "runtime/local/context/DaphneContext.h" #include "runtime/local/datastructures/DenseMatrix.h" template diff --git a/src/runtime/local/kernels/EwBinaryObjSca.h b/src/runtime/local/kernels/EwBinaryObjSca.h index 61b5fa8e8..62d704ded 100644 --- a/src/runtime/local/kernels/EwBinaryObjSca.h +++ b/src/runtime/local/kernels/EwBinaryObjSca.h @@ -63,7 +63,7 @@ struct EwBinaryObjSca, DenseMatrix, VTRhs> { if (res == nullptr) res = DataObjectFactory::create>(numRows, numCols, false); - const VTRes *valuesLhs = lhs->getValues(); + const VTLhs *valuesLhs = lhs->getValues(); VTRes *valuesRes = res->getValues(); EwBinaryScaFuncPtr func = getEwBinaryScaFuncPtr(opCode); diff --git a/src/runtime/local/kernels/ExtractRow.h b/src/runtime/local/kernels/ExtractRow.h index 442bccdd0..a4367d198 100644 --- a/src/runtime/local/kernels/ExtractRow.h +++ b/src/runtime/local/kernels/ExtractRow.h @@ -118,16 +118,23 @@ template struct ExtractRow { throw std::out_of_range(errMsg.str()); } for (size_t c = 
0; c < numCols; c++) { - // We always copy in units of 8 bytes (uint64_t). If the - // actual element size is lower, the superfluous bytes will - // be overwritten by the next match. With this approach, we - // do not need to call memcpy for each element, nor - // interpret the types for a L/S of fitting size. - // TODO Don't multiply by elementSize, but left-shift by - // ld(elementSize). - *reinterpret_cast(resCols[c]) = - *reinterpret_cast(argCols[c] + pos * elementSizes[c]); - resCols[c] += elementSizes[c]; + if (schema[c] == ValueTypeCode::STR) { + // Handle std::string column + *reinterpret_cast(resCols[c]) = + *reinterpret_cast(argCols[c] + pos * elementSizes[c]); + resCols[c] += elementSizes[c]; + } else { + // We always copy in units of 8 bytes (uint64_t). If the + // actual element size is lower, the superfluous bytes will + // be overwritten by the next match. With this approach, we + // do not need to call memcpy for each element, nor + // interpret the types for a L/S of fitting size. + // TODO Don't multiply by elementSize, but left-shift by + // ld(elementSize). + *reinterpret_cast(resCols[c]) = + *reinterpret_cast(argCols[c] + pos * elementSizes[c]); + resCols[c] += elementSizes[c]; + } } } res->shrinkNumRows(numRowsSel); diff --git a/src/runtime/local/kernels/FilterCol.h b/src/runtime/local/kernels/FilterCol.h index 3e9671456..1d6b68227 100644 --- a/src/runtime/local/kernels/FilterCol.h +++ b/src/runtime/local/kernels/FilterCol.h @@ -73,6 +73,16 @@ template struct FilterCol, DenseMa VT *valuesRes = res->getValues(); const size_t rowSkipArg = arg->getRowSkip(); const size_t rowSkipRes = res->getRowSkip(); + + // Two alternative approaches for doing the main work. + // Note that even though sel is essentially a bit vector, we represent it as a 64-bit integer at the moment, for + // simplicity. 
As both the elements in sel used in approach 1 and the positions in approach 2 are currently + // represented as 64-bit integers, approach 2 should always be faster than approach 1, because in approach 2 we + // iterate over at most as many 64-bit values as in approach 1 while we can omit the check. Once we change to a + // 1-bit representation for the values in sel, we should rethink the trade-off between the two approaches. +#if 0 // approach 1 + // For every row in arg, iterate over all elements in sel (one per column in arg), check if the respective + // column should be part of the output and if so, copy the value to the output. for (size_t r = 0; r < numRows; r++) { for (size_t ca = 0, cr = 0; ca < numColsArg; ca++) if (sel->get(ca, 0)) @@ -80,6 +90,22 @@ template struct FilterCol, DenseMa valuesArg += rowSkipArg; valuesRes += rowSkipRes; } +#else // approach 2 + // Once in the beginning, create a vector of the positions of the columns we want to copy to the output. + // Negligible effort, unless the number of rows in arg is very small. + std::vector positions; + for (size_t c = 0; c < numColsArg; c++) + if (sel->get(c, 0)) + positions.push_back(c); + // For every row in arg, iterate over the array of those positions and copy the value in the respective column + // to the output. + for (size_t r = 0; r < numRows; r++) { + for (size_t i = 0; i < positions.size(); i++) + valuesRes[i] = valuesArg[positions[i]]; + valuesArg += rowSkipArg; + valuesRes += rowSkipRes; + } +#endif } }; diff --git a/src/runtime/local/kernels/FilterRow.h b/src/runtime/local/kernels/FilterRow.h index 0c53ec2ab..7b6a8f257 100644 --- a/src/runtime/local/kernels/FilterRow.h +++ b/src/runtime/local/kernels/FilterRow.h @@ -131,13 +131,20 @@ template struct FilterRow { for (size_t r = 0; r < numRows; r++) { if (valuesSel[r]) { for (size_t c = 0; c < numCols; c++) { - // We always copy in units of 8 bytes (uint64_t). 
If the - // actual element size is lower, the superfluous bytes will - // be overwritten by the next match. With this approach, we - // do not need to call memcpy for each element, nor - // interpret the types for a L/S of fitting size. - *reinterpret_cast(resCols[c]) = *reinterpret_cast(argCols[c]); - resCols[c] += elementSizes[c]; + if (schema[c] == ValueTypeCode::STR) { + // Handle std::string column + *reinterpret_cast(resCols[c]) = + *reinterpret_cast(argCols[c]); // Deep copy the string + resCols[c] += elementSizes[c]; + } else { + // We always copy in units of 8 bytes (uint64_t). If the + // actual element size is lower, the superfluous bytes will + // be overwritten by the next match. With this approach, we + // do not need to call memcpy for each element, nor + // interpret the types for a L/S of fitting size. + *reinterpret_cast(resCols[c]) = *reinterpret_cast(argCols[c]); + resCols[c] += elementSizes[c]; + } } } for (size_t c = 0; c < numCols; c++) diff --git a/src/runtime/local/kernels/Group.h b/src/runtime/local/kernels/Group.h index 0d8df6410..5e192d93b 100644 --- a/src/runtime/local/kernels/Group.h +++ b/src/runtime/local/kernels/Group.h @@ -58,6 +58,24 @@ void group(DT *&res, const DT *arg, const char **keyCols, size_t numKeyCols, con // Frame <- Frame // ---------------------------------------------------------------------------- +// TODO If possible, reuse the stringifyGroupEnum() from the DAPHNE compiler. 
+inline std::string myStringifyGroupEnum(mlir::daphne::GroupEnum val) { + using mlir::daphne::GroupEnum; + switch (val) { + case GroupEnum::COUNT: + return "COUNT"; + case GroupEnum::SUM: + return "SUM"; + case GroupEnum::MIN: + return "MIN"; + case GroupEnum::MAX: + return "MAX"; + case GroupEnum::AVG: + return "AVG"; + } + throw std::runtime_error("invalid GroupEnum value"); +} + // returns the result of the aggregation function aggFunc over the (contiguous) // memory between the begin and end pointer template @@ -65,26 +83,60 @@ VTRes aggregate(const mlir::daphne::GroupEnum &aggFunc, const VTArg *begin, cons using mlir::daphne::GroupEnum; switch (aggFunc) { case GroupEnum::COUNT: - return end - begin; + if constexpr (std::is_same::value) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return end - begin; break; // TODO: Do we need to check for Null elements here? 
case GroupEnum::SUM: - return std::accumulate(begin, end, (VTRes)0); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return std::accumulate(begin, end, (VTRes)0); break; case GroupEnum::MIN: - return *std::min_element(begin, end); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return *std::min_element(begin, end); break; case GroupEnum::MAX: - return *std::max_element(begin, end); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return *std::max_element(begin, end); break; case GroupEnum::AVG: - return std::accumulate(begin, end, (double)0) / (double)(end - begin); + if constexpr ((std::is_same::value) || (std::is_same::value)) + throw std::invalid_argument(std::string("aggregate: ") + myStringifyGroupEnum(aggFunc) + + std::string(" aggregation is not supported for these value types.")); + else + return std::accumulate(begin, end, (double)0) / (double)(end - begin); break; default: - return *begin; + if constexpr (std::is_same::value || std::is_same::value) + throw std::invalid_argument("aggregate: Unsupported aggregation operation for string types."); + else + return *begin; break; } } +template <> +std::string aggregate(const mlir::daphne::GroupEnum &aggFunc, const std::string *begin, const std::string *end) { + using mlir::daphne::GroupEnum; + if (aggFunc == GroupEnum::MIN) + return *std::min_element(begin, end); + if (aggFunc == GroupEnum::MAX) + return *std::max_element(begin, end); + else + return *begin; 
+} + // struct which calls the aggregate() function (specified via aggFunc) on each // duplicate group in the groups vector and on all implied single groups for a // sepcified column (colIdx) of the argument frame (arg) and stores the result @@ -117,22 +169,14 @@ template struct ColumnGroupAgg { } }; -inline std::string myStringifyGroupEnum(mlir::daphne::GroupEnum val) { - using mlir::daphne::GroupEnum; - switch (val) { - case GroupEnum::COUNT: - return "COUNT"; - case GroupEnum::SUM: - return "SUM"; - case GroupEnum::MIN: - return "MIN"; - case GroupEnum::MAX: - return "MAX"; - case GroupEnum::AVG: - return "AVG"; +// Since DeduceValueTypeAndExecute can not handle string values, +// we add special ColumnGroupAgg function for arg with std::string values. +template struct ColumnGroupAggStringVTArg { + static void apply(Frame *res, const Frame *arg, size_t colIdx, std::vector> *groups, + mlir::daphne::GroupEnum aggFunc, DCTX(ctx)) { + ColumnGroupAgg::apply(res, arg, colIdx, groups, aggFunc, ctx); } - return ""; -} +}; template <> struct Group { static void apply(Frame *&res, const Frame *arg, const char **keyCols, size_t numKeyCols, const char **aggCols, @@ -270,9 +314,18 @@ template <> struct Group { // copying key columns and column-wise group aggregation for (size_t i = 0; i < numColsRes; i++) { - DeduceValueTypeAndExecute::apply( - res->getSchema()[i], ordered->getSchema()[i], res, ordered, i, groups, - (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); + if (ordered->getSchema()[i] == ValueTypeCode::STR) { + if (res->getSchema()[i] == ValueTypeCode::STR) + ColumnGroupAgg::apply( + res, ordered, i, groups, (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); + else + DeduceValueTypeAndExecute::apply( + res->getSchema()[i], res, ordered, i, groups, + (i < numKeyCols) ? 
(GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); + } else + DeduceValueTypeAndExecute::apply( + res->getSchema()[i], ordered->getSchema()[i], res, ordered, i, groups, + (i < numKeyCols) ? (GroupEnum)0 : aggFuncs[i - numKeyCols], ctx); } delete groups; DataObjectFactory::destroy(ordered); diff --git a/src/runtime/local/kernels/InnerJoin.h b/src/runtime/local/kernels/InnerJoin.h index a6185056b..e0c8be1a3 100644 --- a/src/runtime/local/kernels/InnerJoin.h +++ b/src/runtime/local/kernels/InnerJoin.h @@ -1,3 +1,19 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + #ifndef SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H #define SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H @@ -10,7 +26,8 @@ #include #include -#include +#include +#include #include #include @@ -30,43 +47,85 @@ template void innerJoinSet(ValueTypeCode vtcType, Frame *&res, const Frame *arg, const int64_t toRow, const int64_t toCol, const int64_t fromRow, const int64_t fromCol, DCTX(ctx)) { if (vtcType == ValueTypeUtils::codeFor) { - innerJoinSetValue(res->getColumn(toCol), arg->getColumn(fromCol), toRow, fromRow, ctx); + const DenseMatrix *colArg = arg->getColumn(fromCol); + DenseMatrix *colRes = res->getColumn(toCol); + innerJoinSetValue(colRes, colArg, toRow, fromRow, ctx); + DataObjectFactory::destroy(colArg, colRes); } } -template -bool innerJoinEqual( - // results - Frame *&res, - // arguments - const DenseMatrix *argLhs, const DenseMatrix *argRhs, const int64_t targetLhs, - const int64_t targetRhs, - // context - DCTX(ctx)) { - const VTLhs l = argLhs->get(targetLhs, 0); - const VTRhs r = argRhs->get(targetRhs, 0); - return l == r; +// Create a hash table for rhs +template +std::unordered_map> BuildHashRhs(const Frame *rhs, const char *rhsOn, + const size_t numRowRhs) { + std::unordered_map> res; + const DenseMatrix *col = rhs->getColumn(rhsOn); + for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) { + VTRhs key = col->get(row_idx_r, 0); + res[key].push_back(row_idx_r); + } + DataObjectFactory::destroy(col); + return res; } -template -bool innerJoinProbeIf( - // value type known only at run-time - ValueTypeCode vtcLhs, ValueTypeCode vtcRhs, - // results - Frame *&res, +template +int64_t ProbeHashLhs( + // results and results schema + Frame *&res, ValueTypeCode *schema, // input frames const Frame *lhs, const Frame *rhs, // input column names - const char *lhsOn, const char *rhsOn, - // input rows - const int64_t targetL, const int64_t targetR, + const char *lhsOn, + // num columns + const size_t numColRhs, const size_t numColLhs, // context - DCTX(ctx)) { - 
if (vtcLhs == ValueTypeUtils::codeFor && vtcRhs == ValueTypeUtils::codeFor) { - return innerJoinEqual(res, lhs->getColumn(lhsOn), rhs->getColumn(rhsOn), targetL, - targetR, ctx); + DCTX(ctx), + // hashed map of Rhs + std::unordered_map> hashRhsIndex, + // Lhs rowa + const size_t numRowLhs) { + int64_t row_idx_res = 0; + int64_t col_idx_res = 0; + auto lhsFKCol = lhs->getColumn(lhsOn); + for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) { + auto key = lhsFKCol->get(row_idx_l, 0); + auto it = hashRhsIndex.find(key); + + if (it != hashRhsIndex.end()) { + for (size_t row_idx_r : it->second) { + col_idx_res = 0; + + // Populate result row from lhs columns + for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) { + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, + ctx); + + col_idx_res++; + } + + // Populate result row from rhs columns + for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) { + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + + innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, + ctx); + + col_idx_res++; + } + + row_idx_res++; + } + } } - return false; + DataObjectFactory::destroy(lhsFKCol); + return row_idx_res; } // **************************************************************************** @@ -80,16 +139,17 @@ inline void innerJoin( const Frame *lhs, const Frame *rhs, // input column names const char *lhsOn, const char *rhsOn, + // result size + int64_t numRowRes, // context DCTX(ctx)) { // Find out the value types of the columns to process. 
ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn); - ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn); // Perhaps check if res already allocated. const size_t numRowRhs = rhs->getNumRows(); const size_t numRowLhs = lhs->getNumRows(); - const size_t totalRows = numRowRhs * numRowLhs; + const size_t totalRows = numRowRes == -1 ? numRowRhs * numRowLhs : numRowRes; const size_t numColRhs = rhs->getNumCols(); const size_t numColLhs = lhs->getNumCols(); const size_t totalCols = numColRhs + numColLhs; @@ -97,55 +157,34 @@ inline void innerJoin( const std::string *oldlabels_r = rhs->getLabels(); int64_t col_idx_res = 0; - int64_t row_idx_res = 0; + int64_t row_idx_res; + // Set up schema and labels ValueTypeCode schema[totalCols]; std::string newlabels[totalCols]; - // Setting Schema and Labels for (size_t col_idx_l = 0; col_idx_l < numColLhs; col_idx_l++) { schema[col_idx_res] = lhs->getColumnType(col_idx_l); - newlabels[col_idx_res] = oldlabels_l[col_idx_l]; - col_idx_res++; + newlabels[col_idx_res++] = oldlabels_l[col_idx_l]; } for (size_t col_idx_r = 0; col_idx_r < numColRhs; col_idx_r++) { schema[col_idx_res] = rhs->getColumnType(col_idx_r); - newlabels[col_idx_res] = oldlabels_r[col_idx_r]; - col_idx_res++; + newlabels[col_idx_res++] = oldlabels_r[col_idx_r]; } - // Creating Result Frame + // Initialize result frame with an estimate res = DataObjectFactory::create(totalRows, totalCols, schema, newlabels, false); - for (size_t row_idx_l = 0; row_idx_l < numRowLhs; row_idx_l++) { - for (size_t row_idx_r = 0; row_idx_r < numRowRhs; row_idx_r++) { - col_idx_res = 0; - // PROBE ROWS - bool hit = false; - hit = hit || innerJoinProbeIf(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l, - row_idx_r, ctx); - hit = hit || innerJoinProbeIf(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l, - row_idx_r, ctx); - if (hit) { - for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) { - innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, 
idx_c, - ctx); - innerJoinSet(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c, - ctx); - col_idx_res++; - } - for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) { - innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, - ctx); - - innerJoinSet(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c, - ctx); - col_idx_res++; - } - row_idx_res++; - } - } + // Build hash table and prob left table + if (vtcLhsOn == ValueTypeCode::STR) { + row_idx_res = ProbeHashLhs(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx, + BuildHashRhs(rhs, rhsOn, numRowRhs), numRowLhs); + } else { + row_idx_res = ProbeHashLhs(res, schema, lhs, rhs, lhsOn, numColRhs, numColLhs, ctx, + BuildHashRhs(rhs, rhsOn, numRowRhs), numRowLhs); } + // Shrink result frame to actual size res->shrinkNumRows(row_idx_res); } + #endif // SRC_RUNTIME_LOCAL_KERNELS_INNERJOIN_H diff --git a/src/runtime/local/kernels/Order.h b/src/runtime/local/kernels/Order.h index afa6d1a10..5b78e6a6a 100644 --- a/src/runtime/local/kernels/Order.h +++ b/src/runtime/local/kernels/Order.h @@ -197,8 +197,11 @@ struct OrderFrame { if (numColIdxs > 1) { for (size_t i = 0; i < numColIdxs - 1; i++) { - DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdxs[i]], arg, idx, groups, - ascending[i], colIdxs[i], ctx); + if (arg->getSchema()[colIdxs[i]] == ValueTypeCode::STR) + MultiColumnIDSort::apply(arg, idx, groups, ascending[i], colIdxs[i], ctx); + else + DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdxs[i]], arg, idx, groups, + ascending[i], colIdxs[i], ctx); } } @@ -206,11 +209,17 @@ struct OrderFrame { // use size_t colIdx = colIdxs[numColIdxs - 1]; if (groupsRes == nullptr) { - DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, - ascending[numColIdxs - 1], colIdx, ctx); + if (arg->getSchema()[colIdx] == ValueTypeCode::STR) + ColumnIDSort::apply(arg, idx, groups, ascending[numColIdxs - 1], colIdx, ctx); + else + 
DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, + ascending[numColIdxs - 1], colIdx, ctx); } else { - DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, - ascending[numColIdxs - 1], colIdx, ctx); + if (arg->getSchema()[colIdx] == ValueTypeCode::STR) + MultiColumnIDSort::apply(arg, idx, groups, ascending[numColIdxs - 1], colIdx, ctx); + else + DeduceValueTypeAndExecute::apply(arg->getSchema()[colIdx], arg, idx, groups, + ascending[numColIdxs - 1], colIdx, ctx); groupsRes->insert(groupsRes->end(), groups.begin(), groups.end()); } } diff --git a/src/runtime/local/kernels/SemiJoin.h b/src/runtime/local/kernels/SemiJoin.h index eb815e0b7..4e308d087 100644 --- a/src/runtime/local/kernels/SemiJoin.h +++ b/src/runtime/local/kernels/SemiJoin.h @@ -46,6 +46,8 @@ void semiJoinCol( Frame *&res, DenseMatrix *&resLhsTid, // arguments const DenseMatrix *argLhs, const DenseMatrix *argRhs, + // result size + int64_t numRowRes, // context DCTX(ctx)) { if (argLhs->getNumCols() != 1) @@ -72,11 +74,14 @@ void semiJoinCol( // Create the output data objects. if (res == nullptr) { ValueTypeCode schema[] = {ValueTypeUtils::codeFor}; - res = DataObjectFactory::create(numArgLhs, 1, schema, nullptr, false); + const size_t resSize = numRowRes == -1 ? numArgLhs : numRowRes; + res = DataObjectFactory::create(resSize, 1, schema, nullptr, false); } auto resLhs = res->getColumn(0); - if (resLhsTid == nullptr) - resLhsTid = DataObjectFactory::create>(numArgLhs, 1, false); + if (resLhsTid == nullptr) { + const size_t resLhsTidSize = numRowRes == -1 ? 
numArgLhs : numRowRes; + resLhsTid = DataObjectFactory::create>(resLhsTidSize, 1, false); + } size_t pos = 0; for (size_t i = 0; i < numArgLhs; i++) { @@ -107,11 +112,13 @@ void semiJoinColIf( const Frame *lhs, const Frame *rhs, // input column names const char *lhsOn, const char *rhsOn, + // result size + int64_t numRowRes, // context DCTX(ctx)) { if (vtcLhs == ValueTypeUtils::codeFor && vtcRhs == ValueTypeUtils::codeFor) { semiJoinCol(res, resLhsTid, lhs->getColumn(lhsOn), rhs->getColumn(rhsOn), - ctx); + numRowRes, ctx); } } @@ -127,6 +134,8 @@ void semiJoin( const Frame *lhs, const Frame *rhs, // input column names const char *lhsOn, const char *rhsOn, + // result size + int64_t numRowRes, // context DCTX(ctx)) { // Find out the value types of the columns to process. @@ -136,8 +145,8 @@ void semiJoin( // Call the semiJoin-kernel on columns for the actual combination of // value types. // Repeat this for all type combinations... - semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx); - semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx); + semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx); + semiJoinColIf(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx); // Set the column labels of the result frame. 
std::string labels[] = {lhsOn}; diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 4058f45d0..19597655e 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -419,6 +419,7 @@ [["DenseMatrix", "double"], "Frame"], [["DenseMatrix", "int64_t"], "Frame"], [["DenseMatrix", "uint64_t"], "Frame"], + [["DenseMatrix", "std::string"], "Frame"], ["Frame", ["DenseMatrix", "double"]], ["Frame", ["DenseMatrix", "int64_t"]], ["Frame", ["DenseMatrix", "uint64_t"]], @@ -1539,6 +1540,93 @@ ["double"] ] }, + { + "kernelTemplate": { + "header": "ConvertCSRMatrixToValuesMemRef.h", + "opName": "convertCSRMatrixToValuesMemRef", + "returnType": "StridedMemRefType", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *", + "name": "input" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, + { + "kernelTemplate": { + "header": "ConvertCSRMatrixToColIdxsMemRef.h", + "opName": "convertCSRMatrixToColIdxsMemRef", + "returnType": "StridedMemRefType", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *", + "name": "input" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, + { + "kernelTemplate": { + "header": "ConvertCSRMatrixToRowOffsetsMemRef.h", + "opName": "convertCSRMatrixToRowOffsetsMemRef", + "returnType": "StridedMemRefType", + "templateParams": [ + { + "name": "VT", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CSRMatrix *", + "name": "input" + } + ] + }, + "instantiations": [ + ["int64_t"], + ["int32_t"], + ["int8_t"], + ["uint64_t"], + ["uint32_t"], + ["uint8_t"], + ["float"], + ["double"] + ] + }, { "kernelTemplate": { 
"header": "CreateFrame.h", @@ -2109,6 +2197,11 @@ ["DenseMatrix", "std::string"], ["DenseMatrix", "std::string"], "const char *" + ], + [ + ["DenseMatrix", "int64_t"], + ["DenseMatrix", "std::string"], + "const char *" ] ], "opCodes": [ @@ -2974,7 +3067,11 @@ { "type": "const char *", "name": "rhsOn" - } + }, + { + "type": "int64_t", + "name": "numRowRes" + } ] }, "instantiations": [[]] @@ -3937,6 +4034,7 @@ [["DenseMatrix", "double"]], [["DenseMatrix", "int64_t"]], [["DenseMatrix", "uint8_t"]], + [["DenseMatrix", "std::string"]], ["Frame"] ] }, @@ -4267,6 +4365,10 @@ { "type": "const char *", "name": "rhsOn" + }, + { + "type": "int64_t", + "name": "numRowRes" } ] }, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a8051d1b7..7b707e746 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -32,12 +32,10 @@ set(TEST_SOURCES api/cli/extensibility/HintTest.cpp api/cli/functions/FunctionsTest.cpp api/cli/functions/RecursiveFunctionsTest.cpp - api/cli/io/ReadTest.cpp - api/cli/io/WriteTest.cpp api/cli/import/ImportTest.cpp api/cli/indexing/IndexingTest.cpp api/cli/inference/InferenceTest.cpp - api/cli/io/ReadTest.cpp + api/cli/io/ReadWriteTest.cpp api/cli/lists/ListsTest.cpp api/cli/literals/LiteralsTest.cpp api/cli/operations/ConstantFoldingTest.cpp @@ -63,6 +61,7 @@ set(TEST_SOURCES api/cli/codegen/EwUnaryTest.cpp api/cli/codegen/EwOpLoopFusionTest.cpp api/cli/codegen/MapOpTest.cpp + api/cli/codegen/SparsityExploitTest.cpp api/cli/codegen/TransposeTest.cpp ir/daphneir/InferTypesTest.cpp diff --git a/test/api/cli/codegen/SparsityExploitTest.cpp b/test/api/cli/codegen/SparsityExploitTest.cpp new file mode 100644 index 000000000..1cb495507 --- /dev/null +++ b/test/api/cli/codegen/SparsityExploitTest.cpp @@ -0,0 +1,29 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +const std::string dirPath = "test/api/cli/codegen/"; + +TEST_CASE("sparsity exploit (sparse-dense cross-entropy)", TAG_CODEGEN) { + std::string result = "1.61078\n"; + compareDaphneToStr(result, dirPath + "sparsityExploit.daphne", "--select-matrix-repr"); + compareDaphneToStr(result, dirPath + "sparsityExploit.daphne", "--select-matrix-repr", "--mlir-codegen"); +} diff --git a/test/api/cli/codegen/sparsityExploit.daphne b/test/api/cli/codegen/sparsityExploit.daphne new file mode 100644 index 000000000..eb57018be --- /dev/null +++ b/test/api/cli/codegen/sparsityExploit.daphne @@ -0,0 +1,5 @@ +sparseLhs = rand(15, 10, 0.0, 1.0, 0.1, 1); // size: 15x10 range: (0,1) sparsity: 0.1 +DenseU = rand(15, 5, 0.0, 1.0, 1.0, 2); // size: 15x5 range: (0,1) sparsity: 1 +DenseV = rand(10, 5, 0.0, 1.0, 1.0, 3); // size: 10x5 range: (0,1) sparsity: 1 + +print(sum(sparseLhs * ln(DenseU @ t(DenseV)))); diff --git a/test/api/cli/extensibility/HintTest.cpp b/test/api/cli/extensibility/HintTest.cpp index 94de4c6f1..cf120c142 100644 --- a/test/api/cli/extensibility/HintTest.cpp +++ b/test/api/cli/extensibility/HintTest.cpp @@ -58,7 +58,7 @@ MAKE_FAILURE_TEST_CASE("hint_kernel", 3) // Check if DAPHNE terminates normally when expected and produces the expected // output. -MAKE_SUCCESS_TEST_CASE("hint_kernel", 3) +MAKE_SUCCESS_TEST_CASE("hint_kernel", 5) // Check if DAPHNE terminates normally when expected and if the IR really // contains the kernel hint. 
diff --git a/test/api/cli/extensibility/hint_kernel_success_1.daphne b/test/api/cli/extensibility/hint_kernel_success_1.daphne index bc235a592..f3c4a73a2 100644 --- a/test/api/cli/extensibility/hint_kernel_success_1.daphne +++ b/test/api/cli/extensibility/hint_kernel_success_1.daphne @@ -1,3 +1,3 @@ -// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly zero results. +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly zero results (DaphneDSL built-in function). print::_print__int64_t__bool__bool(42); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_2.daphne b/test/api/cli/extensibility/hint_kernel_success_2.daphne index 5bff4eb99..3cd1a27c8 100644 --- a/test/api/cli/extensibility/hint_kernel_success_2.daphne +++ b/test/api/cli/extensibility/hint_kernel_success_2.daphne @@ -1,4 +1,4 @@ -// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result. +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result (DaphneDSL built-in function). res = sum::_sumAll__int64_t__DenseMatrix_int64_t([21, 21]); print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_3.daphne b/test/api/cli/extensibility/hint_kernel_success_3.daphne index f17fa1353..165b5ed47 100644 --- a/test/api/cli/extensibility/hint_kernel_success_3.daphne +++ b/test/api/cli/extensibility/hint_kernel_success_3.daphne @@ -1,4 +1,4 @@ -// Hint to use an existing pre-compiled kernel for a DaphneIR op with more than one result. +// Hint to use an existing pre-compiled kernel for a DaphneIR op with more than one result (DaphneDSL built-in function). 
codes, dict = recode::_recode__DenseMatrix_int64_t__DenseMatrix_double__DenseMatrix_double__bool([1.1, 3.3, 1.1, 2.2], false); print(codes); diff --git a/test/api/cli/extensibility/hint_kernel_success_4.daphne b/test/api/cli/extensibility/hint_kernel_success_4.daphne new file mode 100644 index 000000000..40be8637c --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_4.daphne @@ -0,0 +1,4 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result (DaphneDSL operator symbol, addExpr (binary +-)). + +res = [2] +::_ewAdd__DenseMatrix_int64_t__DenseMatrix_int64_t__DenseMatrix_int64_t [3]; +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_4.txt b/test/api/cli/extensibility/hint_kernel_success_4.txt new file mode 100644 index 000000000..6a5ef86f3 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_4.txt @@ -0,0 +1,2 @@ +DenseMatrix(1x1, int64_t) +5 diff --git a/test/api/cli/extensibility/hint_kernel_success_5.daphne b/test/api/cli/extensibility/hint_kernel_success_5.daphne new file mode 100644 index 000000000..05444480e --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_5.daphne @@ -0,0 +1,4 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result (DaphneDSL operator symbol, mulExpr (binary */)). 
+ +res = [2] *::_ewMul__DenseMatrix_int64_t__DenseMatrix_int64_t__DenseMatrix_int64_t [3]; +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_5.txt b/test/api/cli/extensibility/hint_kernel_success_5.txt new file mode 100644 index 000000000..4d02485d0 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_5.txt @@ -0,0 +1,2 @@ +DenseMatrix(1x1, int64_t) +6 diff --git a/test/api/cli/io/.gitignore b/test/api/cli/io/.gitignore deleted file mode 100644 index 32aa93035..000000000 --- a/test/api/cli/io/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -matrix_full.csv -matrix_full.csv.meta -matrix_view.csv -matrix_view.csv.meta \ No newline at end of file diff --git a/test/api/cli/io/ReadCsv1.csv b/test/api/cli/io/ReadCsv1.csv deleted file mode 100644 index 79d814f7b..000000000 --- a/test/api/cli/io/ReadCsv1.csv +++ /dev/null @@ -1,2 +0,0 @@ --0.1,-0.2,0.1,0.2 -3.14,5.41,6.22216,5 diff --git a/test/api/cli/io/ReadCsv1.csv.meta b/test/api/cli/io/ReadCsv1.csv.meta deleted file mode 100644 index 82073032e..000000000 --- a/test/api/cli/io/ReadCsv1.csv.meta +++ /dev/null @@ -1,6 +0,0 @@ -{ - "numRows": 2, - "numCols": 4, - "valueType": "f64", - "numNonZeros": 0 -} \ No newline at end of file diff --git a/test/api/cli/io/ReadCsv2.csv b/test/api/cli/io/ReadCsv2.csv deleted file mode 100644 index e09ff2e54..000000000 --- a/test/api/cli/io/ReadCsv2.csv +++ /dev/null @@ -1,10 +0,0 @@ -"banana, grape",36,Fruit Basket -"xyz""uvw",34,"No Category\"" -"parrot, rabbit",31,Pets -"line1 -line2",26,"with newline" -chair,28,Furniture Set -"green, yellow\n",51, -"\n\"xyz""uvw\"",29,"Mixed string" -"\"blue, \"\"",42,"" -"unknown, item",23,Unknown Item diff --git a/test/api/cli/io/ReadCsv3.csv b/test/api/cli/io/ReadCsv3.csv deleted file mode 100644 index 81fa66d93..000000000 --- a/test/api/cli/io/ReadCsv3.csv +++ /dev/null @@ -1,4 +0,0 @@ -1,-1,"" -2,-2, -3,-3,"multi-line," -4,-4,simple string diff --git a/test/api/cli/io/ReadCsv3.csv.meta 
b/test/api/cli/io/ReadCsv3.csv.meta deleted file mode 100644 index fc14ce73d..000000000 --- a/test/api/cli/io/ReadCsv3.csv.meta +++ /dev/null @@ -1,18 +0,0 @@ -{ - "numRows": 4, - "numCols": 3, - "schema": [ - { - "label": "a", - "valueType": "si8" - }, - { - "label": "b", - "valueType": "ui8" - }, - { - "label": "c", - "valueType": "str" - } - ] -} diff --git a/test/api/cli/io/ReadTest.cpp b/test/api/cli/io/ReadTest.cpp deleted file mode 100644 index 885ce81bd..000000000 --- a/test/api/cli/io/ReadTest.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2021 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include - -#include - -#include -#include - -const std::string dirPath = "test/api/cli/io/"; - -// TODO Add script-level test cases for reading files in various formats. -// This test case used to read a COO file, but being a quick fix, it was not -// integrated cleanly. There should either be a seprate reader for COO, or we -// should do that via the respective Matrix Market format. 
-#if 0 -TEST_CASE("readSparse", TAG_IO) { - auto arg = "filename=\"" + dirPath + "readSparse.coo\""; - compareDaphneToRef(dirPath + "readSparse.txt", - dirPath + "readSparse.daphne", - "--select-matrix-representations", - "--args", - arg.c_str()); -} -#endif - -TEST_CASE("readFrameFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadFrame.txt", dirPath + "testReadFrame.daphne"); -} - -TEST_CASE("readStringValuesIntoFrameFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadStringIntoFrame.txt", dirPath + "testReadStringIntoFrame.daphne"); -} - -TEST_CASE("readMatrixFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadMatrix.txt", dirPath + "testReadMatrix.daphne"); -} - -TEST_CASE("readStringMatrixFromCSV", TAG_IO) { - compareDaphneToRef(dirPath + "testReadStringMatrix.txt", dirPath + "testReadStringMatrix.daphne"); -} - -// does not yet work! -// TEST_CASE("readReadMatrixFromCSV_DynamicPath", TAG_IO) -// { -// compareDaphneToRef(dirPath + "testReadMatrix.txt", dirPath + -// "testReadMatrix_DynamicPath.daphne"); -// } \ No newline at end of file diff --git a/test/api/cli/io/ReadWriteTest.cpp b/test/api/cli/io/ReadWriteTest.cpp new file mode 100644 index 000000000..e782ac80d --- /dev/null +++ b/test/api/cli/io/ReadWriteTest.cpp @@ -0,0 +1,111 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include + +#include +#include + +const std::string dirPath = "test/api/cli/io/"; + +// ******************************************************************************** +// Read test cases +// ******************************************************************************** + +// These test cases check if the CSV reader reads the expected data from a CSV file. +// The data is chosen to trigger various interesting cases. The tests perform the following steps: +// - read a matrix/frame from a specified CSV file +// - compare it to a matrix/frame hard-coded in a DaphneDSL script (nan-safe for floating-point value types) +// - check if the number of incorrect elements is zero +#define MAKE_READ_TEST_CASE(dt, name, nanSafe) \ + TEST_CASE("read_" dt "_" name, TAG_IO) { \ + const std::string scriptPath = dirPath + "read/read_" + dt + "_" + name + ".daphne"; \ + const std::string inPath = dirPath + "ref/" + dt + "_" + name + "_ref.csv"; \ + compareDaphneToStr("0\n", scriptPath.c_str(), "--args", \ + ("inPath=\"" + inPath + "\",nanSafe=" + nanSafe).c_str()); \ + } + +MAKE_READ_TEST_CASE("matrix", "si64", "false") +MAKE_READ_TEST_CASE("matrix", "f64", "true") +MAKE_READ_TEST_CASE("matrix", "str", "false") +MAKE_READ_TEST_CASE("frame", "mixed-no-str", "false") +MAKE_READ_TEST_CASE("frame", "mixed-str", "false") + +// These test cases check if the CSV reader can be used in various relevant ways in DaphneDSL. +// The data is rather simple. 
The tests perform the following steps: +// - read a matrix/frame from a specified CSV file +// - compare it to a matrix/frame hard-coded in a DaphneDSL script +// - check if the number of incorrect elements is zero +#define MAKE_READ_TEST_CASE_2(scriptName) \ + TEST_CASE("read_" scriptName ".daphne", TAG_IO) { \ + const std::string scriptPath = dirPath + "read/read_" + scriptName + ".daphne"; \ + compareDaphneToStr("0\n", scriptPath.c_str()); \ + } + +// TODO The commented out test cases don't work yet (see #931). +MAKE_READ_TEST_CASE_2("matrix_read-in-udf") +MAKE_READ_TEST_CASE_2("matrix_dynamic-path-1") +// MAKE_READ_TEST_CASE_2("matrix_dynamic-path-2") +// MAKE_READ_TEST_CASE_2("matrix_dynamic-path-3") +MAKE_READ_TEST_CASE_2("frame_read-in-udf") +MAKE_READ_TEST_CASE_2("frame_dynamic-path-1") +// MAKE_READ_TEST_CASE_2("frame_dynamic-path-2") +// MAKE_READ_TEST_CASE_2("frame_dynamic-path-3") + +// ******************************************************************************** +// Write test cases +// ******************************************************************************** + +// These test cases check if the CSV writer produces the expected CSV files. +// The data is chosen to trigger various interesting cases. The tests perform the following steps: +// - start with matrix/frame hard-coded in a DaphneDSL script +// - write it to a CSV file +// - don't compare the file contents to a reference file, as there can be multiple equivalent CSV representations, +// because +// - any field may be quoted, but only certain fields need to be quoted +// - there could be trailing zeros in floating-point numbers +// - there are different valid notations for floating-point numbers (e.g., 0.001 vs 1e-3) +// - there are different valid notations for integers (e.g., 32 vs 0x20) +// - ... 
+// - instead, read the written file and a reference CSV file and compare the two read matrices/frames in DaphneDSL +// - note: these tests don't check if the matrix/frame read from the reference CSV file looks as expected (that's done +// by the read tests) +#define MAKE_WRITE_TEST_CASE(dt, name, nanSafe) \ + TEST_CASE("write_" dt "_" name, TAG_IO) { \ + const std::string scriptPathWrt = dirPath + "write/write_" + dt + "_" + name + ".daphne"; \ + const std::string scriptPathCmp = dirPath + "do_check_" + dt + ".daphne"; \ + const std::string outPath = dirPath + "out/" + dt + "_" + name + ".csv"; \ + const std::string refPath = dirPath + "ref/" + dt + "_" + name + "_ref.csv"; \ + std::filesystem::remove(outPath); /* remove old output file if it still exists */ \ + checkDaphneStatusCode(StatusCode::SUCCESS, scriptPathWrt.c_str(), "--args", \ + ("outPath=\"" + outPath + "\"").c_str()); \ + /* TODO REQUIRE() the status code to be success (don't only CHECK() it), because in case of failure, */ \ + /* the next check doesn't make sense and might produce a misleading output. */ \ + compareDaphneToStr("0\n", scriptPathCmp.c_str(), "--args", \ + ("chkPath=\"" + outPath + "\",refPath=\"" + refPath + "\",nanSafe=" + nanSafe).c_str()); \ + } + +// TODO The commented out test cases don't work yet, as the CSV writer doesn't support strings yet. 
+MAKE_WRITE_TEST_CASE("matrix", "si64", "false") +MAKE_WRITE_TEST_CASE("matrix", "f64", "true") +MAKE_WRITE_TEST_CASE("matrix", "str", "false") +MAKE_WRITE_TEST_CASE("matrix", "view", "false") +MAKE_WRITE_TEST_CASE("frame", "mixed-no-str", "false") +MAKE_WRITE_TEST_CASE("frame", "mixed-str", "false") \ No newline at end of file diff --git a/test/api/cli/io/WriteTest.cpp b/test/api/cli/io/WriteTest.cpp deleted file mode 100644 index b58b8ec09..000000000 --- a/test/api/cli/io/WriteTest.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2024 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include - -#include - -#include -#include - -const std::string dirPath = "test/api/cli/io/"; - -TEST_CASE("writeMatrixCSV_Full", TAG_IO) { - std::string csvPath = dirPath + "matrix_full.csv"; - std::filesystem::remove(csvPath); // remove old file if it still exists - checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "writeMatrix_full.daphne", "--args", - std::string("outPath=\"" + csvPath + "\"").c_str()); - compareDaphneToRef(dirPath + "matrix_full_ref.csv", dirPath + "readMatrix.daphne", "--args", - std::string("inPath=\"" + csvPath + "\"").c_str()); -} - -TEST_CASE("writeMatrixCSV_View", TAG_IO) { - std::string csvPath = dirPath + "matrix_view.csv"; - std::filesystem::remove(csvPath); // remove old file if it still exists - checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "writeMatrix_view.daphne", "--args", - std::string("outPath=\"" + csvPath + "\"").c_str()); - compareDaphneToRef(dirPath + "matrix_view_ref.csv", dirPath + "readMatrix.daphne", "--args", - std::string("inPath=\"" + csvPath + "\"").c_str()); -} \ No newline at end of file diff --git a/test/api/cli/io/check_frame.daphne b/test/api/cli/io/check_frame.daphne new file mode 100644 index 000000000..8084373ee --- /dev/null +++ b/test/api/cli/io/check_frame.daphne @@ -0,0 +1,53 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Counts in how many corresponding elements the two given frames differ. 
+# If nanSafe is true, then two corresponding nan elements are considered equal. +# If there are differences, the two input frames are printed for debug information. +def checkFrame(chk, ref, nanSafe:bool) { + if(nanSafe) + stop("nan-safe comparison of frames is not supported yet"); + + if(ncol(chk) != ncol(ref)) + stop("the two input frames have different #cols: found " + ncol(chk) + " but expected " + ncol(ref)); + + numDiff = 0; + + # TODO Doesn't work (missing elementwise binary ops on two frames, see #932). + # numDiff = numDiff + sum(as.matrix(chk != ref)); + + # TODO Doesn't work, because the type of the c-th column isn't known at compile-time. + # for(c in 0:ncol(chk)-1) { + # colChk = as.matrix(chk[, c]); + # colRef = as.matrix(ref[, c]); + # numDiff = numDiff + sum(as.matrix(colChk != colRef)); + # } + + # TODO The following code only works for a hard-coded number of columns. + if(ncol(chk) != 3) + stop("this script only works for frames with exactly three columns"); + numDiff = numDiff + sum(as.si64(as.matrix(chk[, 0]) != as.matrix(ref[, 0]))); + numDiff = numDiff + sum(as.si64(as.matrix(chk[, 1]) != as.matrix(ref[, 1]))); + numDiff = numDiff + sum(as.si64(as.matrix(chk[, 2]) != as.matrix(ref[, 2]))); + + print(numDiff); + + # Debug output. + if(numDiff > 0) { + print("chk"); + print(chk); + print("ref"); + print(ref); + } +} \ No newline at end of file diff --git a/test/api/cli/io/check_matrix.daphne b/test/api/cli/io/check_matrix.daphne new file mode 100644 index 000000000..6ed473dca --- /dev/null +++ b/test/api/cli/io/check_matrix.daphne @@ -0,0 +1,36 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Counts in how many corresponding elements the two given matrices differ. +# If nanSafe is true, then two corresponding nan elements are considered equal. +# If there are differences, the two input matrices are printed for debug information. +def checkMatrix(chk, ref, nanSafe:bool) { + numDiff = 0; + if(nanSafe) + # The corresponding values are not equal and at least one of them is not nan (they're not both nan). + numDiff = sum(as.matrix(chk != ref) && (as.matrix(isNan(chk)) == 0 || as.matrix(isNan(ref)) == 0)); + else + # The corresponding values are not equal. + numDiff = sum(as.matrix(chk != ref)); + + print(numDiff); + + # Debug output. + if(numDiff > 0) { + print("chk"); + print(chk); + print("ref"); + print(ref); + } +} \ No newline at end of file diff --git a/test/api/cli/io/do_check_frame.daphne b/test/api/cli/io/do_check_frame.daphne new file mode 100644 index 000000000..75bbc3d71 --- /dev/null +++ b/test/api/cli/io/do_check_frame.daphne @@ -0,0 +1,21 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Reads the files $chkPath and $refPath and compares the two read frames. + +import "check_frame.daphne"; + +chk = readFrame($chkPath); +ref = readFrame($refPath); +check_frame.checkFrame(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/do_check_matrix.daphne b/test/api/cli/io/do_check_matrix.daphne new file mode 100644 index 000000000..578017b9f --- /dev/null +++ b/test/api/cli/io/do_check_matrix.daphne @@ -0,0 +1,21 @@ +# Copyright 2024 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Reads the files $chkPath and $refPath and compares the two read matrices. 
+ +import "check_matrix.daphne"; + +chk = readMatrix($chkPath); +ref = readMatrix($refPath); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/matrix_full_ref.csv b/test/api/cli/io/matrix_full_ref.csv deleted file mode 100644 index b828190bc..000000000 --- a/test/api/cli/io/matrix_full_ref.csv +++ /dev/null @@ -1,4 +0,0 @@ -DenseMatrix(3x2, int64_t) -1 2 -3 4 -5 6 diff --git a/test/api/cli/io/matrix_view_ref.csv b/test/api/cli/io/matrix_view_ref.csv deleted file mode 100644 index 767f21dd4..000000000 --- a/test/api/cli/io/matrix_view_ref.csv +++ /dev/null @@ -1,4 +0,0 @@ -DenseMatrix(3x1, int64_t) -2 -4 -6 diff --git a/test/api/cli/io/out/.gitignore b/test/api/cli/io/out/.gitignore new file mode 100644 index 000000000..9e06bd6c2 --- /dev/null +++ b/test/api/cli/io/out/.gitignore @@ -0,0 +1,2 @@ +*.csv +*.csv.meta \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_dynamic-path-1.daphne b/test/api/cli/io/read/read_frame_dynamic-path-1.daphne new file mode 100644 index 000000000..ade414fde --- /dev/null +++ b/test/api/cli/io/read/read_frame_dynamic-path-1.daphne @@ -0,0 +1,7 @@ +# Read a frame from a file when the file path is the result of an expression (only string concat). + +import "../check_frame.daphne"; + +chk = readFrame("test/api/cli/io/ref/frame_" + "123" + "_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_dynamic-path-2.daphne b/test/api/cli/io/read/read_frame_dynamic-path-2.daphne new file mode 100644 index 000000000..03666a93c --- /dev/null +++ b/test/api/cli/io/read/read_frame_dynamic-path-2.daphne @@ -0,0 +1,7 @@ +# Read a frame from a file when the file path is the result of an expression (string/number concat). 
+ +import "../check_frame.daphne"; + +chk = readFrame("test/api/cli/io/ref/frame_" + 123 + "_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_dynamic-path-3.daphne b/test/api/cli/io/read/read_frame_dynamic-path-3.daphne new file mode 100644 index 000000000..94a8623cd --- /dev/null +++ b/test/api/cli/io/read/read_frame_dynamic-path-3.daphne @@ -0,0 +1,7 @@ +# Read a frame from a file when the file path is the result of an expression (string/casted number concat). + +import "../check_frame.daphne"; + +chk = readFrame("test/api/cli/io/ref/frame_" + as.str(123) + "_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_mixed-no-str.daphne b/test/api/cli/io/read/read_frame_mixed-no-str.daphne new file mode 100644 index 000000000..c04df7052 --- /dev/null +++ b/test/api/cli/io/read/read_frame_mixed-no-str.daphne @@ -0,0 +1,11 @@ +# Read a frame with columns of various value types (not including string) from a file. + +import "../check_frame.daphne"; + +# TODO Test nan values. + +chk = readFrame($inPath); +ref = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_si64_": [1000, 2000, 3000, 4000, 5000]}; +check_frame.checkFrame(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_mixed-str.daphne b/test/api/cli/io/read/read_frame_mixed-str.daphne new file mode 100644 index 000000000..d75f2668a --- /dev/null +++ b/test/api/cli/io/read/read_frame_mixed-str.daphne @@ -0,0 +1,11 @@ +# Read a frame with columns of various value types (including string) from a file. 
+ +import "../check_frame.daphne"; + +# TODO test nan + +chk = readFrame($inPath); +ref = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_str": ["abc", "", "d\"e", "fg\nhi", "mn,op"]}; +check_frame.checkFrame(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_frame_read-in-udf.daphne b/test/api/cli/io/read/read_frame_read-in-udf.daphne new file mode 100644 index 000000000..da3e34fcd --- /dev/null +++ b/test/api/cli/io/read/read_frame_read-in-udf.daphne @@ -0,0 +1,13 @@ +# Read a frame from a file when the file path is a parameter to a UDF. + +import "../check_frame.daphne"; + +# TODO Test nan values. + +def myReadFrame(path:str) { + return readFrame(path); +} + +chk = myReadFrame("test/api/cli/io/ref/frame_123_ref.csv"); +ref = {"a": [1], "b": [2], "c": [3]}; +check_frame.checkFrame(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_dynamic-path-1.daphne b/test/api/cli/io/read/read_matrix_dynamic-path-1.daphne new file mode 100644 index 000000000..69fb87368 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_dynamic-path-1.daphne @@ -0,0 +1,7 @@ +# Read a matrix from a file when the file path is the result of an expression (only string concat). + +import "../check_matrix.daphne"; + +chk = readMatrix("test/api/cli/io/ref/matrix_" + "123" + "_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_dynamic-path-2.daphne b/test/api/cli/io/read/read_matrix_dynamic-path-2.daphne new file mode 100644 index 000000000..551a36713 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_dynamic-path-2.daphne @@ -0,0 +1,7 @@ +# Read a matrix from a file when the file path is the result of an expression (string/number concat). 
+ +import "../check_matrix.daphne"; + +chk = readMatrix("test/api/cli/io/ref/matrix_" + 123 + "_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_dynamic-path-3.daphne b/test/api/cli/io/read/read_matrix_dynamic-path-3.daphne new file mode 100644 index 000000000..99736cd5a --- /dev/null +++ b/test/api/cli/io/read/read_matrix_dynamic-path-3.daphne @@ -0,0 +1,7 @@ +# Read a matrix from a file when the file path is the result of an expression (string/casted number concat). + +import "../check_matrix.daphne"; + +chk = readMatrix("test/api/cli/io/ref/matrix_" + as.str(123) + "_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_f64.daphne b/test/api/cli/io/read/read_matrix_f64.daphne new file mode 100644 index 000000000..7c551f164 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_f64.daphne @@ -0,0 +1,8 @@ +# Read a matrix of value type f64 from a file. + +import "../check_matrix.daphne"; + +chk = readMatrix($inPath); +ref = [1.1, -22.2, 3.3, -44.4, 5.5, + -66.6, 0.0, nan, inf, -inf](2, 5); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_read-in-udf.daphne b/test/api/cli/io/read/read_matrix_read-in-udf.daphne new file mode 100644 index 000000000..50dc821fb --- /dev/null +++ b/test/api/cli/io/read/read_matrix_read-in-udf.daphne @@ -0,0 +1,11 @@ +# Read a matrix from a file when the file path is a parameter to a UDF. 
+ +import "../check_matrix.daphne"; + +def myReadMatrix(path:str) { + return readMatrix(path); +} + +chk = myReadMatrix("test/api/cli/io/ref/matrix_123_ref.csv"); +ref = [1]; +check_matrix.checkMatrix(chk, ref, false); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_si64.daphne b/test/api/cli/io/read/read_matrix_si64.daphne new file mode 100644 index 000000000..aa412bf73 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_si64.daphne @@ -0,0 +1,8 @@ +# Read a matrix of value type si64 from a file. + +import "../check_matrix.daphne"; + +chk = readMatrix($inPath); +ref = [1, -22, 3, -44, + 5, -66, 0, 0](2, 4); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/read/read_matrix_str.daphne b/test/api/cli/io/read/read_matrix_str.daphne new file mode 100644 index 000000000..27a02a8a4 --- /dev/null +++ b/test/api/cli/io/read/read_matrix_str.daphne @@ -0,0 +1,8 @@ +# Read a matrix of value type str from a file. + +import "../check_matrix.daphne"; + +chk = readMatrix($inPath); +ref = ["abc", "", "d\"e", + "fg\nhi", "jkl", "mn,op"](2, 3); +check_matrix.checkMatrix(chk, ref, $nanSafe); \ No newline at end of file diff --git a/test/api/cli/io/readMatrix.daphne b/test/api/cli/io/readMatrix.daphne deleted file mode 100644 index c47563c51..000000000 --- a/test/api/cli/io/readMatrix.daphne +++ /dev/null @@ -1,4 +0,0 @@ -// Read a matrix from the file $inPath and print it. 
- -X = readMatrix($inPath); -print(X); \ No newline at end of file diff --git a/test/api/cli/io/readSparse.daphne b/test/api/cli/io/readSparse.daphne deleted file mode 100644 index 578ece638..000000000 --- a/test/api/cli/io/readSparse.daphne +++ /dev/null @@ -1,2 +0,0 @@ -M = readMatrix($filename); -print(M); \ No newline at end of file diff --git a/test/api/cli/io/readSparse.txt b/test/api/cli/io/readSparse.txt deleted file mode 100644 index a6ed7d416..000000000 --- a/test/api/cli/io/readSparse.txt +++ /dev/null @@ -1,11 +0,0 @@ -CSRMatrix(10x5, double) -1 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 1 1 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 0 -0 0 0 0 1 diff --git a/test/api/cli/io/ref/frame_123_ref.csv b/test/api/cli/io/ref/frame_123_ref.csv new file mode 100644 index 000000000..b0246d596 --- /dev/null +++ b/test/api/cli/io/ref/frame_123_ref.csv @@ -0,0 +1 @@ +1,2,3 diff --git a/test/api/cli/io/ref/frame_123_ref.csv.meta b/test/api/cli/io/ref/frame_123_ref.csv.meta new file mode 100644 index 000000000..6a5eb1cc3 --- /dev/null +++ b/test/api/cli/io/ref/frame_123_ref.csv.meta @@ -0,0 +1,9 @@ +{ + "numRows": 1, + "numCols": 3, + "schema": [ + {"label": "a", "valueType": "si64"}, + {"label": "b", "valueType": "si64"}, + {"label": "c", "valueType": "si64"} + ] +} \ No newline at end of file diff --git a/test/api/cli/io/ref/frame_mixed-no-str_ref.csv b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv new file mode 100644 index 000000000..aa53316e6 --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv @@ -0,0 +1,5 @@ +0,0.0,1000 +1,1.1,2000 +-22,-22.2,3000 +3,inf,4000 +-44,-inf,5000 diff --git a/test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta new file mode 100644 index 000000000..ca2b6b377 --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-no-str_ref.csv.meta @@ -0,0 +1,9 @@ +{ + "numRows": 5, + "numCols": 3, + "schema": [ + {"label": "c_si64", "valueType": "si64"}, + {"label": "c_f64", 
"valueType": "f64"}, + {"label": "c_si64_", "valueType": "si64"} + ] +} \ No newline at end of file diff --git a/test/api/cli/io/ref/frame_mixed-str_ref.csv b/test/api/cli/io/ref/frame_mixed-str_ref.csv new file mode 100644 index 000000000..f55d63405 --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-str_ref.csv @@ -0,0 +1,6 @@ +0,0.0,abc +1,1.1, +-22,-22.2,"d""e" +3,inf,"fg +hi" +-44,-inf,"mn,op" diff --git a/test/api/cli/io/ref/frame_mixed-str_ref.csv.meta b/test/api/cli/io/ref/frame_mixed-str_ref.csv.meta new file mode 100644 index 000000000..986e25f0c --- /dev/null +++ b/test/api/cli/io/ref/frame_mixed-str_ref.csv.meta @@ -0,0 +1,9 @@ +{ + "numRows": 5, + "numCols": 3, + "schema": [ + {"label": "c_si64", "valueType": "si64"}, + {"label": "c_f64", "valueType": "f64"}, + {"label": "c_str", "valueType": "str"} + ] +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_123_ref.csv b/test/api/cli/io/ref/matrix_123_ref.csv new file mode 100644 index 000000000..d00491fd7 --- /dev/null +++ b/test/api/cli/io/ref/matrix_123_ref.csv @@ -0,0 +1 @@ +1 diff --git a/test/api/cli/io/ref/matrix_123_ref.csv.meta b/test/api/cli/io/ref/matrix_123_ref.csv.meta new file mode 100644 index 000000000..97ad56a68 --- /dev/null +++ b/test/api/cli/io/ref/matrix_123_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 1, + "numCols": 1, + "valueType": "si64" +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_f64_ref.csv b/test/api/cli/io/ref/matrix_f64_ref.csv new file mode 100644 index 000000000..f0205d55d --- /dev/null +++ b/test/api/cli/io/ref/matrix_f64_ref.csv @@ -0,0 +1,2 @@ +1.1,-22.2,3.3,-44.4,5.5 +-66.6,0.0,nan,inf,-inf diff --git a/test/api/cli/io/ref/matrix_f64_ref.csv.meta b/test/api/cli/io/ref/matrix_f64_ref.csv.meta new file mode 100644 index 000000000..d5b38231d --- /dev/null +++ b/test/api/cli/io/ref/matrix_f64_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 2, + "numCols": 5, + "valueType": "f64" +} \ No newline at end of file diff --git 
a/test/api/cli/io/ref/matrix_si64_ref.csv b/test/api/cli/io/ref/matrix_si64_ref.csv new file mode 100644 index 000000000..332b71f5c --- /dev/null +++ b/test/api/cli/io/ref/matrix_si64_ref.csv @@ -0,0 +1,2 @@ +1,-22,3,-44 +5,-66,0,0 diff --git a/test/api/cli/io/ref/matrix_si64_ref.csv.meta b/test/api/cli/io/ref/matrix_si64_ref.csv.meta new file mode 100644 index 000000000..0b44c2898 --- /dev/null +++ b/test/api/cli/io/ref/matrix_si64_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 2, + "numCols": 4, + "valueType": "si64" +} \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_str_ref.csv b/test/api/cli/io/ref/matrix_str_ref.csv new file mode 100644 index 000000000..d18309603 --- /dev/null +++ b/test/api/cli/io/ref/matrix_str_ref.csv @@ -0,0 +1,3 @@ +abc,,"d""e" +"fg +hi","jkl","mn,op" diff --git a/test/api/cli/io/ReadCsv2.csv.meta b/test/api/cli/io/ref/matrix_str_ref.csv.meta similarity index 70% rename from test/api/cli/io/ReadCsv2.csv.meta rename to test/api/cli/io/ref/matrix_str_ref.csv.meta index 0fe44b980..4e96a97fc 100644 --- a/test/api/cli/io/ReadCsv2.csv.meta +++ b/test/api/cli/io/ref/matrix_str_ref.csv.meta @@ -1,5 +1,5 @@ { - "numRows": 9, + "numRows": 2, "numCols": 3, "valueType": "str" } \ No newline at end of file diff --git a/test/api/cli/io/ref/matrix_view_ref.csv b/test/api/cli/io/ref/matrix_view_ref.csv new file mode 100644 index 000000000..332b71f5c --- /dev/null +++ b/test/api/cli/io/ref/matrix_view_ref.csv @@ -0,0 +1,2 @@ +1,-22,3,-44 +5,-66,0,0 diff --git a/test/api/cli/io/ref/matrix_view_ref.csv.meta b/test/api/cli/io/ref/matrix_view_ref.csv.meta new file mode 100644 index 000000000..0b44c2898 --- /dev/null +++ b/test/api/cli/io/ref/matrix_view_ref.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 2, + "numCols": 4, + "valueType": "si64" +} \ No newline at end of file diff --git a/test/api/cli/io/testReadFrame.daphne b/test/api/cli/io/testReadFrame.daphne deleted file mode 100644 index 59bde6fe6..000000000 --- 
a/test/api/cli/io/testReadFrame.daphne +++ /dev/null @@ -1,6 +0,0 @@ -# Test reading from a file when the file path is not trivially constant (i.e., a parameter to a UDF) -def readFrameFromCSV(path: str) { - print(readFrame(path)); -} - -readFrameFromCSV("test/api/cli/io/ReadCsv1.csv"); \ No newline at end of file diff --git a/test/api/cli/io/testReadFrame.txt b/test/api/cli/io/testReadFrame.txt deleted file mode 100644 index e5af505cc..000000000 --- a/test/api/cli/io/testReadFrame.txt +++ /dev/null @@ -1,3 +0,0 @@ -Frame(2x4, [col_0:double, col_1:double, col_2:double, col_3:double]) --0.1 -0.2 0.1 0.2 -3.14 5.41 6.22216 5 diff --git a/test/api/cli/io/testReadMatrix.daphne b/test/api/cli/io/testReadMatrix.daphne deleted file mode 100644 index cefdb281c..000000000 --- a/test/api/cli/io/testReadMatrix.daphne +++ /dev/null @@ -1,6 +0,0 @@ -# Test reading from a file when the file path is not trivially constant (i.e., a parameter to a UDF) -def readMatrixFromCSV(path: str) { - print(readMatrix(path)); -} - -readMatrixFromCSV("test/api/cli/io/ReadCsv1.csv"); \ No newline at end of file diff --git a/test/api/cli/io/testReadMatrix.txt b/test/api/cli/io/testReadMatrix.txt deleted file mode 100644 index 26ed2a258..000000000 --- a/test/api/cli/io/testReadMatrix.txt +++ /dev/null @@ -1,3 +0,0 @@ -DenseMatrix(2x4, double) --0.1 -0.2 0.1 0.2 -3.14 5.41 6.22216 5 diff --git a/test/api/cli/io/testReadMatrix_DynamicPath.daphne b/test/api/cli/io/testReadMatrix_DynamicPath.daphne deleted file mode 100644 index 6c31fba84..000000000 --- a/test/api/cli/io/testReadMatrix_DynamicPath.daphne +++ /dev/null @@ -1,5 +0,0 @@ -# Test dynamic computation of string path -> does not yet work! 
-i = 1; -filename = "test/api/cli/io/ReadCsv" + i + ".csv"; -m = readMatrix(filename); -print(m); \ No newline at end of file diff --git a/test/api/cli/io/testReadStringIntoFrame.daphne b/test/api/cli/io/testReadStringIntoFrame.daphne deleted file mode 100644 index d2b457193..000000000 --- a/test/api/cli/io/testReadStringIntoFrame.daphne +++ /dev/null @@ -1,3 +0,0 @@ -// Test reading frame with string columns. - -print(readFrame("test/api/cli/io/ReadCsv3.csv")); \ No newline at end of file diff --git a/test/api/cli/io/testReadStringIntoFrame.txt b/test/api/cli/io/testReadStringIntoFrame.txt deleted file mode 100644 index 1cca28d7b..000000000 --- a/test/api/cli/io/testReadStringIntoFrame.txt +++ /dev/null @@ -1,5 +0,0 @@ -Frame(4x3, [a:int8_t, b:uint8_t, c:std::string]) -1 255 -2 254 -3 253 multi-line, -4 252 simple string diff --git a/test/api/cli/io/testReadStringMatrix.daphne b/test/api/cli/io/testReadStringMatrix.daphne deleted file mode 100644 index f3499212f..000000000 --- a/test/api/cli/io/testReadStringMatrix.daphne +++ /dev/null @@ -1,3 +0,0 @@ -// Test reading matrix of strings - -print(readMatrix("test/api/cli/io/ReadCsv2.csv")); \ No newline at end of file diff --git a/test/api/cli/io/testReadStringMatrix.txt b/test/api/cli/io/testReadStringMatrix.txt deleted file mode 100644 index fc4085f8c..000000000 --- a/test/api/cli/io/testReadStringMatrix.txt +++ /dev/null @@ -1,11 +0,0 @@ -DenseMatrix(9x3, std::string) -banana, grape 36 Fruit Basket -xyz""uvw 34 No Category\" -parrot, rabbit 31 Pets -line1 -line2 26 with newline -chair 28 Furniture Set -green, yellow\n 51 -\n\"xyz""uvw\" 29 Mixed string -\"blue, \"\" 42 -unknown, item 23 Unknown Item diff --git a/test/api/cli/io/write/write_frame_mixed-no-str.daphne b/test/api/cli/io/write/write_frame_mixed-no-str.daphne new file mode 100644 index 000000000..e7f1261e7 --- /dev/null +++ b/test/api/cli/io/write/write_frame_mixed-no-str.daphne @@ -0,0 +1,8 @@ +# Write a frame with columns of various value types (not 
including string) to a file. + +# TODO Test nan values. + +X = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_si64_": [1000, 2000, 3000, 4000, 5000]}; +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_frame_mixed-str.daphne b/test/api/cli/io/write/write_frame_mixed-str.daphne new file mode 100644 index 000000000..eed5a7f51 --- /dev/null +++ b/test/api/cli/io/write/write_frame_mixed-str.daphne @@ -0,0 +1,8 @@ +# Write a frame with columns of various value types (including string) to a file. + +# TODO Test nan values. + +X = {"c_si64": [0, 1, -22, 3, -44], + "c_f64": [0.0, 1.1, -22.2, inf, -inf], + "c_str": ["abc", "", "d\"e", "fg\nhi", "mn,op"]}; +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_f64.daphne b/test/api/cli/io/write/write_matrix_f64.daphne new file mode 100644 index 000000000..69982599f --- /dev/null +++ b/test/api/cli/io/write/write_matrix_f64.daphne @@ -0,0 +1,5 @@ +# Write a matrix of value type f64 to a file. + +X = [1.1, -22.2, 3.3, -44.4, 5.5, + -66.6, 0.0, nan, inf, -inf](2, 5); +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_si64.daphne b/test/api/cli/io/write/write_matrix_si64.daphne new file mode 100644 index 000000000..aaebe897b --- /dev/null +++ b/test/api/cli/io/write/write_matrix_si64.daphne @@ -0,0 +1,5 @@ +# Write a matrix of value type si64 to a file. + +X = [1, -22, 3, -44, + 5, -66, 0, 0](2, 4); +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_str.daphne b/test/api/cli/io/write/write_matrix_str.daphne new file mode 100644 index 000000000..289052275 --- /dev/null +++ b/test/api/cli/io/write/write_matrix_str.daphne @@ -0,0 +1,5 @@ +# Write a matrix of value type str to a file. 
+ +X = ["abc", "", "d\"e", + "fg\nhi", "jkl", "mn,op"](2, 3); +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/write/write_matrix_view.daphne b/test/api/cli/io/write/write_matrix_view.daphne new file mode 100644 index 000000000..d5b1f55fb --- /dev/null +++ b/test/api/cli/io/write/write_matrix_view.daphne @@ -0,0 +1,6 @@ +# Write a matrix which is a view into another one to a file. + +X = [0, 1, -22, 3, -44, 0, + 0, 5, -66, 0, 0, 0](2, 6); +X = X[, 1:ncol(X)-1]; # omit the first and last column +write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/writeMatrix_full.daphne b/test/api/cli/io/writeMatrix_full.daphne deleted file mode 100644 index a66f44b22..000000000 --- a/test/api/cli/io/writeMatrix_full.daphne +++ /dev/null @@ -1,4 +0,0 @@ -// Write an entire small matrix to the file $outPath. - -X = reshape(seq(1, 6, 1), 3, 2); -write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/io/writeMatrix_view.daphne b/test/api/cli/io/writeMatrix_view.daphne deleted file mode 100644 index 90538095d..000000000 --- a/test/api/cli/io/writeMatrix_view.daphne +++ /dev/null @@ -1,5 +0,0 @@ -// Write a view into a small matrix to the file $outPath. 
- -X = reshape(seq(1, 6, 1), 3, 2); -X = X[, 1]; # only the 2nd column -write(X, $outPath); \ No newline at end of file diff --git a/test/api/cli/operations/OperationsTest.cpp b/test/api/cli/operations/OperationsTest.cpp index 47857ceb0..4b8f44aa6 100644 --- a/test/api/cli/operations/OperationsTest.cpp +++ b/test/api/cli/operations/OperationsTest.cpp @@ -40,8 +40,10 @@ MAKE_TEST_CASE("createFrame", 1) MAKE_TEST_CASE("ctable", 1) MAKE_TEST_CASE("fill", 1) MAKE_TEST_CASE("gemv", 1) +MAKE_TEST_CASE("groupSum", 1) MAKE_TEST_CASE("idxMax", 1) MAKE_TEST_CASE("idxMin", 1) +MAKE_TEST_CASE("innerJoin", 1) MAKE_TEST_CASE("isNan", 1) MAKE_TEST_CASE("lower", 1) MAKE_TEST_CASE("mean", 1) @@ -59,6 +61,7 @@ MAKE_TEST_CASE("rbind", 1) MAKE_TEST_CASE("recode", 4) MAKE_TEST_CASE("replace", 1) MAKE_TEST_CASE("reverse", 1) +MAKE_TEST_CASE("semiJoin", 1) MAKE_TEST_CASE("seq", 2) MAKE_TEST_CASE("solve", 1) MAKE_TEST_CASE("sqrt", 1) diff --git a/test/api/cli/operations/groupSum_1.daphne b/test/api/cli/operations/groupSum_1.daphne new file mode 100644 index 000000000..82b08ca51 --- /dev/null +++ b/test/api/cli/operations/groupSum_1.daphne @@ -0,0 +1,8 @@ +// Grouping and sum aggregation (two grouping columns). + +fr = createFrame([ 11, 11, 22, 33, 33 ], [ 1, 1, 2, 3, 4 ], + [ 100, 200, 300, 400, 500 ], "a", "b", "c"); + +res = groupSum(fr, "a", "b", "c"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/operations/groupSum_1.txt b/test/api/cli/operations/groupSum_1.txt new file mode 100644 index 000000000..4e24eb7a4 --- /dev/null +++ b/test/api/cli/operations/groupSum_1.txt @@ -0,0 +1,5 @@ +Frame(4x3, [a:int64_t, b:int64_t, SUM(c):int64_t]) +11 1 300 +22 2 300 +33 3 400 +33 4 500 diff --git a/test/api/cli/operations/innerJoin_1.daphne b/test/api/cli/operations/innerJoin_1.daphne new file mode 100644 index 000000000..b2e97b3b0 --- /dev/null +++ b/test/api/cli/operations/innerJoin_1.daphne @@ -0,0 +1,10 @@ +# Test inner join without and with optional arg for result size. 
+ +f1 = createFrame([ 1, 2 ], [ 3, 4 ], "a", "b"); +f2 = createFrame([ 3, 4, 5 ], [ 6, 7, 8 ], "c", "d"); + +f3 = innerJoin(f1, f2, "b", "c"); +f4 = innerJoin(f1, f2, "b", "c", 2); + +print(f3); +print(f4); \ No newline at end of file diff --git a/test/api/cli/operations/innerJoin_1.txt b/test/api/cli/operations/innerJoin_1.txt new file mode 100644 index 000000000..193f56e12 --- /dev/null +++ b/test/api/cli/operations/innerJoin_1.txt @@ -0,0 +1,6 @@ +Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t]) +1 3 3 6 +2 4 4 7 +Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t]) +1 3 3 6 +2 4 4 7 diff --git a/test/api/cli/operations/semiJoin_1.daphne b/test/api/cli/operations/semiJoin_1.daphne new file mode 100644 index 000000000..8a3b3b321 --- /dev/null +++ b/test/api/cli/operations/semiJoin_1.daphne @@ -0,0 +1,10 @@ +# Test semi-join without and with optional arg for result size. + +f1 = createFrame([ 1, 2 ], [ 3, 4 ], "a", "b"); +f2 = createFrame([ 3, 4, 5 ], [ 6, 7, 8 ], "c", "d"); + +keys1, tids1 = semiJoin(f1, f2, "b", "c"); +keys2, tids2 = semiJoin(f1, f2, "b", "c", 2); + +print(f1[tids1, ]); +print(f1[tids2, ]); \ No newline at end of file diff --git a/test/api/cli/operations/semiJoin_1.txt b/test/api/cli/operations/semiJoin_1.txt new file mode 100644 index 000000000..b135e0a8a --- /dev/null +++ b/test/api/cli/operations/semiJoin_1.txt @@ -0,0 +1,6 @@ +Frame(2x2, [a:int64_t, b:int64_t]) +1 3 +2 4 +Frame(2x2, [a:int64_t, b:int64_t]) +1 3 +2 4 diff --git a/test/codegen/sparseExploit.mlir b/test/codegen/sparseExploit.mlir new file mode 100644 index 000000000..3b5789de0 --- /dev/null +++ b/test/codegen/sparseExploit.mlir @@ -0,0 +1,38 @@ +// RUN: daphne-opt -pass-pipeline="builtin.module(lower-sparse-exploit, canonicalize)" %s | FileCheck %s + +// COM: Canonicalizer is run to guarantee matMul, ewLn, and ewMul are also removed as they become redundant but are not removed by the sparse pass itself. 
+// COM: They are tested here regardless as it is crucial to the pass to replace the entire pattern. + +module { + func.func @double() { + %0 = "daphne.constant"() {value = true} : () -> i1 + %1 = "daphne.constant"() {value = false} : () -> i1 + %2 = "daphne.constant"() {value = 2 : index} : () -> index + %3 = "daphne.constant"() {value = 10 : index} : () -> index + %4 = "daphne.constant"() {value = 3 : si64} : () -> si64 + %5 = "daphne.constant"() {value = 2 : si64} : () -> si64 + %6 = "daphne.constant"() {value = 1 : si64} : () -> si64 + %7 = "daphne.constant"() {value = 0.000000e+00 : f64} : () -> f64 + %8 = "daphne.constant"() {value = 1.000000e+00 : f64} : () -> f64 + %9 = "daphne.constant"() {value = 2.000000e-01 : f64} : () -> f64 + %10 = "daphne.randMatrix"(%3, %3, %7, %8, %9, %6) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]> + %11 = "daphne.randMatrix"(%3, %2, %7, %8, %8, %5) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x2xf64:sp[1.000000e+00]> + %12 = "daphne.randMatrix"(%3, %2, %7, %8, %8, %4) : (index, index, f64, f64, f64, si64) -> !daphne.Matrix<10x2xf64:sp[1.000000e+00]> + // CHECK-NOT: daphne.matMul + // CHECK-NOT: daphne.ewLn + // CHECK-NOT: daphne.ewMul + // CHECK-NOT: daphne.sumAll + // CHECK: affine.parallel + // CHECK: scf.for + // CHECK: scf.for + // CHECK: math.fma + // CHECK: math.log + // CHECK: math.fma + %13 = "daphne.matMul"(%11, %12, %1, %0) : (!daphne.Matrix<10x2xf64:sp[1.000000e+00]>, !daphne.Matrix<10x2xf64:sp[1.000000e+00]>, i1, i1) -> !daphne.Matrix<10x10xf64:sp[1.000000e+00]> + %14 = "daphne.ewLn"(%13) : (!daphne.Matrix<10x10xf64:sp[1.000000e+00]>) -> !daphne.Matrix<10x10xf64> + %15 = "daphne.ewMul"(%10, %14) : (!daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]>, !daphne.Matrix<10x10xf64>) -> !daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]> + %16 = "daphne.sumAll"(%15) : (!daphne.Matrix<10x10xf64:sp[2.000000e-01]:rep[sparse]>) -> f64 + "daphne.print"(%16, %0, 
%1) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +} \ No newline at end of file diff --git a/test/runtime/local/io/ReadCsvTest.cpp b/test/runtime/local/io/ReadCsvTest.cpp index eed5e30e0..fd68ece99 100644 --- a/test/runtime/local/io/ReadCsvTest.cpp +++ b/test/runtime/local/io/ReadCsvTest.cpp @@ -226,7 +226,7 @@ TEST_CASE("ReadCsv, frame of numbers and strings", TAG_IO) { CHECK(m->getColumn(2)->get(1, 0) == "sample,"); CHECK(m->getColumn(2)->get(2, 0) == "line1\nline2"); CHECK(m->getColumn(2)->get(3, 0) == ""); - CHECK(m->getColumn(2)->get(4, 0) == "\"\"\\n\\\"abc\"\"def\\\""); + CHECK(m->getColumn(2)->get(4, 0) == "\"\\n\\\"abc\"def\\\""); CHECK(m->getColumn(2)->get(5, 0) == ""); CHECK(m->getColumn(3)->get(0, 0) == 444); @@ -337,10 +337,10 @@ TEMPLATE_PRODUCT_TEST_CASE("ReadCsv", TAG_IO, (DenseMatrix), (ALL_STRING_VALUE_T CHECK(m->get(0, 0) == "apple, orange"); CHECK(m->get(1, 0) == "dog, cat"); CHECK(m->get(2, 0) == "table"); - CHECK(m->get(3, 0) == "\"\""); - CHECK(m->get(4, 0) == "abc\"\"def"); + CHECK(m->get(3, 0) == "\""); + CHECK(m->get(4, 0) == "abc\"def"); CHECK(m->get(5, 0) == "red, blue\\n"); - CHECK(m->get(6, 0) == "\\n\\\"abc\"\"def\\\""); + CHECK(m->get(6, 0) == "\\n\\\"abc\"def\\\""); CHECK(m->get(7, 0) == "line1\nline2"); CHECK(m->get(8, 0) == "\\\"red, \\\"\\\""); diff --git a/test/runtime/local/kernels/InnerJoinTest.cpp b/test/runtime/local/kernels/InnerJoinTest.cpp index 371c20cfe..c2ee660f0 100644 --- a/test/runtime/local/kernels/InnerJoinTest.cpp +++ b/test/runtime/local/kernels/InnerJoinTest.cpp @@ -31,25 +31,25 @@ #include -TEST_CASE("innerJoin", TAG_KERNELS) { +TEST_CASE("InnerJoin", TAG_KERNELS) { auto lhsC0 = genGivenVals>(4, {1, 2, 3, 4}); auto lhsC1 = genGivenVals>(4, {11.0, 22.0, 33.0, 44.00}); std::vector lhsCols = {lhsC0, lhsC1}; std::string lhsLabels[] = {"a", "b"}; auto lhs = DataObjectFactory::create(lhsCols, lhsLabels); - auto rhsC0 = genGivenVals>(3, {1, 4, 5}); - auto rhsC1 = genGivenVals>(3, {-1, -4, -5}); - auto 
rhsC2 = genGivenVals>(3, {0.1, 0.2, 0.3}); + auto rhsC0 = genGivenVals>(4, {1, 4, 5, 4}); + auto rhsC1 = genGivenVals>(4, {-1, -4, -5, -6}); + auto rhsC2 = genGivenVals>(4, {0.1, 0.2, 0.3, 0.4}); std::vector rhsCols = {rhsC0, rhsC1, rhsC2}; std::string rhsLabels[] = {"c", "d", "e"}; auto rhs = DataObjectFactory::create(rhsCols, rhsLabels); Frame *res = nullptr; - innerJoin(res, lhs, rhs, "a", "c", nullptr); + innerJoin(res, lhs, rhs, "a", "c", -1, nullptr); // Check the meta data. - CHECK(res->getNumRows() == 2); + CHECK(res->getNumRows() == 3); CHECK(res->getNumCols() == 5); CHECK(res->getColumnType(0) == ValueTypeCode::SI64); @@ -64,11 +64,11 @@ TEST_CASE("innerJoin", TAG_KERNELS) { CHECK(res->getLabels()[3] == "d"); CHECK(res->getLabels()[4] == "e"); - auto resC0Exp = genGivenVals>(2, {1, 4}); - auto resC1Exp = genGivenVals>(2, {11.0, 44.0}); - auto resC2Exp = genGivenVals>(2, {1, 4}); - auto resC3Exp = genGivenVals>(2, {-1, -4}); - auto resC4Exp = genGivenVals>(2, {0.1, 0.2}); + auto resC0Exp = genGivenVals>(3, {1, 4, 4}); + auto resC1Exp = genGivenVals>(3, {11.0, 44.0, 44.0}); + auto resC2Exp = genGivenVals>(3, {1, 4, 4}); + auto resC3Exp = genGivenVals>(3, {-1, -4, -6}); + auto resC4Exp = genGivenVals>(3, {0.1, 0.2, 0.4}); CHECK(*(res->getColumn(0)) == *resC0Exp); CHECK(*(res->getColumn(1)) == *resC1Exp); diff --git a/test/runtime/local/kernels/SemiJoinTest.cpp b/test/runtime/local/kernels/SemiJoinTest.cpp index 5b0d76017..ea6a47070 100644 --- a/test/runtime/local/kernels/SemiJoinTest.cpp +++ b/test/runtime/local/kernels/SemiJoinTest.cpp @@ -58,7 +58,7 @@ TEST_CASE("SemiJoin", TAG_KERNELS) { // res Frame *res = nullptr; DenseMatrix *lhsTid = nullptr; - semiJoin(res, lhsTid, lhs, rhs, "a", "c", nullptr); + semiJoin(res, lhsTid, lhs, rhs, "a", "c", -1, nullptr); CHECK(*res == *expRes); CHECK(*lhsTid == *expTid);