diff --git a/src/runtime/local/io/WriteMM.h b/src/runtime/local/io/WriteMM.h
new file mode 100644
index 000000000..6de781312
--- /dev/null
+++ b/src/runtime/local/io/WriteMM.h
@@ -0,0 +1,174 @@
+#ifndef WRITE_MM_H
+#define WRITE_MM_H
+
+#include <runtime/local/datastructures/CSRMatrix.h>
+#include <runtime/local/datastructures/DenseMatrix.h>
+#include <runtime/local/io/MMFile.h>
+
+#include <cstddef>
+#include <fstream>
+#include <iomanip>
+#include <ostream>
+#include <stdexcept>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+// Primary template: only the specializations below can be called.
+template <class DTArg> struct WriteMM { static void apply(const DTArg *arg, const char *filename) = delete; };
+
+// Convenience function
+template <class DTArg> void writeMM(const DTArg *arg, const char *filename) { WriteMM<DTArg>::apply(arg, filename); }
+
+// ----------------------------------------------------------------------------
+// Helpers shared by all specializations
+// ----------------------------------------------------------------------------
+
+namespace write_mm_detail {
+
+// Returns the MatrixMarket field keyword for VT; throws for unsupported types.
+template <typename VT> const char *fieldFor() {
+    if constexpr (std::is_integral_v<VT>)
+        return MM_INT_STR;
+    else if constexpr (std::is_floating_point_v<VT>)
+        return MM_REAL_STR;
+    else
+        throw std::runtime_error("WriteMM::apply: Unsupported data type");
+}
+
+// Creates the file and writes the banner line. The value type is validated
+// first, so an unsupported type does not leave an empty file behind.
+template <typename VT> std::ofstream openWithBanner(const char *filename, const char *format, const char *symmetry) {
+    const char *field = fieldFor<VT>();
+    std::ofstream f(filename);
+    if (!f.is_open())
+        throw std::runtime_error("WriteMM::apply: Cannot open file");
+    f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << '\n';
+    // Real values are written in scientific notation with 13 fractional digits.
+    if constexpr (std::is_floating_point_v<VT>)
+        f << std::scientific << std::setprecision(13);
+    return f;
+}
+
+// Writes one value followed by a newline. Unary plus promotes 8-bit integers
+// to int, so they are printed as numbers instead of raw characters.
+template <typename VT> void putValue(std::ostream &os, VT value) { os << +value << '\n'; }
+
+// Closes the file and reports write errors (e.g., a full disk).
+inline void finish(std::ofstream &f) {
+    f.close();
+    if (f.fail())
+        throw std::runtime_error("WriteMM::apply: Error while writing file");
+}
+
+} // namespace write_mm_detail
+
+// ----------------------------------------------------------------------------
+// DenseMatrix
+// ----------------------------------------------------------------------------
+
+template <typename VT> struct WriteMM<DenseMatrix<VT>> {
+    static void apply(const DenseMatrix<VT> *arg, const char *filename) {
+        const VT *values = arg->getValues();
+        if (!values)
+            throw std::runtime_error("WriteMM::apply: Null pointer for 'values' in DenseMatrix");
+
+        // Symmetric files store the lower triangle including the diagonal,
+        // skew-symmetric files the strictly lower triangle.
+        const char *symmetry = MM_GENERAL_STR;
+        bool lowerTriangleOnly = false;
+        size_t diagOffset = 0;
+        if (isSymmetric(arg)) {
+            symmetry = MM_SYMM_STR;
+            lowerTriangleOnly = true;
+        } else if (isSkewSymmetric(arg)) {
+            symmetry = MM_SKEW_STR;
+            lowerTriangleOnly = true;
+            diagOffset = 1;
+        }
+
+        std::ofstream f = write_mm_detail::openWithBanner<VT>(filename, MM_DENSE_STR, symmetry);
+
+        const size_t rows = arg->getNumRows();
+        const size_t cols = arg->getNumCols();
+        f << rows << " " << cols << '\n';
+
+        // The array format is column-major; the matrix is stored row-major with
+        // a stride of getRowSkip() elements, which may differ from cols.
+        const size_t rowSkip = arg->getRowSkip();
+        for (size_t c = 0; c < cols; ++c)
+            for (size_t r = lowerTriangleOnly ? c + diagOffset : 0; r < rows; ++r)
+                write_mm_detail::putValue(f, values[r * rowSkip + c]);
+
+        write_mm_detail::finish(f);
+    }
+
+  private:
+    // True if the matrix is square and equal to its transpose.
+    static bool isSymmetric(const DenseMatrix<VT> *arg) {
+        const size_t n = arg->getNumRows();
+        if (n != arg->getNumCols())
+            return false;
+        const VT *values = arg->getValues();
+        const size_t rowSkip = arg->getRowSkip();
+        for (size_t r = 0; r < n; ++r)
+            for (size_t c = r + 1; c < n; ++c)
+                if (values[r * rowSkip + c] != values[c * rowSkip + r])
+                    return false;
+        return true;
+    }
+
+    // True if the matrix is square, has an all-zero diagonal, and equals its
+    // negated transpose.
+    static bool isSkewSymmetric(const DenseMatrix<VT> *arg) {
+        const size_t n = arg->getNumRows();
+        if (n != arg->getNumCols())
+            return false;
+        const VT *values = arg->getValues();
+        const size_t rowSkip = arg->getRowSkip();
+        for (size_t r = 0; r < n; ++r) {
+            if (values[r * rowSkip + r] != 0)
+                return false;
+            for (size_t c = r + 1; c < n; ++c)
+                if (values[r * rowSkip + c] != -values[c * rowSkip + r])
+                    return false;
+        }
+        return true;
+    }
+};
+
+// ----------------------------------------------------------------------------
+// CSRMatrix
+// ----------------------------------------------------------------------------
+
+template <typename VT> struct WriteMM<CSRMatrix<VT>> {
+    static void apply(const CSRMatrix<VT> *arg, const char *filename) {
+        std::ofstream f = write_mm_detail::openWithBanner<VT>(filename, MM_SPARSE_STR, MM_GENERAL_STR);
+
+        const size_t rows = arg->getNumRows();
+        const size_t cols = arg->getNumCols();
+        f << rows << " " << cols << " " << arg->getNumNonZeros() << '\n';
+
+        const size_t *rowOffsets = arg->getRowOffsets();
+        const size_t *colIdxs = arg->getColIdxs();
+        const VT *values = arg->getValues();
+
+        // Entries are emitted column by column, so re-bucket the row-major CSR
+        // data by column first; rows stay in ascending order within a column.
+        std::vector<std::vector<std::pair<size_t, VT>>> colEntries(cols);
+        for (size_t r = 0; r < rows; ++r)
+            for (size_t idx = rowOffsets[r]; idx < rowOffsets[r + 1]; ++idx)
+                colEntries[colIdxs[idx]].emplace_back(r, values[idx]);
+
+        // Row and column indices are written 1-based.
+        for (size_t c = 0; c < cols; ++c)
+            for (const auto &[r, value] : colEntries[c]) {
+                f << r + 1 << " " << c + 1 << " ";
+                write_mm_detail::putValue(f, value);
+            }
+
+        write_mm_detail::finish(f);
+    }
+};
+
+#endif // WRITE_MM_H
diff --git a/src/runtime/local/kernels/Write.h b/src/runtime/local/kernels/Write.h
index 40f6014a3..b3f9316cb 100644
--- a/src/runtime/local/kernels/Write.h
+++ b/src/runtime/local/kernels/Write.h
@@ -26,6 +26,7 @@
 #include <runtime/local/io/FileMetaData.h>
 #include <runtime/local/io/WriteCsv.h>
 #include <runtime/local/io/WriteDaphne.h>
