#include<bits/stdc++.h>
using namespace std;
#define int long long
#define vec vector
// Disjoint-set structure over groups of users, used to maintain a follow
// count incrementally as directed "u follows v" edges are added.
//
// Representation (map keys and vector indices used for groups are DSU roots):
//   outc[a][b]        - members of group a that follow some member of group b
//   inc[b][a]         - the same members, indexed from the followed group b
//   total_inc_size[g] - sum of inc[g][*].size(): vertices outside g following g
//   to_merge          - queued pairs whose groups follow each other and still
//                       need to be joined
//
// ans is maintained as
//   sum over roots g of  sz[g] * (sz[g] - 1) + total_inc_size[g] * sz[g].
struct DSU {
    std::vector<int> par;
    std::vector<int> sz;
    std::vector<std::map<int, std::set<int>>> outc;
    std::vector<std::map<int, std::set<int>>> inc;
    std::vector<std::pair<int, int>> to_merge;
    std::vector<int> total_inc_size;
    int ans = 0;

    // Creates n singleton groups 0..n-1 with no edges.
    DSU(int n) : par(n), sz(n, 1), outc(n), inc(n), total_inc_size(n) {
        std::iota(par.begin(), par.end(), 0);
    }

    // Size of m[key], WITHOUT inserting an empty bucket when key is absent
    // (operator[] would, leaving stale keys for later loops to walk over).
    static int bucket_size(const std::map<int, std::set<int>>& m, int key) {
        auto it = m.find(key);
        return it == m.end() ? 0 : static_cast<int>(it->second.size());
    }

    // Representative of x's group, with path compression.
    int root(int x) {
        if (par[x] != x) par[x] = root(par[x]);
        return par[x];
    }

    // Merges the groups containing x and y (the one with fewer members into
    // the other), re-keys neighbour maps from the absorbed root to the
    // surviving root, updates ans, and queues every group that now follows
    // the merged group in both directions.
    // Returns false if x and y were already in the same group.
    bool join(int x, int y) {
        x = root(x);
        y = root(y);
        if (x == y) return false;
        if (sz[x] < sz[y]) std::swap(x, y);  // y is absorbed into x

        // Drop both groups' contributions; x's is added back at the end.
        ans -= total_inc_size[x] * sz[x];
        ans -= total_inc_size[y] * sz[y];
        // Members of y that followed x become internal to the merged group.
        total_inc_size[x] -= bucket_size(inc[x], y);
        total_inc_size[y] = 0;
        inc[x].erase(y);
        inc[y].erase(x);
        outc[x].erase(y);
        outc[y].erase(x);
        ans -= sz[x] * (sz[x] - 1);
        ans -= sz[y] * (sz[y] - 1);

        // Groups followed by y: move their inc[c][y] bucket to key x.
        for (const auto& [c, members] : outc[y]) {
            auto it = inc[c].find(y);
            if (it == inc[c].end()) continue;
            if (!it->second.empty()) inc[c][x].insert(it->second.begin(), it->second.end());
            inc[c].erase(it);
        }
        // Groups following y: move their outc[c][y] bucket to key x.
        for (const auto& [c, members] : inc[y]) {
            auto it = outc[c].find(y);
            if (it == outc[c].end()) continue;
            if (!it->second.empty()) outc[c][x].insert(it->second.begin(), it->second.end());
            outc[c].erase(it);
        }

        // Fold y's outgoing follows into x. If c already follows x, the merged
        // group and c now follow each other. Queue c itself: members of y
        // are about to share x's root, so queueing one of them (as before)
        // produced a no-op join and the mutual pair was never merged.
        for (const auto& [c, members] : outc[y]) {
            if (members.empty()) continue;
            outc[x][c].insert(members.begin(), members.end());
            if (bucket_size(inc[x], c) > 0) to_merge.push_back({x, c});
        }
        // Fold y's incoming follows into x. The sets deduplicate vertices that
        // already followed x, and total_inc_size grows only by the new ones.
        for (const auto& [c, members] : inc[y]) {
            if (members.empty()) continue;
            auto& bucket = inc[x][c];
            const auto before = bucket.size();
            bucket.insert(members.begin(), members.end());
            total_inc_size[x] += static_cast<int>(bucket.size() - before);
            if (bucket_size(outc[x], c) > 0) to_merge.push_back({x, c});
        }

        inc[y].clear();
        outc[y].clear();
        par[y] = x;
        sz[x] += sz[y];
        sz[y] = 0;
        ans += total_inc_size[x] * sz[x];
        ans += sz[x] * (sz[x] - 1);
        return true;
    }

    // Processes queued merges; join() may queue further ones (cascades).
    void merge_all() {
        while (!to_merge.empty()) {
            const auto [a, b] = to_merge.back();
            to_merge.pop_back();
            join(a, b);
        }
    }

    // Records that vertex u follows vertex v and updates ans. If v's group
    // already follows u's group, the two groups are merged (with cascades).
    // The former O(N)-per-call debug consistency scans were removed.
    void make_edge(int u, int v) {
        const int ru = root(u);
        const int rv = root(v);
        if (ru == rv) return;
        outc[ru][rv].insert(u);
        // A previously unseen external follower adds sz[rv] to ans.
        if (inc[rv][ru].insert(u).second) {
            ++total_inc_size[rv];
            ans += sz[rv];
        }
        if (bucket_size(inc[ru], rv) > 0) {
            to_merge.push_back({u, v});
            merge_all();
        }
    }
};
// Upper bound constant (100005); currently not referenced by the code shown here.
constexpr int MXN{100005};
// Reads N and M, then M follow edges "u v" (1-based), printing the running
// count maintained by DSU::ans after each edge.
int32_t main() {
    // Up to 2 + 2*M reads and M writes: decouple iostreams from C stdio and
    // stop cin from flushing cout before every read.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N, M;
    std::cin >> N >> M;
    DSU dsu(N);
    while (M--) {
        int u, v;
        std::cin >> u >> v;
        dsu.make_edge(u - 1, v - 1);  // input indices are 1-based
        std::cout << dsu.ans << '\n';
    }
}
/*
 * Compilation message reported by the judge (its line numbers 108/113 do not
 * match this listing; they point at the two `for (int i = 0; i < par.size(); ...)`
 * loops in DSU::make_edge):
 *
 * joitter2.cpp: In member function 'void DSU::make_edge(long long int, long long int)':
 * joitter2.cpp:108:19: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
 *   108 | for(int i = 0; i<par.size(); i++) {
 *       | ~^~~~~~~~~~~
 * joitter2.cpp:113:19: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
 *   113 | for(int i = 0; i<par.size(); i++) {
 *       | ~^~~~~~~~~~~
 */
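/*
 * Grader output, table 1 (headers translated: 결과 = Result,
 * 실행 시간 = Execution time, 메모리 = Memory):
 *
 * | # | Result    | Execution time | Memory | Grader output        |
 * |---|-----------|----------------|--------|----------------------|
 * | 1 | Correct   | 0 ms           | 348 KB | Output is correct    |
 * | 2 | Correct   | 0 ms           | 348 KB | Output is correct    |
 * | 3 | Incorrect | 0 ms           | 348 KB | Output isn't correct |
 * | 4 | Halted    | 0 ms           | 0 KB   | -                    |
 */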
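/*
 * Grader output, table 2 (headers translated: 결과 = Result,
 * 실행 시간 = Execution time, 메모리 = Memory):
 *
 * | # | Result    | Execution time | Memory | Grader output        |
 * |---|-----------|----------------|--------|----------------------|
 * | 1 | Correct   | 0 ms           | 348 KB | Output is correct    |
 * | 2 | Correct   | 0 ms           | 348 KB | Output is correct    |
 * | 3 | Incorrect | 0 ms           | 348 KB | Output isn't correct |
 * | 4 | Halted    | 0 ms           | 0 KB   | -                    |
 */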
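/*
 * Grader output, table 3 (headers translated: 결과 = Result,
 * 실행 시간 = Execution time, 메모리 = Memory):
 *
 * | # | Result    | Execution time | Memory | Grader output        |
 * |---|-----------|----------------|--------|----------------------|
 * | 1 | Correct   | 0 ms           | 348 KB | Output is correct    |
 * | 2 | Correct   | 0 ms           | 348 KB | Output is correct    |
 * | 3 | Incorrect | 0 ms           | 348 KB | Output isn't correct |
 * | 4 | Halted    | 0 ms           | 0 KB   | -                    |
 */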