#include <bits/stdc++.h>
#define ll long long
#define pii pair<int, int>
#define fst first
#define snd second
using namespace std;
// Second-dimension selectors for rep[][] and deg[][].
constexpr int IN = 0;
constexpr int OUT = 1;

int N, M;  // number of users, number of follow operations read from input

// rep[c][IN] / rep[c][OUT]: the node index whose deg[.][IN] / deg[.][OUT] set
// currently stores the edges of DSU component c (starts as c itself).
int rep[100001][2];
int par[100001];          // DSU parent pointers
long long sz[100001];     // DSU component sizes (meaningful at roots)

// deg[r][IN]  holds (follower component's OUT rep, follower user) pairs.
// deg[r][OUT] holds (followed component's IN rep, following user) pairs.
std::set<std::pair<int, int>> deg[100001][2];

long long sum = 0;                    // running answer printed after every operation
std::queue<std::pair<int, int>> Q;    // component pairs found to follow each other, pending merge
// Returns the DSU root of x and re-points every node on the walked path
// directly at that root (full path compression, done iteratively).
int find(int x) {
    int root = x;
    while (par[root] != root) root = par[root];
    while (par[x] != root) {
        int next = par[x];
        par[x] = root;
        x = next;
    }
    return root;
}
void join(int a, int b) {
a = find(a), b = find(b);
if (a == b) return;
sum -= sz[a] * sz[a];
sum -= sz[b] * sz[b];
int Ain = rep[a][IN], Bin = rep[b][IN];
sum -= deg[Ain][IN].size() * sz[a];
sum -= deg[Bin][IN].size() * sz[b];
sz[a] += sz[b];
par[b] = a;
sum += sz[a] * sz[a];
if (deg[Ain][IN].size() < deg[Bin][IN].size()) swap(Ain, Bin);
for (const auto &[x, y] : deg[Bin][IN]) {
deg[Ain][IN].insert({x, y});
deg[x][OUT].erase({Bin, y});
deg[x][OUT].insert({Ain, y});
}
deg[Bin][IN].clear();
sum += deg[Ain][IN].size() * sz[a];
rep[a][IN] = Ain;
int Aout = rep[a][OUT], Bout = rep[b][OUT];
if (deg[Aout][OUT].size() < deg[Bout][OUT].size()) swap(Aout, Bout);
for (const auto &[x, y] : deg[Bout][OUT]) {
deg[Aout][OUT].insert({x, y});
deg[x][IN].erase({Bout, y});
deg[x][IN].insert({Aout, y});
auto it = deg[rep[find(x)][OUT]][OUT].lower_bound(make_pair(Ain, -1));
if (it != deg[rep[find(x)][OUT]][OUT].end()) {
if (it -> first == Ain) Q.push({a, find(x)});
}
}
deg[Bout][OUT].clear();
rep[a][OUT] = Aout;
//cerr << a << " join " << b << " " << sum << "\n";
}
void eraseAllX(int x, set<pii>& S) {
auto it = S.lower_bound(make_pair(x, -1));
while (it != S.end()) {
if (it -> first == x) it = S.erase(it);
else break;
}
}
int main() {
ios :: sync_with_stdio(0); cin.tie(0);
cin >> N >> M;
for (int i = 0; i < N; i++) {
rep[i][IN] = rep[i][OUT] = i;
par[i] = i;
sz[i] = 1;
}
for (int i = 0; i < M; i++) {
int u, v; cin >> u >> v; u--; v--;
int ruo = rep[find(u)][OUT], rui = rep[find(u)][IN];
int rvo = rep[find(v)][OUT], rvi = rep[find(v)][IN];
if (deg[ruo][OUT].count({rvi, u}) == 0) sum += sz[find(v)];
deg[ruo][OUT].insert({rvi, u});
deg[rvi][IN].insert({ruo, u});
auto it = deg[rui][IN].lower_bound(make_pair(rvo, -1));
if (it != deg[rui][IN].end()) {
if (it -> first == rvo) Q.push({find(u), find(v)});
}
while (Q.size()) {
auto [a, b] = Q.front(); Q.pop();
a = find(a), b = find(b);
ll cntA = deg[rep[a][IN]][IN].size(), cntB = deg[rep[b][IN]][IN].size();
eraseAllX(rep[b][OUT], deg[rep[a][IN]][IN]);
eraseAllX(rep[a][OUT], deg[rep[b][IN]][IN]);
cntA -= deg[rep[a][IN]][IN].size();
cntB -= deg[rep[b][IN]][IN].size();
eraseAllX(rep[b][IN], deg[rep[a][OUT]][OUT]);
eraseAllX(rep[a][IN], deg[rep[b][OUT]][OUT]);
sum -= cntA * sz[b] + cntB * sz[a];
//cerr << "SUM UP TO HERE: " << sum << "\n";
join(a, b);
//cerr << ":: " << a << " JOIN " << b << "\n";
}
cout << sum << "\n";
}
return 0;
}
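/*
 * Grader result:
 * | # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output        |
 * |---|---------------|----------------------|-----------------|----------------------|
 * | 1 | Incorrect     | 2 ms                 | 10840 KB        | Output isn't correct |
 * | 2 | Halted        | 0 ms                 | 0 KB            | -                    |
 */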
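/*
 * Grader result:
 * | # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output        |
 * |---|---------------|----------------------|-----------------|----------------------|
 * | 1 | Incorrect     | 2 ms                 | 10840 KB        | Output isn't correct |
 * | 2 | Halted        | 0 ms                 | 0 KB            | -                    |
 */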
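/*
 * Grader result:
 * | # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output        |
 * |---|---------------|----------------------|-----------------|----------------------|
 * | 1 | Incorrect     | 2 ms                 | 10840 KB        | Output isn't correct |
 * | 2 | Halted        | 0 ms                 | 0 KB            | -                    |
 */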