// Submission #813273
//
// Username:       hugo_pm
// Problem:        Tourism (JOI23_tourism)
// Language:       C++17
// Result:         100 / 100
// Execution time: 1405 ms
// Memory:         135632 KiB
#include <bits/stdc++.h>
#define int long long
using namespace std;

// Competitive-programming shorthands.
// all / rall: the full forward / reverse iterator range of a container.
#define all(v) (v).begin(), (v).end()
#define rall(v) (v).rbegin(), (v).rend()
// rep(i, a, b): loop i over the half-open range [a, b).
#define rep(i, a, b) for(int i = (a); i < (b); i++)
// sz(v): container size as a signed `int` (which is `long long` in this file).
#define sz(v) ((int)((v).size()))

/// Raises x to v when v is strictly larger (x < v); otherwise x is left untouched.
template<typename T>
void chmax(T &x, const T &v) {
	if (!(x < v)) return;
	x = v;
}
/// Lowers x to v when v is strictly smaller (x > v); otherwise x is left untouched.
template<typename T>
void chmin(T &x, const T &v) {
	if (!(x > v)) return;
	x = v;
}

// Shorthand aliases (note: `int` is macro-expanded to `long long` in this file).
using pii = pair<int, int>;
using vi = vector<int>;

// Debug-printing overloads used by dbg_out.
// All overloads are declared up front. Inside a template, an unqualified call such as
// to_string(x) only sees overloads declared before the template, plus ADL at
// instantiation. ADL for std::pair searches only namespace std, so without these
// declarations a container of pairs would fail to compile.
string to_string(string s);
template <typename A, typename B> string to_string(const pair<A, B> &p);
template <typename T> string to_string(const T &v);

/// Strings are returned verbatim.
string to_string(string s) { return s; }

/// Any range-for iterable value is printed as "[a, b, c]", each element via to_string.
template <typename T> string to_string(const T &v) {
	bool first = true;
	string res = "[";
	for (const auto &x : v) {
		if (!first)
			res += ", ";
		first = false;
		res += to_string(x);
	}
	res += "]";
	return res;
}

/// A pair is printed as "(first, second)".
template <typename A, typename B>
string to_string(const pair<A, B> &p) {
	return "(" + to_string(p.first) + ", " + to_string(p.second) + ")";
}

/// With no values left to print, just ends the debug line (flushing stdout).
void dbg_out() { cout << endl; }
/// Writes each argument to stdout as ' ' followed by its to_string form, then ends the line.
template <typename Head, typename... Tail> void dbg_out(Head H, Tail... T) {
	cout << ' ' << to_string(H);
	((cout << ' ' << to_string(T)), ...);
	cout << endl;
}

// dbg(x, y, ...) prints "(x, y, ...):" followed by the values (via dbg_out) to stdout
// when compiled with -DDEBUG. Otherwise it expands to nothing, so its arguments
// are not evaluated at all.
#ifdef DEBUG
#define dbg(...) cout << "(" << #__VA_ARGS__ << "):", dbg_out(__VA_ARGS__)
#else
#define dbg(...)
#endif

/// Fenwick (binary indexed) tree over sz(s) cells with 0-based positions.
struct FT {
	vector<int> s; // s[i] holds the sum of cells [i & (i + 1), i]
	/// Creates n cells, all zero.
	FT(int n) : s(n) {}
	/// Adds dif to cell pos; a pos >= sz(s) leaves the tree unchanged.
	void update(int pos, int dif) {
		while (pos < sz(s)) {
			s[pos] += dif;
			pos |= pos + 1;
		}
	}
	/// Returns the sum of cells [0, pos); pos <= 0 yields 0.
	int query(int pos) {
		int total = 0;
		while (pos > 0) {
			total += s[pos - 1];
			pos &= pos - 1;
		}
		return total;
	}
};

// Number of binary-lifting levels. A depth difference (< node count) must fit in LG bits;
// the 100'000 bound presumably mirrors the problem's limit on the number of nodes.
constexpr int LG = 17;
static_assert((1 << LG) > 100'000);

struct Req {
	int left, right, iReq;
};
string to_string(Req rq) { return to_string(make_pair(rq.left,rq.right)); }
struct Spot {
	int node, pos;
};

/*
 * Reads a tree (nbNode nodes, nbNode-1 edges, 1-based in the input), a sequence of
 * nbSpot spots (tree nodes), and nbReq queries [L, R] over spot indices (1-based).
 * For each query, prints the number of nodes in the minimal subtree connecting
 * spots L..R: the edge count of that subtree, plus one.
 *
 * Approach: divide and conquer over the spot index range. A call on
 * [leftDPR, rightDPR] with middle index mid answers every query containing mid,
 * using a virtual tree of those spots rooted at spot mid. Queries strictly left or
 * right of mid are forwarded to the corresponding half.
 */
signed main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	int nbNode, nbSpot, nbReq;
	cin >> nbNode >> nbSpot >> nbReq;
	// Undirected adjacency lists; input vertices are 1-based, stored 0-based.
	vector<vi> adj(nbNode);
	rep(iEdge, 0, nbNode-1) {
		int u, v; cin >> u >> v;
		--u, --v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	// anc[lvl][x]: 2^lvl-th ancestor of x, or -1 if it does not exist.
	// tin / tout: DFS entry / exit times. prof: depth (root 0 has depth 0).
	vector<vi> anc(LG, vi(nbNode, -1));
	vector<int> tin(nbNode), tout(nbNode), prof(nbNode);
	{
		int lastTime = 0;
		// Recursive DFS from node 0; recursion depth equals the tree height.
		auto Dfs = [&] (auto dfs, int node, int parent) -> void {
			anc[0][node] = parent;
			// Ancestors are visited first, so their jump tables are already complete.
			rep(lvl, 0, LG-1) {
				if (anc[lvl][node] == -1) break;
				anc[lvl+1][node] = anc[lvl][anc[lvl][node]];
			}
			tin[node] = lastTime++;
			for (int child : adj[node]) if (child != parent) {
				prof[child] = prof[node]+1;
				dfs(dfs, child, node);
			}
			tout[node] = lastTime++;
		};
		Dfs(Dfs, 0, -1);
	}
	// True iff `ance` is an ancestor of `child` (a node counts as its own ancestor).
	auto ancestor = [&] (int ance, int child) {
		return tin[ance] <= tin[child] && tout[child] <= tout[ance];
	};
	// Lowest common ancestor by binary lifting: lift the deeper node to the same
	// depth, then lift both while their ancestors differ.
	auto lca = [&] (int u, int v) {
		if (prof[u] > prof[v]) swap(u, v);
		int firstJmp = prof[v] - prof[u];
		rep(lvl, 0, LG) if ((1<<lvl) & firstJmp) {
			v = anc[lvl][v];
		}
		assert(prof[u] == prof[v]);
		if (u == v) return u;
		for (int lvl = LG-1; lvl >= 0; --lvl) {
			int pu = anc[lvl][u], pv = anc[lvl][v];
			if (pu != -1 && pv != -1 && pu != pv) {
				u = pu, v = pv;
			}
		}
		assert(u != v);
		u = anc[0][u], v = anc[0][v];
		assert(u == v);
		return u;
	};
	// spots[i]: the i-th spot (0-based node) together with its own index i.
	vector<Spot> spots(nbSpot);
	rep(iSpot, 0, nbSpot) {
		auto &s = spots[iSpot];
		cin >> s.node; --s.node;
		s.pos = iSpot;
	}
	// answers[i]: edge count of the connecting subtree for query i.
	vector<int> answers(nbReq);
	// Queries as 0-based inclusive spot-index intervals.
	vector<Req> allReqs(nbReq);
	rep(iReq, 0, nbReq) {
		int L, R;
		cin >> L >> R;
		--L, --R;
		allReqs[iReq] = {L, R, iReq};
	}
	// Orders spots by DFS entry time, the order used to build a virtual tree.
	auto compTin = [&] (const Spot &x, const Spot &y) {
		return tin[x.node] < tin[y.node];
	};
	// Answers every query in `reqs`. All of them lie inside [leftDPR, rightDPR].
	auto Solve = [&] (auto solve, int leftDPR, int rightDPR, vector<Req> reqs) -> void {
		if (reqs.empty()) return;
		assert(leftDPR <= rightDPR);
		int mid = (leftDPR+rightDPR)/2, szDPR = (rightDPR-leftDPR+1);
		dbg(leftDPR, mid, rightDPR, reqs);
		// Virtual tree of the spots in [leftDPR, rightDPR]: the spots sorted by tin,
		// plus the LCA of every tin-adjacent pair.
		vector<Spot> subset(begin(spots)+leftDPR, begin(spots)+rightDPR+1);
		sort(all(subset), compTin);
		// Sentinel `pos` for LCA-only entries, which match no spot index.
		const int FICTIVE_LCA_POS = -1e6;
		// Exactly szDPR-1 adjacent pairs among the original entries. The bound must
		// not be sz(subset), because push_back grows `subset` inside the loop.
		rep(i, 0, rightDPR-leftDPR) {
			subset.push_back({lca(subset[i].node, subset[i+1].node), FICTIVE_LCA_POS});
		}
		sort(all(subset), compTin);

		int szItv = sz(subset);
		// Virtual-tree edges: (neighbor index in `subset`, depth difference).
		vector<vector<pii>> adjVirt(szItv);
		// Links subset[iCur] to each following entry inside its subtree that is not
		// already covered by an earlier child's subtree. Returns the first index
		// outside the subtree of subset[iCur].
		auto Virt = [&] (auto virt, int iCur) -> int {
			int iLook = iCur+1;
			while (iLook < szItv && ancestor(subset[iCur].node, subset[iLook].node)) {
				int virtWeight = prof[subset[iLook].node] - prof[subset[iCur].node];
				adjVirt[iCur].emplace_back(iLook, virtWeight);
				adjVirt[iLook].emplace_back(iCur, virtWeight);
				iLook = virt(virt, iLook);
			}
			return iLook;
		};
		// The build must stay outside assert(): with NDEBUG the assert expression is
		// not evaluated, and the virtual tree would never be constructed.
		const int virtEnd = Virt(Virt, 0);
		assert(virtEnd == szItv);
		(void)virtEnd;
		// Fenwick tree indexed by (right - leftDPR). A prefix query up to `right` gives
		// the total weight of edges used by query [current left, right].
		FT fenwick(szDPR);
		// sweepLine[cl - leftDPR]: Fenwick updates applied once the sweep passes left = cl.
		vector<vector<pii>> sweepLine(szDPR);
		// Registers an edge of weight `weight` whose far subtree holds, at best, spot
		// index cl < mid (none if cl < leftDPR) and cr > mid (none if cr > rightDPR).
		// While left <= cl the edge is always used; afterwards only when right >= cr.
		auto add = [&] (int cl, int cr, int weight) {
			if (leftDPR <= cl) {
				fenwick.update(0, weight);
				// remove in 0, readd in cr
				sweepLine[cl-leftDPR].push_back({0, -weight});
				if (cr <= rightDPR)
					sweepLine[cl-leftDPR].push_back({cr-leftDPR, weight});
			} else if (cr <= rightDPR) {
				fenwick.update(cr-leftDPR, weight);
			}
		};
		// DFS over the virtual tree rooted at spot mid. Returns the largest spot index
		// < mid and the smallest spot index > mid in the subtree (-1e9 / 1e9 if absent).
		// Each child edge is registered through add().
		auto Dfs = [&] (auto dfs, int vNode, int parent) -> pii {
			int closestLeft = -1e9, closestRight = 1e9;
			const int &curPos = subset[vNode].pos;
			if (curPos != FICTIVE_LCA_POS) {
				if (curPos < mid)
					closestLeft = curPos;
				else if (curPos > mid)
					closestRight = curPos;
			}
			for (auto [vChild, weight] : adjVirt[vNode]) if (vChild != parent) {
				auto [ccleft, ccright] = dfs(dfs, vChild, vNode);
				add(ccleft, ccright, weight);
				chmax(closestLeft, ccleft), chmin(closestRight, ccright);
			}
			dbg(vNode, closestLeft, closestRight);
			return {closestLeft, closestRight};
		};

		int root = 0;
		while (subset[root].pos != mid) ++root;
		Dfs(Dfs, root, -1);
		// Split queries into three groups: entirely left of mid, entirely right of mid,
		// or containing mid. The last group is answered here, bucketed by left end.
		vector<Req> sL, sR;
		vector<vector<Req>> toComp(szDPR);
		for (const auto &rq : reqs) {
			if (rq.right < mid)
				sL.push_back(rq);
			else if (rq.left > mid)
				sR.push_back(rq);
			else
				toComp[rq.left-leftDPR].push_back(rq);
		}
		// Sweep the left end upward. Queries with this left end are answered before
		// applying the updates that only hold for larger left ends.
		rep(compCl, 0, szDPR) {
			for (const auto &rq : toComp[compCl]) {
				answers[rq.iReq] = fenwick.query(rq.right-leftDPR+1);
			}
			for (const auto &[compCr, delta] : sweepLine[compCl]) {
				fenwick.update(compCr, delta);
			}
		}
		// The split lists are not used again here, so hand them over without copying.
		solve(solve, leftDPR, mid-1, std::move(sL));
		solve(solve, mid+1, rightDPR, std::move(sR));
	};
	Solve(Solve, 0, nbSpot-1, std::move(allReqs));
	// The answers count edges; a connected subtree has one more node than edges.
	rep(iReq, 0, nbReq) {
		cout << answers[iReq]+1 << '\n';
	}
}
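// Per-subtask result tables (six in total). None had loaded when this page was captured.
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...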