이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include "split.h"
#include <bits/stdc++.h>
using namespace std;
// Size of a container as a signed int (avoids signed/unsigned comparisons).
// Function-like macro: a bare `sz` not followed by `(` (e.g. the DSU member
// vector below) is unaffected, but any `sz(...)` call with two arguments breaks.
#define sz(v) int(v.size())
// Alias for std::array; not referenced elsewhere in this file.
#define ar array
// 64-bit integer alias; not referenced elsewhere in this file.
typedef long long ll;
// N bounds the fixed-size per-vertex global arrays; MOD is not referenced
// elsewhere in this file.
const int N = 2e5+10, MOD = 1e9+7;
// Disjoint-set union with union by size and full path compression.
// The global instance `d` is declared here but is not used by find_split.
struct DSU {
    vector<int> par, sz;  // parent link and component size (0 for non-roots)

    DSU() {}

    // Creates n singleton sets {0}, {1}, ..., {n-1}.
    DSU(int n): par(n) {
        for (int i = 0; i < n; ++i) par[i] = i;
        sz.assign(n, 1);
    }

    // Returns the representative of v's set and points every vertex on the
    // walked path directly at it.
    int find_set(int v) {
        int root = v;
        while (par[root] != root) root = par[root];
        while (par[v] != root) {
            const int next = par[v];
            par[v] = root;
            v = next;
        }
        return root;
    }

    // Merges the sets holding a and b; the larger set's root stays the root.
    // Returns false when they were already in the same set.
    bool union_sets(int a, int b) {
        int ra = find_set(a);
        int rb = find_set(b);
        if (ra == rb) return false;
        if (sz[ra] < sz[rb]) swap(ra, rb);
        par[rb] = ra;
        sz[ra] += sz[rb];
        sz[rb] = 0;
        return true;
    }
} d;
// n: vertex count, m: edge count (both set at the start of find_split).
// sub[v] / par[v]: subtree size and parent of v in the rooted tree walked by
// dfs_sub (par of the root is whatever dfs_sub is called with, -1 in find_split).
int n, m, sub[N], par[N];
// adj[v]: neighbours of v. Filled by find_split and never cleared, so this
// assumes find_split is called once per process.
vector<int> adj[N];
// Roots the tree at c (with parent p), recording par[] and sub[] for every
// vertex below it, and returns sub[c]. Only terminates on acyclic input:
// it excludes the parent edge but does not track visited vertices.
int dfs_sub(int c, int p) {
    par[c] = p;
    int total = 1;
    for (int child : adj[c]) {
        if (child == p) continue;
        total += dfs_sub(child, c);
    }
    sub[c] = total;
    return total;
}
bool vis[N];  // shared "visited" marks; every user clears it before use

// Walks greedily from c, always stepping to the first unvisited neighbour,
// and returns the vertex where the walk gets stuck. All of that vertex's
// neighbours lie on the walked path, so deleting it cannot disconnect the rest
// of a connected graph. Marks every walked vertex in vis[].
int get_leaf(int c) {
    for (;;) {
        vis[c] = 1;
        int step = -1;
        for (int nb : adj[c]) {
            if (!vis[nb]) {
                step = nb;
                break;
            }
        }
        if (step == -1) return c;
        c = step;
    }
}
vector<int> build;  // DFS-preorder prefix collected by dfs_build (never cleared)

// Recursive DFS from c that never enters `skip` (it is only marked visited).
// Appends vertices to `build` in preorder until it holds cnt of them. A
// preorder prefix always contains each member's DFS parent, so `build` is connected.
void dfs_build(int c, int cnt, int skip) {
    vis[c] = 1;
    if (c != skip) {
        if (int(build.size()) < cnt) build.push_back(c);
        for (int nxt : adj[c]) {
            if (vis[nxt]) continue;
            dfs_build(nxt, cnt, skip);
        }
    }
}
// BFS from root with p pre-marked as visited, so the search never crosses p.
// On a tree with (root, p) an edge, this explores exactly root's side of that
// edge. Returns the first cnt vertices in BFS order; a BFS prefix is connected.
// Asserts that the explored side has at least cnt vertices.
vector<int> gather(int root, int p, int cnt) {
    fill(begin(vis), end(vis), false);
    vis[root] = vis[p] = 1;
    vector<int> order(1, root);
    size_t head = 0;
    while (head < order.size()) {
        const int cur = order[head++];
        for (int nxt : adj[cur]) {
            if (vis[nxt]) continue;
            vis[nxt] = 1;
            order.push_back(nxt);
        }
    }
    assert(int(order.size()) >= cnt);
    order.resize(cnt);
    return order;
}
// Labels vertices in DFS order: the first p1 get 1, the next p2 get 2, the
// next p3 get 3. The remaining quotas are passed BY VALUE, so later siblings
// do not see the decrements made deeper in the recursion. This is only correct
// when the traversal is a single chain, as on the ring find_split calls it for.
void dfs_dumb(int c, int p1, int p2, int p3, vector<int>& ans) {
    vis[c] = 1;
    if (p1 != 0) {
        ans[c] = 1;
        --p1;
    } else if (p2 != 0) {
        ans[c] = 2;
        --p2;
    } else if (p3 != 0) {
        ans[c] = 3;
        --p3;
    }
    for (int nxt : adj[c]) {
        if (vis[nxt]) continue;
        dfs_dumb(nxt, p1, p2, p3, ans);
    }
}
// ---- General-graph support (m > n - 1) -------------------------------------
// The tree branch of find_split only works on trees. The helpers below handle
// any connected graph with the DFS-tree / low-link method. They use only the
// global `adj` plus their own local buffers, so they do not disturb sub[],
// par[] or vis[].

// Iterative DFS from vertex 0 (no deep recursion). It fills:
//   tin[v]  - preorder index of v,
//   low[v]  - minimum tin reachable from v's subtree via one back edge,
//   tsub[v] - subtree size, tpar[v] - DFS-tree parent (-1 for the root),
//   ord     - the preorder, so subtree(v) == ord[tin[v] .. tin[v]+tsub[v]-1].
static void split_dfs_tree(int nn, vector<int>& tin, vector<int>& low,
                           vector<int>& tsub, vector<int>& tpar, vector<int>& ord) {
    tin.assign(nn, -1);
    low.assign(nn, 0);
    tsub.assign(nn, 1);
    tpar.assign(nn, -1);
    ord.clear();
    vector<int> next_edge(nn, 0), stk{0};
    tin[0] = low[0] = 0;
    ord.push_back(0);
    while (!stk.empty()) {
        const int v = stk.back();
        if (next_edge[v] < (int)adj[v].size()) {
            const int u = adj[v][next_edge[v]++];
            if (tin[u] == -1) {  // tree edge: descend
                tpar[u] = v;
                tin[u] = low[u] = (int)ord.size();
                ord.push_back(u);
                stk.push_back(u);
            } else if (u != tpar[v]) {  // non-tree edge
                low[v] = min(low[v], tin[u]);
            }
        } else {  // v finished: fold its results into its parent
            stk.pop_back();
            const int up = tpar[v];
            if (up != -1) {
                tsub[up] += tsub[v];
                low[up] = min(low[up], low[v]);
            }
        }
    }
}

// BFS from `start`, restricted to vertices whose side[] flag equals
// side[start]. Returns the first `want` vertices reached. A BFS prefix is
// connected, so the result is a connected set when the region is connected.
static vector<int> split_collect(int start, int want, const vector<char>& side) {
    const char tag = side[start];
    vector<int> order{start};
    vector<char> seen(side.size(), 0);
    seen[start] = 1;
    for (size_t i = 0; i < order.size() && (int)order.size() < want; ++i)
        for (int u : adj[order[i]])
            if (!seen[u] && side[u] == tag) {
                seen[u] = 1;
                order.push_back(u);
            }
    assert((int)order.size() >= want);  // the caller guarantees the region is large enough
    order.resize(want);
    return order;
}

