Submission #888828

# | Time | Username | Problem | Language | Result | Execution time | Memory
888828 | (not shown) | vjudge1 | Two Currencies (JOI23_currencies) | C++17 | 68 / 100 | 2386 ms | 123816 KiB
#include <bits/stdc++.h>
using namespace std;

// `int` is widened to 64 bits for the whole file (hence `signed main`).
#define int long long
// Shorthands for pair members and whole-container ranges.
#define ff first
#define ss second
#define all(a) a.begin(), a.end()
// NOTE(review): `mod` is not referenced anywhere in this file.
const int mod = 998244353;
// Capacity bound for the per-vertex arrays below (vertices are 1-indexed).
const int N = 1e5;
// n vertices, m checkpoints, q queries (read at the start of main).
int n, m, q;
// Adjacency lists of the tree.
vector<int > g[N+10];
// check: (road index, toll) per checkpoint; edge: endpoints of road i at index i-1.
vector< pair<int, int> > check, edge;
// depth: edges from root 1; sub: subtree size (both filled by dfs).
int depth[N+10], sub[N+10];
// up[v][j]: 2^j-th ancestor of v; sum[v][j]: checkpoint count on the 2^j edges above v, as read by checks().
int up[N+10][20], sum[N+10][20];
// DFS entry/exit timestamps used by upper().
int tin[N+10], tout[N+10];
// val[v]: tolls of checkpoints on the road from v up to its parent.
vector<int> val[N+10];
// Heavy child (0 = none), heavy-first DFS position, and heavy-chain head per vertex.
int bigchild[N+10], pos[N+10], chain[N+10];
// Running counter shared by dfs() (tin/tout) and dfs2() (pos).
int timer = 1;
void dfs(int v, int par){
	tin[v] = timer++;
	up[v][0] = par;
	for(int to : g[v]){
		if(to == par) continue;
		depth[to] = depth[v] + 1;
		dfs(to, v);
		sub[v]+= sub[to];
		if(!bigchild[v] or sub[bigchild[v]] < sub[to]){
			bigchild[v] = to;
		}
	}
	sub[v] += 1, tout[v] = timer++;
}


// Returns 1 when p is an ancestor of a (or a itself), judged by DFS timestamps.
int upper(int p, int a){
	const bool enters_first = tin[p] <= tin[a];
	const bool leaves_last = tout[a] <= tout[p];
	return enters_first && leaves_last;
}


int lca(int a, int b){
	if(depth[b] < depth[a]) swap(a, b);
	int k = depth[b] - depth[a];
	for(int i = 0;i < 20; i++){
		if(k & (1<<i)) b = up[b][i];
	}
	if(a == b) return a;
	for(int i = 19; i >= 0; i--){
		if(up[b][i] != up[a][i]){
			a = up[a][i];
			b = up[b][i];
		}
	}
	return up[a][0];
	
}

// Heavy-first DFS: assigns pos[v] from the shared timer and chain[v] = head of
// the heavy path containing v. Light children start their own chain.
void dfs2(int v, int par, int head){
	chain[v] = head;
	pos[v] = timer++;
	const int heavy = bigchild[v];
	if(heavy) dfs2(heavy, v, head);
	for(int child : g[v]){
		if(child == par || child == heavy) continue;
		dfs2(child, v, child);
	}
}



/*
int query(int a, int b){
	if(chain[a] > chain[b]) swap(a, b);
	int sum = 0;
	while(chain[a] != chain[b]){
			if(chain[a] > chain[b]) swap(a, b);
			int l =pos[chain[b]], r = pos[b];
			//sum+= get(l, r, 1, 1, n);
			b = up[chain[b]][0];
	}
	if(chain[a] > chain[b]) swap(a, b);
	int l = pos[a], r = pos[b];
	//sum+= get(l, r, 1, 1, n);
	return sum; 
}
*/
int distance(int a, int b){
	int lc = lca(a, b);
	return (depth[a] + depth[b] - 2*depth[lc]);
}
// Total of sum[][] over the d edges climbed from a (checkpoint count on that
// upward segment); only the low 20 bits of d are honoured.
int checks(int a, int d){
	int found = 0;
	int at = a;
	for(int bit = 0; bit < 20; ++bit){
		const bool take = (d >> bit) & 1;
		if(!take) continue;
		found += sum[at][bit];
		at = up[at][bit];
	}
	return found;
}

// Ancestor of a reached by climbing d edges via up[][]; only the low 20 bits of d are used.
int go_k(int a, int d){
	int at = a;
	for(int bit = 0; bit < 20; ++bit){
		const bool take = (d >> bit) & 1;
		at = take ? up[at][bit] : at;
	}
	return at;
}

// Merge-sort-tree node: sorted values `a` and their running prefix sums `pr`.
struct node{
	vector<int> a;
	vector<int> pr;
};


signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> n >> m >> q;
	int bambo = 1;
	for(int i = 1;i < n; i++){
		int a, b; cin >> a >> b;
		if(a != i or b != a+1) bambo = 0;
		edge.push_back({a, b});
		g[a].push_back(b);
		g[b].push_back(a);
	}
	
	dfs(1, 1);
	timer = 1;
	dfs2(1, 1, 1);
	int coins = 0;
	for(int i = 0;i < m; i++){
		int p, c; cin >> p >> c;
		check.push_back({p, c});
		int a = edge[p-1].ff, b = edge[p-1].ss;
		if(depth[a] < depth[b]) swap(a, b);
		//cout << a << " , " << b << " = " << c << '\n';
		sum[a][0] += 1;
		val[a].push_back(c);
		coins = c;
	}
	for(int j = 1;j < 20; j++){
		for(int i = 1;i <= n; i++){
			up[i][j] = up[up[i][j-1]][j-1];
			sum[i][j] = sum[i][j-1] + sum[up[i][j-1]][j-1];
		}
	}
	if(max({n, m, q}) <= 2000){
		while(q--){
			int s, t; cin >> s >> t;
			int x, y; cin >> x >> y;
			int lc = lca(s, t);
			vector<pair<int, int> > vec;
			int cnt = checks(s, depth[s] - depth[lc]);
			cnt+= checks(t, depth[t] - depth[lc]);
			while(s != lc){
				for(int sa : val[s]) vec.push_back({sa, 1});
				s = up[s][0];
			}
			while(t != lc){
				for(int sa : val[t]) vec.push_back({sa, 1});
				t = up[t][0];
			}
			//cout << "ancestor : ";
			//cout << lc << "\n";
			sort(all(vec), [&](auto A, auto B){
				return A.ff < B.ff;
			});
			
			for(auto [silver, gold] : vec){
			//	cout << silver << ' ' << gold << ", ";
				if(y >= silver){
					y-= silver;
				}else{
					x-= gold;
				}
			}
			//cout << '\n';
			cout << max(-1LL, x) << '\n';
		}
		return 0;
	}else if(!bambo){
		while(q--){
			int s, t; cin >> s >> t;
			int x, y; cin >> x >> y;
			int lc = lca(s, t);
			int cnt = checks(s, depth[s] - depth[lc]);
			cnt+= checks(t, depth[t] - depth[lc]);
			cnt-= (y / coins);
			x-= max(0LL, cnt);
			cout << max(-1LL, x) << '\n';
		}
		return 0;
	}else{
		vector<int> p;
		vector<pair<int, int> > range(n+1);
		for(int i = 1;i <= n; i++){
			range[i].ff = p.size();
			for(int x : val[i]){
				p.push_back(x);
			}
			range[i].ss = p.size()-1;
		}
		int sz = p.size();
		vector<node> t(4*sz);
		auto merge=[&](auto left, auto right){
			node res ={{}, {}};
			int i = 0, j = 0;
			while(i < (int)left.a.size() || j < (int)right.a.size()){
				if(i >= (int)left.a.size() && j >= (int)right.a.size()) break;
				if(i >= (int)left.a.size()){
					res.a.push_back(right.a[j]);
					j++;
				}else if(j >= (int)right.a.size()){
					res.a.push_back(left.a[i]);
					i++;
				}else{
					if(left.a[i] < right.a[j]){
						res.a.push_back(left.a[i]);
						i++;
					}else{
						res.a.push_back(right.a[j]);
						j++;
					}
				}
				res.pr.push_back((res.pr.empty() ? 0 : res.pr.back()) + res.a.back());
			}
			return res;
		};
		
		auto build=[&](auto build, int v, int vl, int vr)->auto{
			if(vl == vr){
				t[v] = {{p[vl]}, {p[vl]}};
				return;
			}
			int mid = (vl + vr)>>1;
			build(build, v<<1, vl, mid);
			build(build, v<<1|1, mid+1, vr);
			t[v] = merge(t[v<<1], t[v<<1|1]);
		};
		node zero = {{}, {}};
		auto get=[&](auto get, int l, int r, int v, int vl, int vr)->auto{
			if(l > vr or vl > r) return zero;
			if(l <= vl && r >= vr) return t[v];
			int mid = (vl + vr)>>1;
			return merge(get(get, l, r, v<<1, vl, mid), get(get, l, r, v<<1|1, mid+1, vr));
		};
		
		auto sum=[&](auto sum, int l, int r, int x, int v, int vl, int vr)->auto{
			if(l > vr or vl > r) return 0LL;
			if(l <= vl && r >= vr){
				int it = upper_bound(all(t[v].a), x) - t[v].a.begin();
				it--;
				if(it >= 0) return t[v].pr[it];
				else return 0LL;
			} 
			int mid = (vl + vr)>>1;
			return sum(sum, l, r, x, v<<1, vl, mid) + sum(sum, l, r, x, v<<1|1, mid+1, vr);
		};
		
		auto smaller=[&](auto smaller, int l, int r, int x, int v, int vl, int vr)->auto{
			if(l > vr or vl > r) return 0LL;
			if(l <= vl && r >= vr){
				int it = upper_bound(all(t[v].a), x) - t[v].a.begin();
				return it;
			} 
			int mid = (vl + vr)>>1;
			return smaller(smaller, l, r, x, v<<1, vl, mid) + smaller(smaller, l, r, x, v<<1|1, mid+1, vr);
		};
		
		build(build, 1, 0, sz-1);
		while(q--){
			int s, t; cin >> s >> t;
			int x, y; cin >> x >> y;
			if(depth[s] < depth[t]) swap(s, t);
			int lc = t;
			int cnt = checks(s, depth[s] - depth[lc]);
			
			int l = range[go_k(s, depth[s] - depth[lc] - 1)].ff, r = range[s].ss;
		
			int lo = -1, ro = (int)y + 2;
			while(lo + 1 < ro){
				int mid = (lo + ro)>>1;
				if(sum(sum, l, r, mid, 1, 0, sz-1) > y){
					ro = mid;
				}else lo = mid;
			}
			
			cnt-= smaller(smaller, l, r, ro-1, 1, 0, sz-1);
			y-= sum(sum, l, r, ro-1, 1, 0, sz-1);
			if(ro > 0) cnt-= (y/ro);
			 x-= cnt;		
			cout << max(-1LL, x) << '\n';
		}
	}
	
	
	return 0;
}

Compilation message (stderr)

currencies.cpp: In function 'int main()':
currencies.cpp:242:8: warning: variable 'get' set but not used [-Wunused-but-set-variable]
  242 |   auto get=[&](auto get, int l, int r, int v, int vl, int vr)->auto{
      |        ^~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...