From c692850dde3d42a337a715385261b77dd1068589 Mon Sep 17 00:00:00 2001 From: STommydx Date: Fri, 5 Apr 2024 17:35:28 -0400 Subject: [PATCH] feat(templates): matrix implementation based on valarray --- matrix.hpp | 345 ++++++++++++++++++++++++++++++++++++++++++++++++ matrix_test.cpp | 161 ++++++++++++++++++++++ 2 files changed, 506 insertions(+) create mode 100644 matrix.hpp create mode 100644 matrix_test.cpp diff --git a/matrix.hpp b/matrix.hpp new file mode 100644 index 0000000..71b19d7 --- /dev/null +++ b/matrix.hpp @@ -0,0 +1,345 @@ +/** + * @file matrix.hpp + * @brief Provides math matrix operations + */ + +#ifndef MATRIX_HPP +#define MATRIX_HPP + +#include + +template class matrix; + +#define NON_MEMBER_BINARY_OP(OP) \ + template \ + matrix operator OP(const matrix &a, const matrix &b); \ + template matrix operator OP(const matrix &a, const T & b); \ + template matrix operator OP(const T & a, const matrix &b); + +NON_MEMBER_BINARY_OP(+) +NON_MEMBER_BINARY_OP(-) +NON_MEMBER_BINARY_OP(*) +NON_MEMBER_BINARY_OP(/) +NON_MEMBER_BINARY_OP(%) +NON_MEMBER_BINARY_OP(&) +NON_MEMBER_BINARY_OP(|) +NON_MEMBER_BINARY_OP(^) +NON_MEMBER_BINARY_OP(<<) +NON_MEMBER_BINARY_OP(>>) +#undef NON_MEMBER_BINARY_OP + +#define NON_MEMBER_BINARY_PREDICATE(OP) \ + template \ + matrix operator OP(const matrix &a, const matrix &b); \ + template \ + matrix operator OP(const matrix &a, const T & b); \ + template \ + matrix operator OP(const T & a, const matrix &b); + +NON_MEMBER_BINARY_PREDICATE(&&) +NON_MEMBER_BINARY_PREDICATE(||) +NON_MEMBER_BINARY_PREDICATE(==) +NON_MEMBER_BINARY_PREDICATE(!=) +NON_MEMBER_BINARY_PREDICATE(<) +NON_MEMBER_BINARY_PREDICATE(>) +NON_MEMBER_BINARY_PREDICATE(<=) +NON_MEMBER_BINARY_PREDICATE(>=) +#undef NON_MEMBER_BINARY_PREDICATE + +template matrix matmul(const matrix &a, const matrix &b); + +template class matrix { + size_t n, m; + std::valarray dat; + + public: + static constexpr size_t none_axis = -1; + + /* + * Constructors and assignment operators + */ + explicit matrix(size_t count_n, size_t count_m, const T &val = {}) + : n(count_n), m(count_m), dat(val, count_n * count_m) {} + explicit matrix(size_t count_n, size_t count_m, + const std::valarray &vals) + : n(count_n), m(count_m), dat(vals) {} + explicit matrix(size_t count_n, size_t count_m, std::valarray &&vals) + : n(count_n), m(count_m), dat(std::move(vals)) {} + matrix(const std::vector> &v) + : n(v.size()), m(v.empty() ? 0 : v[0].size()), dat(n * m) { + for (size_t i = 0; i < n; ++i) { + std::ranges::copy(v[i], dat.begin() + i * m); + } + } + + matrix &operator=(const std::valarray &other) { + dat = other; + return *this; + } + matrix &operator=(const std::valarray &&other) { + dat = std::move(other); + return *this; + } + matrix &operator=(const T &val) { + dat = val; + return *this; + } + + /* + * Other matrix builders + */ + static matrix zeros(size_t count_n, size_t count_m) { + return matrix(count_n, count_m, 0); + } + static matrix ones(size_t count_n, size_t count_m) { + return matrix(count_n, count_m, 1); + } + static matrix eye(size_t count_n, size_t count_m = none_axis, + size_t k = 0) { + if (count_m == none_axis) { + count_m = count_n; + } + matrix res(count_n, count_m, 0); + res.diagonal(k) = 1; + return res; + } + static matrix identity(size_t count_n) { return eye(count_n); } + + /* + * Element access + */ + const T &at(size_t i, size_t j) const { return dat[i * m + j]; } + T &at(size_t i, size_t j) { return dat[i * m + j]; } + std::valarray row(size_t i) const { + return dat[std::slice(i * m, m, 1)]; + } + std::slice_array row(size_t i) { return dat[std::slice(i * m, m, 1)]; } + std::valarray col(size_t j) const { return dat[std::slice(j, n, m)]; } + std::slice_array col(size_t j) { return dat[std::slice(j, n, m)]; } + std::valarray diagonal(size_t offset) const { + return dat[std::slice(offset, std::min(n, m - offset), m + 1)]; + } + std::slice_array diagonal(size_t offset) { + return dat[std::slice(offset, std::min(n, m - offset), m + 1)]; + } + matrix transpose() const { + matrix res(m, n); + for (size_t i = 0; i < n; ++i) + res.col(i) = row(i); + return res; + } + const std::valarray &data() const { return dat; } + std::valarray &data() { return dat; } + std::valarray flatten() const { return dat; } + std::vector> to_vector() const { + std::vector> res(n, std::vector(m)); + for (size_t i = 0; i < n; ++i) { + std::ranges::copy(row(i), res[i].begin()); + } + return res; + } + + /* + * Metadata + */ + size_t size() const { return n * m; } + std::pair shape() const { return {n, m}; } + + /* + * Aggregate operations + */ + template + using aggregate_t = typename std::conditional_t< + KeepDims, matrix, + std::conditional_t>>; + template &> Aggregate> + aggregate_t aggregate(Aggregate f = {}) const { + // note that partial function templates specialization is not allowed + if constexpr (KeepDims) { + if constexpr (Axis == none_axis) { + return matrix(1, 1, std::invoke(f, dat)); + } else if constexpr (Axis == 0) { + matrix ret(1, m); + for (size_t i = 0; i < m; ++i) { + ret(0, i) = std::invoke(f, col(i)); + } + return ret; + } else if constexpr (Axis == 1) { + matrix ret(n, 1); + for (size_t i = 0; i < n; ++i) { + ret(i, 0) = std::invoke(f, row(i)); + } + return ret; + } else { + return *this; + } + } else { + if constexpr (Axis == none_axis) { + return std::invoke(f, dat); + } else if constexpr (Axis == 0) { + std::valarray ret(m); + for (size_t i = 0; i < m; ++i) { + ret[i] = std::invoke(f, col(i)); + } + return ret; + } else if constexpr (Axis == 1) { + std::valarray ret(n); + for (size_t i = 0; i < n; ++i) { + ret[i] = std::invoke(f, row(i)); + } + return ret; + } else { + return dat; + } + } + } + + template + aggregate_t sum() const { + return aggregate(&std::valarray::sum); + } + template + aggregate_t min() const { + return aggregate(&std::valarray::min); + } + template + aggregate_t max() const { + return aggregate(&std::valarray::max); + } + + /* + * Operator overloads + */ + std::valarray operator[](size_t i) const { return row(i); } + std::slice_array operator[](size_t i) { return row(i); } + const T &operator()(size_t i, size_t j) const { return at(i, j); } + T &operator()(size_t i, size_t j) { return at(i, j); } + operator std::vector>() const { return to_vector(); } + + matrix operator+() const { return matrix(n, m, +dat); } + matrix operator-() const { return matrix(n, m, -dat); } + matrix operator~() const { return matrix(n, m, ~dat); } + matrix operator!() const { return matrix(n, m, !dat); } + +#define MEMBER_BINARY_OP(OP) \ + matrix &operator OP(const matrix & m) { \ + dat OP m.dat; \ + return *this; \ + } \ + matrix &operator OP(const T & x) { \ + dat OP x; \ + return *this; \ + } + + MEMBER_BINARY_OP(+=) + MEMBER_BINARY_OP(-=) + MEMBER_BINARY_OP(*=) + MEMBER_BINARY_OP(/=) + MEMBER_BINARY_OP(%=) + MEMBER_BINARY_OP(&=) + MEMBER_BINARY_OP(|=) + MEMBER_BINARY_OP(^=) + MEMBER_BINARY_OP(<<=) + MEMBER_BINARY_OP(>>=) +#undef MEMBER_BINARY_OP + +/* + * Non-member functions + */ +#define NON_MEMBER_BINARY_OP(OP) \ + friend matrix operator OP<>(const matrix &a, const matrix &b); \ + friend matrix operator OP<>(const matrix &a, const T & b); \ + friend matrix operator OP<>(const T & a, const matrix &b); + + NON_MEMBER_BINARY_OP(+) + NON_MEMBER_BINARY_OP(-) + NON_MEMBER_BINARY_OP(*) + NON_MEMBER_BINARY_OP(/) + NON_MEMBER_BINARY_OP(%) + NON_MEMBER_BINARY_OP(&) + NON_MEMBER_BINARY_OP(|) + NON_MEMBER_BINARY_OP(^) + NON_MEMBER_BINARY_OP(<<) + NON_MEMBER_BINARY_OP(>>) +#undef NON_MEMBER_BINARY_OP + +#define NON_MEMBER_BINARY_PREDICATE(OP) \ + friend matrix operator OP<>(const matrix &a, const matrix &b); \ + friend matrix operator OP<>(const matrix &a, const T & b); \ + friend matrix operator OP<>(const T & a, const matrix &b); + + NON_MEMBER_BINARY_PREDICATE(&&) + NON_MEMBER_BINARY_PREDICATE(||) + NON_MEMBER_BINARY_PREDICATE(==) + NON_MEMBER_BINARY_PREDICATE(!=) + NON_MEMBER_BINARY_PREDICATE(<) + NON_MEMBER_BINARY_PREDICATE(<=) + NON_MEMBER_BINARY_PREDICATE(>) + NON_MEMBER_BINARY_PREDICATE(>=) +#undef NON_MEMBER_BINARY_PREDICATE + + friend matrix matmul<>(const matrix &a, const matrix &b); +}; + +#define NON_MEMBER_BINARY_OP(OP) \ + template \ + matrix operator OP(const matrix &a, const matrix &b) { \ + return matrix(a.n, a.m, a.dat OP b.dat); \ + } \ + template \ + matrix operator OP(const matrix &a, const T & b) { \ + return matrix(a.n, a.m, a.dat OP b); \ + } \ + template \ + matrix operator OP(const T & a, const matrix &b) { \ + return matrix(b.n, b.m, a OP b.dat); \ + } + +NON_MEMBER_BINARY_OP(+) +NON_MEMBER_BINARY_OP(-) +NON_MEMBER_BINARY_OP(*) +NON_MEMBER_BINARY_OP(/) +NON_MEMBER_BINARY_OP(%) +NON_MEMBER_BINARY_OP(&) +NON_MEMBER_BINARY_OP(|) +NON_MEMBER_BINARY_OP(^) +NON_MEMBER_BINARY_OP(<<) +NON_MEMBER_BINARY_OP(>>) +#undef NON_MEMBER_BINARY_OP + +#define NON_MEMBER_BINARY_PREDICATE(OP) \ + template \ + matrix operator OP(const matrix &a, const matrix &b) { \ + return matrix(a.n, a.m, a.dat OP b.dat); \ + } \ + template \ + matrix operator OP(const matrix &a, const T & b) { \ + return matrix(a.n, a.m, a.dat OP b); \ + } \ + template \ + matrix operator OP(const T & a, const matrix &b) { \ + return matrix(b.n, b.m, a OP b.dat); \ + } + +NON_MEMBER_BINARY_PREDICATE(&&) +NON_MEMBER_BINARY_PREDICATE(||) +NON_MEMBER_BINARY_PREDICATE(==) +NON_MEMBER_BINARY_PREDICATE(!=) +NON_MEMBER_BINARY_PREDICATE(<) +NON_MEMBER_BINARY_PREDICATE(<=) +NON_MEMBER_BINARY_PREDICATE(>) +NON_MEMBER_BINARY_PREDICATE(>=) +#undef NON_MEMBER_BINARY_PREDICATE + +template matrix matmul(const matrix &a, const matrix &b) { + matrix result(a.n, b.m); + for (size_t i = 0; i < a.n; ++i) { + for (size_t j = 0; j < b.m; ++j) { + result.at(i, j) = (a.row(i) * b.col(j)).sum(); + } + } + return result; +} + +#endif diff --git a/matrix_test.cpp b/matrix_test.cpp new file mode 100644 index 0000000..dbfdf53 --- /dev/null +++ b/matrix_test.cpp @@ -0,0 +1,161 @@ +#include "matrix.hpp" + +#include + +#include +#include + +TEST_CASE("matrix construction and member functions", "[matrix]") { + matrix a(2, 2, 1), b(2, 2, 3); + SECTION("matrix construction and element access") { + a.at(0, 0) = 2; + REQUIRE(a(0, 0) == 2); + REQUIRE(a(0, 1) == 1); + REQUIRE(a(1, 0) == 1); + REQUIRE(a(1, 1) == 1); + a[0] = 3; + REQUIRE(a(0, 0) == 3); + REQUIRE(a(0, 1) == 3); + REQUIRE(a(1, 0) == 1); + REQUIRE(a(1, 1) == 1); + a.col(1) = 4; + REQUIRE(a(0, 0) == 3); + REQUIRE(a(0, 1) == 4); + REQUIRE(a(1, 0) == 1); + REQUIRE(a(1, 1) == 4); + a.row(1) = a.col(1); + REQUIRE(a(0, 0) == 3); + REQUIRE(a(0, 1) == 4); + REQUIRE(a(1, 0) == 4); + REQUIRE(a(1, 1) == 4); + a(0, 1) = 2; + matrix c = a.transpose(); + REQUIRE(c(0, 0) == 3); + REQUIRE(c(0, 1) == 4); + REQUIRE(c(1, 0) == 2); + REQUIRE(c(1, 1) == 4); + } + SECTION("matrix construction functions") { + matrix c = matrix::eye(2, 3); + REQUIRE(c(0, 0) == 1); + REQUIRE(c(0, 1) == 0); + REQUIRE(c(0, 2) == 0); + REQUIRE(c(1, 0) == 0); + REQUIRE(c(1, 1) == 1); + REQUIRE(c(1, 2) == 0); + matrix d = matrix::eye(3, 2, 1); + REQUIRE(d(0, 0) == 0); + REQUIRE(d(0, 1) == 1); + REQUIRE(d(1, 0) == 0); + REQUIRE(d(1, 1) == 0); + REQUIRE(d(2, 0) == 0); + REQUIRE(d(2, 1) == 0); + } + SECTION("aggregate functions") { + a(0, 0) = 2; + a(0, 1) = 0; + REQUIRE(a.sum() == 4); + REQUIRE(std::ranges::equal(a.sum<0>(), std::vector{3, 1})); + REQUIRE(std::ranges::equal(a.sum<1>(), std::vector{2, 2})); + auto aggregated = a.sum<0, true>(); + REQUIRE(aggregated.shape() == std::make_pair(1, 2)); + REQUIRE(std::ranges::equal(std::valarray(aggregated.row(0)), + std::vector{3, 1})); + REQUIRE(a.max() == 2); + REQUIRE(a.min() == 0); + } + SECTION("operator overloads") { + a += b; + REQUIRE(a(0, 0) == 4); + a -= b; + REQUIRE(a(0, 0) == 1); + a *= b; + REQUIRE(a(0, 0) == 3); + a /= b; + REQUIRE(a(0, 0) == 1); + a %= b; + REQUIRE(a(0, 0) == 1); + a ^= b; + REQUIRE(a(0, 0) == 2); + a |= b; + REQUIRE(a(0, 0) == 3); + a &= b; + REQUIRE(a(0, 0) == 3); + a <<= b; + REQUIRE(a(0, 0) == 24); + a >>= b; + REQUIRE(a(0, 0) == 3); + + REQUIRE((+a)(0, 0) == 3); + REQUIRE((-a)(0, 0) == -3); + REQUIRE((~a)(0, 0) == ~3); + REQUIRE((!a)(0, 0) == 0); + + a = 1; + a += 3; + REQUIRE(a(0, 0) == 4); + a -= 3; + REQUIRE(a(0, 0) == 1); + a *= 3; + REQUIRE(a(0, 0) == 3); + a /= 3; + REQUIRE(a(0, 0) == 1); + a %= 3; + REQUIRE(a(0, 0) == 1); + a ^= 3; + REQUIRE(a(0, 0) == 2); + a |= 3; + REQUIRE(a(0, 0) == 3); + a &= 3; + REQUIRE(a(0, 0) == 3); + a <<= 3; + REQUIRE(a(0, 0) == 24); + a >>= 3; + REQUIRE(a(0, 0) == 3); + } +} + +TEST_CASE("matrix arithmetic operator overload", "[matrix]") { + matrix a(2, 2, 5); + matrix b(2, 2, 7); + +#define BINARY_OP_TEST_CASE(OP) \ + SECTION(#OP) { \ + auto c = a OP b; \ + REQUIRE(c(0, 0) == (5 OP 7)); \ + auto d = a OP 7; \ + REQUIRE(d(0, 0) == (5 OP 7)); \ + auto e = 7 OP a; \ + REQUIRE(e(0, 0) == (7 OP 5)); \ + } + + BINARY_OP_TEST_CASE(+) + BINARY_OP_TEST_CASE(-) + BINARY_OP_TEST_CASE(*) + BINARY_OP_TEST_CASE(/) + BINARY_OP_TEST_CASE(%) + BINARY_OP_TEST_CASE(&) + BINARY_OP_TEST_CASE(|) + BINARY_OP_TEST_CASE(^) + BINARY_OP_TEST_CASE(<<) + BINARY_OP_TEST_CASE(>>) + BINARY_OP_TEST_CASE(&&) + BINARY_OP_TEST_CASE(||) + BINARY_OP_TEST_CASE(==) + BINARY_OP_TEST_CASE(!=) + BINARY_OP_TEST_CASE(<) + BINARY_OP_TEST_CASE(<=) + BINARY_OP_TEST_CASE(>) + BINARY_OP_TEST_CASE(>=) +#undef BINARY_OP_TEST_CASE +} + +TEST_CASE("matrix multiplication", "[matrix]") { + matrix mat(2, 2); + mat.at(0, 0) = mat.at(0, 1) = mat.at(1, 0) = 1; + const auto res = matmul(mat, mat); + REQUIRE(res[0][0] == 2); + REQUIRE(res[0][1] == 1); + REQUIRE(res[1][0] == 1); + REQUIRE(res[1][1] == 1); +}