제출 #1162048

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1162048 | guagua0407 | JOI tour (JOI24_joitour) | C++20
100 / 100
2180 ms | 400392 KiB
//#include "grader.cpp"
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<ll,ll>
#define f first
#define s second
#define all(x) x.begin(),x.end()
#define _ ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);

// Internal state and helpers of the solution; the anonymous namespace gives
// every name here internal linkage, so only init/change/num_tours (defined
// below, outside the namespace) form the grader-facing interface.
//
// Overview (reconstructed from the arithmetic in add(); confirm against the
// problem statement): res holds the number of vertex triples (x, y, z) with
// a[x]==0, a[y]==1, a[z]==2 such that y lies on the tree path from x to z.
// The tree is centroid-decomposed; a triple is accounted for at the first
// centroid (in decomposition order) that lies on its path, and add() keeps
// per-centroid counters so that inserting/removing one vertex adjusts res
// with a few Fenwick operations per centroid level of that vertex.
namespace{
    // Capacity of the vertex-indexed arrays; vertices are 1-indexed because
    // init() shifts the grader's 0-indexed ids by +1.
    const ll mxn=2e5+5;
    // Adjacency lists of the tree.
    vector<ll> adj[mxn];
    // a[v]: colour of vertex v; the code treats it as one of 0, 1, 2
    // (it indexes size-3 arrays with a[v] and 2-a[v]).
    vector<ll> a(mxn);
    // sz[v]: size of v's subtree as computed by the most recent dfs().
    vector<ll> sz(mxn);
    // visited[v]: v has already been chosen as a centroid (removed).
    vector<bool> visited(mxn);
    // st[v][i] / en[v][i]: entry / exit time of v in the Euler tour of the
    // component of centroid par[v][i].first (the centroid itself gets 1).
    vector<vector<ll>> st(mxn);
    vector<vector<ll>> en(mxn);
    // par[v][i] = {centroid c, index of c's child subtree containing v},
    // ordered from the top-level centroid downward; the last entry of a
    // vertex that became a centroid is {v, -1}.
    vector<vector<pair<ll,ll>>> par(mxn);
    // cnt[c][t] (t = 0 or 2): number of currently inserted colour-t vertices
    // in c's component, excluding c itself. Colour 1 is not tracked here.
    vector<vector<ll>> cnt(mxn,vector<ll>(3));
    // cntc[c][k][t]: the same count restricted to c's k-th child subtree.
    vector<vector<vector<ll>>> cntc(mxn);
    // sum[c][t] (t = 0 or 2): number of inserted pairs (x, y) in c's component
    // with a[x]==t, a[y]==1 and y on the path from c (excluded) down to x.
    vector<vector<ll>> sum(mxn,vector<ll>(3));
    // sumc[c][k][t]: the same pair count restricted to c's k-th child subtree.
    vector<vector<vector<ll>>> sumc(mxn);
    // mulsum[c] = sum over child subtrees k of cntc[c][k][0] * cntc[c][k][2].
    vector<ll> mulsum(mxn);
    // Current answer, returned by num_tours().
    ll res=0;
    // 1-indexed Fenwick tree over positions 1..n (point add, prefix sum).
    struct BIT{
        vector<ll> bit;
        ll n;
        // Resets the tree to n zeroed positions.
        void init(ll _n){
            n=_n;
            bit=vector<ll>(n+1,0);
        }
        // Adds val at pos (pos must be >= 1). Positions past n are ignored,
        // which lets callers pass en+1 for the last Euler slot unchecked.
        void update(ll pos,ll val){
            if(pos>n) return;
            for(;pos<=n;pos+=(pos&-pos)){
                bit[pos]+=val;
            }
        }
        // Returns the sum of positions 1..pos (0 when pos <= 0).
        ll query(ll pos){
            ll ans=0;
            for(;pos>0;pos-=(pos&-pos)){
                ans+=bit[pos];
            }
            return ans;
        }
    };
    // bit[c][t]: Fenwick trees over the Euler positions of c's component.
    // bit[c][1] is used as range-add / point-query (each colour-1 vertex adds
    // over its subtree range, so a point query counts colour-1 ancestors);
    // bit[c][0] and bit[c][2] are point-add / range-sum.
    vector<vector<BIT>> bit(mxn,vector<BIT>(3));
    // Euler-tour clock used by dfs1; centroid() resets it to 1 per centroid.
    ll timer=1;
    // Fills sz[] for the not-yet-removed component containing v and returns
    // its size. p=0 means "no parent" (vertex ids start at 1).
    // NOTE(review): recursive; depth can reach the component size.
    ll dfs(ll v,ll p=0){
        sz[v]=1;
        for(auto u:adj[v]){
            if(u==p or visited[u]) continue;
            sz[v]+=dfs(u,v);
        }
        return sz[v];
    }
    // Returns the centroid of a component of tot vertices, walking from v into
    // any child whose subtree holds more than half of them. Relies on sz[]
    // from a dfs() rooted at v.
    // NOTE(review): with `using namespace std` this shares its name with
    // std::find; the non-template overload wins resolution for these calls.
    ll find(ll v,ll tot,ll p=0){
        for(auto u:adj[v]){
            if(u==p or visited[u]) continue;
            if(sz[u]*2>tot) return find(u,tot,v);
        }
        return v;
    }
    // Euler-tours v's subtree (entered from parent p) in the current
    // component: records st/en for this level and the par entry
    // {centroid ppp, child-subtree index pp}.
    void dfs1(ll v,ll p,ll pp,ll ppp){
        st[v].push_back(++timer);
        //cout<<"p "<<p<<' '<<v<<'\n';
        par[v].push_back({ppp,pp});
        for(auto u:adj[v]){
            if(u==p or visited[u]) continue;
            dfs1(u,v,pp,ppp);
        }
        en[v].push_back(timer);
    }
    // Recursively builds the centroid decomposition of the component that
    // contains v: numbers it in Euler order (centroid at position 1), sizes the
    // per-centroid counters and Fenwick trees, marks the centroid removed, and
    // then decomposes each remaining piece.
    void centroid(ll v=1){
        //cout<<"cen "<<v<<'\n';
        v=find(v,dfs(v));
        timer=1;
        st[v].push_back(1);
        par[v].push_back({v,-1});
        ll cnt=0; // local child-subtree counter (shadows the global cnt)
        for(auto u:adj[v]){
            if(visited[u]) continue;
            dfs1(u,v,cnt,v);
            cnt++;
        }
        cntc[v]=vector<vector<ll>>(cnt,vector<ll>(3));
        sumc[v]=vector<vector<ll>>(cnt,vector<ll>(3));
        en[v].push_back(timer);
        // timer now equals the component size, i.e. the last Euler position.
        for(ll t=0;t<3;t++){
            bit[v][t].init(timer);
        }
        visited[v]=true;
        for(auto u:adj[v]){
            if(visited[u]) continue;
            centroid(u);
        }
    }
    // Inserts (d=1) or removes (d=-1) vertex v with its current colour a[v],
    // updating the counters of every centroid level v belongs to, and adds
    // d * (number of counted triples involving v and other inserted vertices)
    // to res.
    // b=true is used by init(), which inserts vertices in increasing index
    // order: triples that use a centroid c as middle/endpoint are counted here
    // only when c<v (c already inserted); otherwise c's own-level term counts
    // them when c is inserted later, so nothing is counted twice.
    void add(ll v,ll d,bool b=false){
        ll tmp=0;
        for(ll i=0;i<(ll)par[v].size();i++){
            auto p=par[v][i];
            if(v==p.f){
                // v is the centroid of this level: count triples whose path
                // passes through v using its own component's counters.
                if(a[v]==1){
                    // v in the middle: colour-0 and colour-2 endpoints taken
                    // from different child subtrees.
                    tmp+=1ll*cnt[v][2]*cnt[v][0];
                    tmp-=mulsum[v];//sum(cnt[u][0]*cnt[u][2])
                }
                else{
                    // v as an endpoint: the opposite endpoint together with a
                    // colour-1 vertex between it and v.
                    tmp+=sum[v][2-a[v]];
                }
                //cout<<p.f<<' '<<tmp<<'\n';
                continue;
            }
            // p.f is an ancestor centroid; v lies in its child subtree p.s and
            // [st[v][i], en[v][i]] is v's subtree range at this level.
            if(a[v]==1){
                // v becomes a colour-1 ancestor of every vertex in its range.
                bit[p.f][1].update(st[v][i],d);
                bit[p.f][1].update(en[v][i]+1,-d);
                // cnt0 / cnt2: inserted colour-0 / colour-2 vertices below v;
                // each forms a new (endpoint, v) pair for sum / sumc.
                ll cnt0=bit[p.f][0].query(en[v][i])-bit[p.f][0].query(st[v][i]-1);
                sumc[p.f][p.s][0]+=cnt0*d;
                sum[p.f][0]+=cnt0*d;
                // Colour-0 endpoint below v, colour-2 endpoint in another
                // child subtree of p.f.
                tmp+=1ll*cnt0*(cnt[p.f][2]-cntc[p.f][p.s][2]);
                ll cnt2=bit[p.f][2].query(en[v][i])-bit[p.f][2].query(st[v][i]-1);
                sumc[p.f][p.s][2]+=cnt2*d;
                sum[p.f][2]+=cnt2*d;
                tmp+=1ll*cnt2*(cnt[p.f][0]-cntc[p.f][p.s][0]);
                // The centroid itself as the opposite endpoint.
                if(a[p.f]==0 and ((!b) or p.f<v)){
                    tmp+=cnt2;
                }
                else if(a[p.f]==2 and ((!b) or p.f<v)){
                    tmp+=cnt0;
                }
            }
            else{
                // v is an endpoint (colour 0 or 2).
                bit[p.f][a[v]].update(st[v][i],d);
                // cnt1: colour-1 vertices on the path from p.f (excluded) to v.
                ll cnt1=bit[p.f][1].query(st[v][i]);
                cnt[p.f][a[v]]+=d;
                cntc[p.f][p.s][a[v]]+=d;
                sum[p.f][a[v]]+=cnt1*d;
                sumc[p.f][p.s][a[v]]+=cnt1*d;
                // Opposite endpoint in another child subtree, middle on its side.
                tmp+=(sum[p.f][2-a[v]]-sumc[p.f][p.s][2-a[v]]);
                // Opposite endpoint in another child subtree, middle on v's side.
                tmp+=1ll*cnt1*(cnt[p.f][2-a[v]]-cntc[p.f][p.s][2-a[v]]);
                // The centroid itself as the middle, or as the opposite endpoint.
                if(a[p.f]==1 and ((!b) or p.f<v)){
                    tmp+=(cnt[p.f][2-a[v]]-cntc[p.f][p.s][2-a[v]]);
                }
                else if(a[p.f]==2-a[v] and ((!b) or p.f<v)){
                    tmp+=cnt1;
                }
                // Keep mulsum[p.f] == sum_k cntc[p.f][k][0]*cntc[p.f][k][2].
                mulsum[p.f]+=d*cntc[p.f][p.s][2-a[v]];
            }
            //cout<<p.f<<' '<<tmp<<'\n';
        }
        res+=tmp*d;
        //cout<<v<<' '<<d<<' '<<res<<'\n';
    }
};

