diff --git a/utilities.hpp b/utilities.hpp index c4e9832..6a8f62b 100644 --- a/utilities.hpp +++ b/utilities.hpp @@ -191,4 +191,38 @@ template struct magic_vector { } }; +template class vector_rollback { + std::vector data; + std::vector> history; + + public: + static constexpr size_t last_version = -1; + + template + vector_rollback(Args... args) : data(args...), history() {} + size_t size() const { return data.size(); } + size_t version() const { return history.size(); } + T modify(size_t i, T x) { + history.emplace_back(i, data[i]); + data[i] = x; + return x; + } + const T &front() const { return data.front(); } + const T &back() const { return data.back(); } + const T &operator[](size_t i) const { return data[i]; } + const T &at(size_t i) const { return data.at(i); } + void rollback(size_t version = last_version) { + if (version == last_version) { + if (history.empty()) + return; + version = history.size() - 1; + } + while (history.size() > version) { + auto [i, x] = history.back(); + history.pop_back(); + data[i] = x; + } + } +}; + #endif diff --git a/utilities_test.cpp b/utilities_test.cpp index 0265a1a..bfbecd1 100644 --- a/utilities_test.cpp +++ b/utilities_test.cpp @@ -76,3 +76,18 @@ TEST_CASE("compression_vector constructors", "[utilities]") { v3.compress(); REQUIRE(v3[0] == ""); } + +TEST_CASE("vector with rollback", "[utilities]") { + std::vector v{2, 4, 3, 5, 8, 3, 5, 5}; + vector_rollback v2(v); + v2.modify(3, 10); + REQUIRE(v2[3] == 10); + v2.rollback(); + REQUIRE(v2[3] == 5); + v2.modify(3, 10); + v2.modify(4, 10); + v2.modify(3, 5); + v2.rollback(1); + REQUIRE(v2[3] == 10); + REQUIRE(v2[4] == 8); +}