제출 #800464

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
800464 | acatmeowmeow | Cat Exercise (JOI23_ho_t4) | C++11
100 / 100
327 ms | 126200 KiB
#include <bits/stdc++.h>

using namespace std;

// Competitive-programming shortcut: every `int` below (arrays included) is really a 64-bit long long.
#define int long long 

// N: maximum vertex count; KMAX: highest sparse-table level built (2^20 exceeds the 2*N Euler tour length).
const int N = 2e5, KMAX = 20;
// h: vertex heights (read in main).
// euler/tin/tout/d: Euler tour, first/last tour positions per vertex, and depths, filled by dfs1.
//   timer starts at 1 and dfs1 pre-increments it, so the first tour position used is 2.
// lg/table: floor-log2 lookup and min-depth sparse table over the tour, filled by build().
int n, h[N + 5], euler[2*N + 5], lg[2*N + 5], tin[N + 5], tout[N + 5], timer = 1, d[N + 5], table[2*N + 5][KMAX + 5];
// Tree adjacency lists, indexed by vertex (main reads vertices as 1..n).
vector<int> adj[N + 5];

// Euler-tour DFS of the tree rooted at u, where e is u's parent (0 = none).
// Uses an explicit frame stack instead of recursion, so a path-shaped tree with ~2e5
// vertices cannot exhaust the call stack.
// For every vertex x reached:
//   tin[x]  = tour position where x is first written,
//   tout[x] = last tour position written while inside x's subtree,
//   d[x]    = d[parent] + 1 (d[u] itself is left unchanged).
// x is written to euler[] on entry and again after each of its children finishes.
void dfs1(int u, int e) {
	struct Frame {
		int node;
		int parent;
		std::size_t next;  // index of the next entry of adj[node] to examine
	};
	std::vector<Frame> frames;
	tin[u] = ++timer;
	euler[timer] = u;
	frames.push_back(Frame{u, e, 0});
	while (!frames.empty()) {
		Frame &top = frames.back();
		if (top.next < adj[top.node].size()) {
			const int v = adj[top.node][top.next++];
			if (v == top.parent) continue;
			const int from = top.node;  // copy now: push_back below may invalidate `top`
			d[v] = d[from] + 1;
			tin[v] = ++timer;
			euler[timer] = v;
			frames.push_back(Frame{v, from, 0});
		} else {
			tout[top.node] = timer;
			frames.pop_back();
			// Back in the parent: the tour records the parent again after each child.
			if (!frames.empty()) euler[++timer] = frames.back().node;
		}
	}
}

// Sparse-table merge: returns the shallower vertex (smaller d); on equal depth y is returned.
int combine(int x, int y) {
	if (d[x] < d[y]) return x;
	return y;
}

// Fills lg[len] = floor(log2(len)) for len in 1..n and a min-depth sparse table over
// euler[1..n]: table[pos][lvl] is the shallowest vertex among euler[pos .. pos + 2^lvl - 1].
void build(int n) {
	lg[1] = 0;
	for (int len = 2; len <= n; ++len) lg[len] = 1 + lg[len >> 1];
	for (int pos = 1; pos <= n; ++pos) table[pos][0] = euler[pos];
	for (int lvl = 1; lvl <= KMAX; ++lvl) {
		const int half = 1ll << (lvl - 1);  // length of each sub-range at level lvl - 1
		const int span = half << 1;         // 2^lvl
		for (int pos = 1; pos + span - 1 <= n; ++pos)
			table[pos][lvl] = combine(table[pos][lvl - 1], table[pos + half][lvl - 1]);
	}
}

int lca(int x, int y) {
	if (tin[x] > tin[y]) swap(x, y);
	int k = lg[tin[y] - tin[x] + 1], w = combine(table[tin[x]][k], table[tin[y] - (1ll << k) + 1][k]);
	return d[x] + d[y] - 2*d[w];
}

// Disjoint-set state: par/sz are parent links and set sizes (unite() links the smaller set under the larger).
// At a set's root, mx is the maximum mx value in the set and highest is its tallest vertex.
// indx lists the vertices, sorted by height in main.
int par[N + 5], sz[N + 5], mx[N + 5], highest[N + 5], indx[N + 5];

// Returns the representative (root) of x's set by walking parent links; par is not modified.
int find(int x) {
	while (par[x] != x) x = par[x];
	return x;
}

// True when x and y currently belong to the same disjoint set.
bool same(int x, int y) {
	const int rootX = find(x);
	return rootX == find(y);
}

// Returns the taller of two vertices by h, treating 0 as "no vertex".
// If exactly one argument is 0 the other is returned; on equal heights y wins.
int combine2(int x, int y) {
	if (x == 0 || y == 0) return x ? x : y;
	return h[y] >= h[x] ? y : x;
}

// Merges the sets containing x and y, attaching the smaller set under the larger one.
// The surviving root's mx becomes the maximum of both roots' mx, and its highest
// becomes the taller of the two roots' highest vertices (see combine2).
// If x and y already share a root this is a no-op; otherwise sz would be double-counted.
void unite(int x, int y) {
	x = find(x), y = find(y);
	if (x == y) return;  // already in the same set
	if (sz[x] < sz[y]) std::swap(x, y);
	sz[x] += sz[y];
	par[y] = x;
	mx[x] = std::max(mx[x], mx[y]);
	highest[x] = combine2(highest[x], highest[y]);
}

// Reads n, the heights h[1..n] and n - 1 undirected edges, then prints the maximum mx value.
// mx[v] is computed bottom-up over vertices taken in increasing height order.
signed main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cin >> n;
	for (int i = 1; i <= n; i++) cin >> h[i];
	for (int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	// Root the tree at vertex 1 and prepare O(1) distance queries (lca() returns a distance).
	dfs1(1, 0);
	build(2*n);
	// Every vertex starts as its own singleton set, whose tallest vertex is itself.
	fill(sz + 1, sz + n + 1, 1ll);
	iota(par + 1, par + n + 1, 1ll);
	iota(highest + 1, highest + n + 1, 1ll);
	iota(indx + 1, indx + n + 1, 1ll);
	// Process vertices from lowest to highest.
	sort(indx + 1, indx + n + 1, [&](int a, int b) { return h[a] < h[b]; });
	for (int i = 1;i <= n; i++) {
		int index = indx[i];
		// Pass 1: for each neighboring set whose tallest vertex is not taller than index
		// (NOTE(review): with distinct heights, these are exactly the already-processed sets),
		// extend from that set's mx by the distance from index to the set's tallest vertex.
		for (auto&v : adj[index]) {
			int root = find(v);
			if (h[highest[root]] > h[index]) continue;
			mx[index] = max(mx[index], mx[root] + lca(index, highest[root]));
		}
		// Pass 2: merge index with those same sets. This runs only after pass 1 has finished,
		// so unite() folds the final mx[index] into whichever vertex becomes the root.
		for (auto&v : adj[index]) {
			int root = find(v);
			if (h[highest[root]] > h[index]) continue;
			if (!same(index, v)) unite(index, v);
		}
	}
	cout << *max_element(mx + 1, mx + n + 1) << '\n';
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...