This submission was migrated from the previous version of oj.uz, which used a different machine for grading. This submission may have a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
// Debug-print helpers. Every helper routes its output through dbg(), so
// redefining dbg(x) to expand to nothing silences all of them. Output goes to
// cerr (the file relies on `using namespace std;` for the bare name).
#define dbg(x) x
// Streams an expression chain, e.g. prt(a << ' ' << b). The argument is
// deliberately NOT parenthesized so the << chain continues onto cerr.
#define prt(x) dbg(cerr << x)
// Prints "name = value". The argument is parenthesized so expressions such as
// pv(a & 3) print their value instead of binding to the stream operators.
#define pv(x) dbg(cerr << #x << " = " << (x) << '\n')
// Prints "name = first,second" for a pair-like value.
#define pv2(x) dbg(cerr << #x << " = " << (x).first << ',' << (x).second << '\n')
// Prints "name = { e1 e2 ... }" for any range. The do/while(0) wrapper makes the
// macro a single statement, so `if (c) parr(v);` guards the whole printout.
#define parr(x) dbg(do { prt(#x << " = { "); for (const auto& y : x) prt(y << ' '); prt("}\n"); } while (0))
// Same as parr for a range of pairs: "name = { a,b c,d ... }".
#define parr2(x) dbg(do { prt(#x << " = { "); for (const auto& [y, z] : x) prt(y << ',' << z << " "); prt("}\n"); } while (0))
// 2-D variants: a "name:" header, one parr/parr2 line per row, then a blank line.
#define parr2d(x) dbg(do { prt(#x << ":\n"); for (const auto& arr : x) { parr(arr); } prt('\n'); } while (0))
#define parr2d2(x) dbg(do { prt(#x << ":\n"); for (const auto& arr : x) { parr2(arr); } prt('\n'); } while (0))
// log2 bound for jump tables (sparse table / binary lifting); 2^20 > 10^6.
const int lg = 20;
/*
Model: a tree whose nodes are blocked one at a time.
On each move:
  - block a node;
  - if you block the node the cat is on, the cat moves to the MAXIMUM
    reachable node.
Goal: make the cat move as much as possible.
Note that whatever node the cat is on, the unblocked nodes around it form a
number of subgraphs (components) it can be sent into.
Pick the subgraph that yields the best answer when the cat goes to the maximum
value inside it.
This is equivalent to blocking every other subgraph first, then blocking the
current node so the cat ends up in the chosen subgraph:
  for (each subgraph)
    dp[node] = max(dp[node], dp[max of subgraph] + dist(node, max of subgraph))
How do we get the max of a subgraph, though?
  Same approach as centroid decomposition.
  An O(n^2) implementation of this should be worth about 31 points.
Other special cases: path, binary tree.
O(n^2) approach:
  find the max within the given subgraphs with a centroid-like divide and
  conquer, except with more layers.
Optimizing the O(n^2) approach:
  maybe reconstruct the "centroid" tree (really the tree of maxima),
  but you still need to find the max node in each subgraph;
  that is the bottleneck here.
  How can we quickly find the max node in a subgraph?
Alternatively, add the nodes back to the tree in increasing order, starting
from an empty graph.
  Can we tell when a node's component is a proper subgraph? It always is,
  and the node being added is always the maximum of the component it joins,
  so a DSU (union-find) is enough.
*/
namespace {

// Depth and binary-lifting ancestor tables of a tree, answering distance
// (edge-count) queries between two vertices in O(log n).
// The tree is traversed with an explicit BFS queue instead of recursion, so a
// path-shaped tree with ~2*10^5 vertices cannot overflow the call stack.
class TreeDistance {
public:
    // adj:  adjacency lists of a tree on vertices 0..adj.size()-1.
    // root: vertex used as the root of the ancestor tables.
    TreeDistance(const std::vector<std::vector<int>>& adj, int root)
        : depth_(adj.size(), 0) {
        const int n = static_cast<int>(adj.size());
        // 2^levels_ >= n, which exceeds any depth (at most n - 1).
        while ((1 << levels_) < n) ++levels_;
        // The root is its own ancestor at every level.
        up_.assign(levels_, std::vector<int>(n, root));

        std::vector<char> seen(n, 0);
        std::vector<int> queue;
        queue.reserve(n);
        queue.push_back(root);
        seen[root] = 1;
        for (std::size_t head = 0; head < queue.size(); ++head) {
            const int v = queue[head];
            for (const int w : adj[v]) {
                if (seen[w]) continue;
                seen[w] = 1;
                up_[0][w] = v;
                depth_[w] = depth_[v] + 1;
                queue.push_back(w);
            }
        }
        for (int j = 1; j < levels_; ++j) {
            for (int v = 0; v < n; ++v) {
                up_[j][v] = up_[j - 1][up_[j - 1][v]];
            }
        }
    }

