# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
1125680 | salmon | Cat Exercise (JOI23_ho_t4) | C++20 | 0 ms | 0 KiB |
#include <bits/stdc++.h>
using namespace std;
// Array capacity shared by every per-vertex table.
// NOTE(review): assumes N < MAXN (presumably N <= 200000) -- confirm limits.
const int MAXN = 200100;
// Number of binary-lifting levels stored per vertex.
const int LOG = 30;

int N;                        // number of vertices
int lst[MAXN];                // per-vertex value read from input (tower height)
vector<int> adjlst[MAXN];     // undirected adjacency lists, 0-based vertex ids
int u, v;                     // scratch variables used while reading edges
bool visited[MAXN];           // true once solve() has processed (removed) a vertex
pair<int,int> st[MAXN * 4];   // max segment tree of (height, pre-order index)
int pre[MAXN];                // pre-order index of each vertex
int invpre[MAXN];             // vertex stored at each pre-order index
int post[MAXN];               // last pre-order index inside each vertex's subtree
int cont = 0;                 // next unused pre-order index (advanced by dfs)
int plst[MAXN];               // heights laid out in pre-order
int parent[MAXN][LOG];        // parent[x][k]: 2^k-th ancestor of x, or -1
int d[MAXN];                  // depth of each vertex below the root
// Builds the max segment tree over plst[] for node i covering [s, e].
// A leaf at position s stores (plst[s], s); an internal node stores the
// lexicographically larger of its two children, so equal heights would be
// broken in favour of the larger position.
void build(int i, int s, int e){
    if(s != e){
        const int mid = (s + e) / 2;
        const int lc = 2 * i;
        const int rc = 2 * i + 1;
        build(lc, s, mid);
        build(rc, mid + 1, e);
        st[i] = (st[lc] < st[rc]) ? st[rc] : st[lc];
        return;
    }
    st[i] = pair<int,int>(plst[s], s);
}
// Returns the largest (height, position) pair among positions [S, E] of the
// segment tree node i, which covers [s, e]. Halves that do not intersect the
// query contribute the sentinel (-1, -1); deactivated leaves hold (0, 0).
pair<int,int> query(int i, int s, int e, int S, int E){
    if(S <= s && e <= E) return st[i];

    const int mid = (s + e) / 2;
    pair<int,int> fromLeft(-1, -1);
    pair<int,int> fromRight(-1, -1);
    if(S <= mid) fromLeft = query(2 * i, s, mid, S, E);
    if(mid < E) fromRight = query(2 * i + 1, mid + 1, e, S, E);
    return max(fromLeft, fromRight);
}
// Point update: overwrites leaf `it` (inside the range [s, e] covered by node
// i) with the inactive marker (0, 0), then recomputes every node on the path
// from that leaf back up to i, bottom-up.
void deactivate(int i, int s, int e, int it){
    const int top = i;
    int node = i;
    // Descend to the leaf, narrowing [s, e] exactly as the tree was built.
    while(s != e){
        const int mid = (s + e) / 2;
        node *= 2;
        if(it <= mid){
            e = mid;
        }else{
            s = mid + 1;
            node += 1;
        }
    }
    st[node] = {0, 0};
    // Climb back to the starting node, refreshing each ancestor's maximum.
    while(node != top){
        node /= 2;
        st[node] = max(st[2 * node], st[2 * node + 1]);
    }
}
// Roots the tree at the first call's vertex. For each vertex it records:
// its parent (level 0 of the lifting table, -1 for the root), its depth, and
// its pre-order index (pre[] plus the inverse map invpre[]). It also copies
// the vertex's height into the pre-order array plst[], and stores the last
// pre-order index of its subtree in post[]. As a result, subtree(x) occupies
// exactly [pre[x], post[x]].
void dfs(int node, int par, int depth){
    const int idx = cont++;
    parent[node][0] = par;
    d[node] = depth;
    pre[node] = idx;
    invpre[idx] = node;
    plst[idx] = lst[node];

    for(const int nb : adjlst[node]){
        if(nb != par) dfs(nb, node, depth + 1);
    }

    post[node] = cont - 1;
}
// Number of edges on the tree path between a and b, using the binary-lifting
// table parent[][] and depths d[]. First the deeper endpoint is lifted to the
// other's depth. Then both are lifted together to just below their lowest
// common ancestor, and the final 2 counts the two edges up to it.
int dist(int a, int b){
    if(d[a] > d[b]) swap(a, b);
    // From here on, b is at least as deep as a.
    int steps = d[b] - d[a];

    for(int k = 25; k >= 0; --k){
        const int up = parent[b][k];
        if(up != -1 && d[up] >= d[a]) b = up;
    }
    if(a == b) return steps;

    for(int k = 25; k >= 0; --k){
        const int upB = parent[b][k];
        if(upB == -1 || upB == parent[a][k]) continue;
        a = parent[a][k];
        b = upB;
        steps += 2 << k;  // 2^k edges on each side
    }
    return steps + 2;
}
// Maximum total number of moves the cat can make starting from vertex i.
//
// Precondition, maintained by the recursion: the still-active vertices of
// subtree(h) form exactly the cat's current connected component. The tree is
// rooted by dfs() at the tallest vertex, and i is the tallest vertex of that
// component.
//
// Removing i splits the component into:
//   * the active part of subtree(j), for every child j of i not yet visited;
//   * the rest of subtree(h), when i's parent is still unvisited.
// The cat is steered into one piece and walks to that piece's tallest
// vertex. The answer is therefore the best dist(i, tallest) + solve(tallest, top).
//
// Returns long long: up to N-1 moves of up to N-1 edges each (about N^2/2 on
// a path) does not fit in int.
// NOTE(review): recursion here and in dfs() can reach depth N; this relies on
// a large stack.
long long solve(int i, int h){
    visited[i] = true;
    deactivate(1, 0, N - 1, pre[i]);
    long long ans = 0;
    for(int j : adjlst[i]){
        if(visited[j] || j == parent[i][0]) continue;
        // Tallest still-active vertex hanging below child j.
        const int nxt = invpre[query(1, 0, N - 1, pre[j], post[j]).second];
        ans = max(ans, dist(nxt, i) + solve(nxt, j));
    }
    // The piece above i. Each recursive solve() deactivates every vertex of
    // its piece, so by now all of subtree(i) is inactive. Querying subtree(h)
    // therefore sees only the vertices above i.
    if(parent[i][0] != -1 && !visited[parent[i][0]]){
        const int nxt = invpre[query(1, 0, N - 1, pre[h], post[h]).second];
        ans = max(ans, dist(nxt, i) + solve(nxt, h));
    }
    return ans;
}