void init(int n, std::vector<int> F, std::vector<int> U, std::vector<int> V,
          int Q){
    for(ll i=1;i<=n;i++){
        a[i]=F[i-1];
    }
    for(ll i=0;i<n-1;i++){
        U[i]++;
        V[i]++;
        adj[V[i]].push_back(U[i]);
        adj[U[i]].push_back(V[i]);
    }
    centroid();
    /*for(ll i=1;i<=n;i++){
            cout<<i<<'\n';
        for(auto p:par[i]){
            cout<<p.f<<' '<<p.s<<'\n';
        }
    }*/
    for(ll i=1;i<=n;i++){
        add(i,1,true);
    }
}

void change(int x, int y) {
    x++;
    //cout<<'\n';
    //cout<<"q "<<x<<' '<<y<<'\n';
    add(x,-1);
    a[x]=y;
    add(x,1);
}

// Grader query: reports the tour count that add() keeps up to date in res.
long long num_tours() {
    const long long total=res;
    return total;
}
/*
11
0 1 1 2 1 0 2 1 2 2 1
0 1
0 2
0 3
1 4
1 5
2 7
2 8
3 9
5 6
9 10
0

9
0 1 2 0 1 2 0 1 2
0 1
1 2
2 3
3 4
4 5
5 6
6 7
7 8
0
*/
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...