Submission #745555

# | Time | Username | Problem | Language | Result | Execution time | Memory
745555 | — | onebit1024 | Two Currencies (JOI23_currencies) | C++17 | 100 / 100 | 3694 ms | 657712 KiB
#include <bits/stdc++.h>
using namespace std;
 
#define int long long
#define pb push_back
#define all(c) c.begin(), c.end()
#define endl "\n"
 
const double PI=3.141592653589;
 
 
// Debug-printer overloads for scalar and string types. Each one writes its
// argument to stderr; characters are wrapped in single quotes, strings in
// double quotes, and booleans are spelled out as true/false.
void __print(int x) {
    std::cerr << x;
}
void __print(long x) {
    std::cerr << x;
}
void __print(unsigned x) {
    std::cerr << x;
}
void __print(unsigned long x) {
    std::cerr << x;
}
void __print(unsigned long long x) {
    std::cerr << x;
}
void __print(float x) {
    std::cerr << x;
}
void __print(double x) {
    std::cerr << x;
}
void __print(long double x) {
    std::cerr << x;
}
void __print(char x) {
    std::cerr << "'" << x << "'";
}
void __print(const char *x) {
    std::cerr << "\"" << x << "\"";
}
void __print(const string &x) {
    std::cerr << "\"" << x << "\"";
}
void __print(bool x) {
    if (x) {
        std::cerr << "true";
    } else {
        std::cerr << "false";
    }
}
 
// Debug-prints a pair to stderr as "{first,second}", delegating each half to
// the matching __print overload.
template<typename T, typename V>
void __print(const pair<T, V> &x) {
    std::cerr << '{';
    __print(x.first);
    std::cerr << ',';
    __print(x.second);
    std::cerr << '}';
}
// Debug-prints any range usable in a range-for to stderr as "{e0,e1,...}".
template<typename T>
void __print(const T &x) {
    bool needComma = false;
    std::cerr << '{';
    for (auto &elem : x) {
        // Separator goes before every element except the first.
        std::cerr << (needComma ? "," : "");
        needComma = true;
        __print(elem);
    }
    std::cerr << "}";
}
// Terminates a debug line: closes the bracket opened by dbg and ends the line.
void _print() {
    std::cerr << "]\n";
}
// Prints the first argument via __print, then recurses on the remaining ones,
// separating consecutive arguments with ", ".
template <typename T, typename... V>
void _print(T t, V... v) {
    __print(t);
    if (sizeof...(v) != 0) {
        std::cerr << ", ";
    }
    _print(v...);
}
// Debug tracing macro. Unless ONLINE_JUDGE is defined, `dbg(a, b);` writes
// `LINE(<line>) -> [a, b] = [<a>, <b>]` to stderr through _print; with
// ONLINE_JUDGE defined it expands to nothing.
//
// The expansion is wrapped in do { } while (0) so that it forms exactly one
// statement: with the previous two-statement expansion, `if (c) dbg(x);`
// guarded only the header output and always ran _print. The standard
// `...` / __VA_ARGS__ form replaces the GNU-specific named variadic `x...`.
#ifndef ONLINE_JUDGE
#define dbg(...) do { std::cerr << "LINE(" << __LINE__ << ") -> " <<"[" << #__VA_ARGS__ << "] = ["; _print(__VA_ARGS__); } while (0)
#else
#define dbg(...)
#endif

// Point-assign / range-sum segment tree over positions [0, size).
// Nodes live in `arr` in heap order: the root is index 0 and node x has
// children 2*x+1 (left half) and 2*x+2 (right half).
struct segtree{
    int size = 1;          // number of leaves; doubled by init() until >= n
    std::vector<int> arr;  // node sums, 2*size entries after init()

    // Grows the leaf count to a power-of-two multiple covering n positions
    // and makes sure the node array has 2*size zero-initialised slots.
    void init(int n){
        for(; size < n; size *= 2){}
        arr.resize(2*size, 0);
    }

    // Assigns v to position i inside node x, which covers [lx, rx), and
    // refreshes the sums on the way back up.
    void set(int i, int v, int x, int lx, int rx){
        if(rx - lx != 1){
            const int mid = (lx+rx)/2;
            const int lc = 2*x+1;
            const int rc = 2*x+2;
            if(i >= mid){
                set(i, v, rc, mid, rx);
            }else{
                set(i, v, lc, lx, mid);
            }
            arr[x] = arr[lc]+arr[rc];
            return;
        }
        // Leaf: store the value directly.
        arr[x] = v;
    }

    // Assigns v to position i.
    void set(int i, int v){
        set(i, v, 0, 0, size);
    }

    // Sum of positions in [l, r) restricted to node x covering [lx, rx).
    int sol(int l, int r, int x, int lx, int rx){
        const bool disjoint = (lx>=r) || (rx<=l);
        if(disjoint)return 0;
        const bool covered = (lx>=l) && (rx<=r);
        if(covered)return arr[x];
        const int mid = (lx+rx)/2;
        return sol(l,r,2*x+1,lx,mid) + sol(l,r,2*x+2,mid,rx);
    }

