Submission #763605

# | Time | Username | Problem | Language | Result | Execution time | Memory
763605 | | minhcool | Tourism (JOI23_tourism) | C++17 | 100 / 100 | 3866 ms | 43904 KiB
#pragma GCC optimize("unroll-loops")
// NOTE: this submission originally also had `#pragma gcc optimize("Ofast")`.
// Pragma namespaces are case-sensitive, so GCC ignored that line
// (-Wunknown-pragmas) and it never had any effect. It is dropped instead of
// being re-spelled as `GCC`, so the optimisation settings stay exactly the ones
// the solution was measured with.
#include<bits/stdc++.h>
using namespace std;

// Short-hands used throughout the file.
#define fi first
#define se second
#define pb push_back
#define mp make_pair

typedef pair<int, int> ii;
typedef pair<ii, int> iii;   // used as ((l, r), query index)
typedef pair<ii, ii> iiii;   // not referenced in the visible code

// Capacity of every per-vertex and per-Euler-position array. The Euler tour
// built by dfs has 2n - 1 entries and several arrays are indexed up to 2n, so
// N must exceed 2n.
const int N = 3e5 + 5;

// oo: "no neighbour" sentinel (-oo / oo) for the ordered set in ins/er.
// mod is not referenced in the visible code.
const int oo = 1e9 + 7, mod = 1e9 + 7;

/*
Approach:
For a query [l, r] the answer is the number of vertices of the smallest
connected subtree that contains every city arr[l..r]. Mo's algorithm keeps the
set of distinct cities in the window, keyed by their Euler-tour position.
Take the stored positions in increasing (DFS) order. Sum d2[b] - d[lca(a, b)]
over every consecutive pair (a, b), then add d2[first] - d[lca(first, last)].
The result is the subtree's edge count, and adding 1 turns edges into vertices.
*/

int n, m, q;  // tree size, length of arr, number of queries

vector<int> Adj[N];  // undirected adjacency lists, vertices 1..n

int arr[N];  // arr[1..m]: the city sequence the queries refer to

int cnt;          // running Euler-tour index advanced by dfs
int d[N];         // d[v]: depth of vertex v (vertex 1 is the root, depth 0)
int mn_d[N][20];  // sparse table over the tour: mn_d[i][k] = shallowest vertex in [i, i + 2^k)
int lg[N];        // lg[i] = floor(log2(i)), filled by prep
int d2[N];        // d2[p]: depth of the vertex whose first tour index is p

int le[N], ri[N], pos[N];  // first tour index (le == pos) and last tour index of each subtree

// Euler-tour DFS starting at u, whose parent is p.
// A vertex is appended to the tour when it is entered and again each time one
// of its children finishes, so a tree on n vertices gives 2n - 1 entries.
// For every vertex it fills:
//   pos[v] = le[v]  first tour index of v
//   ri[v]           last tour index inside v's subtree
//   d[w]            d[v] + 1 for each child w of v
//   mn_d[i][0]      the vertex written at tour index i
// An explicit stack is used instead of recursion. The recursive form nests
// once per tree level, i.e. up to n - 1 frames on a path-shaped tree, which can
// overflow a default-size call stack. Children are visited in Adj order, so the
// tour is identical to the recursive one.
void dfs(int u, int p){
	struct Frame{
		int v, parent;
		std::size_t next;  // index into Adj[v] of the next neighbour to examine
	};
	std::vector<Frame> stk;
	// Enter vertex v: give it the next tour index and push its frame.
	auto enter = [&](int v, int parent){
		cnt++;
		le[v] = pos[v] = cnt;
		mn_d[cnt][0] = v;
		stk.push_back({v, parent, 0});
	};
	enter(u, p);
	while(!stk.empty()){
		// Copy out the fields: enter() may reallocate stk.
		const int v = stk.back().v;
		const int parent = stk.back().parent;
		if(stk.back().next < Adj[v].size()){
			const int w = Adj[v][stk.back().next++];
			if(w == parent) continue;
			d[w] = d[v] + 1;
			enter(w, v);
		}
		else{
			ri[v] = cnt;
			stk.pop_back();
			if(!stk.empty()){
				// Back in the parent after finishing child v: record it again.
				cnt++;
				mn_d[cnt][0] = stk.back().v;
			}
		}
	}
}

