Skip to content

Commit

Permalink
[DAPHNE-#822] Added MatrixMarket write support
Browse files Browse the repository at this point in the history
- Write support for CSRMatrix in coordinate system
- Write support for DenseMatrix in array format
- Support for recognizing symmetric, skew-symmetric matrices, and storing them accordingly in .mtx
- Added testcases, calling readMM, then writeMM and then comparing if the files are still the same
  • Loading branch information
ldirry committed Oct 2, 2024
1 parent 0e5e8a9 commit bfa0a93
Show file tree
Hide file tree
Showing 10 changed files with 541 additions and 7 deletions.
237 changes: 237 additions & 0 deletions src/runtime/local/io/WriteMM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
#ifndef WRITE_MM_H
#define WRITE_MM_H

#include <runtime/local/datastructures/CSRMatrix.h>
#include <runtime/local/datastructures/DenseMatrix.h>
#include <runtime/local/datastructures/Frame.h>
#include <runtime/local/io/MMFile.h>

#include <cstdio>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <type_traits>
#include <vector>

template <class DTArg> struct WriteMM { static void apply(const DTArg *arg, const char *filename) = delete; };

// Convenience function
template <class DTArg> void writeMM(const DTArg *arg, const char *filename) { WriteMM<DTArg>::apply(arg, filename); }

// ----------------------------------------------------------------------------
// DenseMatrix
// ----------------------------------------------------------------------------

template <typename VT> struct WriteMM<DenseMatrix<VT>> {
static void apply(const DenseMatrix<VT> *arg, const char *filename) {
const char *format = MM_DENSE_STR;
std::ofstream f(filename);
if (!f.is_open()) {
throw std::runtime_error("WriteMM::apply: Cannot open file");
}

const char *field;
if (std::is_integral<VT>::value) {
field = MM_INT_STR;
} else if (std::is_floating_point<VT>::value) {
field = MM_REAL_STR;
} else {
throw std::runtime_error("WriteMM::apply: Unsupported data type");
}

const char *symmetry = MM_GENERAL_STR;
if (isSymmetric(arg)) {
symmetry = MM_SYMM_STR;
} else if (isSkewSymmetric(arg)) {
symmetry = MM_SKEW_STR;
}

f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl;

size_t rows = arg->getNumRows();
size_t cols = arg->getNumCols();
f << rows << " " << cols << std::endl;

const VT *values = arg->getValues();
if (!values) {
throw std::runtime_error("WriteMM::apply: Null pointer for 'values' in DenseMatrix");
}

if (strcmp(symmetry, MM_GENERAL_STR) == 0) {
for (size_t i = 0; i < cols; ++i) {
for (size_t j = 0; j < rows; ++j) {
size_t idx = j * cols + i;
if (strcmp(field, MM_REAL_STR) == 0)
f << std::scientific << std::setprecision(13) << values[idx] << std::endl;
else
f << values[idx] << std::endl;
}
}
} else if (strcmp(symmetry, MM_SYMM_STR) == 0) {
for (size_t i = 0; i < cols; ++i) {
for (size_t j = i; j < rows; ++j) {
size_t idx = j * cols + i;
if (strcmp(field, MM_REAL_STR) == 0)
f << std::scientific << std::setprecision(13) << values[idx] << std::endl;
else
f << values[idx] << std::endl;
}
}
} else if (strcmp(symmetry, MM_SKEW_STR) == 0) {
for (size_t i = 0; i < cols; ++i) {
for (size_t j = i + 1; j < rows; ++j) {
size_t idx = j * cols + i;
if (strcmp(field, MM_REAL_STR) == 0)
f << std::scientific << std::setprecision(13) << values[idx] << std::endl;
else
f << values[idx] << std::endl;
}
}
} else {
throw std::runtime_error("WriteMM::apply: Unsupported symmetry type");
}
f.close();
}

private:
static bool isSymmetric(const DenseMatrix<VT> *arg) {
size_t rows = arg->getNumRows();
size_t cols = arg->getNumCols();
if (rows != cols)
return false;
const VT *values = arg->getValues();
for (size_t i = 0; i < cols; ++i) {
for (size_t j = i + 1; j < rows; ++j) {
size_t idx1 = j + i * rows;
size_t idx2 = i + j * rows;
if (values[idx1] != values[idx2])
return false;
}
}
return true;
}

static bool isSkewSymmetric(const DenseMatrix<VT> *arg) {
size_t rows = arg->getNumRows();
size_t cols = arg->getNumCols();
if (rows != cols)
return false;
const VT *values = arg->getValues();
for (size_t i = 0; i < rows; ++i) {
size_t idx_diag = i + i * rows;
if (values[idx_diag] != 0)
return false;
}
for (size_t i = 0; i < cols; ++i) {
for (size_t j = i + 1; j < rows; ++j) {
size_t idx1 = j + i * rows;
size_t idx2 = i + j * rows;
if (values[idx1] != -values[idx2])
return false;
}
}
return true;
}
};

// ----------------------------------------------------------------------------
// CSRMatrix
// ----------------------------------------------------------------------------

template <typename VT> struct WriteMM<CSRMatrix<VT>> {
static void apply(const CSRMatrix<VT> *arg, const char *filename) {
const char *format = MM_SPARSE_STR;
std::ofstream f(filename);
if (!f.is_open()) {
throw std::runtime_error("WriteMM::apply: Cannot open file");
}

const char *field;
if (std::is_integral<VT>::value) {
field = MM_INT_STR;
} else if (std::is_floating_point<VT>::value) {
field = MM_REAL_STR;
} else {
throw std::runtime_error("WriteMM::apply: Unsupported data type");
}

const char *symmetry = MM_GENERAL_STR;

f << MatrixMarketBanner << " matrix " << format << " " << field << " " << symmetry << std::endl;

size_t rows = arg->getNumRows();
size_t cols = arg->getNumCols();
size_t nnz = countNNZ(arg, symmetry);

f << rows << " " << cols << " " << nnz << std::endl;

const size_t *rowOffsets = arg->getRowOffsets();
const size_t *colIdxs = arg->getColIdxs();
const VT *values = arg->getValues();

std::vector<std::vector<std::pair<size_t, VT>>> colEntries(cols);

for (size_t i = 0; i < rows; ++i) {
for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) {
size_t j = colIdxs[idx];
VT val = values[idx];
colEntries[j].emplace_back(i, val);
}
}

if (strcmp(symmetry, MM_GENERAL_STR) == 0) {
for (size_t j = 0; j < cols; ++j) {
for (const auto &entry : colEntries[j]) {
size_t i = entry.first;
VT val = entry.second;
if (strcmp(field, MM_REAL_STR) == 0) {
if (val >= 0) {
f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val
<< std::endl;
} else {
f << i + 1 << " " << j + 1 << " " << std::scientific << std::setprecision(13) << val
<< std::endl;
}
} else {
f << i + 1 << " " << j + 1 << " " << val << std::endl;
}
}
}
} else {
throw std::runtime_error("WriteMM::apply: Unsupported symmetry type");
}

f.close();
}

private:
static size_t countNNZ(const CSRMatrix<VT> *arg, const char *symmetry) {
size_t nnz = 0;
size_t rows = arg->getNumRows();
size_t cols = arg->getNumCols();

std::vector<std::vector<std::pair<size_t, VT>>> colEntries(cols);

const size_t *rowOffsets = arg->getRowOffsets();
const size_t *colIdxs = arg->getColIdxs();
const VT *values = arg->getValues();

for (size_t i = 0; i < rows; ++i) {
for (size_t idx = rowOffsets[i]; idx < rowOffsets[i + 1]; ++idx) {
size_t j = colIdxs[idx];
colEntries[j].emplace_back(i, values[idx]);
}
}

if (strcmp(symmetry, MM_GENERAL_STR) == 0) {
nnz = arg->getNumNonZeros();
} else {
throw std::runtime_error("WriteMM::apply: Unsupported symmetry type");
}

return nnz;
}
};

