Submission #1233864

# | Time | Username | Problem | Language | Result | Execution time | Memory
1233864 | shjeong | Regions (IOI09_regions) | C++20
30 / 100
8023 ms | 196608 KiB
#include <bits/stdc++.h>
using namespace std;
// Optimisation hints. All three must use the `GCC` pragma namespace; a bare
// `#pragma optimize(...)` is not recognised by GCC and is silently ignored.
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
// Detach C++ streams from C stdio and from each other for faster I/O.
#define fast std::ios_base::sync_with_stdio(false); std::cin.tie(nullptr); std::cout.tie(nullptr)
// Iterator-pair shorthands. The argument is parenthesised so that any
// expression (not just a plain identifier) can be passed safely.
#define all(x) (x).begin(), (x).end()
#define rll(x) (x).rbegin(), (x).rend()
// Remove adjacent duplicates in place (sort first to deduplicate fully).
#define COMP(x) (x).erase(std::unique(all(x)), (x).end())
#define MOD 1000000007
#define MOD2 998244353
// Signed container size, so that expressions such as `sz(v) - 1` do not wrap.
#define sz(x) (static_cast<ll>((x).size()))
typedef __int128_t lll;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef std::pair<ll,ll> pll;
typedef std::pair<ll, pll> PP;
// Large "infinity" sentinel that still fits in a signed 64-bit integer.
const ll Lnf = 2e18;
// Problem sizes, read from stdin at the start of main:
// n = number of tree vertices, r = number of regions, q = number of queries.
ll n, r, q;
// One node of the persistent segment tree.
//   l, r : indices of the left / right child inside PST::tree, -1 when absent
//   v    : sum of all deltas applied to positions inside this node's range
struct Node{
    long long l, r, v;
    Node(): l(-1), r(-1), v(0){}
};
// root[k] = index (in seg.tree) of the root node of version k.
// Filled by the callers (main creates version 0, dfs one version per vertex).
long long root[252525];
// Persistent (path-copying) segment tree over a position range [s, e] that the
// caller passes to every operation. Nodes live in `tree` and refer to each
// other by index; a version is identified by the index of its root node.
//
// An index of -1 stands for "no node" and is treated as an empty subtree by
// both upd() and query(), so a version may be built from nothing (prv = -1)
// as well as on top of a fully materialised tree created with init().
struct PST{
    std::vector<Node> tree;

    // Append a fresh empty node and return its index.
    long long ins(){
        tree.push_back(Node());
        return static_cast<long long>(tree.size())-1;
    }

    // Materialise a complete tree of zero-valued nodes below `node` for [s, e].
    void init(long long node, long long s, long long e){
        if(s==e)return;
        // Allocate first, index afterwards: push_back may reallocate `tree`,
        // so no reference into it may be alive across ins(). (Writing
        // `tree[node].l = ins()` is only safe under C++17 evaluation order.)
        const long long left = ins();
        const long long right = ins();
        tree[node].l = left;
        tree[node].r = right;
        const long long mid = s+(e-s)/2;
        init(left,s,mid);
        init(right,mid+1,e);
    }

    // Build version `cur` (an already allocated, empty root) as a copy of
    // version `prv` with `d` added at position `i`. Only the nodes on the path
    // to `i` are newly allocated; everything else is shared with `prv`.
    // `prv` may be -1, meaning "previous version is empty".
    void upd(long long prv, long long cur, long long s, long long e, long long i, long long d){
        for(;;){
            const bool hasPrv = prv>=0;
            tree[cur].v = (hasPrv ? tree[prv].v : 0) + d;
            if(s==e)return;
            const long long mid = s+(e-s)/2;
            const long long prvL = hasPrv ? tree[prv].l : -1;
            const long long prvR = hasPrv ? tree[prv].r : -1;
            const long long fresh = ins();   // may reallocate: indices only from here on
            if(i<=mid){
                tree[cur].l = fresh;
                tree[cur].r = prvR;          // untouched half is shared
                prv = prvL;
                e = mid;
            }
            else{
                tree[cur].l = prvL;          // untouched half is shared
                tree[cur].r = fresh;
                prv = prvR;
                s = mid+1;
            }
            cur = fresh;
        }
    }

    // Sum over positions [l, r] of (version `cur` - version `prv`), where both
    // arguments are node indices covering the range [s, e].
    // A negative index on either side is read as an all-zero subtree.
    long long query(long long prv, long long cur, long long s, long long e, long long l, long long r){
        if(cur<0 or e<l or r<s)return 0;
        const bool hasPrv = prv>=0;
        if(l<=s and e<=r)return tree[cur].v - (hasPrv ? tree[prv].v : 0);
        const long long mid = s+(e-s)/2;
        // Never index tree[] with a negative prv: propagate "absent" downwards.
        const long long prvL = hasPrv ? tree[prv].l : -1;
        const long long prvR = hasPrv ? tree[prv].r : -1;
        return query(prvL,tree[cur].l,s,mid,l,r) + query(prvR,tree[cur].r,mid+1,e,l,r);
    }
} seg;
ll in[252525], out[252525];
ll cnt;
vector<ll> adj[252525];
ll ans[501][501];   //루트보다 큰 것들에 대한 답
int res[250001][501];    //각 정점->루트보다 큰 거 답(올라가면서)
ll regions[250001];
ll conv[250001];    //region convert
ll sqn;
void dfs(ll x, ll prv=-1){
    in[x] = ++cnt;
    root[cnt] = seg.ins();
    seg.upd(root[cnt-1],root[cnt],1,r,regions[x],1);
    for(auto next : adj[x])if(prv!=next){
        for(int i = 1 ; i <= sqn ; i++)res[next][i] = res[x][i] + (regions[x] == i);
        dfs(next,x);
    }
    out[x] = cnt;
}
vector<ll> reg[252525];
int main(){
    fast;
    cin>>n>>r>>q;
    vector<array<ll,2>> cnt(r+1);
    cin>>regions[1]; cnt[regions[1]][0]++;
    for(int i = 2 ; i <= n ; i++){
        ll x; cin>>x; adj[x].push_back(i);
        cin>>regions[i]; cnt[regions[i]][0]++;
    }
    for(int i = 1 ; i <= r ; i++)cnt[i][1] = i;
    sort(rll(cnt));
    for(int i = 0 ; i < r ; i++){
        conv[cnt[i][1]] = i+1;
    }
    for(int i = 1 ; i <= n ; i++)regions[i] = conv[regions[i]], reg[regions[i]].push_back(i);
    ll t = sqrt(n);
    for(int i = 1 ; i <= r ; i++){
        if(t > sz(reg[i]))break;
        sqn=i;
    }
    root[0] = seg.ins(); seg.init(0,1,r);
    dfs(1);
    for(int i = 1 ; i <= n ; i++)if(regions[i]<=sqn and in[i] != out[i]){
        for(int j = 1 ; j <= sqn ; j++){
            ans[regions[i]][j] += seg.query(root[in[i]],root[out[i]],1,r,j,j);
        }
    }
    while(q--){
        ll r1, r2; cin>>r1>>r2; r1 = conv[r1], r2 = conv[r2];
        if(r1>sqn){
            ll s = 0;
            for(auto i : reg[r1])s += seg.query(root[in[i]],root[out[i]],1,r,r2,r2);
            cout<<s<<endl;
        }
        else if(r2<=sqn){
            cout<<ans[r1][r2]<<endl;
        }
        else{
            ll s = 0;
            for(auto i : reg[r2])s += res[i][r1];
            cout<<s<<endl;
        }
    }
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...