// Builds the Euler-tour sparse table and the log table used by lca.
// Level k merges the halves [j, j + 2^(k-1)) and [j + 2^(k-1), j + 2^k) from
// level k-1 and keeps the shallower vertex (the right one on a depth tie).
// Indices run over 1..2n. The tour has 2n - 1 entries, and index 2n only
// appears in windows ending at 2n, which lca (called on vertex positions) never
// reads.
void prep(){
	const int len = n << 1;
	for(int k = 1; k <= 19; k++){
		const long long half = 1LL << (k - 1);
		for(int j = 1; j + 2 * half - 1 <= len; j++){
			const int left = mn_d[j][k - 1];
			const int right = mn_d[j + half][k - 1];
			mn_d[j][k] = d[left] < d[right] ? left : right;
		}
	}
	for(int i = 2; i <= len; i++) lg[i] = lg[i >> 1] + 1;
}

int lca(int x, int y){
	//x = pos[x], y = pos[y];
	if(x > y) swap(x, y);
	int k = lg[y - x + 1];
	int a = mn_d[x][k], b = mn_d[y - (1LL << k) + 1][k];
	return (d[a] < d[b] ? a : b);
}

// Orders vertices by their first Euler-tour index (DFS preorder).
// Not referenced in the visible code.
bool cmp(int a, int b){
	const int pa = pos[a];
	const int pb = pos[b];
	return pa < pb;
}

vector<iii> queries;  // ((l, r), index of the query in input order)

const int S = 400;  // Mo's algorithm block size on the left endpoint

// Mo's ordering. Queries are grouped by block l / S, with ascending l across
// blocks. Inside a block, r ascends for even blocks and descends for odd
// blocks, so the right pointer sweeps back and forth instead of resetting.
bool cmpp(iii a, iii b){
	const int blockA = a.fi.fi / S;
	const int blockB = b.fi.fi / S;
	if(blockA != blockB) return a.fi.fi < b.fi.fi;
	if(blockA & 1) return a.fi.se > b.fi.se;
	return a.fi.se < b.fi.se;
}

int cntt[N];  // cntt[c]: occurrences of city c inside the current window

int BIT[N];  // Fenwick tree over Euler positions 1..2n (1 where a city is stored)

int answer[N];  // answer[i]: result for the i-th query in input order

int cur_ans = 0;  // sum of d2[b] - d[lca(a, b)] over consecutive stored positions a < b

int total = 0;  // number of stored positions (running sum of every upd value)

// Fenwick point update: adds val at 1-based index pos (must be non-zero) over
// the index range 1..2n, and keeps `total` equal to the sum of all updates.
void upd(int pos, int val){
	assert(pos);
	total += val;
	const int limit = n << 1;
	while(pos <= limit){
		BIT[pos] += val;
		pos += pos & -pos;
	}
}

// Fenwick prefix sum: total of the values stored at indices 1..pos.
int get(int pos){
	int sum = 0;
	while(pos != 0){
		sum += BIT[pos];
		pos &= pos - 1;  // clear the lowest set bit
	}
	return sum;
}

// Fenwick binary lifting: returns the smallest index i with get(i) >= val.
// The caller must make sure such an index exists (1 <= val <= total). Point
// values must be non-negative; here every position holds 0 or 1.
// Steps start at 2^18, which reaches every index below 2^19. Indices beyond
// 2n are never taken.
int find(int val){
	const int limit = 2 * n + 1;
	int idx = 0;
	for(int step = 1 << 18; step > 0; step >>= 1){
		const int cand = idx + step;
		if(cand >= limit) continue;
		if(BIT[cand] < val){
			val -= BIT[cand];
			idx = cand;
		}
	}
	return idx + 1;
}

// The Euler positions stored for the current window form a doubly linked list
// in position order. lstt[x] / nxtt[x] are the previous / next stored
// positions, and -oo / oo mean "none". An entry is only meaningful while x is
// stored.
int nxtt[N], lstt[N];

