#include "split.h"
#include<iostream>
#include<algorithm>
#include<vector>
#include<cassert>
#include<utility>
#include<random>
#include<chrono>
using namespace std;
namespace {

// Working state for a single find_split() call. find_split() re-initialises
// all of it on entry, so the helpers may assume it describes the current graph.
std::vector<int> ord;                     // vertices in the order dfs()/dfs_ban() visited them
std::vector<int> sz, pa, vis, low, dep;   // per vertex: subtree size, tree parent, visited flag, low value, depth
std::vector<std::vector<int>> graph, ng;  // input adjacency lists / adjacency lists of the DFS tree only

// Builds the DFS spanning tree below `node` (the root is passed as its own parent).
// Fills pa, dep (root gets depth 1), sz, the preorder `ord`, the tree adjacency `ng`,
// and low[v] = the smallest depth among v's subtree and the already-visited
// non-parent neighbours of vertices in that subtree.
void dfs(int node, int parent) {
    vis[node] = 1;
    pa[node] = parent;
    dep[node] = dep[parent] + 1;
    low[node] = dep[node];
    sz[node] = 1;
    ord.push_back(node);
    for (int next : graph[node]) {
        if (next == parent) continue;
        if (vis[next]) {
            // Non-tree edge to a vertex that is already on the tree.
            low[node] = std::min(low[node], dep[next]);
        } else {
            dfs(next, node);
            low[node] = std::min(low[node], low[next]);
            sz[node] += sz[next];
            // The child is appended to ng[node] before node's own parent is,
            // so a vertex's tree parent is always the last entry of its ng list.
            ng[node].push_back(next);
            ng[next].push_back(node);
        }
    }
}

// Appends to `ord` every vertex reachable from `node` along tree edges without
// stepping into `parent`. Called as dfs_ban(pa[v], v) it lists the whole tree
// except v's subtree, starting at pa[v].
void dfs_ban(int node, int parent) {
    ord.push_back(node);
    for (int next : ng[node]) {
        if (next == parent) continue;
        dfs_ban(next, node);
    }
}

// Walks down from `node`, always entering the first tree neighbour (other than
// `parent`) whose subtree has at least `a` vertices; returns the vertex at which
// no such neighbour exists.
int dfs_cent(int node, int parent, int a) {
    for (int next : ng[node]) {
        if (next == parent) continue;
        if (sz[next] >= a) return dfs_cent(next, node, a);
    }
    return node;
}

// Breadth-first search over the input graph restricted to vertices with
// allowed[v] != 0, starting at `start`. Returns the first `want` vertices
// reached (fewer if the allowed component of `start` is smaller). Every prefix
// of a BFS order induces a connected subgraph, so the result is connected.
std::vector<int> pick_connected(const std::vector<char>& allowed, int start, int want) {
    std::vector<int> picked;
    if (want <= 0 || !allowed[start]) return picked;
    std::vector<char> seen(allowed.size(), 0);
    seen[start] = 1;
    picked.push_back(start);
    // `picked` doubles as the BFS queue.
    for (int head = 0;
         head < static_cast<int>(picked.size()) && static_cast<int>(picked.size()) < want;
         ++head) {
        for (int next : graph[picked[head]]) {
            if (!allowed[next] || seen[next]) continue;
            seen[next] = 1;
            picked.push_back(next);
            if (static_cast<int>(picked.size()) == want) break;
        }
    }
    return picked;
}

}  // namespace

