Submission #1303723

# | Time | Username | Problem | Language | Result | Execution time | Memory
1303723 | | g4yuhg | Tourism (JOI23_tourism) | C++20 | 100 / 100 | 642 ms | 81568 KiB
#include<bits/stdc++.h>
// 64-bit integer type used for most indices and values in this file.
typedef long long ll;
#define pii pair<ll, ll>
#define fi first
#define se second
// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl '\n'
// NOTE(review): TASK appears unused in this file.
#define TASK ""
// N: capacity of the vertex / position / query arrays.
// LOG: highest binary-lifting level (jumps of up to 2^LOG ancestors).
#define N 300005
#define LOG 17
using namespace std;

// NOTE(review): inf appears unused in this file.
const ll inf = 1e18;

// Marker global; main() prints its address distance to klinh on stderr.
bool ghuy4g;

// One query over the inclusive position range [l, r] of c[]; id keeps its input order
// so answers can be printed in the original order after sorting.
struct Qr {
	ll l;   // left end of the range
	ll r;   // right end of the range
	ll id;  // 1-based position of the query in the input
} qr[N];

ll n;                // number of tree vertices (n - 1 edges are read)
ll m;                // length of the sequence c[1..m]
ll q;                // number of queries
vector<ll> adj[N];   // tree adjacency lists
vector<ll> g[N];     // positions i with c[i] == u; dfs2 appends more to them

ll sz[N];            // subtree sizes, filled by dfs
ll high[N];          // depth below vertex 1, filled by dfs
ll par[N][LOG + 2];  // par[v][j]: 2^j-th ancestor of v, 0 above the root
ll cur_pos;          // NOTE(review): assigned in pre() but never read
ll cur_chain;        // index of the most recently opened heavy chain
ll lst[N];           // bottom (last visited) vertex of each chain
ll c[N];             // the input sequence of vertices, 1-indexed
ll kq[N];            // answers, indexed by original query id
ll chain_head[N];    // topmost vertex of each chain
ll chain_id[N];      // chain index of each vertex

// Root the tree below u: record each child's depth and direct parent,
// and accumulate subtree sizes on the way back up.
void dfs(ll u, ll parent) {
	sz[u] = 1;
	for (ll child : adj[u]) {
		if (child != parent) {
			high[child] = high[u] + 1;
			par[child][0] = u;
			dfs(child, u);
			sz[u] += sz[child];
		}
	}
}

// Heavy-light decomposition. u joins the current chain (cur_chain); the child with
// the largest subtree continues that chain, and every other child opens a new chain.
void hld(ll u, ll parent) {
	// The first vertex assigned to a chain is its head.
	if (!chain_head[cur_chain]) chain_head[cur_chain] = u;
	// Overwritten as we descend, so it ends up holding the chain's bottom vertex.
	lst[cur_chain] = u;
	chain_id[u] = cur_chain;

	// sz[0] is never set (vertices are 1..n), so any real child beats vertex 0.
	ll heavy = 0;
	for (auto child : adj[u])
		if (child != parent && sz[child] > sz[heavy]) heavy = child;

	if (heavy != 0) hld(heavy, u);

	for (auto child : adj[u]) {
		if (child == parent || child == heavy) continue;
		++cur_chain;
		hld(child, u);
	}
}

// Lowest common ancestor by binary lifting over par[][].
ll lca(ll u, ll v) {
	// Make v the deeper vertex, then lift it to u's depth.
	if (high[u] > high[v]) swap(u, v);
	for (int j = LOG; j >= 0; --j)
		if (high[v] - high[u] >= (1 << j)) v = par[v][j];

	if (u == v) return u;

	// Lift both while their 2^j-th ancestors still differ; they stop just below the LCA.
	for (int j = LOG; j >= 0; --j) {
		ll up_u = par[u][j];
		ll up_v = par[v][j];
		if (up_u != up_v) {
			u = up_u;
			v = up_v;
		}
	}
	return par[u][0];
}

// Segment trees over positions 1..m of c[] (see build):
// st: LCA of the vertices in the node's range.
// st1: the shallowest chain head among those vertices.
ll st[4 * N];
ll st1[4 * N];

// Combine two st values; -1 stands for "empty range" and acts as the identity.
ll cb(ll u, ll v) {
	return u == -1 ? v : (v == -1 ? u : lca(u, v));
}

// Combine two st1 values: keep the shallower vertex (v on a depth tie).
// -1 stands for "empty range" and acts as the identity.
ll cb1(ll u, ll v) {
	if (u == -1 || v == -1) return u == -1 ? v : u;
	return high[u] < high[v] ? u : v;
}

// Build st and st1 for node id, which covers positions [l, r] of c[].
void build(ll id, ll l, ll r) {
	if (l != r) {
		ll half = (l + r) / 2;
		ll lc = id * 2;
		ll rc = id * 2 + 1;
		build(lc, l, half);
		build(rc, half + 1, r);
		st[id] = cb(st[lc], st[rc]);
		st1[id] = cb1(st1[lc], st1[rc]);
		return;
	}
	// Leaf: the vertex itself, and the head of the chain it belongs to.
	st[id] = c[l];
	st1[id] = chain_head[chain_id[c[l]]];
}

// LCA of c[L..R] within node id's range [l, r]; returns -1 when the ranges do not overlap.
ll get(ll id, ll l, ll r, ll L, ll R) {
	bool disjoint = R < l || r < L;
	if (disjoint) return -1;
	bool covered = L <= l && r <= R;
	if (covered) return st[id];
	ll half = (l + r) / 2;
	ll lhs = get(id * 2, l, half, L, R);
	ll rhs = get(id * 2 + 1, half + 1, r, L, R);
	return cb(lhs, rhs);
}

// Shallowest chain head among c[L..R] within node id's range [l, r]; returns -1 when disjoint.
ll get1(ll id, ll l, ll r, ll L, ll R) {
	bool disjoint = R < l || r < L;
	if (disjoint) return -1;
	bool covered = L <= l && r <= R;
	if (covered) return st1[id];
	ll half = (l + r) / 2;
	ll lhs = get1(id * 2, l, half, L, R);
	ll rhs = get1(id * 2 + 1, half + 1, r, L, R);
	return cb1(lhs, rhs);
}

// Post-order pass. Once u's subtree is finished, append u's positions to g[] of the
// parent of u's chain head. Chains headed by the root are skipped, since par[1][0] == 0.
void dfs2(ll u, ll parent) {
	for (ll child : adj[u])
		if (child != parent) dfs2(child, u);

	ll above = par[chain_head[chain_id[u]]][0];
	if (above == 0) return;
	// above is a strict ancestor of u, so these are two different vectors.
	g[above].insert(g[above].end(), g[u].begin(), g[u].end());
}

// Preprocessing:
// 1. Root the tree at vertex 1.
// 2. Run the heavy-light decomposition.
// 3. Fill the ancestor table.
// 4. Propagate g[] up the chains.
// 5. Build the segment trees.
void pre() {
	dfs(1, 1);
	cur_pos = 1;
	cur_chain = 1;
	hld(1, 1);
	// Level j depends only on level j - 1, so fill the table level by level.
	for (int j = 1; j <= LOG; ++j)
		for (int v = 1; v <= n; ++v)
			par[v][j] = par[par[v][j - 1]][j - 1];
	dfs2(1, 1);
	build(1, 1, m);
}

// A contribution segment produced by xly(). In solve(), a query (ql, qr) receives d
// when l <= ql <= i and i <= qr <= r.
// Field order matters: the struct is brace-initialised as {L, R, id, d}.
struct Node {
	ll l, r;  // window [l, r] around the anchor position
	ll i;     // anchor position in c[]
	ll d;     // weight added to matching queries
};
vector<Node> vt;

// One past the largest element of s that is below id, or 1 if there is none.
ll findL(ll id, set<ll>&s) {
	auto it = s.lower_bound(id);
	return it == s.begin() ? 1 : *prev(it) + 1;
}

// One before the smallest element of s that is >= id, or m if there is none.
ll findR(ll id, set<ll>&s) {
	auto it = s.lower_bound(id);
	return it != s.end() ? *it - 1 : m;
}

// Emit the contribution segments (vt) for one heavy chain.
// Walks from the chain's bottom vertex lst[chain] up to its head. s holds every position
// processed so far on this chain: those from lower vertices, plus earlier entries of the
// current g[u].
void xly(ll chain) {
	set<ll> s;
	ll u = lst[chain];
	// Number of chain vertices from the head down to u, inclusive.
	ll d = high[u] - high[chain_head[chain]] + 1;
	while (true) {
		for (auto id : g[u]) {
			// Maximal window [L, R] around id that contains no already-processed position.
			ll L = findL(id, s);
			ll R = findR(id, s);
			if (L <= id && id <= R) {
				vt.push_back({L, R, id, d});
			}
			s.insert(id);
		}
		// Stop at the chain head, where d == 1. Stepping past it would give d == 0,
		// which contributes nothing.
		if (u == chain_head[chain]) break;
		d -- ;
		u = par[u][0];
	}
}

// Fenwick tree over positions 1..m. upd_rand/get use it as a difference array
// (range add, point query).
ll bit[N];

// Fenwick point update: add v at position u, propagating through the nodes up to m.
void upd(ll u, ll v) {
	for (ll idx = u; idx <= m; idx += idx & (-idx)) {
		bit[idx] += v;
	}
}

// Fenwick prefix sum over positions 1..u.
ll get(ll u) {
	ll total = 0;
	for (ll idx = u; idx > 0; idx -= idx & (-idx)) {
		total += bit[idx];
	}
	return total;
}

// Add d to every position in [l, r] of the difference array; get(x) then yields the value at x.
void upd_rand(ll l, ll r, ll d) {
	upd(r + 1, -d);  // cancel past the right end (a no-op when r == m)
	upd(l, d);
}

void solve() {
	for (int chain = 1; chain <= cur_chain; chain ++) {
		xly(chain);
	}
	sort(vt.begin(), vt.end(), [&](Node A, Node B) {
		return A.l < B.l;
	});
	sort(qr + 1, qr + 1 + q, [&](Qr A, Qr B) {
		return A.l < B.l;
	});
	ll cur = 0;
	priority_queue<pii, vector<pii>, greater<pii>> pq;
	for (int i = 1; i <= q; i ++) {
		while (cur < vt.size() && vt[cur].l <= qr[i].l) {
			upd_rand(vt[cur].i, vt[cur].r, vt[cur].d);
			pq.push({vt[cur].i, cur});
			cur ++ ;
		}
		while (pq.size() && pq.top().fi < qr[i].l) {
			ll id = pq.top().se;
			upd_rand(vt[id].i, vt[id].r, -vt[id].d);
			pq.pop();
		}
		kq[qr[i].id] = get(qr[i].r);
		ll x = get(1, 1, m, qr[i].l, qr[i].r); // lca cua ca doan
		ll cao = get1(1, 1, m, qr[i].l, qr[i].r); // dinh top ma cno lay
		kq[qr[i].id] -= high[x];
	}
	for (int i = 1; i <= q; i ++) {
		cout << kq[i] << endl;
	}
}

// Second marker global; main() prints its address distance from ghuy4g on stderr.
bool klinh;

// Entry point:
// 1. Read the tree, the sequence c[1..m] and the queries.
// 2. Run pre() and solve().
// 3. Print timing and marker-distance diagnostics on stderr.
signed main() {
	//freopen("test.inp", "r", stdin);
	//freopen("test.out", "w", stdout);
	//srand(time(0));
	ios_base::sync_with_stdio(0);
	cin.tie(0);

	cin >> n >> m >> q;
	// n - 1 undirected tree edges.
	for (int i = 1; i <= n - 1; i ++) {
		ll u, v;
		cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	// c[i] is the i-th vertex of the sequence; record position i at vertex c[i].
	for (int i = 1; i <= m; i ++) {
		cin >> c[i];
		g[c[i]].push_back(i);
	}
	for (int i = 1; i <= q; i ++) {
		cin >> qr[i].l >> qr[i].r;
		qr[i].id = i;
	}
	pre();
	solve();

	cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
	// Address distance, in MiB, between the marker globals ghuy4g and klinh.
	// Subtracting pointers to two distinct objects is undefined behaviour,
	// so compare the addresses as integers instead.
	const uintptr_t first_addr = reinterpret_cast<uintptr_t>(&ghuy4g);
	const uintptr_t second_addr = reinterpret_cast<uintptr_t>(&klinh);
	cerr << fabs(static_cast<double>(second_addr) - static_cast<double>(first_addr)) / 1048576.0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...