// Adds Euler position x to the set. process() calls this only when a city's
// count goes from 0 to 1, so x is not already stored.
// The predecessor is found through the Fenwick tree (get/find), and the
// successor is then read from the linked list. cur_ans is the sum, over
// consecutive stored positions a < b, of d2[b] - d[lca(a, b)]. Putting x
// between lst and nxt replaces the (lst, nxt) term with the (lst, x) and
// (x, nxt) terms.
void ins(int x){
	int lst = -oo, nxt = oo;
	const int before = get(x);  // stored positions strictly less than x
	if(before > 0){
		lst = find(before);
		nxt = nxtt[lst];
	}
	else if(before < total){
		nxt = find(before + 1);  // no predecessor: successor is the smallest stored
	}
	const bool hasPrev = (lst != -oo);
	const bool hasNext = (nxt != oo);
	if(hasPrev && hasNext) cur_ans -= d2[nxt] - d[lca(lst, nxt)];
	if(hasPrev) cur_ans += d2[x] - d[lca(lst, x)];
	if(hasNext) cur_ans += d2[nxt] - d[lca(nxt, x)];
	// Splice x into the list between lst and nxt.
	lstt[x] = lst;
	nxtt[x] = nxt;
	if(hasPrev) nxtt[lst] = x;
	if(hasNext) lstt[nxt] = x;
	upd(x, 1);
}

void er(int x){// erase from the set
	int nxt = oo, lst = -oo;
	/*
	if(temp > 1) lst = find(temp - 1);
	if(temp < total) nxt = find(temp + 1);*/
	lst = lstt[x];
	nxt = nxtt[x];
	/*
	set<int>::iterator it = cur.find(x);
	assert(it != cur.end());
	int lst = -oo, nxt = oo;
	it++;
	if(it != cur.end()) nxt = (*it);
	it--;
	if(it != cur.begin()){
		it--;
		lst = (*it);
		it++;
	}*/
	if(lst != -oo && nxt != oo) cur_ans += (d2[nxt] - d[lca(lst, nxt)]);
	if(lst != -oo) cur_ans -= (d2[x] - d[lca(lst, x)]);
	if(nxt != oo) cur_ans -= (d2[nxt] - d[lca(nxt, x)]);
	if(lstt[x] != -oo) nxtt[lstt[x]] = nxtt[x];
	if(nxtt[x] != oo) lstt[nxtt[x]] = lstt[x];
	upd(x, -1);
	//assert(se.find(x) != se.end());
}

void process(){
	cin >> n >> m >> q;
	for(int i = 1; i < n; i++){
		int x, y;
		cin >> x >> y;
		Adj[x].pb(y);
		Adj[y].pb(x);
	}
	dfs(1, 1);
	prep();
	for(int i = 1; i <= n; i++) d2[pos[i]] = d[i];
	for(int i = 1; i <= m; i++) cin >> arr[i];
	for(int i = 0; i < q; i++){
		int l, r;
		cin >> l >> r;
		queries.pb({{l, r}, i});
	}
	sort(queries.begin(), queries.end(), cmpp);
	int lstl = 1, lstr = 0;
	for(int i = 1; i <= n; i++){
		lstt[i] = -oo, nxtt[i] = oo;
	}
	for(auto it : queries){
		while(lstr < it.fi.se){
			lstr++;
			if(!cntt[arr[lstr]]) ins(pos[arr[lstr]]);
			cntt[arr[lstr]]++;
		}
		while(lstl > it.fi.fi){
			lstl--;
			if(!cntt[arr[lstl]]) ins(pos[arr[lstl]]);
			cntt[arr[lstl]]++;
		}
		//for(auto it : cur) cout << it << " ";
		//cout << cur_ans << "\n";
		while(lstr > it.fi.se){
			cntt[arr[lstr]]--;
			if(!cntt[arr[lstr]]) er(pos[arr[lstr]]);
			lstr--;
		}
		while(lstl < it.fi.fi){
			cntt[arr[lstl]]--;
			if(!cntt[arr[lstl]]) er(pos[arr[lstl]]);
			lstl++;
		}
		//cout << cur_ans << "\n";
		//cout << it.fi.fi << " " << it.fi.se << " " << lstl << " " << lstr << "\n";
		//for(auto it : cur) cout << it << " ";
		//cout << "\n";
		if(total){
			int a = find(1), b = find(total);
			answer[it.se] = cur_ans + d2[a] - d[lca(a, b)] + 1;
		}
	}
	for(int i = 0; i < q; i++) cout << answer[i] << "\n";
}

// Entry point: unsynchronised, untied iostreams for bulk I/O, then solve.
signed main(){
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	process();
}


Compilation message (stderr)

tourism.cpp:2: warning: ignoring '#pragma gcc optimize' [-Wunknown-pragmas]
    2 | #pragma gcc optimize("Ofast")
      |
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...