답안 #763604

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
763604 2023-06-22T13:48:24 Z minhcool Tourism (JOI23_tourism) C++17
0 / 100
3 ms 7380 KB
#include<bits/stdc++.h>
using namespace std;

// Short-hand macros for pair members and common container/utility calls.
#define fi first
#define se second
#define pb push_back
#define mp make_pair

// ii: plain pair of ints.
// iii: ((a, b), c) - used below to store a query as ((l, r), original index).
// iiii: pair of pairs; not referenced by the visible code.
typedef pair<int, int> ii;
typedef pair<ii, int> iii;
typedef pair<ii, ii> iiii;

// Capacity of every global array in this file.
const int N = 3e5 + 5;

// oo serves as the "no neighbour" sentinel (+oo / -oo) in ins() and er();
// mod is not referenced by the visible code.
const int oo = 1e9 + 7, mod = 1e9 + 7;

/*
What I am thinking:
We just find their lca + total edges covered?
*/

// Sizes read at the start of process(): n vertices (n - 1 edges are read),
// m entries of arr, q queries.
int n, m, q;

// Adjacency lists of the tree, indexed by vertex (1-based).
vector<int> Adj[N];

// Input sequence arr[1..m]; each entry is a vertex id.
int arr[N];

// Running Euler-tour index, advanced by dfs().
int cnt;
// d[v]: depth of vertex v (root has depth 0).
// mn_d[i][k]: shallowest vertex among tour indices [i, i + 2^k - 1] (sparse table).
// lg[i]: floor(log2(i)), filled by prep().
int d[N], mn_d[N][20], lg[N];
// d2[p]: depth of the vertex whose first tour index is p (set in process()).
int d2[N];


// le[u] / pos[u]: tour index of the first visit of u (dfs() assigns both the same value).
// ri[u]: last tour index written before dfs(u) returns.
int le[N], ri[N], pos[N];

// Builds the Euler tour of the tree rooted at the initial call.
// u is the current vertex, p the vertex we arrived from (skipped when
// iterating neighbours). Every entry into u and every return from a child
// back to u appends u to the tour (mn_d[idx][0]); depths of children are
// set before descending.
void dfs(int u, int p){
	le[u] = pos[u] = ++cnt;   // first time u appears in the tour
	mn_d[cnt][0] = u;
	for(const int child : Adj[u]){
		if(child != p){
			d[child] = d[u] + 1;
			dfs(child, u);
			mn_d[++cnt][0] = u;   // u re-appears after finishing this child
		}
	}
	ri[u] = cnt;              // last tour index produced inside this call
}

// Fills the floor-log2 table lg[] and the upper levels of the sparse table
// mn_d[][] over tour indices 1..2n. Level 0 must already be filled by dfs().
// For equal depths the right-hand candidate is kept.
void prep(){
	const int len = n * 2;
	for(int i = 2; i <= len; ++i) lg[i] = lg[i / 2] + 1;
	for(int level = 1; level < 20; ++level){
		const int half = 1 << (level - 1);
		for(int start = 1; start + 2 * half - 1 <= len; ++start){
			const int left = mn_d[start][level - 1];
			const int right = mn_d[start + half][level - 1];
			mn_d[start][level] = (d[left] < d[right]) ? left : right;
		}
	}
}

int lca(int x, int y){
	//x = pos[x], y = pos[y];
	if(x > y) swap(x, y);
	int k = lg[y - x + 1];
	int a = mn_d[x][k], b = mn_d[y - (1LL << k) + 1][k];
	return (d[a] < d[b] ? a : b);
}

// Orders two vertices by their first Euler-tour index.
bool cmp(int a, int b){
	return pos[b] > pos[a];
}

// All queries as ((l, r), original index); sorted with cmpp in process().
vector<iii> queries;

// Block width used by cmpp to bucket queries by their left endpoint.
const int S = 400;

// Query ordering for the offline sweep: bucket by left endpoint (width S);
// inside an even-numbered bucket sort by right endpoint ascending, inside
// an odd-numbered bucket descending.
bool cmpp(iii a, iii b){
	const int blockA = a.fi.fi / S;
	const int blockB = b.fi.fi / S;
	if(blockA != blockB) return a.fi.fi < b.fi.fi;
	if(blockA & 1) return a.fi.se > b.fi.se;
	return a.fi.se < b.fi.se;
}

// cntt[v]: how many times vertex v occurs in the current window of arr.
int cntt[N];

// Fenwick tree over tour indices 1..2n; upd() adds +1 / -1 at an index.
int BIT[N];

// answer[i]: result for the query whose original index is i.
int answer[N];

// Running sum maintained by ins()/er() over adjacent members of the active set.
int cur_ans = 0;

// Sum of all values passed to upd(), i.e. the number of active indices.
int total = 0;

// Adds val at index pos of the Fenwick tree and keeps `total` in sync.
// pos must be non-zero (index 0 would loop forever).
void upd(int pos, int val){
	assert(pos);
	total += val;
	const int limit = n << 1;
	while(pos <= limit){
		BIT[pos] += val;
		pos += pos & -pos;
	}
}

// Prefix sum of the Fenwick tree over indices 1..pos.
int get(int pos){
	int sum = 0;
	while(pos){
		sum += BIT[pos];
		pos -= pos & -pos;
	}
	return sum;
}

// Fenwick-tree descent: returns the smallest index p (1 <= p <= 2n) whose
// prefix sum get(p) is >= val, i.e. the val-th smallest active index.
// Callers must pass 1 <= val <= total; outside that range the returned
// index is not meaningful.
int find(int val){
	int idx = 0;                   // largest index known to have prefix sum < val
	const int mx = 2 * n + 1;      // one past the last valid index
	for(int i = 19; i >= 0; i--){
		const int cand = idx + (1 << i);
		if(cand >= mx) continue;   // would step outside the tree
		if(val > BIT[cand]){
			// The whole block ending at cand lies below the target: skip it.
			val -= BIT[cand];
			idx = cand;
		}
	}
	return idx + 1;
}

// Neighbour links of an earlier linked-list variant; the visible code only
// initialises them in process() and never reads them.
int lstt[N], nxtt[N];

// Inserts tour index x into the active set (x must not be present yet)
// and updates cur_ans. For neighbours lst < x < nxt in the set, the old
// contribution of the pair (lst, nxt) is replaced by the contributions of
// (lst, x) and (x, nxt); each pair (a, b) contributes
// depth(b) - depth(lca(a, b)).
void ins(int x){
	if(total == 0){
		// First element: there are no neighbouring pairs to account for.
		upd(x, 1);
		return;
	}
	int lst = -oo, nxt = oo;
	const int below = get(x);                 // active indices smaller than x
	if(below > 0) lst = find(below);          // predecessor
	if(below < total) nxt = find(below + 1);  // successor
	const bool hasPrev = (lst != -oo);
	const bool hasNext = (nxt != oo);
	if(hasPrev && hasNext) cur_ans -= (d2[nxt] - d[lca(lst, nxt)]);
	if(hasPrev) cur_ans += (d2[x] - d[lca(lst, x)]);
	if(hasNext) cur_ans += (d2[nxt] - d[lca(nxt, x)]);
	upd(x, 1);
}

// Removes tour index x from the active set (x must currently be present)
// and updates cur_ans: the contributions of the pairs (lst, x) and
// (x, nxt) are withdrawn and, when both neighbours exist, the
// contribution of the merged pair (lst, nxt) is added back.
void er(int x){
	const int rank = get(x);   // 1-based rank of x; counts x itself
	int nxt = oo, lst = -oo;
	// A predecessor exists only when x is not the smallest element.
	// Testing `rank != 0` here is always true and would call find(0),
	// producing a bogus predecessor for the first element.
	if(rank > 1) lst = find(rank - 1);
	if(rank < total) nxt = find(rank + 1);
	if(lst != -oo && nxt != oo) cur_ans += (d2[nxt] - d[lca(lst, nxt)]);
	if(lst != -oo) cur_ans -= (d2[x] - d[lca(lst, x)]);
	if(nxt != oo) cur_ans -= (d2[nxt] - d[lca(nxt, x)]);
	upd(x, -1);
}

// Pseudo-random engine seeded from the steady clock at program start.
std::mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// Returns a uniformly distributed integer in the closed range [l, r].
// Precondition: l <= r.
// Uses std::uniform_int_distribution instead of truncating the 64-bit
// output to int and folding it with % and abs(), which is biased.
int rnd(int l, int r){
	std::uniform_int_distribution<int> dist(l, r);
	return dist(rng);
}

void process(){
	cin >> n >> m >> q;
	for(int i = 1; i < n; i++){
		int x, y;
		cin >> x >> y;
		Adj[x].pb(y);
		Adj[y].pb(x);
	}
	/*
	n = m = q = 100000;
	for(int i = 2; i <= n; i++){
	    int x = rnd(1, i - 1), y = i;
	    Adj[x].pb(y);
		Adj[y].pb(x);
	}*/
	dfs(1, 1);
	prep();
	for(int i = 1; i <= n; i++) d2[pos[i]] = d[i];
	for(int i = 1; i <= m; i++) cin >> arr[i];
	for(int i = 0; i < q; i++){
		int l, r;
		cin >> l >> r;
		queries.pb({{l, r}, i});
	}
	//for(int i = 1; i <= m; i++) arr[i] = rnd(1, n);
	/*
	for(int i = 0; i < q; i++){
	    int l = rnd(1, m), r = rnd(l, m);
	    queries.pb({{l, r}, i});
	}*/
	sort(queries.begin(), queries.end(), cmpp);
	int lstl = 1, lstr = 0;
	for(int i = 1; i <= n; i++){
	    lstt[i] = -oo, nxtt[i] = oo;
	}
	for(auto it : queries){
		while(lstr < it.fi.se){
			lstr++;
			if(!cntt[arr[lstr]]) ins(pos[arr[lstr]]);
			cntt[arr[lstr]]++;
		}
		while(lstl > it.fi.fi){
			lstl--;
			if(!cntt[arr[lstl]]) ins(pos[arr[lstl]]);
			cntt[arr[lstl]]++;
		}
		//cout << lstl << " " << lstr << " " << cur_ans << "\n";
		//for(auto it : cur) cout << it << " ";
		//cout << cur_ans << "\n";
		while(lstr > it.fi.se){
			cntt[arr[lstr]]--;
			if(!cntt[arr[lstr]]) er(pos[arr[lstr]]);
			lstr--;
		}
		while(lstl < it.fi.fi){
			cntt[arr[lstl]]--;
			if(!cntt[arr[lstl]]) er(pos[arr[lstl]]);
			lstl++;
		}
		//cout << it.se << "\n";
		//cout << cur_ans << "\n";
		//cout << it.fi.fi << " " << it.fi.se << " " << lstl << " " << lstr << "\n";
		//for(auto it : cur) cout << it << " ";
		//cout << "\n";
		if(total){
			int a = find(1), b = find(total);
			answer[it.se] = cur_ans + d2[a] - d[lca(a, b)] + 1;
		}
		else answer[it.se] = 0;
	}
	for(int i = 0; i < q; i++) cout << answer[i] << "\n";
}

// Entry point: switch iostreams to fast mode, then run the solver.
signed main(){
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	process();
	return 0;
}

Compilation message

tourism.cpp: In function 'int find(int)':
tourism.cpp:106:10: warning: unused variable 'temp' [-Wunused-variable]
  106 |      int temp = answer + (1LL << i);
      |          ^~~~
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 7380 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 7380 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 7380 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 7380 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 7380 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 7380 KB Output isn't correct
2 Halted 0 ms 0 KB -