diff --git a/matrix.hpp b/matrix.hpp index 71b19d7..97b16d0 100644 --- a/matrix.hpp +++ b/matrix.hpp @@ -7,6 +7,7 @@ #define MATRIX_HPP #include +#include template class matrix; @@ -68,7 +69,7 @@ template class matrix { matrix(const std::vector> &v) : n(v.size()), m(v.empty() ? 0 : v[0].size()), dat(n * m) { for (size_t i = 0; i < n; ++i) { - std::ranges::copy(v[i], dat.begin() + i * m); + std::ranges::copy(v[i], begin(dat) + i * m); } } diff --git a/matrix_test.cpp b/matrix_test.cpp index dbfdf53..c61f2bb 100644 --- a/matrix_test.cpp +++ b/matrix_test.cpp @@ -113,6 +113,23 @@ TEST_CASE("matrix construction and member functions", "[matrix]") { a >>= 3; REQUIRE(a(0, 0) == 3); } + SECTION("std::vector conversion") { + std::vector> v = a; + REQUIRE(v[0][0] == 1); + REQUIRE(v[1][0] == 1); + REQUIRE(v[1][1] == 1); + REQUIRE(v[0][1] == 1); + matrix m = v; + REQUIRE(m(0, 0) == 1); + REQUIRE(m(1, 0) == 1); + REQUIRE(m(1, 1) == 1); + REQUIRE(m(0, 1) == 1); + std::vector> result = matrix(v) + m; + REQUIRE(result[0][0] == 2); + REQUIRE(result[1][0] == 2); + REQUIRE(result[1][1] == 2); + REQUIRE(result[0][1] == 2); + } } TEST_CASE("matrix arithmetic operator overload", "[matrix]") {