#include "incursion.h"
#include <bits/stdc++.h>
using namespace std;
// Shared tree state, rebuilt by Parse() at the start of every mark()/locate().
int n = 0;                  // number of vertices; vertices are labelled 1..n
vector<vector<int>> G{};    // G[u]: neighbours of u (slot 0 unused)
vector<vector<int>> csz{};  // csz[u][i]: size of the component reached from u through G[u][i]
vector<int> sz{};           // sz[u]: subtree size of u in the DFS rooted at vertex 1
// Root the tree at the initial call's vertex and fill sz[] / csz[].
// For every vertex u reached with parent p, G[u] is reordered so that p is the
// last entry; csz[u] is then built in the same order, so csz[u][i] is the size
// of the component hanging off G[u][i] (children first, the "upward" part last).
void dfs(int u, int p){
    if (p != 0) {
        // Shift the parent to the end while keeping the other neighbours'
        // relative order (same effect as erase + push_back).
        auto parentPos = find(G[u].begin(), G[u].end(), p);
        rotate(parentPos, parentPos + 1, G[u].end());
    }
    sz[u] = 1;
    csz[u].clear();
    for (int child : G[u]) {
        if (child == p) continue;
        dfs(child, u);
        sz[u] += sz[child];
        csz[u].push_back(sz[child]);
    }
    // Everything outside u's subtree is reached through the parent.
    if (p != 0) csz[u].push_back(n - sz[u]);
}
// Rebuild the global tree state from the edge list F (vertices 1..F.size()+1):
// sets n, rebuilds the adjacency lists G, then runs dfs from vertex 1 to fill
// sz[] and csz[].
void Parse(vector<pair<int,int>> F){
    n = static_cast<int>(F.size()) + 1;
    // Every per-vertex array has a dummy slot 0 so vertices index directly.
    G.assign(n + 1, vector<int>());
    csz.resize(n + 1);  // entries are cleared per vertex inside dfs()
    for (const pair<int,int>& edge : F) {
        G[edge.first].push_back(edge.second);
        G[edge.second].push_back(edge.first);
    }
    sz.assign(n + 1, 0);
    dfs(1, 0);
}
// Search for target t in the part of the tree entered at u from p.
// If t is found, set mk[v] = 1 for every vertex v on the path u..t (t first,
// then its ancestors as the recursion unwinds) and return true; otherwise leave
// mk untouched and return false.
bool dfs1(int u, int p, int t, vector<int> &mk){
    bool onPath = (u == t);
    for (size_t i = 0; !onPath && i < G[u].size(); ++i) {
        const int next = G[u][i];
        if (next != p && dfs1(next, u, t, mk)) onPath = true;
    }
    if (onPath) mk[u] = 1;
    return onPath;
}
// Compute the marks for the tree with edge list F and safe vertex `safe`.
// The tree is rooted at its centroid(s); every vertex on the path from the
// centroid to `safe` gets mark 1, all others 0.
// Returns a vector of length n where element i is the mark of vertex i + 1.
vector<int> mark(vector<pair<int, int>> F, int safe) {
    Parse(std::move(F));

    // Centroids: vertices whose largest remaining component after removal
    // has at most n / 2 vertices (a tree has one, or two adjacent ones).
    vector<int> cen;
    for (int i = 1; i <= n; i++) {
        // csz[i] is empty when vertex i has no neighbours (single-vertex tree);
        // dereferencing max_element's result on an empty range is UB.
        const int largest =
            csz[i].empty() ? 0 : *max_element(csz[i].begin(), csz[i].end());
        if (largest <= n / 2) cen.push_back(i);
    }

    // With two centroids, cs - c is "the other centroid". Passing it as the
    // parent keeps each search on its own side, so only the side containing
    // `safe` gets marked. With a single centroid cs - c == 0, i.e. no parent.
    vector<int> res(n + 1, 0);
    const int cs = accumulate(cen.begin(), cen.end(), 0);
    for (int c : cen) {
        dfs1(c, cs - c, safe, res);
    }

    // Drop the dummy slot 0 so the result is indexed by vertex - 1.
    res.erase(res.begin());
    return res;
}
vector<int> par{};  // par[u]: parent of u after locate() re-roots the tree at its centroid(s) via dfs2
// Remove the first occurrence of e from v; leave v unchanged if e is absent.
template<class T>
void Erase(vector<T> &v, T e){
    const auto pos = find(v.begin(), v.end(), e);
    if (pos != v.end()) v.erase(pos);
}
// Re-root the tree (as built by Parse) at the initial call's vertex.
// For every visited u, its parent p is removed from G[u] together with the
// matching entry of csz[u] (the two arrays stay index-aligned), so afterwards
// G[u] lists only children and csz[u] their subtree sizes; par[u] records p.
void dfs2(int u, int p){
    if (p != 0) {
        auto parentPos = find(G[u].begin(), G[u].end(), p);
        const auto offset = parentPos - G[u].begin();
        G[u].erase(parentPos);
        csz[u].erase(csz[u].begin() + offset);
    }
    par[u] = p;
    for (int child : G[u]) {
        if (child != p) dfs2(child, u);
    }
}
// Walk from vertex `curr` to the safe vertex using visit().
// NOTE(review): t is assumed to be curr's mark as produced by mark(), and
// visit(v) is assumed (from its usage here) to move to neighbour v and return
// v's mark.
// Strategy: the marked vertices form the path from the centroid to the safe
// vertex (see mark()). Phase 1 climbs towards the centroid until it stands on a
// marked vertex; phase 2 walks down the marked path, trying larger subtrees
// first, until no untried child remains. That vertex is the safe one.
void locate(vector<pair<int, int>> F, int curr, int t) {
    Parse(std::move(F));

    // Same centroid rule as mark(), so both agree on where the path starts.
    vector<int> cen;
    for (int i = 1; i <= n; i++) {
        // Guard the single-vertex tree: max_element on an empty range returns
        // end(), which must not be dereferenced.
        const int largest =
            csz[i].empty() ? 0 : *max_element(csz[i].begin(), csz[i].end());
        if (largest <= n / 2) cen.push_back(i);
    }
    const int cs = accumulate(cen.begin(), cen.end(), 0);

    // Root at the centroid(s). Afterwards G[u]/csz[u] hold only u's children
    // and their subtree sizes, and par[u] is u's parent. With two centroids,
    // each one's parent is the other (cs - c).
    par.clear();
    par.resize(n + 1, 0);
    for (int c : cen) {
        dfs2(c, cs - c);
    }

    // Phase 1: climb until we stand on a marked vertex.
    int from = 0;  // the (unmarked) vertex we last climbed away from
    while (!t) {
        from = curr;
        curr = par[curr];
        t = visit(curr);
    }
    // `from` is known to be unmarked, so never descend back into it; this saves
    // two visit() calls. With two centroids, `from` may be the other centroid,
    // which dfs2 already removed from G[curr].
    if (from != 0) {
        auto it = find(G[curr].begin(), G[curr].end(), from);
        if (it != G[curr].end()) {
            const auto idx = it - G[curr].begin();
            G[curr].erase(it);
            csz[curr].erase(csz[curr].begin() + idx);
        }
    }

    // Phase 2: descend along the marked path, trying the largest untried child
    // first. An unmarked child is a dead end, so step back. When no untried
    // child remains, curr is the end of the marked path.
    while (!G[curr].empty()) {
        const auto idx =
            max_element(csz[curr].begin(), csz[curr].end()) - csz[curr].begin();
        const int nxt = G[curr][idx];
        G[curr].erase(G[curr].begin() + idx);
        csz[curr].erase(csz[curr].begin() + idx);
        if (visit(nxt)) {
            curr = nxt;
        } else {
            visit(curr);
        }
    }
}