#endif // WRITE_MM_H
10 changes: 6 additions & 4 deletions src/runtime/local/kernels/Write.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <runtime/local/io/FileMetaData.h>
#include <runtime/local/io/WriteCsv.h>
#include <runtime/local/io/WriteDaphne.h>
#include <runtime/local/io/WriteMM.h>
#if USE_HDFS
#include <runtime/local/io/HDFS/WriteHDFS.h>
#endif
Expand All @@ -34,9 +35,7 @@
// Struct for partial template specialization
// ****************************************************************************

template <class DTArg> struct Write {
static void apply(const DTArg *arg, const char *filename, DCTX(ctx)) = delete;
};
template <class DTArg> struct Write { static void apply(const DTArg *arg, const char *filename, DCTX(ctx)) = delete; };

// ****************************************************************************
// Convenience function
Expand Down Expand Up @@ -82,10 +81,11 @@ template <typename VT> struct Write<DenseMatrix<VT>> {
// call WriteHDFS
writeHDFS(arg, filename, ctx);
#endif
} else if (ext == "mtx") {
writeMM(arg, filename);
}
}
};

// ----------------------------------------------------------------------------
// Frame
// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -121,6 +121,8 @@ template <typename VT> struct Write<Matrix<VT>> {
MetaDataParser::writeMetaData(filename, metaData);
writeCsv(arg, file);
closeFile(file);
} else if (ext == "mtx") {
writeMM(arg, filename); // Write Matrix in MatrixMarket format
} else {
throw std::runtime_error("[Write.h] - generic Matrix type currently only supports csv "
"file extension.");
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ set(TEST_SOURCES
runtime/local/io/ReadCsvTest.cpp
runtime/local/io/ReadParquetTest.cpp
runtime/local/io/ReadMMTest.cpp
runtime/local/io/WriteMMTest.cpp
runtime/local/io/WriteDaphneTest.cpp
runtime/local/io/ReadDaphneTest.cpp
runtime/local/io/DaphneSerializerTest.cpp
Expand Down
67 changes: 67 additions & 0 deletions test/api/cli/io/WriteTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,49 @@
#include <catch.hpp>

#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>

const std::string dirPath = "test/api/cli/io/";
const std::string dirPath2 = "test/runtime/local/io/";

bool compareFiles(const std::string &filePath1, const std::string &filePath2) {

std::ifstream file1(filePath1, std::ios::binary);
std::ifstream file2(filePath2, std::ios::binary);

if (!file1.is_open() || !file2.is_open()) {
std::cerr << "Cannot open one or both files." << std::endl;
return false;
}

std::string line1, line2;
bool filesAreEqual = true;

while (std::getline(file1, line1)) {
if (!std::getline(file2, line2)) {
filesAreEqual = false;
break;
}

if (line1 != line2) {
filesAreEqual = false;
break;
}
}

if (filesAreEqual && std::getline(file2, line2)) {
if (!line2.empty()) {
filesAreEqual = false;
}
}

file1.close();
file2.close();

return filesAreEqual;
}

TEST_CASE("writeMatrixCSV_Full", TAG_IO) {
std::string csvPath = dirPath + "matrix_full.csv";
Expand All @@ -41,4 +81,31 @@ TEST_CASE("writeMatrixCSV_View", TAG_IO) {
std::string("outPath=\"" + csvPath + "\"").c_str());
compareDaphneToRef(dirPath + "matrix_view_ref.csv", dirPath + "readMatrix.daphne", "--args",
std::string("inPath=\"" + csvPath + "\"").c_str());
}

TEST_CASE("writeMatrixMtxaig", TAG_IO) {
std::string expectedPath = dirPath2 + "aig.mtx";
std::string resultPath = "out.mtx";
checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args",
std::string("inPath=\"" + expectedPath + "\"").c_str());
CHECK(compareFiles(expectedPath, resultPath));
std::filesystem::remove(resultPath); // remove old file if it still exists
}

TEST_CASE("writeMatrixMtxaik", TAG_IO) {
std::string expectedPath = dirPath2 + "aik.mtx";
std::string resultPath = "out.mtx";
checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args",
std::string("inPath=\"" + expectedPath + "\"").c_str());
CHECK(compareFiles(expectedPath, resultPath));
std::filesystem::remove(resultPath); // remove old file if it still exists
}

TEST_CASE("writeMatrixMtxais", TAG_IO) {
std::string expectedPath = dirPath2 + "ais.mtx";
std::string resultPath = "out.mtx";
checkDaphneStatusCode(StatusCode::SUCCESS, dirPath + "readAndWriteMtx.daphne", "--args",
std::string("inPath=\"" + expectedPath + "\"").c_str());
CHECK(compareFiles(expectedPath, resultPath));
std::filesystem::remove(resultPath); // remove old file if it still exists
}
3 changes: 3 additions & 0 deletions test/api/cli/io/readAndWriteMtx.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
X = readMatrix($inPath);
print(X);
write(X, "out.mtx");
Loading

0 comments on commit bfa0a93

Please sign in to comment.