#include "split.h"
#include <bits/stdc++.h>
#define maxn 200005
#define fi first
#define se second
using namespace std;
using ii = pair<int, int>;
// Global working state shared by find_split(), solve() and the DFS helpers.
//   N, M      - number of vertices and number of edges of the input graph.
//   A, B, C   - requested part sizes; find_split() reorders them so that A <= B <= C.
//   ID[k]     - original part label (1..3) of the part held in sorted position k;
//               ID[0] is never assigned and stays 0.
//   cl[v]     - non-zero once pfs() has visited v.
//   par[v]    - parent of v in the DFS tree built by pfs() (-1 for the start vertex).
//   sz[v]     - number of vertices in the DFS subtree of v.
//   num[v]    - 1-based visiting order of v in pfs(); `id` is the running counter.
//   euler[k]  - vertex whose visiting order is k (written by pfs(); no reader is
//               visible in this file).
//   cntA/B/C  - how many vertices dfsA/dfsB/dfsC have labelled so far.
int N, A, B, C, M, cl[maxn], par[maxn], ID[4], sz[maxn], num[maxn], id = 0, cntA = 0, cntB = 0, cntC = 0, euler[maxn];
// low[v] = (smallest visiting order recorded in v's DFS subtree, counting the subtree
// vertices themselves and the far ends of their non-parent edges; the subtree vertex
// at which that minimum was recorded).
ii low[maxn];
// Adjacency lists of the undirected input graph.
vector<int> adj[maxn];
// kq[v] is the label given to v: 0 = not assigned yet, 1/2/3 = part index in the
// sorted (A <= B <= C) order.
int kq[maxn];
// Grows part 1: labels u with 1 and keeps expanding along DFS-tree edges
// (edges joining a vertex with its pfs() parent) into unlabelled vertices,
// never stepping straight back to `dad`. Stops as soon as A vertices carry
// the label, so every labelled vertex was reached through a labelled one.
void dfsA(int u, int dad) {
    if (cntA == A) {
        return;  // quota for part 1 already filled
    }
    kq[u] = 1;
    ++cntA;
    for (int nxt : adj[u]) {
        if (kq[nxt]) continue;     // already belongs to some part
        if (nxt == dad) continue;  // do not walk back where we came from
        const bool treeEdge = (nxt == par[u]) || (u == par[nxt]);
        if (treeEdge) {
            dfsA(nxt, u);
        }
    }
}
// Grows part 2: labels u with 2 and keeps expanding along DFS-tree edges
// (edges joining a vertex with its pfs() parent) into unlabelled vertices,
// never stepping straight back to `dad`. Stops as soon as B vertices carry
// the label.
void dfsB(int u, int dad) {
    if (cntB == B) {
        return;  // quota for part 2 already filled
    }
    kq[u] = 2;
    ++cntB;
    for (int nxt : adj[u]) {
        if (kq[nxt]) continue;     // already belongs to some part
        if (nxt == dad) continue;  // do not walk back where we came from
        const bool treeEdge = (nxt == par[u]) || (u == par[nxt]);
        if (treeEdge) {
            dfsB(nxt, u);
        }
    }
}
// Grows part 3: labels u with 3 and keeps expanding along DFS-tree edges
// (edges joining a vertex with its pfs() parent) into unlabelled vertices,
// never stepping straight back to `dad`. Stops as soon as C vertices carry
// the label.
void dfsC(int u, int dad) {
    if (cntC == C) {
        return;  // quota for part 3 already filled
    }
    kq[u] = 3;
    ++cntC;
    for (int nxt : adj[u]) {
        if (kq[nxt]) continue;     // already belongs to some part
        if (nxt == dad) continue;  // do not walk back where we came from
        const bool treeEdge = (nxt == par[u]) || (u == par[nxt]);
        if (treeEdge) {
            dfsC(nxt, u);
        }
    }
}
// From here on the part sizes satisfy A <= B <= C: find_split() sorts them before calling solve().
// Depth-first search that builds the DFS tree rooted at the first call.
// For every vertex it reaches it records:
//   par[u] - the vertex it was entered from (`dad`),
//   num[u] - its 1-based visiting order (euler[] is the inverse mapping),
//   sz[u]  - the size of its DFS subtree,
//   low[u] - the smallest visiting order seen from the subtree, paired with
//            the subtree vertex where it was seen.
void pfs(int u, int dad) {
    par[u] = dad;
    cl[u] = 1;
    sz[u] = 1;
    ++id;
    num[u] = id;
    euler[id] = u;
    low[u] = ii{id, u};
    for (int nxt : adj[u]) {
        if (nxt == dad) {
            continue;  // the edge we arrived through
        }
        if (cl[nxt]) {
            // Already visited: a non-tree edge, remember how far it reaches.
            low[u] = min(low[u], ii{num[nxt], u});
            continue;
        }
        pfs(nxt, u);
        sz[u] += sz[nxt];
        low[u] = min(low[u], low[nxt]);
    }
}
// Walks down the DFS tree starting at u (entered from `dad`) and returns the
// first vertex whose tree neighbours, other than the one it was entered from,
// all have sz < A. While such a neighbour with sz >= A exists, the walk moves
// to the first one in adjacency order.
//
// Note: despite the name this is not the classical centroid; it is the pivot
// solve() splits the tree around.
//
// Written as a loop so that every path returns a value (the recursive form
// fell off the end of the function as far as the compiler could tell) and so
// that the depth of the tree does not translate into call-stack depth.
int find_centroid(int u, int dad) {
    while (true) {
        int heavy = -1;
        for (int v : adj[u]) {
            if (v == dad) continue;
            // Only DFS-tree edges count; skip non-tree edges.
            if (v != par[u] && u != par[v]) continue;
            if (sz[v] >= A) {
                heavy = v;
                break;
            }
        }
        if (heavy == -1) {
            return u;  // nothing below is large enough to host part A
        }
        dad = u;
        u = heavy;
    }
}
// Computes the labelling for the sorted sizes A <= B <= C.
//
// Returns a vector of N labels: 1 and 2 mark the two parts grown by
// dfsA/dfsB, 3 marks every remaining vertex. If no split is found the
// untouched kq[] (all zeros on a fresh run) is returned.
//
// Outline:
//   1. Build the DFS tree and pick the pivot returned by find_centroid():
//      its subtree has at least A vertices, every child subtree fewer than A.
//   2. If the part of the tree outside the pivot's subtree already has at
//      least A vertices, split the tree along the pivot's parent edge.
//   3. Otherwise enlarge the outside part with child subtrees of the pivot
//      that own an edge reaching above the pivot (low[child].first <
//      num[pivot]) until it holds at least A vertices.
vector<int> solve() {
    // Gives label 3 to everything still unlabelled and snapshots the result.
    const auto finish = []() {
        for (int i = 0; i < N; i++) {
            if (!kq[i]) kq[i] = 3;
        }
        return vector<int>(kq, kq + N);
    };

    pfs(0, -1);
    const int centroid = find_centroid(0, -1);
    const int up = par[centroid];
    const int above = N - sz[centroid];

    if (above >= A) {
        // above >= A >= 1 means the pivot is not the root, so `up` is a vertex.
        if (sz[centroid] >= B) {
            dfsA(up, centroid);
            dfsB(centroid, up);
        } else {
            dfsA(centroid, up);
            dfsB(up, centroid);
        }
        return finish();
    }

    // The outside part is too small on its own: attach child subtrees of the
    // pivot that can reach it without passing through the pivot.
    vector<int> lost;  // per attached subtree: the vertex owning the upward edge
    int s1 = above;
    int s2 = sz[centroid];
    for (int v : adj[centroid]) {
        if (s1 >= A) break;
        // Fix: the original filter could never be true; tree children of the
        // pivot are exactly the neighbours whose DFS parent is the pivot.
        if (par[v] != centroid) continue;
        if (low[v].first < num[centroid]) {
            s1 += sz[v];
            s2 -= sz[v];
            lost.emplace_back(low[v].second);
        }
    }
    if (s1 < A) {
        return vector<int>(kq, kq + N);  // no split found
    }

    if (s2 >= B) {
        kq[centroid] = 2;  // keeps dfsA out of the pivot
        // Fix: label the outside part first. It is the only thing linking the
        // attached subtrees together, so it must be in the part before the
        // quota of A vertices can run out.
        if (up != -1) dfsA(up, centroid);
        for (int i : lost) dfsA(i, -1);
        dfsB(centroid, up);
        return finish();
    }
    if (s1 >= B) {
        // Mirror case kept from the original; with A <= B <= C and every child
        // subtree smaller than A it is not expected to trigger.
        kq[centroid] = 1;  // keeps dfsB out of the pivot
        if (up != -1) dfsB(up, centroid);
        for (int i : lost) dfsB(i, -1);
        dfsA(centroid, up);
        return finish();
    }
    return vector<int>(kq, kq + N);
}
// Grader entry point.
//
//   n        - number of vertices (numbered 0..n-1).
//   a, b, c  - requested sizes of parts 1, 2 and 3.
//   p, q     - edge lists: edge i joins p[i] and q[i].
//
// Returns one label per vertex using the caller's original part numbers
// (1 for size a, 2 for size b, 3 for size c), or all zeros when solve()
// reports no split.
vector<int> find_split(int n, int a, int b, int c, vector<int> p, vector<int> q) {
    N = n; A = a; B = b; C = c;
    M = static_cast<int>(p.size());

    // Start from a clean slate so that a repeated call in the same process
    // does not see edges, visit marks or labels of a previous graph.
    for (int i = 0; i < n; i++) {
        adj[i].clear();
        cl[i] = 0;
        kq[i] = 0;
    }
    id = 0;
    cntA = cntB = cntC = 0;

    for (int i = 0; i < M; i++) {
        adj[p[i]].emplace_back(q[i]);
        adj[q[i]].emplace_back(p[i]);
    }

    // Sort the sizes so that A <= B <= C, remembering in ID[k] which original
    // part ended up in sorted position k. ID[0] keeps "no label" as 0.
    ID[0] = 0; ID[1] = 1; ID[2] = 2; ID[3] = 3;
    if (A > B) {
        swap(A, B);
        swap(ID[1], ID[2]);
    }
    if (A > C) {
        swap(A, C);
        swap(ID[1], ID[3]);
    }
    if (B > C) {
        swap(B, C);
        swap(ID[2], ID[3]);
    }

    vector<int> ans = solve();
    // Translate sorted-position labels back to the caller's numbering.
    for (int &label : ans) label = ID[label];
    return ans;
}
/*
9 10
4 2 3
0 1
0 2
0 3
0 4
0 6
0 8
1 7
3 7
4 5
5 6
*/
/*
Output copied from the online judge's result page (not part of the program):

Compilation message (stderr)
split.cpp: In function 'int find_centroid(int, int)':
split.cpp:63:1: warning: control reaches end of non-void function [-Wreturn-type]
   63 | }
      | ^

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/