    // Sum of positions in the half-open range [l, r).
    int sol(int l, int r){
        return sol(l, r, 0, 0, size);
    }
};


// --- Persistent segment tree storage -------------------------------------
// sz  : number of leaves of the persistent tree; init() doubles it until it
//       is at least the requested size.
// ptr : index of the most recently allocated node. Nodes are handed out with
//       ++ptr, so real nodes start at 1; sol() treats index 0 as an empty
//       subtree.
int sz = 1, ptr = 0;
// Sizing constants for the node pool: both vectors below hold mxn*mxk nodes.
int mxn = 8e5+5,mxk = 22;
// seg[x]   : aggregate stored at node x (a pair combined component-wise by
//            merge()).
// child[x] : (left child index, right child index) of node x.
vector<pair<int,int>>seg(mxn*mxk),child(mxn*mxk);
// Grows the global leaf count `sz` by repeated doubling until it is >= n.
void init(int n){
    for(; sz < n; sz *= 2){}
}

// Combines two node aggregates by adding them component-wise.
pair<int,int> merge(pair<int,int>a, pair<int,int>b){
    a.first += b.first;
    a.second += b.second;
    return a;
}

// Persistent point assignment.
// `curr` is an already-allocated node that becomes the new version of `prev`
// for the range [lx, rx): position i is set to v, the child on the path to i
// is freshly allocated with ++ptr, and the other child is shared with `prev`.
void upd(int curr, int prev, int lx, int rx, int i, pair<int,int>v){
    if(rx-lx==1){
        // Leaf reached: overwrite the stored aggregate.
        seg[curr] = v;
        return;
    }
    const int mid = (lx+rx)/2;
    const int fresh = ++ptr;          // new node for the half that contains i
    if(i >= mid){
        child[curr].second = fresh;
        child[curr].first  = child[prev].first;   // left half is reused as-is
        upd(fresh,child[prev].second,mid,rx,i,v);
    }else{
        child[curr].first  = fresh;
        child[curr].second = child[prev].second;  // right half is reused as-is
        upd(fresh,child[prev].first,lx,mid,i,v);
    }
    const pair<int,int> &leftAgg  = seg[child[curr].first];
    const pair<int,int> &rightAgg = seg[child[curr].second];
    seg[curr] = merge(leftAgg, rightAgg);
}


void upd(int curr, int prev, int i, pair<int,int>v){
    upd(curr,prev,0,sz,i,v);
}

// Aggregate of positions in [l, r) inside node `curr`, which covers [lx, rx).
// Node index 0 stands for an empty subtree and contributes {0,0}.
pair<int,int>sol(int curr, int lx, int rx, int l, int r){
    const pair<int,int> nothing{0,0};
    if(curr==0)return nothing;
    const bool inside = (l <= lx) && (rx <= r);
    if(inside)return seg[curr];
    const bool outside = (rx <= l) || (r <= lx);
    if(outside)return nothing;
    const int mid = (lx+rx)/2;
    const pair<int,int> leftPart  = sol(child[curr].first,lx,mid,l,r);
    const pair<int,int> rightPart = sol(child[curr].second,mid,rx,l,r);
    return merge(leftPart, rightPart);
}

pair<int,int>sol(int curr, int l, int r){
    return sol(curr,0,sz,l,r);
}
// --- Tree state shared by comp/dfs/lca/go/solve ---------------------------
// f, m, q are read from input in solve(); n is set there to f+m+1 and is the
// size bound used for all per-vertex arrays.
int f,n,m,q;
// tax[v] : values recorded for the edge between v and par[v] (filled in solve).
// up[v][j] : ancestor of v reached by 2^j parent steps (filled in dfs).
vector<vector<int>>tax,up;
// adj[v] : neighbour set of v; a set so that solve() can erase edges.
vector<set<int>>adj;
// in/out : entry and exit timestamps assigned by dfs.
// dist   : number of edges from the dfs root.
// val    : value attached to a vertex (0 when it has none).
// par    : parent of each vertex (set by comp, refreshed by dfs).
vector<int>in,out,dist,val,par;
// Next timestamp handed out by dfs; starts at 1 so 0 is never a valid time.
int t = 1;
// Records par[] for every vertex in the subtree of u, where p is the vertex
// we arrived from (and is therefore skipped).
void comp(int u, int p){
    for(const int v : adj[u]){
        if(v!=p){
            par[v] = u;
            comp(v,u);
        }
    }
}

// Depth-first traversal from u (arriving from p) that fills, for every vertex
// below u: entry/exit timestamps, depth, parent and the ancestor jump table.
void dfs(int u, int p){
    in[u] = t++;
    for(const int v : adj[u]){
        if(v==p)continue;
        par[v] = u;
        dist[v] = dist[u]+1;
        up[v][0] = u;
        // Level j+1 is two level-j jumps; ancestors are already complete.
        for(int j = 0;j<20;++j){
            up[v][j+1] = up[up[v][j]][j];
        }
        dfs(v,u);
    }
    out[u] = t++;
}

// Lowest common ancestor of u and v using the up[][] jump table.
int lca(int u, int v){
    // Make u the deeper endpoint, then lift it to v's depth.
    if(dist[u] < dist[v])swap(u,v);
    const int gap = dist[u]-dist[v];
    for(int j = 20;j>=0;--j){
        if(gap&(1ll<<j)){
            u = up[u][j];
        }
    }
    if(u==v)return u;
    // Lift both endpoints together while their ancestors still differ.
    for(int j = 20;j>=0;--j){
        if(up[u][j] == up[v][j])continue;
        u = up[u][j];
        v = up[v][j];
    }
    return up[v][0];
}

// Returns the vertex reached from u after k parent steps, using one table
// jump per set bit of k (bits 0..20).
int go(int u, int k){
    int j = 0;
    while(j<=20){
        if(k&(1ll<<j)){
            u = up[u][j];
        }
        ++j;
    }
    return u;
}
void solve()
{
    cin >> f >> m >> q;
    vector<pair<int,int>>edges = {{0,0}};
    n = f+m+1;
    tax.resize(n+1);
    adj.resize(n+1);
    up = vector<vector<int>>(n+1, vector<int>(21));
    dist.resize(n+1);
    in.resize(n+1);
    out.resize(n+1);
    val.resize(n+1);
    par.resize(n+1);
    for(int i = 1;i<f;++i){
        int u,v;
        cin >> u >> v;
        edges.pb({u,v});
        adj[u].insert(v);
        adj[v].insert(u);
    }
    comp(1,-1);
    for(int i = 1;i<=m;++i){
        int p,c;
        cin >> p >> c;
        int u = edges[p].first, v = edges[p].second;
        if(par[u]==v)swap(v,u);
        // par[v] = u
        tax[v].pb(c);
    }

    int p = f+1;

    for(int i = 1;i<=f;++i){
        if(tax[i].empty())continue;
        val[i] = tax[i][0];
        int prev = i;
        adj[par[i]].erase(i);
        adj[i].erase(par[i]);
        for(int j = 1;j<tax[i].size();++j){
            adj[p].insert(prev);
            adj[prev].insert(p);
            val[p] = tax[i][j];
            prev = p;
            p++;
            
        }
        adj[prev].insert(par[i]);
        adj[par[i]].insert(prev);
    }
    dfs(1,-1);

    init(2*n+1);

    segtree st;
    st.init(2*n+1);
    vector<pair<int,int>>feed={{0,0}};
    for(int i = 1;i<=n;++i){
        if(val[i]){
            st.set(in[i],1);
            st.set(out[i],-1);
            feed.pb({val[i],i});
        }
    }

    sort(all(feed));
    vector<int>root(mxn+1);

    for(int i = 1;i<feed.size();++i){
        int prev_root = root[i-1];
        int curr_root = ++ptr;
        int u = feed[i].second;
        root[i] = curr_root;
        upd(curr_root,prev_root,in[u],{feed[i].first,1});
        curr_root = ++ptr;
        prev_root = root[i];
        upd(curr_root,prev_root,out[u],{-feed[i].first,-1});
        root[i] = curr_root;
    }
    int sz = feed.size();

    while(q--){
        int u,v,g,s;
        cin >> u >> v >> g >> s;
        if(in[v] < in[u])swap(v,u);
        int l = 1, r = sz;
        int L = lca(u,v);
        
        int k1 = go(v,dist[v]-dist[L]-1), k2 = 0;
        if(u!=L)k2 = go(u,dist[u]-dist[L]-1);
        int tot = 0;
        if(u==L)tot = st.sol(in[k1],in[v]+1);
        else tot=st.sol(in[k1],in[v]+1)+st.sol(in[k2],in[u]+1);
        
        int res = -1;
        if(g>=tot)res = g-tot;
        while(l<= r){
            int m = (l+r)/2;
            pair<int,int>get;
            if(u==L){
                get = sol(root[m],in[k1],in[v]+1);
            }else get = merge(sol(root[m],in[k1],in[v]+1),sol(root[m],in[k2],in[u]+1));

            if(get.first <= s){
                int left = tot-get.second;
                res = max(res,max(-1ll,g-left));
                l = m+1;
            }else r = m-1;
        }
        cout << res << endl;
    }
}   


// Program entry point: switches the standard streams to fast mode and runs
// solve() once. Declared as int32_t because `int` is redefined in this file.
int32_t main()
{
    // Drop synchronisation with C stdio and untie the streams.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int remaining = 1;   // number of test cases to process
    while(remaining-- > 0){
        solve();
    }
}

Compilation message (stderr)

currencies.cpp: In function 'void solve()':
currencies.cpp:209:24: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  209 |         for(int j = 1;j<tax[i].size();++j){
      |                       ~^~~~~~~~~~~~~~
currencies.cpp:238:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  238 |     for(int i = 1;i<feed.size();++i){
      |                   ~^~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...