    // Number of edges on the unique path between x and y.
    int dist(int x, int y) const {
        if (depth_[x] > depth_[y]) std::swap(x, y);
        const int lift = depth_[y] - depth_[x];
        // Bring y up to x's depth.
        for (int j = 0; j < levels_; ++j) {
            if ((lift >> j) & 1) y = up_[j][y];
        }
        if (x == y) return lift;
        int ret = lift;
        for (int j = levels_ - 1; j >= 0; --j) {
            if (up_[j][x] != up_[j][y]) {
                x = up_[j][x];
                y = up_[j][y];
                ret += 2 << j;  // both endpoints climbed 2^j edges
            }
        }
        return ret + 2;  // final step from x and y up to their LCA
    }

private:
    int levels_ = 1;
    std::vector<int> depth_;
    std::vector<std::vector<int>> up_;
};

// Union-find over vertices that also tracks the largest height in each set.
class MaxDsu {
public:
    explicit MaxDsu(const std::vector<int>& height)
        : parent_(height.size()), size_(height.size(), 1), top_(height) {
        std::iota(parent_.begin(), parent_.end(), 0);
    }

    // Representative of v's set (path halving).
    int find(int v) {
        while (v != parent_[v]) v = parent_[v] = parent_[parent_[v]];
        return v;
    }

    // Largest height among the members of v's set.
    int max_height(int v) { return top_[find(v)]; }

    // Union by size; keeps the larger of the two set maxima.
    void unite(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return;
        if (size_[a] < size_[b]) std::swap(a, b);
        parent_[b] = a;
        size_[a] += size_[b];
        top_[a] = std::max(top_[a], top_[b]);
    }

private:
    std::vector<int> parent_;
    std::vector<int> size_;
    std::vector<int> top_;
};

}  // namespace

// Reads n, a height permutation p (1-based on input) and n-1 tree edges, then
// prints the maximum total distance the cat can be made to travel.
//
// Vertices are activated in increasing height. When v is activated, each
// already-active neighbour w lies in its own component (in a tree, removing v
// separates them). The cat on v can be forced into that component, landing on
// its highest vertex `peak`, so
//   best[v] = max(best[v], best[peak] + dist(v, peak)).
// The components are then merged through v with a DSU. This handles every
// tree shape, paths included, in O(n log n) with no deep recursion.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n <= 0) return 0;

    // p[v] = 0-based height of vertex v; at[h] = vertex whose height is h.
    std::vector<int> p(n), at(n);
    for (int v = 0; v < n; ++v) {
        std::cin >> p[v];
        --p[v];
        at[p[v]] = v;
    }

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        int x = 0;
        int y = 0;
        std::cin >> x >> y;
        --x;
        --y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    const TreeDistance tree(adj, at[n - 1]);
    MaxDsu dsu(p);
    std::vector<long long> best(n, 0);
    for (int h = 0; h < n; ++h) {
        const int v = at[h];
        // Evaluate every neighbouring component before merging any of them.
        for (const int w : adj[v]) {
            if (p[w] > h) continue;  // w is not active yet
            const int peak = at[dsu.max_height(w)];
            best[v] = std::max(best[v], best[peak] + tree.dist(peak, v));
        }
        for (const int w : adj[v]) {
            if (p[w] < h) dsu.unite(w, v);
        }
    }

    std::cout << *std::max_element(best.begin(), best.end()) << '\n';
    return 0;
}
Compilation message (stderr)
Main.cpp: In function 'int main()':
Main.cpp:161:25: warning: unused variable 'cmx' [-Wunused-variable]
161 | int node = at[i], cmx = -1;
| ^~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |