이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include "split.h"
#include <bits/stdc++.h>
// Half-open range loops with a long long counter: Loop walks [lo, hi) upwards,
// LoopR walks the same range downwards.
#define Loop(it,lo,hi) for (ll it = (lo); it < (ll)(hi); ++it)
#define LoopR(it,lo,hi) for (ll it = (hi)-1; it >= (ll)(lo); --it)
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using namespace std;

// Capacity of the per-vertex arrays (presumably max vertex count plus slack).
constexpr int N = 100'010;
// A[v]: adjacency list of the input graph (filled in find_split).
vector<int> A[N];
// C[v]: children of v in the DFS spanning tree (filled in dfs0).
vector<int> C[N];
// Per-vertex DFS data written by dfs0: subtree size, low-link height, depth.
int sz[N];
int mn[N];
int height[N];
// Visited flags, reused by several traversals.
bool vis[N];
// Builds the DFS spanning tree below v (v placed at depth h) and fills, for
// every newly discovered vertex w:
//   height[w] - depth of w in the DFS tree,
//   sz[w]     - number of vertices in w's DFS subtree,
//   mn[w]     - smallest height reached from w's subtree through one
//               already-visited neighbour (a low-link measured in heights),
// and appends w's tree children to C[w]. vis[] marks discovered vertices.
// NOTE(review): the edge back to the parent is also counted, so
// mn[w] <= height[w] - 1 for every non-root w; callers compare against it.
void dfs0(int v, int h)
{
	vis[v] = 1;
	mn[v] = height[v] = h;
	sz[v] = 1;
	for (int u : A[v]) {
		if (vis[u]) {
			// Already-discovered neighbour: a non-tree edge (or the parent edge).
			mn[v] = min(mn[v], height[u]);
			continue;
		}
		C[v].push_back(u);
		dfs0(u, h + 1);
		sz[v] += sz[u];
		// Fold the child's low-link into ours; both operands are heights.
		mn[v] = min(mn[v], mn[u]);
	}
}
// is_in[w] is set when add() counts w's subtree, and cleared by rem() or when
// add() is later called on w's tree parent.
bool is_in[N];
// Counts v's subtree as movable relative to rt, provided v is not already
// counted and some vertex of the subtree reaches strictly above rt
// (mn[v] < height[rt]). Direct children that were counted are absorbed into v.
// Returns the number of vertices newly counted (0 if v was rejected).
int add(int v, int rt) {
	const bool already_counted = is_in[v];
	const bool cannot_escape = mn[v] >= height[rt];
	if (already_counted || cannot_escape)
		return 0;
	int gained = sz[v];
	for (size_t k = 0; k < C[v].size(); ++k) {
		const int child = C[v][k];
		if (!is_in[child])
			continue;
		// The child's vertices were already counted; v's subtree now covers them.
		gained -= sz[child];
		is_in[child] = 0;
	}
	is_in[v] = 1;
	return gained;
}
// Un-counts v's subtree if add() had marked it. Returns the number of vertices
// removed from the count (sz[v]), or 0 if v was not marked. rt is unused.
int rem(int v, int rt) {
	(void)rt;
	if (is_in[v]) {
		is_in[v] = 0;
		return sz[v];
	}
	return 0;
}
// n: vertex count (set in find_split). m is not referenced in the visible code.
int n, m;
// Small-to-large union: leaves the union of both sets in a and empties b.
// The larger set's storage is kept (swapped into a) so only the smaller one
// is re-inserted.
void merge(set<pii> &a, set<pii> &b)
{
	if (b.size() > a.size())
		a.swap(b);
	a.insert(b.begin(), b.end());
	b.clear();
}
int dfs1(int v, int sz1, int sz2, set<pii> &by_mn, set<pii> &by_sz, int &sum)
{
by_mn = {{mn[v], v}};
by_sz = {{sz[v], v}};
sum = 0;
for (int u : C[v]) {
set<pii> x, y;
int z;
int ret;
ret = dfs1(u, sz1, sz2, x, y, z);
if (ret != -1)
return ret;
merge(by_mn, x);
merge(by_sz, y);
sum += z;
}
while (by_sz.size()) {
int u = by_sz.begin()->second;
if (sz[v] - sz[u] < sz2)
break;
sum += add(u, v);
by_sz.erase(by_sz.begin());
}
while (by_mn.size()) {
int u = by_mn.begin()->second;
if (mn[u] > height[v])
break;
sum -= rem(u, v);
by_mn.erase(by_mn.begin());
}
if (sz[v] >= sz2 && n - sz[v] + sum >= sz1)
return v;
return -1;
}
// Appends to vec the topmost vertices u in the DFS subtree of v that satisfy
// both conditions:
//   - cutting u's subtree away from rt's subtree leaves at least sz_target
//     vertices (sz[rt] - sz[u] >= sz_target);
//   - u's subtree reaches strictly above rt (mn[u] < height[rt]).
// The search does not descend below a vertex that was appended.
void dfs2(int v, int rt, int sz_target, vector<int> &vec)
{
	const bool leaves_enough = sz[rt] - sz[v] >= sz_target;
	const bool reaches_above = mn[v] < height[rt];
	if (leaves_enough && reaches_above) {
		vec.push_back(v);
		return;
	}
	for (const int child : C[v])
		dfs2(child, rt, sz_target, vec);
}
void dfs_col(int v, vector<int> &col, int c)
{
col[v] = c;
for (int u : C[v])
dfs_col(u, col, c);
}
void dfs3(int v, vector<int> &col, int c, int &rem)
{
if (!rem)
return;
vis[v] = 1;
col[v] = c;
--rem;
for (int u : A[v]) {
if (vis[u])
continue;
dfs3(u, col, c, rem);
}
}
// Attempts a split in which:
//   - a connected piece of size a is grown from vertex 0 outside a chosen DFS
//     subtree (plus some subtrees moved out of it);
//   - a connected piece of size b is grown from that subtree's root.
// Internal colours 0/1/2 are then mapped to the labels sa/sb/sc. Returns the
// labelling (size n), or an empty vector if dfs1 finds no suitable root.
vector<int> solve(int a, int b, int sa, int sb, int sc)
{
	set<pii> x, y;
	int z;
	memset(is_in, 0, sizeof(is_in));
	int v = dfs1(0, a, b, x, y, z);
	if (v == -1)
		return {};
	// Subtrees of v that may be handed to the outside part; sorted so the
	// largest is at the back and gets moved first.
	vector<int> vec;
	dfs2(v, v, b, vec);
	sort(vec.begin(), vec.end(), [](int i, int j) {
		return sz[i] < sz[j];
	});
	// Region marks: -1 = outside part, -2 = v's subtree.
	vector<int> col(n);
	dfs_col(0, col, -1);
	dfs_col(v, col, -2);
	int sza = n - sz[v], szb = sz[v];
	while (vec.size() && sza < a) {
		int u = vec.back();
		vec.pop_back();
		sza += sz[u];
		szb -= sz[u];
		dfs_col(u, col, -1);
	}
	// NOTE(review): these rely on dfs1's acceptance test implying that the
	// greedy above succeeds; that is not proven locally.
	assert(sza >= a);
	assert(szb >= b);
	// Colour exactly a vertices of the -1 region with 0, starting at vertex 0.
	Loop (i,0,n)
		vis[i] = col[i] != -1;
	dfs3(0, col, 0, a);
	assert(a == 0);
	// Colour exactly b vertices of the -2 region with 1, starting at v.
	Loop (i,0,n)
		vis[i] = col[i] != -2;
	dfs3(v, col, 1, b);
	assert(b == 0);
	// Uncoloured vertices (still negative) form part 2. Map 0/1/2 to labels
	// through one table instead of building a vector per vertex.
	const int label[3] = {sa, sb, sc};
	Loop (i,0,n)
		col[i] = label[col[i] < 0 ? 2 : col[i]];
	return col;
}
// Entry point: builds the graph from the edge list (p[i], q[i]) and runs dfs0
// from vertex 0. It then tries to split the vertices into connected parts of
// sizes a, b, c, labelled 1, 2, 3. The two smallest parts are attempted in
// both roles; the largest is whatever remains. Returns the labelling, or n
// zeros if neither attempt succeeds.
vector<int> find_split(int _n, int a, int b, int c, vector<int> p, vector<int> q)
{
	n = _n;
	for (size_t e = 0; e < p.size(); ++e) {
		A[p[e]].push_back(q[e]);
		A[q[e]].push_back(p[e]);
	}
	dfs0(0, 0);
	// Sort (size, label) pairs ascending; equal sizes keep label order 1, 2, 3.
	pii parts[3] = {{a, 1}, {b, 2}, {c, 3}};
	sort(parts, parts + 3);
	const int small = parts[0].first, mid = parts[1].first;
	const int lab_small = parts[0].second;
	const int lab_mid = parts[1].second;
	const int lab_big = parts[2].second;
	vector<int> ans = solve(small, mid, lab_small, lab_mid, lab_big);
	if (!ans.empty())
		return ans;
	ans = solve(mid, small, lab_mid, lab_small, lab_big);
	if (!ans.empty())
		return ans;
	return vector<int>(n, 0);
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |