// Submission #1165418
//
// #       | Time | Username | Problem                           | Language | Result    | Execution time | Memory
// 1165418 |      | trvhung  | Two Currencies (JOI23_currencies) | C++20    | 100 / 100 | 1965 ms        | 362140 KiB
#include <bits/stdc++.h>
// #include <ext/rope>
// #include <ext/pb_ds/assoc_container.hpp>

// using namespace __gnu_pbds;
// using namespace __gnu_cxx;
using namespace std;

// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")

// #define   ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// Shorthand type names.
#define            ll long long
#define           ull unsigned long long
#define            ld long double
// pb: vector append shorthand.
#define            pb push_back
// bit(mask, i): the i-th bit of mask (0 or 1). Unused by this solution.
// NOTE(review): the arguments are not parenthesized, so pass only simple expressions.
#define  bit(mask, i) ((mask >> i) & 1)
// el: newline character, written instead of std::endl so output is not flushed per line.
#define            el '\n'
// F / S: std::pair member shorthands (x.F is x.first, x.S is x.second).
// Every identifier spelled F or S in the file is replaced by this macro.
#define             F first
#define             S second

// Sets x = y when y is strictly greater than x; returns true iff x changed.
template <class X, class Y>
bool maximize(X &x, const Y &y) {
	if (!(x < y)) return false;
	x = y;
	return true;
}

// Sets x = y when y is strictly smaller than x; returns true iff x changed.
template <class X, class Y>
bool minimize(X &x, const Y &y) {
	if (!(x > y)) return false;
	x = y;
	return true;
}

// ---- General-purpose constants ----
const int INF = 1e9;
const ll LINF = 1e18;
const int MOD = 1e9 + 7;
const int MULTI = 0;   // non-zero: input starts with a test-case count (see main)
const ld eps = 1e-9;

// Template leftovers; not referenced by this solution.
const int dx[4] = {0, 1, 0, -1};    // R D L U
const int dy[4] = {1, 0, -1, 0};
const int ddx[4] = {-1, 1, 1, -1};  // UR DR DL UL
const int ddy[4] = {1, 1, -1, -1};
const char cx[4] = {'R', 'D', 'L', 'U'};
const ll base = 31;
const int nMOD = 2;
const ll mods[] = {(ll)1e9 + 10777, (ll)1e9 + 19777, (ll)1e9 + 3, (ll)1e9 + 3777};

const int maxn = 1e5 + 5;

// Input sizes, plus the smallest / largest checkpoint cost seen
// (minC starts at INF, maxC at 0).
int n, m, q;
int minC = INF, maxC;

// Rooted tree, filled by dfs(): parent, depth, subtree size.
int par[maxn], h[maxn], sz[maxn];
vector<int> adj[maxn];
pair<int, int> E[maxn];   // E[i] = endpoints of input edge i

// Heavy-light decomposition, filled by hld().
int curChain, curPos;
int pos[maxn], chainID[maxn], chainHead[maxn];

// Flattened checkpoint costs: a-indices [L[p], R[p]] (1-based) hold the costs
// stored at HLD position p; zeroes[] is a prefix count of 0 entries in a[].
vector<int> weightAtPos[maxn];
int a[2 * maxn], zeroes[2 * maxn], L[maxn], R[maxn];

// Recursively roots the subtree of u: fills par[], depth h[] and subtree
// size sz[] for every vertex below u. par[u] and h[u] must already be set
// (for the root they are the global zero values).
void dfs(int u) {
	sz[u] = 1;
	for (int child : adj[u]) {
		if (child == par[u]) continue;   // skip the edge back to the parent
		par[child] = u;
		h[child] = h[u] + 1;
		dfs(child);
		sz[u] += sz[child];
	}
}

// Heavy-light decomposition of the subtree of u (requires dfs() to have run).
// u gets the next preorder position (curPos) and the current chain (curChain).
// Its heavy child (largest sz) continues the same chain; every light child
// opens a new chain.
void hld(int u) {
	if (chainHead[curChain] == 0) chainHead[curChain] = u;
	chainID[u] = curChain;
	pos[u] = curPos;
	++curPos;

	// Find the heavy child; vertex 0 is never written by dfs, so sz[0] == 0
	// and 0 serves as "no child yet".
	int heavy = 0;
	for (int child : adj[u]) {
		if (child == par[u]) continue;
		if (sz[child] > sz[heavy]) heavy = child;
	}

	// Visit the heavy child first so that each chain occupies consecutive positions.
	if (heavy != 0) hld(heavy);

	for (int child : adj[u]) {
		if (child == par[u] || child == heavy) continue;
		++curChain;
		hld(child);
	}
}

int LCA(int u, int v) {
	while (chainID[u] != chainID[v])
		if (chainID[u] > chainID[v])
			u = par[chainHead[chainID[u]]];
		else
			v = par[chainHead[chainID[v]]];

	return h[u] < h[v] ? u : v;
}

// Component-wise sum of two (count, sum) pairs.
pair<int, ll> add(pair<int, ll> A, pair<int, ll> B) {
	A.F += B.F;
	A.S += B.S;
	return A;
}

// Wavelet tree over integer values in [low, high], answering
// "count and sum of the elements <= k in positions [l, r]".
// Each internal node stores, for the sequence it represents:
//   freq[i] - how many of the first i elements are <= mid (they go to the left child)
//   pref[i] - sum of the first i elements
// Leaves (low == high) store only pref. Positions are 1-based and inclusive.
class wavelet_tree {
public:
	int low = 0, high = 0;

	// Children; they stay nullptr for leaves and empty nodes, which never recurse.
	// NOTE(review): nodes are allocated with new and never freed. This is
	// acceptable for a single-shot program that exits after answering.
	wavelet_tree *l = nullptr, *r = nullptr;

	std::vector<int> freq;
	std::vector<long long> pref;

	wavelet_tree() = default;

	// Builds the tree over [from, to), whose values must lie in [x, y].
	// The range [from, to) is reordered in place (std::stable_partition).
	wavelet_tree(int *from, int *to, int x, int y) : low(x), high(y) {
		if (from >= to) return;

		const std::size_t count = static_cast<std::size_t>(to - from);
		pref.reserve(count + 1);
		pref.push_back(0);

		if (low == high) {
			// Leaf: kOrLess always answers from pref here, so freq is not needed.
			for (int *it = from; it != to; ++it)
				pref.push_back(pref.back() + *it);
			return;
		}

		const int mid = (low + high) >> 1;
		auto goesLeft = [mid](int value) { return value <= mid; };

		freq.reserve(count + 1);
		freq.push_back(0);
		for (int *it = from; it != to; ++it) {
			freq.push_back(freq.back() + goesLeft(*it));
			pref.push_back(pref.back() + *it);
		}

		// Stable, so each child keeps the relative order needed by freq mapping.
		int *pivot = std::stable_partition(from, to, goesLeft);
		l = new wavelet_tree(from, pivot, low, mid);
		r = new wavelet_tree(pivot, to, mid + 1, high);
	}

	// Returns (count, sum) of the elements at positions [l, r] whose value is <= k.
	// An empty range (l > r) yields (0, 0).
	std::pair<int, long long> kOrLess(int l, int r, int k) {
		if (l > r || k < low) return std::make_pair(0, 0LL);
		if (high <= k) return std::make_pair(r - l + 1, pref[r] - pref[l - 1]);

		// Map [l, r] onto each child's own position range.
		const int leftBefore = freq[l - 1], leftUpTo = freq[r];
		const std::pair<int, long long> fromLeft = this->l->kOrLess(leftBefore + 1, leftUpTo, k);
		const std::pair<int, long long> fromRight = this->r->kOrLess(l - leftBefore, r - leftUpTo, k);
		return std::make_pair(fromLeft.first + fromRight.first, fromLeft.second + fromRight.second);
	}
} wt;

// Fenwick tree over positions 1..n (global n); solve() uses it to count
// checkpoints per HLD position.
struct BIT {
	int bit[maxn];

	// Adds v at position id.
	void update(int id, int v) {
		while (id <= n) {
			bit[id] += v;
			id += id & -id;
		}
	}

	// Sum over positions 1..id.
	int get(int id) {
		int total = 0;
		while (id > 0) {
			total += bit[id];
			id &= id - 1;   // drop the lowest set bit
		}
		return total;
	}

	// Sum over positions [l, r]; 0 for an empty range.
	int getRange(int l, int r) {
		return l > r ? 0 : get(r) - get(l - 1);
	}
} BIT;

// Number of checkpoints on the path u - v, read from the Fenwick tree.
// Each checkpoint is stored at the HLD position of its edge's deeper endpoint,
// so the LCA's own position is excluded.
int getGold(int u, int v) {
	const int lca = LCA(u, v);
	const int top = chainID[lca];
	int total = 0;

	// Lift x onto the LCA's chain, adding every whole chain segment passed.
	auto climb = [&](int &x) {
		while (chainID[x] != top) {
			const int head = chainHead[chainID[x]];
			total += BIT.getRange(pos[head], pos[x]);
			x = par[head];
		}
	};
	climb(u);
	climb(v);

	// Both on the LCA's chain: count strictly below the shallower vertex.
	if (pos[u] > pos[v]) swap(u, v);
	total += BIT.getRange(pos[u] + 1, pos[v]);
	return total;
}

// Number of 0 entries among a-indices [l, r] (1-based, inclusive),
// taken from the prefix counts in zeroes[].
int getZeroes(int l, int r) {
	const int upTo = zeroes[r];
	const int before = zeroes[l - 1];
	return upTo - before;
}

// For the path u - v, returns (count, sum) over the checkpoint costs <= x.
// HLD positions without checkpoints hold a 0 placeholder in the flattened
// sequence. Those placeholders are subtracted from the count and add nothing
// to the sum.
pair<int, ll> getPath(int u, int v, int x) {
	const int lca = LCA(u, v);
	const int top = chainID[lca];
	pair<int, ll> res = make_pair(0, 0);

	// Adds the checkpoints stored at HLD positions [fromPos, toPos].
	auto addPositions = [&](int fromPos, int toPos) {
		const int lo = L[fromPos], hi = R[toPos];
		res = add(res, wt.kOrLess(lo, hi, x));
		res.F -= getZeroes(lo, hi);
	};

	while (chainID[u] != top) {
		addPositions(pos[chainHead[chainID[u]]], pos[u]);
		u = par[chainHead[chainID[u]]];
	}

	while (chainID[v] != top) {
		addPositions(pos[chainHead[chainID[v]]], pos[v]);
		v = par[chainHead[chainID[v]]];
	}

	if (pos[u] > pos[v]) swap(u, v);
	// Skip the shallower vertex's own slot (it belongs to the edge above it).
	// When u == v nothing is left to add. Querying pos[u] + 1 could then run
	// past position n, where L[] was never written (0), making kOrLess and
	// getZeroes read index -1.
	if (pos[u] < pos[v]) addPositions(pos[u] + 1, pos[v]);

	return res;
}

int query(int u, int v, int x, ll y) {
	int l = minC - 1, r = maxC, lim = -1;
	int gold = getGold(u, v), recGold = 0;
	ll sil = 0;

	while (l <= r) {
		int mid = (l + r) >> 1;
		auto get = getPath(u, v, mid);
		if (get.S <= y) {
			sil = get.S;
			recGold = get.F;
			lim = mid;
			l = mid + 1;
		} else r = mid - 1;
	}

	if (lim == -1) return (x >= gold ? x - gold : -1);
	if (lim < maxC) recGold += (y - sil) / (lim + 1);

	int fin = x - gold + recGold;
	return fin < 0 ? -1 : fin;
}

void solve() {
	cin >> n >> m >> q;
	for (int i = 1, u, v; i < n; ++i) {
		cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
		E[i] = make_pair(u, v);
	}

	curChain = curPos = 1;
	dfs(1); hld(1);

	for (int i = 1, p, c; i <= m; ++i) {
		cin >> p >> c;
		int u = E[p].F, v = E[p].S;
			
		if (pos[u] < pos[v]) swap(u, v);
		weightAtPos[pos[u]].push_back(c);
		BIT.update(pos[u], 1);

		maximize(maxC, c);
		minimize(minC, c);
	}

	int cur = 0;
	for (int i = 1; i <= n; ++i) {
		if (weightAtPos[i].empty()) {
			L[i] = R[i] = cur + 1;
			a[cur++] = 0;
			continue;
		}
		L[i] = cur + 1; R[i] = cur + (int) weightAtPos[i].size();
		for (int x: weightAtPos[i])
			a[cur++] = x;
	}

	for (int i = 1; i <= cur; ++i)
		zeroes[i] = zeroes[i - 1] + (a[i - 1] == 0);

	wt = wavelet_tree(a, a + cur, 0, maxC);

	while (q--) {
		int s, t, x;
		ll y;
		cin >> s >> t >> x >> y;
		cout << query(s, t, x, y) << el;
	}
}

// Entry point: fast unsynchronized iostreams, then one test case,
// or a counted list of them when MULTI is non-zero.
signed main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	if (MULTI) {
		int test;
		cin >> test;
		while (test--) solve();
	} else {
		solve();
	}

	return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...