+#include <runtime/local/io/WriteMM.h>
 #if USE_HDFS
 #include <runtime/local/io/HDFS/WriteHDFS.h>
 #endif
@@ -34,9 +35,7 @@
 // Struct for partial template specialization
 // ****************************************************************************
 
-template <class DTArg> struct Write {
-    static void apply(const DTArg *arg, const char *filename, DCTX(ctx)) = delete;
-};
+template <class DTArg> struct Write { static void apply(const DTArg *arg, const char *filename, DCTX(ctx)) = delete; };
 
 // ****************************************************************************
 // Convenience function
@@ -82,10 +81,11 @@ template <typename VT> struct Write<DenseMatrix<VT>> {
             // call WriteHDFS
             writeHDFS(arg, filename, ctx);
 #endif
+        } else if (ext == "mtx") {
+            writeMM(arg, filename);
         }
     }
 };
-
 // ----------------------------------------------------------------------------
 // Frame
 // ----------------------------------------------------------------------------
@@ -121,6 +121,8 @@ template <typename VT> struct Write<Matrix<VT>> {
             MetaDataParser::writeMetaData(filename, metaData);
             writeCsv(arg, file);
             closeFile(file);
+        } else if (ext == "mtx") {
+            writeMM(arg, filename); // Write Matrix in MatrixMarket format
         } else {
             throw std::runtime_error("[Write.h] - generic Matrix type currently only supports csv "
                                      "file extension.");
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index fad31d46d..e6103c19b 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -82,6 +82,7 @@ set(TEST_SOURCES
         runtime/local/io/ReadCsvTest.cpp
         runtime/local/io/ReadParquetTest.cpp
         runtime/local/io/ReadMMTest.cpp
+        runtime/local/io/WriteMMTest.cpp
         runtime/local/io/WriteDaphneTest.cpp
         runtime/local/io/ReadDaphneTest.cpp
         runtime/local/io/DaphneSerializerTest.cpp
diff --git a/test/api/cli/io/WriteTest.cpp b/test/api/cli/io/WriteTest.cpp
index b58b8ec09..0800945db 100644
--- a/test/api/cli/io/WriteTest.cpp
+++ b/test/api/cli/io/WriteTest.cpp
@@ -21,9 +21,50 @@
 #include <tags.h>
 
 #include <catch.hpp>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
 #include <string>
 
 const std::string dirPath = "test/api/cli/io/";
+const std::string dirPath2 = "test/runtime/local/io/";
+
+bool compareFiles(const std::string &filePath1, const std::string &filePath2) {
+
+    std::ifstream file1(filePath1, std::ios::binary);
+    std::ifstream file2(filePath2, std::ios::binary);
+
+    if (!file1.is_open() || !file2.is_open()) {
+        std::cerr << "Cannot open one or both files." << std::endl;
+        return false;
+    }
+
+    std::string line1, line2;
+    bool filesAreEqual = true;
+
+    while (std::getline(file1, line1)) {
+        if (!std::getline(file2, line2)) {
+            filesAreEqual = false;
+            break;
+        }
+
+        if (line1 != line2) {
+            filesAreEqual = false;
+            break;
+        }
+    }
+
+    if (filesAreEqual && std::getline(file2, line2)) {
+        if (!line2.empty()) {
+            filesAreEqual = false;
+        }
+    }
+
+    file1.close();
+    file2.close();
+
+    return filesAreEqual;
+}
 
 TEST_CASE("writeMatrixCSV_Full", TAG_IO) {
     std::string csvPath = dirPath + "matrix_full.csv";
@@ -41,4 +82,31 @@ TEST_CASE("writeMatrixCSV_View", TAG_IO) {
                           std::string("outPath=\"" + csvPath + "\"").c_str());
     compareDaphneToRef(dirPath + "matrix_view_ref.csv", dirPath + "readMatrix.daphne", "--args",
                        std::string("inPath=\"" + csvPath + "\"").c_str());
-}
\ No newline at end of file
+}
+
+TEST_CASE("writeMatrixMtxaig", TAG_IO) {
+    std::string expectedPath = dirPath2 + "aig.mtx";
+    std::string resultPath = "out.mtx";
+    checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args",
+                          std::string("inPath=\"" + expectedPath + "\"").c_str());
+    CHECK(compareFiles(expectedPath, resultPath));
+    std::filesystem::remove(resultPath); // remove old file if it still exists
+}
+
+TEST_CASE("writeMatrixMtxaik", TAG_IO) {
+    std::string expectedPath = dirPath2 + "aik.mtx";
+    std::string resultPath = "out.mtx";
+    checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args",
+                          std::string("inPath=\"" + expectedPath + "\"").c_str());
+    CHECK(compareFiles(expectedPath, resultPath));
+    std::filesystem::remove(resultPath); // remove old file if it still exists
+}
+
+TEST_CASE("writeMatrixMtxais", TAG_IO) {
+    std::string expectedPath = dirPath2 + "ais.mtx";
+    std::string resultPath = "out.mtx";
+    checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args",
+                          std::string("inPath=\"" + expectedPath + "\"").c_str());
+    CHECK(compareFiles(expectedPath, resultPath));
+    std::filesystem::remove(resultPath); // remove old file if it still exists
+}
diff --git a/test/api/cli/io/readAndWriteMtx.daphne b/test/api/cli/io/readAndWriteMtx.daphne
new file mode 100644
index 000000000..590559c4f
--- /dev/null
+++ b/test/api/cli/io/readAndWriteMtx.daphne
@@ -0,0 +1,3 @@
+X = readMatrix($inPath);
+print(X);
+write(X, "out.mtx");
diff --git a/test/runtime/local/io/WriteMMTest.cpp b/test/runtime/local/io/WriteMMTest.cpp
new file mode 100644
index 000000000..4d02edfa3
--- /dev/null
+++ b/test/runtime/local/io/WriteMMTest.cpp
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2022 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <runtime/local/datastructures/CSRMatrix.h>
+#include <runtime/local/datastructures/DataObjectFactory.h>
+#include <runtime/local/datastructures/DenseMatrix.h>
+#include <runtime/local/io/ReadMM.h>
+#include <runtime/local/io/WriteMM.h>
+
+#include <tags.h>
+
+#include <catch.hpp>
+
+#include <cstddef>
+#include <cstdint>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <string>
+
+bool compareContentsFromFile(const std::string &filePath1, const std::string &filePath2) {
+    std::ifstream file1(filePath1, std::ios::binary);
+    std::ifstream file2(filePath2, std::ios::binary);
+
+    if (!file1.is_open() || !file2.is_open()) {
+        std::cerr << "Cannot open one or both files." << std::endl;
+        return false;
+    }
+
+    std::string line1, line2;
+    bool filesAreEqual = true;
+
+    while (std::getline(file1, line1)) {
+        if (!std::getline(file2, line2)) {
+            filesAreEqual = false;
+            break;
+        }
+
+        if (line1 != line2) {
+            filesAreEqual = false;
+            break;
+        }
+    }
+
+    if (filesAreEqual && std::getline(file2, line2)) {
+        if (!line2.empty()) {
+            filesAreEqual = false;
+        }
+    }
+
+    file1.close();
+    file2.close();
+
+    return filesAreEqual;
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIG", TAG_IO, (DenseMatrix), (int32_t)) {
+    using DT = TestType;
+    DT *m = nullptr;
+
+    size_t numRows = 4;
+    size_t numCols = 3;
+
+    char filename[] = "./test/runtime/local/io/aig.mtx";
+    char resultPath[] = "out.mtx";
+    readMM(m, filename);
+
+    REQUIRE(m->getNumRows() == numRows);
+    REQUIRE(m->getNumCols() == numCols);
+
+    writeMM(m, resultPath);
+
+    CHECK(compareContentsFromFile(filename, resultPath));
+    std::filesystem::remove(resultPath);
+
+    CHECK(m->get(0, 0) == 1);
+    CHECK(m->get(1, 0) == 2);
+    CHECK(m->get(0, 1) == 5);
+    CHECK(m->get(3, 2) == 12);
+    CHECK(m->get(2, 1) == 7);
+
+    DataObjectFactory::destroy(m);
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIK", TAG_IO, (DenseMatrix), (int32_t)) {
+    using DT = TestType;
+    DT *m = nullptr;
+
+    size_t numRows = 4;
+    size_t numCols = 4;
+
+    char filename[] = "./test/runtime/local/io/aik.mtx";
+    char resultPath[] = "out.mtx";
+    readMM(m, filename);
+
+    REQUIRE(m->getNumRows() == numRows);
+    REQUIRE(m->getNumCols() == numCols);
+
+    writeMM(m, resultPath);
+
+    CHECK(compareContentsFromFile(filename, resultPath));
+    std::filesystem::remove(resultPath);
+
+    CHECK(m->get(1, 0) == 1);
+
+    for (size_t r = 0; r < numRows; r++) {
+        CHECK(m->get(r, r) == 0);
+        for (size_t c = r + 1; c < numCols; c++)
+            CHECK(m->get(r, c) == -m->get(c, r));
+    }
+
+    DataObjectFactory::destroy(m);
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("WriteMM AIS", TAG_IO, (DenseMatrix), (int32_t)) {
+    using DT = TestType;
+    DT *m = nullptr;
+
+    size_t numRows = 3;
+    size_t numCols = 3;
+
+    char filename[] = "./test/runtime/local/io/ais.mtx";
+    char resultPath[] = "out.mtx";
+    readMM(m, filename);
+
+    REQUIRE(m->getNumRows() == numRows);
+    REQUIRE(m->getNumCols() == numCols);
+
+    writeMM(m, resultPath);
+
+    CHECK(compareContentsFromFile(filename, resultPath));
+    std::filesystem::remove(resultPath);
+
+    CHECK(m->get(1, 1) == 4);
+
+    for (size_t r = 0; r < numRows; r++)
+        for (size_t c = r + 1; c < numCols; c++)
+            CHECK(m->get(r, c) == m->get(c, r));
+
+    DataObjectFactory::destroy(m);
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("WriteMM CIG (CSR)", TAG_IO, (CSRMatrix), (int32_t)) {
+    using DT = TestType;
+    DT *m = nullptr;
+
+    size_t numRows = 9;
+    size_t numCols = 9;
+
+    char filename[] = "./test/runtime/local/io/cig.mtx";
+    char resultPath[] = "out.mtx";
+    readMM(m, filename);
+
+    REQUIRE(m->getNumRows() == numRows);
+    REQUIRE(m->getNumCols() == numCols);
+
+    writeMM(m, resultPath);
+
+    CHECK(compareContentsFromFile(filename, resultPath));
+    std::filesystem::remove(resultPath);
+
+    CHECK(m->get(0, 0) == 1);
+    CHECK(m->get(2, 0) == 0);
+    CHECK(m->get(3, 4) == 9);
+    CHECK(m->get(7, 4) == 4);
+
+    DataObjectFactory::destroy(m);
+}
+
+TEMPLATE_PRODUCT_TEST_CASE("WriteMM CRG (CSR)", TAG_IO, (CSRMatrix), (double)) {
+    using DT = TestType;
+    DT *m = nullptr;
+
+    size_t numRows = 497;
+    size_t numCols = 507;
+
+    char filename[] = "./test/runtime/local/io/crg.mtx";
+    char resultPath[] = "out.mtx";
+    readMM(m, filename);
+
+    REQUIRE(m->getNumRows() == numRows);
+    REQUIRE(m->getNumCols() == numCols);
+
+    writeMM(m, resultPath);
+
+    CHECK(compareContentsFromFile(filename, resultPath));
+    std::filesystem::remove(resultPath);
+
+    CHECK(m->get(5, 0) == 0.25599762);
+    CHECK(m->get(6, 0) == 0.13827993);
+    CHECK(m->get(200, 4) == 0.20001954);
+
+    DataObjectFactory::destroy(m);
+}
diff --git a/test/runtime/local/io/aig.mtx.meta b/test/runtime/local/io/aig.mtx.meta
new file mode 100644
index 000000000..f84a4ff4f
--- /dev/null
+++ b/test/runtime/local/io/aig.mtx.meta
@@ -0,0 +1,6 @@
+{
+  "numRows": 4,
+  "numCols": 3,
+  "valueType": "si64",
+  "numNonZeros": 12
+}
diff --git a/test/runtime/local/io/aik.mtx.meta b/test/runtime/local/io/aik.mtx.meta
new file mode 100644
index 000000000..b04ed1b53
--- /dev/null
+++ b/test/runtime/local/io/aik.mtx.meta
@@ -0,0 +1,6 @@
+{
+  "numRows": 4,
+  "numCols": 4,
+  "valueType": "si64",
+  "numNonZeros": 16
+}
diff --git a/test/runtime/local/io/ais.mtx.meta b/test/runtime/local/io/ais.mtx.meta
new file mode 100644
index 000000000..43bfe32f6
--- /dev/null
+++ b/test/runtime/local/io/ais.mtx.meta
@@ -0,0 +1,6 @@
+{
+  "numRows": 3,
+  "numCols": 3,
+  "valueType": "si64",
+  "numNonZeros": 9
+}
diff --git a/test/runtime/local/io/cig.mtx b/test/runtime/local/io/cig.mtx
index f12363fd9..f30c9d439 100644
--- a/test/runtime/local/io/cig.mtx
+++ b/test/runtime/local/io/cig.mtx
@@ -1,7 +1,4 @@
 %%MatrixMarket matrix coordinate integer general
-%
-%
-% 1 2 3
 9 9 50
 1 1 1
 2 1 2