#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
struct unionFind{
int n;
int par[100002];
ll sz[100002];
set<int> st[100002]; /// 집합의 원소 종류, 그리고 해당 집합에 edge를 연결한 원소 종류
set<int> in[100002]; /// x번 집합을 가리키고 있는 집합의 번호
set<int> out[100002]; /// x번 집합이 가리키고 있는 집합의 번호
ll ans = 0;
void init(int _n){
n = _n;
for(int i=1; i<=n; i++) par[i] = i, sz[i] = 1, st[i].insert(i);
ans = 0;
}
int find(int x){
if(x==par[x]) return x;
return par[x] = find(par[x]);
}
queue<pair<int, int> > que;
void merge(){
while(!que.empty()){
int x = que.back().first, y = que.back().second; que.pop();
x = find(x), y = find(y);
if(x==y) continue;
if(st[x].size()+in[x].size()+out[x].size() < st[y].size()+in[y].size()+out[y].size()) swap(x, y);
/// (tox, x, ytox) -> y, (toy, y, xtoy) -> x
/// 제거해줘야 하는 것: x, y 안쪽에서의 묶음. 양쪽에 포함된 수는 제거 필요.
ans += sz[x]*st[y].size() + sz[y]*st[x].size();
for(int p: st[y]){
if(st[x].count(p)) ans -= sz[x] + sz[y];
else st[x].insert(p);
}
/// 연결 제거
in[x].erase(y), in[y].erase(x);
out[x].erase(y), out[y].erase(x);
/// 그래프 관리
for(int p: out[y]){
in[p].erase(y);
addEdge(x, p);
}
for(int p: in[y]){
out[p].erase(x);
addEdge(p, x);
}
in[y].clear(), out[y].clear(), st[y].clear();
par[y] = x, sz[x] += sz[y];
}
}
void addEdge(int x, int y){
out[x].insert(y);
in[y].insert(x);
/// 이걸로 인해 사이클이 생기면 merge
if(in[x].count(y)){
que.push(make_pair(x, y));
}
}
} dsu;
int n, q;
int main(){
scanf("%d %d", &n, &q);
dsu.init(n);
for(int i=1; i<=q; i++){
int x, y;
scanf("%d %d", &x, &y);
y = dsu.find(y);
int xp = dsu.find(x);
/// 이미 같은 컴포넌트 or 이미 가리키고 있음
if(xp!=y && !dsu.st[y].count(x)){
dsu.st[y].insert(x);
dsu.ans += dsu.sz[y];
x = dsu.find(x);
dsu.addEdge(x, y);
while(!dsu.que.empty()) dsu.merge();
}
printf("%lld\n", dsu.ans);
// printf("After query %d: \n", i);
// for(int i=1; i<=n; i++) {printf("st[%d]: ", i); for(auto j: dsu.st[i]) printf("%d ", j); printf(",\t");} puts("");
// for(int i=1; i<=n; i++) {printf("in[%d]: ", i); for(auto j: dsu.in[i]) printf("%d ", j); printf(",\t");} puts("");
// for(int i=1; i<=n; i++) {printf("out[%d]: ", i); for(auto j: dsu.out[i]) printf("%d ", j); printf(",\t");} puts("");
}
}
Compilation message
joitter2.cpp: In function 'int main()':
joitter2.cpp:75:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
75 | scanf("%d %d", &n, &q);
| ~~~~~^~~~~~~~~~~~~~~~~
joitter2.cpp:79:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
79 | scanf("%d %d", &x, &y);
| ~~~~~^~~~~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
7 ms |
14292 KB |
Output is correct |
2 |
Correct |
8 ms |
14292 KB |
Output is correct |
3 |
Correct |
6 ms |
14360 KB |
Output is correct |
4 |
Correct |
8 ms |
14292 KB |
Output is correct |
5 |
Correct |
7 ms |
14392 KB |
Output is correct |
6 |
Correct |
6 ms |
14340 KB |
Output is correct |
7 |
Incorrect |
7 ms |
14420 KB |
Output isn't correct |
8 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
7 ms |
14292 KB |
Output is correct |
2 |
Correct |
8 ms |
14292 KB |
Output is correct |
3 |
Correct |
6 ms |
14360 KB |
Output is correct |
4 |
Correct |
8 ms |
14292 KB |
Output is correct |
5 |
Correct |
7 ms |
14392 KB |
Output is correct |
6 |
Correct |
6 ms |
14340 KB |
Output is correct |
7 |
Incorrect |
7 ms |
14420 KB |
Output isn't correct |
8 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
7 ms |
14292 KB |
Output is correct |
2 |
Correct |
8 ms |
14292 KB |
Output is correct |
3 |
Correct |
6 ms |
14360 KB |
Output is correct |
4 |
Correct |
8 ms |
14292 KB |
Output is correct |
5 |
Correct |
7 ms |
14392 KB |
Output is correct |
6 |
Correct |
6 ms |
14340 KB |
Output is correct |
7 |
Incorrect |
7 ms |
14420 KB |
Output isn't correct |
8 |
Halted |
0 ms |
0 KB |
- |