// Splits the vertices 0..n-1 of an undirected graph into three groups of sizes
// a, b and c such that the two smallest groups each induce a connected subgraph.
//
// Parameters:
//   n        number of vertices
//   a, b, c  required sizes of the groups labelled 1, 2 and 3 (expected: a + b + c == n)
//   p, q     edge list; edge i joins p[i] and q[i]
// Returns a vector of n labels in {1, 2, 3}, or n zeros when this procedure
// finds no such split.
// Precondition: the graph is connected. If the DFS from vertex 0 does not reach
// every vertex the function returns all zeros instead of indexing out of range.
std::vector<int> find_split(int n, int a, int b, int c, std::vector<int> p, std::vector<int> q) {
    if (n <= 0) return std::vector<int>();
    std::vector<int> ans(n);

    // Reset the shared state completely so repeated calls do not see leftovers.
    ord.clear();
    sz.assign(n, 0);
    pa.assign(n, 0);
    vis.assign(n, 0);
    low.assign(n, 0);
    dep.assign(n, 0);
    graph.assign(n, std::vector<int>());
    ng.assign(n, std::vector<int>());

    const int edges = static_cast<int>(std::min(p.size(), q.size()));
    for (int i = 0; i < edges; ++i) {
        graph[p[i]].push_back(q[i]);
        graph[q[i]].push_back(p[i]);
    }

    // Sort the sizes so that a <= b <= c, carrying the caller's labels along.
    int na = 1, nb = 2, nc = 3;
    if (a > b) { std::swap(a, b); std::swap(na, nb); }
    if (b > c) { std::swap(b, c); std::swap(nb, nc); }
    if (a > b) { std::swap(a, b); std::swap(na, nb); }

    dfs(0, 0);
    if (static_cast<int>(ord.size()) != n) return ans;  // disconnected input: outside the contract

    std::vector<int> pos(n);  // pos[v] = index of v in the preorder
    for (int i = 0; i < n; ++i) pos[ord[i]] = i;

    // Cuts the tree edge above v: `inside` vertices are taken from v's subtree
    // (a preorder prefix, hence connected), `outside` vertices from the rest of
    // the tree (a dfs_ban prefix starting at pa[v]), everything else gets nc.
    auto split_at = [&](int v, int inside, int inside_label, int outside, int outside_label) {
        for (int j = pos[v]; j < pos[v] + inside; ++j) ans[ord[j]] = inside_label;
        ord.clear();
        dfs_ban(pa[v], v);
        for (int j = 0; j < outside; ++j) ans[ord[j]] = outside_label;
        for (int& label : ans) {
            if (label == 0) label = nc;
        }
    };

    // Phase 1: a single tree edge whose two sides can host the two small groups.
    for (int v = 0; v < n; ++v) {
        const int below = sz[v];
        const int above = n - sz[v];
        if (below >= a && above >= b) {
            split_at(v, a, na, b, nb);
            return ans;
        }
        if (below >= b && above >= a) {
            split_at(v, b, nb, a, na);
            return ans;
        }
    }

    // Phase 2: no tree edge works. `hub` is a vertex all of whose child subtrees
    // have fewer than a vertices. Group A is grown from the part of the tree
    // outside hub's subtree plus child subtrees that have a non-tree edge to a
    // vertex above hub; group B is hub plus the remaining child subtrees.
    const int hub = dfs_cent(0, 0, a);
    const int first = pos[hub];
    const int last = pos[hub] + sz[hub];  // hub's subtree is ord[first .. last)

    std::vector<char> in_a(n, 0);
    int a_size = 0;
    for (int j = 0; j < first; ++j) { in_a[ord[j]] = 1; ++a_size; }
    for (int j = last; j < n; ++j) { in_a[ord[j]] = 1; ++a_size; }

    std::vector<int> side_b(1, hub);
    const std::vector<int>& kids = ng[hub];
    const int kid_count = static_cast<int>(kids.size());
    for (int k = 0; k < kid_count; ++k) {
        const int child = kids[k];
        if (child == pa[hub]) continue;

        // low[child] < dep[hub] means the subtree reaches above hub without using hub.
        const bool escapes = low[child] < dep[hub];
        for (int j = pos[child]; j < pos[child] + sz[child]; ++j) {
            if (escapes) in_a[ord[j]] = 1;
            else side_b.push_back(ord[j]);
        }
        if (escapes) a_size += sz[child];
        if (a_size < a) continue;

        // A is large enough; every child not looked at yet stays with hub.
        // The tree parent also sits in ng[hub] and must not be treated as a child.
        for (int r = k + 1; r < kid_count; ++r) {
            const int rest = kids[r];
            if (rest == pa[hub]) continue;
            for (int j = pos[rest]; j < pos[rest] + sz[rest]; ++j) side_b.push_back(ord[j]);
        }

        // The candidate set for A is connected in the input graph but an
        // arbitrary a-subset of it is not, so select the a vertices by BFS from
        // pa[hub] (which always belongs to the outside part).
        const std::vector<int> side_a = pick_connected(in_a, pa[hub], a);
        if (static_cast<int>(side_a.size()) < a || static_cast<int>(side_b.size()) < b) {
            // Not expected for a connected graph with a + b + c == n; report "no split".
            return ans;
        }

        for (int v : side_a) ans[v] = na;
        // side_b starts with hub followed by whole child subtrees in preorder,
        // so any prefix of it is connected through hub.
        for (int j = 0; j < b; ++j) ans[side_b[j]] = nb;
        for (int& label : ans) {
            if (label == 0) label = nc;
        }
        return ans;
    }

    // Group A could not be grown to size a: no split found.
    return ans;
}
// g++ -std=c++17 -Wall -Wextra -Wshadow -fsanitize=undefined -fsanitize=address -o run split.cpp grader.cpp
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |