제출 #1089563

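# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리  (#, submission time, user ID, problem, language, result, execution time, memory)
1089563 | VMaksimoski008 | Two Currencies (JOI23_currencies) | C++17
100 / 100
2775 ms | 238088 KiB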
#include <bits/stdc++.h>

#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
//#define int long long

using namespace std;

// Short type aliases used throughout the solution.
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;

// Global constants. NOTE: `mod` and `LOG` are not referenced by the visible
// code (the binary-lifting tables and loops hard-code 20 instead of LOG).
constexpr int mod = 1'000'000'007;
constexpr int LOG = 20;
constexpr int maxn = 100'005;  // capacity of the per-node global arrays

struct Node {
    ll sum, cnt;
    Node *l, *r;

    Node(ll a, ll b) : sum(a), cnt(b), l(nullptr), r(nullptr) {}
    Node(Node *l, Node *r) : sum(0), cnt(0), l(l), r(r) {
        if(l) sum += l->sum;
        if(l) cnt += l->cnt;
        if(r) sum += r->sum;
        if(r) cnt += r->cnt;
    }
};

// Builds the initial (all-zero) version of the tree over positions [tl, tr]
// and returns its root.
Node* build(int tl, int tr) {
    if (tl != tr) {
        const int mid = (tl + tr) / 2;
        Node *left = build(tl, mid);
        Node *right = build(mid + 1, tr);
        return new Node(left, right);
    }
    return new Node(0LL, 0LL);
}

// Returns the root of a new version equal to the version rooted at `u` (over
// [tl, tr]) with `v` added to position `p`. The leaf's count changes by +1
// when v > 0 and by -1 otherwise. Only the nodes on the root-to-leaf path are
// newly allocated; every other subtree is shared with the old version.
Node* update(Node *u, int tl, int tr, int p, ll v) {
    if (tl == tr) {
        const ll delta = v > 0 ? 1 : -1;
        return new Node(u->sum + v, u->cnt + delta);
    }
    const int tm = (tl + tr) / 2;
    Node *left = u->l;
    Node *right = u->r;
    if (p <= tm)
        left = update(left, tl, tm, p, v);
    else
        right = update(right, tm + 1, tr, p, v);
    return new Node(left, right);
}

// Component-wise sum of two (sum, count) pairs.
pll merge(pll a, pll b) {
    a.first += b.first;
    a.second += b.second;
    return a;
}

// Returns the (sum, count) aggregate of positions [l, r] in the version
// rooted at `u`, which spans [tl, tr].
pll query(Node *u, int tl, int tr, int l, int r) {
    const bool disjoint = tl > r || l > tr;
    if (disjoint) return {0, 0};

    const bool covered = l <= tl && tr <= r;
    if (covered) return {u->sum, u->cnt};

    const int tm = (tl + tr) / 2;
    const pll left = query(u->l, tl, tm, l, r);
    const pll right = query(u->r, tm + 1, tr, l, r);
    return merge(left, right);
}

// graph: adjacency lists; main() reads node ids and roots the tree at 1.
// vals:  not used by the visible code.
// in/out: Euler-tour entry/exit timestamps, both taken from `timer`, so they
//         lie in [0, 2n) and index the segment tree's positions.
// d:     depth of each node (the root stays at its zero-initialized depth 0).
// euler: node sequence of the Euler tour (each node pushed on entry and exit);
//        written by dfs_init but not read by the visible code.
vector<int> graph[maxn], vals[maxn], in(maxn), out(maxn), d(maxn), euler;
// up[u][j]:  2^j-th ancestor of u; becomes 0 past the root because up[root][0]
//            is never assigned (zero-initialized global).
// cnt[u][j]: number of checkpoints on the 2^j edges directly above u
//            (cnt[u][0] is filled in main, higher levels by dfs2).
int timer = 0, up[maxn][20], cnt[maxn][20];

// Preorder/postorder DFS from `u` (parent `p`): assigns Euler timestamps
// in/out, records the tour in `euler`, sets depths and the binary-lifting
// table `up`. The caller must have set up[u][0] (zero for the root).
void dfs_init(int u, int p) {
    // Entry timestamp and first tour occurrence.
    in[u] = timer++;
    euler.push_back(u);

    // up[u][0] was set by the caller, so higher jumps can be composed now.
    for (int k = 1; k < 20; ++k) {
        const int halfway = up[u][k - 1];
        up[u][k] = up[halfway][k - 1];
    }

    for (int child : graph[u]) {
        if (child != p) {
            d[child] = d[u] + 1;
            up[child][0] = u;
            dfs_init(child, u);
        }
    }

    // Exit timestamp and second tour occurrence.
    out[u] = timer++;
    euler.push_back(u);
}

// Fills cnt[u][k] for k >= 1 (checkpoints on the 2^k edges above u) in
// preorder, so the ancestor's level k-1 entries are already final.
void dfs2(int u, int p) {
    for (int k = 1; k < 20; ++k) {
        const int halfway = up[u][k - 1];
        cnt[u][k] = cnt[u][k - 1] + cnt[halfway][k - 1];
    }
    for (const int child : graph[u])
        if (child != p) dfs2(child, u);
}

// Lowest common ancestor of a and b via binary lifting on `up` and depths `d`.
int get_lca(int a, int b) {
    // Make `a` the deeper node.
    if (d[a] < d[b]) swap(a, b);

    // Lift `a` to b's depth using the set bits of the depth gap.
    for (int j = 19, gap = d[a] - d[b]; j >= 0; --j)
        if ((gap >> j) & 1) a = up[a][j];
    if (a == b) return a;

    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int j = 19; j >= 0; --j) {
        const int pa = up[a][j];
        const int pb = up[b][j];
        if (pa != pb) {
            a = pa;
            b = pb;
        }
    }
    return up[a][0];
}

// Number of checkpoints on the tree path between a and b, accumulated from
// the binary-lifting table `cnt` while climbing exactly as get_lca does.
int get_cnt(int a, int b) {
    if (d[a] < d[b]) swap(a, b);

    int total = 0;
    const int gap = d[a] - d[b];
    for (int j = 19; j >= 0; --j) {
        if (!((gap >> j) & 1)) continue;
        total += cnt[a][j];
        a = up[a][j];
    }
    if (a == b) return total;

    for (int j = 19; j >= 0; --j) {
        if (up[a][j] == up[b][j]) continue;
        total += cnt[a][j] + cnt[b][j];
        a = up[a][j];
        b = up[b][j];
    }

    // a and b are now distinct children of the LCA: add their parent edges.
    return total + cnt[a][0] + cnt[b][0];
}

signed main() {
    ios_base::sync_with_stdio(false);
    cout.tie(0); cin.tie(0);

    int n, m, q;
    cin >> n >> m >> q;
    vector<array<int, 2> > edges;

    const int B = sqrt(2 * n);

    for(int i=0; i<n-1; i++) {
        int a, b; cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
        edges.push_back({ a, b });
    }

    dfs_init(1, 1);

    vector<pll> vec;
    for(int i=0; i<m; i++) {
        int id, x; cin >> id >> x;
        if(d[edges[id-1][0]] > d[edges[id-1][1]]) swap(edges[id-1][0], edges[id-1][1]);
        vec.push_back({ x, edges[id-1][1] });
        cnt[edges[id-1][1]][0]++;
    }

    dfs2(1, 1);

    sort(vec.begin(), vec.end());

    vector<Node*> roots; roots.push_back(build(0, 2*n-1));

    for(auto &[val, u] : vec) {
        roots.push_back(update(roots.back(), 0, 2*n-1, in[u], val));
        roots.push_back(update(roots.back(), 0, 2*n-1, out[u], -val));
    }

    auto Q = [&](int t, int a, int b) {
        pll x = query(roots[t], 0, 2*n-1, 0, in[a]);
        pll y = query(roots[t], 0, 2*n-1, 0, in[b]);
        pll z = query(roots[t], 0, 2*n-1, 0, in[get_lca(a, b)]);
        return pll{ x.first + y.first - 2 * z.first, x.second + y.second - 2 * z.second };
    };

    while(q--) {
        ll a, b, G, S;
        cin >> a >> b >> G >> S;

        if(a == b) {
            cout << G << '\n';
            continue;
        }
       
        ll total = get_cnt(a, b);

        int l=0, r=m, can=0;
        while(l <= r) {
            int mid = (l + r) / 2;
            auto res = Q(2*mid, a, b);
            if(res.first <= S) can = res.second, l = mid + 1;
            else r = mid - 1; 
        }

        cout << max(-1ll, G - total + can) << '\n';
    }

    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

currencies.cpp: In function 'int main()':
currencies.cpp:117:15: warning: unused variable 'B' [-Wunused-variable]
  117 |     const int B = sqrt(2 * n);
      |               ^
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...