// Reads N, then N heights, then N-1 edges (1-based endpoints). Roots the tree
// at the tallest vertex, builds the lookup structures, and prints the answer.
int main(){
    if(scanf(" %d", &N) != 1 || N <= 0){
        printf("0\n");
        return 0;
    }
    // solve() must start at the maximum of the whole tree. For heights that
    // are a permutation of 1..N this is the vertex with height N.
    int r = 0;
    for(int i = 0; i < N; i++){
        scanf(" %d", &lst[i]);
        if(lst[i] > lst[r]) r = i;
        visited[i] = false;
    }
    for(int i = 0; i < N - 1; i++){
        scanf(" %d", &u);
        scanf(" %d", &v);
        u--;
        v--;
        adjlst[u].push_back(v);
        adjlst[v].push_back(u);
    }
    dfs(r, -1, 0);
    // Level j of vertex i reads level j-1 of a different vertex, which may
    // have a larger id. So every vertex's level j-1 must be complete before
    // any level j entry is computed: the level loop has to be the outer one.
    for(int j = 1; j < 30; j++){
        for(int i = 0; i < N; i++){
            const int mid = parent[i][j - 1];
            parent[i][j] = (mid == -1) ? -1 : parent[mid][j - 1];
        }
    }
    build(1, 0, N - 1);
    printf("%lld\n", solve(r, r));
}