// Solver for an arbitrary connected graph on nn vertices.
// cols holds (size, label) pairs sorted ascending: a = cols[0], b = cols[1].
// Returns labels with the A and B groups connected, or all zeros if impossible.
static vector<int> split_general(int nn, const vector<pair<int, int>>& cols) {
    const int a = cols[0].first, b = cols[1].first;
    vector<int> tin, low, tsub, tpar, ord;
    split_dfs_tree(nn, tin, low, tsub, tpar, ord);
    assert((int)ord.size() == nn);  // the input graph is assumed connected

    // Descend to a vertex c with tsub[c] >= a whose children all have tsub < a.
    int c = 0;
    for (bool moved = true; moved;) {
        moved = false;
        for (int u : adj[c])
            if (tpar[u] == c && tsub[u] >= a) {
                c = u;
                moved = true;
                break;
            }
    }

    // "upper" starts as everything outside c's subtree (connected via tree edges).
    vector<char> upper(nn, 0);
    for (int t = 0; t < nn; ++t)
        if (t < tin[c] || t >= tin[c] + tsub[c]) upper[ord[t]] = 1;
    int up = nn - tsub[c];
    // Move child subtrees that have a back edge strictly above c into upper.
    // That back edge keeps upper connected. Each moved subtree has < a
    // vertices, so we stop with up < 2a, which leaves the other side > b+c-a >= a.
    for (int u : adj[c]) {
        if (up >= a) break;
        if (tpar[u] == c && !upper[u] && low[u] < tin[c]) {
            for (int t = tin[u]; t < tin[u] + tsub[u]; ++t) upper[ord[t]] = 1;
            up += tsub[u];
        }
    }
    // If c is the root, nothing can move (low >= 0) and up stays 0.
    if (up < a) return vector<int>(nn, 0);

    // Both sides now have >= a vertices. The larger side has >= nn/2 >= b,
    // because 2b <= b + c < nn. c != root here, so tpar[c] lies in upper.
    const int down = nn - up;
    const int up_start = tpar[c];
    vector<int> ans(nn, cols[2].second);
    auto paint = [&](int start, int want, int label) {
        for (int x : split_collect(start, want, upper)) ans[x] = label;
    };
    if (up <= down) {
        paint(up_start, a, cols[0].second);
        paint(c, b, cols[1].second);
    } else {
        paint(c, a, cols[0].second);
        paint(up_start, b, cols[1].second);
    }
    return ans;
}

// Entry point (interface from split.h; presumably IOI 2019 "Split the Attractions").
// _n: vertex count; p[i], q[i]: endpoints of edge i; p1, p2, p3: requested
// group sizes for labels 1, 2, 3. Returns one label per vertex, or all zeros
// when no split is found. Uses the global adj/vis state, so call it once.
vector<int> find_split(int _n, int p1, int p2, int p3, vector<int> p, vector<int> q) {
    n = _n, m = sz(p);
    for (int i = 0; i < m; i++) {
        int x = p[i], y = q[i];
        adj[x].push_back(y), adj[y].push_back(x);
    }
    // Sort the (size, label) pairs so cols[0] is the smallest group A, cols[1]
    // is B, and cols[2] (largest) is the leftover group C.
    vector<pair<int, int>> cols(3);
    cols[0] = {p1, 1}, cols[1] = {p2, 2}, cols[2] = {p3, 3};
    sort(cols.begin(), cols.end());
    if (cols[0].first == 1) {
        // |A| == 1: A is a vertex whose removal keeps the graph connected.
        // B is a DFS-preorder prefix of the rest.
        memset(vis, 0, sizeof(vis));
        int leaf = get_leaf(0);
        memset(vis, 0, sizeof(vis));
        dfs_build(0, cols[1].first, leaf);
        vector<int> ans(n, cols[2].second);
        ans[leaf] = cols[0].second;
        for (int x : build) ans[x] = cols[1].second;
        return ans;
    }
    // Every vertex has degree exactly 2: a ring (assuming connectivity).
    // Label three consecutive arcs.
    bool bad = 0;
    for (int i = 0; i < n; i++) if (sz(adj[i]) != 2) bad = 1;
    if (!bad) {
        memset(vis, 0, sizeof(vis));
        vector<int> ans(n);
        dfs_dumb(0, p1, p2, p3, ans);
        return ans;
    }
    // Previously `assert(m == n-1)`, which aborted on every non-tree graph.
    if (m != n - 1) return split_general(n, cols);
    // Tree: find an edge whose two sides can hold A and B respectively.
    dfs_sub(0, -1);
    pair<int, int> cut{-1, -1};
    for (int c = 1; c < n; c++) {
        if (sub[c] >= cols[0].first && n - sub[c] >= cols[1].first) {
            cut = {c, par[c]};
        }
        if (n - sub[c] >= cols[0].first && sub[c] >= cols[1].first) {
            cut = {par[c], c};
        }
    }
    if (cut == make_pair(-1, -1)) return vector<int>(n, 0);
    vector<int> one = gather(cut.first, cut.second, cols[0].first);
    vector<int> two = gather(cut.second, cut.first, cols[1].first);
    vector<int> ans(n, cols[2].second);
    for (int x : one) ans[x] = cols[0].second;
    for (int x : two) ans[x] = cols[1].second;
    return ans;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |