-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdsu.hpp
100 lines (87 loc) · 2.58 KB
/
dsu.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
* @file dsu.hpp
* @brief Disjoint set union data structure
*/
#ifndef DSU_HPP
#define DSU_HPP
#include <functional>
#include <utility>
#include <vector>
/**
 * @brief Disjoint set union (union-find), optionally carrying a value per set.
 *
 * dsu<> (T = void) tracks only set membership and set sizes.
 * dsu<T, Op> additionally stores one T per set and combines the values of two
 * sets with Op when they are united.
 */
template <class T = void, class Op = std::plus<>> class dsu;

/**
 * @brief Membership-only disjoint set union (union by size + path compression).
 *
 * Each entry holds either its parent index (>= 0) or, for a root, the negated
 * size of its set. Op is not used by this specialization; it is a template
 * parameter only so that dsu<T, Op> can derive from dsu<void, Op>.
 */
template <class Op> class dsu<void, Op> : std::vector<int> {
public:
  /// Returned by union_sets when both elements were already in the same set.
  /// constexpr => implicitly inline (C++17), so odr-uses need no definition.
  static constexpr int same_set = -1;

  /// Creates n singleton sets {0}, {1}, ..., {n - 1}.
  explicit dsu(size_t n) : std::vector<int>(n, -1) {}

  /// Returns the representative (root) of x's set. Afterwards every element
  /// on the walked path points directly at the root (path compression).
  int find_set(int x) {
    const int root = std::as_const(*this).find_set(x);
    while ((*this)[x] >= 0) {
      const int parent = (*this)[x];
      (*this)[x] = root;
      x = parent;
    }
    return root;
  }

  /// Read-only lookup of x's representative; walks to the root without
  /// modifying the structure, so it is usable on const objects.
  int find_set(int x) const {
    while ((*this)[x] >= 0)
      x = (*this)[x];
    return x;
  }

  /// Number of elements in the set containing x (compresses the path).
  int size_of_set(int x) { return -(*this)[find_set(x)]; }

  /// Number of elements in the set containing x (read-only).
  int size_of_set(int x) const { return -(*this)[find_set(x)]; }

  /**
   * @brief Merges the sets containing u and v.
   * @return the root of the merged set, or same_set if u and v were already
   *         in the same set.
   */
  int union_sets(int u, int v) {
    int gu = find_set(u), gv = find_set(v);
    if (gu == gv)
      return same_set;
    // Roots store -size: make gu the smaller set so it hangs under gv.
    if ((*this)[gu] < (*this)[gv])
      std::swap(gu, gv);
    (*this)[gv] += (*this)[gu];
    (*this)[gu] = gv;
    return gv;
  }

  /// Total number of elements managed.
  size_t size() const { return std::vector<int>::size(); }
};

/**
 * @brief Disjoint set union that keeps one value of type T per set.
 *
 * When two sets are united, the surviving root's value becomes
 * op(value_of_surviving_root, value_of_absorbed_root). Which root survives is
 * decided by set size, so for a non-commutative Op the argument order is not
 * controlled by the order of union_sets' arguments.
 */
template <class T, class Op> class dsu : public dsu<void, Op> {
  using base = dsu<void, Op>;
  std::vector<T> dat; // dat[r] is meaningful only while r is a root
  Op op;

public:
  using base::same_set;
  using base::find_set;

  /// One singleton set per element of v, with v[i] as element i's value.
  explicit dsu(const std::vector<T> &v, const Op &op = {})
      : base(v.size()), dat(v), op(op) {}

  /// n singleton sets, each holding a copy of init.
  explicit dsu(size_t n, const T &init = {}, const Op &op = {})
      : dsu(std::vector<T>(n, init), op) {}

  /// Value of the set containing x (read-only; no path compression).
  const T &operator[](int x) const { return dat[find_set(x)]; }

  /// Mutable value of the set containing x.
  T &operator[](int x) { return dat[find_set(x)]; }

  /**
   * @brief Merges the sets containing u and v and combines their values.
   * @return the root of the merged set, or same_set if already merged.
   */
  int union_sets(int u, int v) {
    const int gu = find_set(u), gv = find_set(v);
    const int root = base::union_sets(gu, gv);
    if (root != same_set) {
      const int child = root == gu ? gv : gu;
      dat[root] = op(dat[root], dat[child]);
    }
    return root;
  }
};
/**
 * @brief Disjoint set union whose unions can be undone (union by size, no path
 *        compression).
 *
 * Each entry holds either its parent index (>= 0) or, for a root, the negated
 * size of its set. Every successful union_sets call pushes one undo record, so
 * version() counts successful unions and rollback(k) restores the state that
 * existed when version() was k.
 */
class dsu_rollback : std::vector<int> {
  /// Undo log: (root that was linked under another root, its entry before
  /// linking, i.e. the negated size of its set).
  std::vector<std::pair<int, int>> history;

public:
  /// Returned by union_sets when both elements were already in the same set.
  /// constexpr => implicitly inline (C++17), so odr-uses need no definition.
  static constexpr int same_set = -1;
  /// Default argument of rollback meaning "undo only the most recent union".
  static constexpr size_t last_version = static_cast<size_t>(-1);

  /// Creates n singleton sets {0}, {1}, ..., {n - 1}.
  explicit dsu_rollback(size_t n) : std::vector<int>(n, -1) {}

  /// Returns the representative (root) of x's set. Paths are never compressed
  /// so that rollback can restore parent links exactly; union by size keeps
  /// the walk short.
  int find_set(int x) const {
    while ((*this)[x] >= 0)
      x = (*this)[x];
    return x;
  }

  /// Number of elements in the set containing x.
  int size_of_set(int x) const { return -(*this)[find_set(x)]; }

  /**
   * @brief Merges the sets containing u and v, recording an undo entry.
   * @return the root of the merged set, or same_set (no undo entry recorded)
   *         if u and v were already in the same set.
   */
  int union_sets(int u, int v) {
    int gu = find_set(u), gv = find_set(v);
    if (gu == gv)
      return same_set;
    // Roots store -size: make gu the smaller set so it hangs under gv.
    if ((*this)[gu] < (*this)[gv])
      std::swap(gu, gv);
    history.emplace_back(gu, (*this)[gu]);
    (*this)[gv] += (*this)[gu];
    (*this)[gu] = gv;
    return gv;
  }

  /**
   * @brief Undoes successful unions until version() == version.
   *
   * With the default argument only the most recent union is undone (no-op if
   * there is none). A version >= version() leaves the structure unchanged.
   */
  void rollback(size_t version = last_version) {
    if (version == last_version) {
      if (history.empty())
        return;
      version = history.size() - 1;
    }
    while (history.size() > version) {
      auto [gu, sz_gu] = history.back();
      history.pop_back();
      // gu still points at the root it was linked under; restore both entries.
      int gv = (*this)[gu];
      (*this)[gv] -= sz_gu;
      (*this)[gu] = sz_gu;
    }
  }

  /// Total number of elements managed.
  size_t size() const { return std::vector<int>::size(); }

  /// Number of successful unions currently applied (a rollback target).
  size_t version() const { return history.size(); }
};
#endif