#include "bits/stdc++.h"
using namespace std;
#define ll long long
#define ln '\n'
const int N = 2e5 + 5;
const int LG = 20;
int n, glit = 1, p[N], mx[LG][N], up[LG][N], where[N], sz[N];
vector<int> adj[N];
int dfs(int u, int par){
up[0][u] = par;
mx[0][glit] = u;
where[u] = glit++;
int tot = 1;
for (auto v: adj[u]){
if (v == par) continue;
tot += dfs(v, u);
}
sz[where[u]] = tot;
return tot;
}
int answer(int l, int r){
if (l > r) return 0;
int lg = log2(r-l+1);
return p[mx[lg][l]] > p[mx[lg][r-(1<<lg)+1]]? mx[lg][l]: mx[lg][r-(1<<lg)+1];
}
bool is_ancestor(int u, int v){
return where[u] <= where[v] && where[v] <= where[u] + sz[where[u]] - 1;
}
int lca(int u, int v){
if (is_ancestor(u, v)) return u;
if (is_ancestor(v, u)) return v;
for (int i = LG-1; i >= 0; i--){
if (!is_ancestor(up[i][u], v)) u = up[i][u];
}
return up[0][u];
}
ll dist(int u, int v){
int ances = lca(u, v);
ll ans = 0;
for (int i = LG-1; i >= 0; i--){
if (!is_ancestor(up[i][u], ances)) {ans += (1 << i); u = up[i][u];}
if (!is_ancestor(up[i][v], ances)) {ans += (1 << i); v = up[i][v];}
}
return ans + (u != ances) + (v != ances);
}
int get_max(int v, const vector<int>& btm){
int nxt;
if (btm.empty()) nxt = answer(where[v],
where[v] + sz[where[v]] - 1);
else {
nxt = answer(where[v], where[btm.front()] - 1);
// cout << where[v] << ' ' << where[btm.front()] - 1 << ln;
for (int i = 0; i + 1 < (int)btm.size(); i++){
nxt = max(nxt,
answer(where[btm[i]] + sz[where[btm[i]]], where[btm[i+1]] - 1),
[&](int x, int y){return p[x] < p[y];});
// cout << nxt << ln;
}
nxt = max(nxt,
answer(where[btm.back()] + sz[where[btm.back()]], where[v] + sz[where[v]] - 1),
[&](int x, int y){return p[x] < p[y];});
// cout << nxt << ln;
}
return nxt;
}
// int db = 0;
ll f(int u, int tp, vector<int>& btm){
// if (db++ > 30) exit(0);
// cout << u << ' ' << tp << "| ";
// for (auto xx: btm) cout << xx << ' ';
// cout << ln;
sort(btm.begin(), btm.end(), [&](int x, int y){return where[x] < where[y];});
ll res = 0, nxt;
for (auto v: adj[u]){
if (v == up[0][u]) continue;
nxt = get_max(v, btm);
// if (nxt) cout << "into " << nxt << ln;
if (nxt) {
vector<int> new_btm;
res = max(res, dist(u, nxt) + f(nxt, v, new_btm));
}
}
btm.push_back(u);
sort(btm.begin(), btm.end(), [&](int x, int y){return where[x] < where[y];});
nxt = get_max(tp, btm);
// cout << u << ' ' << tp << "| ";
// for (auto xx: btm) cout << xx << ' ';
// cout << ln;
if (nxt) {
// cout << "going up " << nxt << ln;
res = max(res, dist(u, nxt) + f(nxt, tp, btm));
}
btm.pop_back();
return res;
}
void solve(){
cin >> n;
for (int i = 1; i <= n; i++) cin >> p[i];
for (int i = 0; i < n-1; i++){
int u, v; cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
int root = -1;
for (int i = 1; i <= n; i++) if (p[i] == n) root = i;
dfs(root, root);
for (int i = 1; i < LG; i++){
for (int j = 1; j <= n; j++) up[i][j] = up[i-1][up[i-1][j]];
for (int j = 1; j + (1 << (i-1)) <= n; j++){
mx[i][j] = (p[mx[i-1][j]] > p[mx[i-1][j + (1 << (i-1))]])? mx[i-1][j]: mx[i-1][j + (1 << (i-1))];
}
}
// for (int i = 1; i <= n; i++) cout << where[i] << ' ';
// cout << ln;
vector<int> btm;
cout << f(root, root, btm) << ln;
// cout << max(0, 0, [&](int x, int y){return p[x] < p[y];});
// btm.push_back(5); btm.push_back(4);
// cout << get_max(3, btm);
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
// int TT; cin >> TT;
// while (TT--) {solve();}
solve();
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Correct |
3 ms |
27228 KB |
Output is correct |
3 |
Correct |
4 ms |
27228 KB |
Output is correct |
4 |
Incorrect |
3 ms |
27228 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Correct |
3 ms |
27228 KB |
Output is correct |
3 |
Correct |
4 ms |
27228 KB |
Output is correct |
4 |
Incorrect |
3 ms |
27228 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Correct |
3 ms |
27228 KB |
Output is correct |
3 |
Correct |
4 ms |
27228 KB |
Output is correct |
4 |
Incorrect |
3 ms |
27228 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Correct |
3 ms |
27228 KB |
Output is correct |
3 |
Correct |
4 ms |
27228 KB |
Output is correct |
4 |
Incorrect |
3 ms |
27228 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Correct |
3 ms |
27228 KB |
Output is correct |
3 |
Correct |
4 ms |
27228 KB |
Output is correct |
4 |
Incorrect |
3 ms |
27228 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Incorrect |
92 ms |
29380 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
25180 KB |
Output is correct |
2 |
Correct |
3 ms |
27228 KB |
Output is correct |
3 |
Correct |
4 ms |
27228 KB |
Output is correct |
4 |
Incorrect |
3 ms |
27228 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |