제출 #1162774

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1162774 | — | yeediot | JOI tour (JOI24_joitour) | C++20 | 0 / 100 | 1749 ms | 396832 KiB
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand: F/S alias pair members, all(x) expands to
// a begin/end iterator pair, sz(x) yields the size as a signed long long,
// chmin(x,y) assigns min(x,y) to x.
#define F first
#define S second
#define all(x) x.begin(),x.end()
#define pii pair<long long,long long>
#define pb push_back
#define sz(x) (long long)(x.size())
#define chmin(x,y) x=min(x,y)
#ifdef local
// Local build only: CHECK() (defined near the end of the file) compares the
// produced output file with an expected-output file.
void CHECK();
// Redirect stdin/stdout to the author's local test files.
void setio(){
    freopen("/Users/iantsai/cpp/input.txt","r",stdin);
    freopen("/Users/iantsai/cpp/output.txt","w",stdout);
}
#else
// Grader build: no redirection.
void setio(){}
#endif
// Fast I/O setup (unsync from C stdio, untie cin/cout), then setio().
#define TOI_is_so_de ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);setio();
// Capacity of the per-vertex arrays; vertices are stored 1-indexed.
const long long mxn = 2e5 + 5;
// Three independent Fenwick trees over positions 1..tmr, indexed by d = 0, 1, 2,
// plus `sum`, a per-d running total that callers maintain themselves.
struct BIT{
    vector<vector<long long>>bit;
    vector<long long>sum;
    // (Re)allocate three zero-filled trees for positions 1..tmr and zero `sum`.
    void init(long long tmr){
        sum.assign(3, 0);
        bit.assign(3, vector<long long>(tmr + 1, 0));
    }
    // Point update: add v at position p of tree d.
    void m(long long d, long long p, long long v){
        auto &tree = bit[d];
        const long long limit = sz(tree);
        while(p < limit){
            tree[p] += v;
            p += p & -p;
        }
    }
    // Range update on a difference array: add v to every position in [l, r]
    // (observed by the point query q(d, x)).
    void m(long long d, long long l, long long r, long long v){
        m(d, r + 1, -v);
        m(d, l, v);
    }
    // Prefix sum of tree d over positions 1..p.
    long long q(long long d, long long p){
        const auto &tree = bit[d];
        long long acc = 0;
        while(p != 0){
            acc += tree[p];
            p &= p - 1;
        }
        return acc;
    }
    // Sum of tree d over positions l..r.
    long long q(long long d, long long l, long long r){
        const long long upto_r = q(d, r);
        return upto_r - q(d, l - 1);
    }
};
// ans    : current number of counted triples (returned by num_tours()).
// ty[v]  : type of vertex v (0, 1 or 2); vertices are 1-indexed.
// cnt[v] : subtree size computed by cnt_sz().
// tmr    : Euler-tour timer of the component currently being decomposed.
// neg[c] : sum over child subtrees id of c of sum[c][id][0] * sum[c][id][2].
long long ans = 0, ty[mxn], cnt[mxn], tmr, neg[mxn];
// adj: tree adjacency lists. For every centroid c whose component contains v
// (outermost centroid first): p[v] holds c, in[v]/out[v] hold v's Euler
// interval in c's component, num[v] holds the index of the child subtree of c
// that contains v (-1 when v == c). ord: centroids in the order chosen.
vector<long long>adj[mxn], p[mxn], in[mxn], out[mxn], num[mxn], ord;
// cal[c][0] (resp. cal[c][2]): number of pairs (x of type 0 (resp. 2),
// y of type 1) in c's component with y an ancestor of x when rooted at c
// (c itself never participates). This can reach about n^2 / 4, so it must be
// long long; an int overflows on large inputs.
long long cal[mxn][3];
// sum[c][id][t]: number of type-t vertices in child subtree id of c.
// s2[c][id][t]: the part of cal[c][t] contributed by child subtree id.
vector<vector<long long>>sum[mxn], s2[mxn];
// vis[v]: v has already been chosen as a centroid.
bool vis[mxn];
// bt[c]: Fenwick trees over the Euler tour of c's component (see upd()).
BIT bt[mxn];
// Counted triples at centroid c that contain c itself. The per-child-subtree
// bookkeeping in upd() never inserts c into its own structures, so this term
// is tracked separately:
//   ty[c] == 1: c is the middle vertex; (type-0, type-2) pairs lying in
//               different child subtrees = sum0 * sum2 - neg[c].
//   ty[c] == 0: c is the type-0 endpoint; count = cal[c][2].
//   ty[c] == 2: c is the type-2 endpoint; count = cal[c][0].
long long centroid_own(long long c){
    if(ty[c] == 1) return bt[c].sum[0] * bt[c].sum[2] - neg[c];
    if(ty[c] == 0) return cal[c][2];
    if(ty[c] == 2) return cal[c][0];
    return 0;
}
// Insert (d == 1) or remove (d == -1) vertex v with its current type ty[v]
// from every centroid structure it belongs to, keeping `ans` up to date.
//   f == 0: initial build. init() processes centroids deepest first and adds
//           each centroid's own term once its whole component is inserted,
//           so ancestor centroids' own terms must not be adjusted here.
//   f == 1: online change. Every centroid's own term is already in `ans`,
//           so the change v causes in each ancestor's own term is applied.
void upd(long long v, long long d, bool f){
    // p[v].back() is the centroid chosen at v itself; its own term depends on
    // ty[v], so it is removed before and re-added after a type change.
    if(d == -1) ans -= centroid_own(p[v].back());
    for(long long _ = 0; _ < sz(p[v]) - 1; _++){
        long long pa = p[v][_], l = in[v][_], r = out[v][_], id = num[v][_];
        // Snapshot pa's own term so its change can be applied below (f == 1).
        long long own_before = f ? centroid_own(pa) : 0;
        if(ty[v] == 0){
            if(d == 1){
                bt[pa].m(0, l, d);
                sum[pa][id][0] += d;
                bt[pa].sum[0] += d;
                neg[pa] += sum[pa][id][2] * d;
            }
            // Type-1 ancestors of v (pa excluded) form (0, 1) pairs with v.
            long long val = d * bt[pa].q(1, l);
            cal[pa][0] += val;
            s2[pa][id][0] += val;
            // Those pairs combine with type-2 vertices in other subtrees...
            ans += (bt[pa].sum[2] - sum[pa][id][2]) * val;
            // ...and v combines with (2, 1) pairs from other subtrees.
            ans += (cal[pa][2] - s2[pa][id][2]) * d;
            if(d == -1){
                bt[pa].m(0, l, d);
                sum[pa][id][0] += d;
                bt[pa].sum[0] += d;
                neg[pa] += sum[pa][id][2] * d;
            }
        }
        else if(ty[v] == 1){
            if(d == 1){
                bt[pa].m(1, l, r, d);
                bt[pa].sum[1] += d;
                sum[pa][id][1] += d;
            }
            // Type-0 vertices below v pair with v; combine with type-2
            // vertices in other subtrees.
            long long val = bt[pa].q(0, l, r) * d;
            cal[pa][0] += val;
            s2[pa][id][0] += val;
            ans += val * (bt[pa].sum[2] - sum[pa][id][2]);
            // Same with the roles of types 0 and 2 swapped.
            val = bt[pa].q(2, l, r) * d;
            cal[pa][2] += val;
            s2[pa][id][2] += val;
            ans += val * (bt[pa].sum[0] - sum[pa][id][0]);
            if(d == -1){
                bt[pa].m(1, l, r, d);
                bt[pa].sum[1] += d;
                sum[pa][id][1] += d;
            }
        }
        else{
            if(d == 1){
                bt[pa].m(2, l, d);
                sum[pa][id][2] += d;
                bt[pa].sum[2] += d;
                neg[pa] += sum[pa][id][0] * d;
            }
            // Type-1 ancestors of v (pa excluded) form (2, 1) pairs with v.
            long long val = d * bt[pa].q(1, l);
            cal[pa][2] += val;
            s2[pa][id][2] += val;
            // Those pairs combine with type-0 vertices in other subtrees...
            ans += (bt[pa].sum[0] - sum[pa][id][0]) * val;
            // ...and v combines with (0, 1) pairs from other subtrees.
            ans += (cal[pa][0] - s2[pa][id][0]) * d;
            if(d == -1){
                bt[pa].m(2, l, d);
                sum[pa][id][2] += d;
                bt[pa].sum[2] += d;
                neg[pa] += sum[pa][id][0] * d;
            }
        }
        // Triples through pa that use pa itself as an endpoint or middle
        // vertex; previously never updated on change(), giving wrong answers.
        if(f) ans += centroid_own(pa) - own_before;
    }
    if(d == 1) ans += centroid_own(p[v].back());
}
// Store in cnt[v] the number of vertices in v's subtree when the current
// component is rooted away from pa; vertices already chosen as centroids
// (vis) are treated as removed.
void cnt_sz(long long v, long long pa){
    long long total = 1;
    for(long long u : adj[v]){
        if(u != pa && !vis[u]){
            cnt_sz(u, v);
            total += cnt[u];
        }
    }
    cnt[v] = total;
}
// Walk from v towards any child whose cnt[] exceeds half of tar, until no
// such child exists; the vertex reached is returned as the centroid.
// Requires cnt[] to have been filled by cnt_sz() from the same root.
long long find(long long v, long long pa, long long tar){
    while(true){
        bool moved = false;
        for(long long u : adj[v]){
            if(u == pa || vis[u] || cnt[u] * 2 <= tar) continue;
            pa = v;
            v = u;
            moved = true;
            break;
        }
        if(!moved) return v;
    }
}
// Euler-tour the part of centroid `top`'s component below v. For each vertex
// visited, record the centroid, its child-subtree index cc, and its
// [in, out] interval in this tour.
void dfs(long long v, long long pa, long long top, long long cc){
    p[v].pb(top);
    num[v].pb(cc);
    in[v].pb(++tmr);
    for(long long u : adj[v])
        if(u != pa && !vis[u]) dfs(u, v, top, cc);
    out[v].pb(tmr);
}
// Centroid decomposition of the component containing v. For the chosen
// centroid c: append it to ord, Euler-tour its component (c gets position 1,
// each unvisited neighbour starts child subtree cc = 0, 1, ...), size the
// per-subtree counters and the Fenwick trees, then recurse into each
// remaining component.
void cd(long long v){
    cnt_sz(v, v);
    v = find(v, v, cnt[v]);
    vis[v] = 1;
    ord.pb(v);
    tmr = 0;
    long long cc = 0;
    p[v].pb(v);
    in[v].pb(++tmr);
    num[v].pb(-1);
    for(auto u : adj[v]){
        if(vis[u]) continue;
        dfs(u, v, v, cc);
        cc++;
    }
    sum[v] = vector<vector<long long>>(cc, vector<long long>(3));
    s2[v] = vector<vector<long long>>(cc, vector<long long>(3));
    out[v].pb(tmr);
    // init() already allocates all three trees; a single call suffices
    // (previously repeated three times, reallocating each time).
    bt[v].init(tmr);
    for(auto u : adj[v]){
        if(vis[u]) continue;
        cd(u);
    }
}






#ifdef local
// Local driver: read the types and edges, build the structure, print the
// initial count, then apply each (vertex, new type) change and print the
// count after it.
void solve(){
    long long n;
    cin >> n;
    for(long long v = 1; v <= n; v++) cin >> ty[v];
    for(long long e = 1; e < n; e++){
        long long a, b;
        cin >> a >> b;
        // Shift to the 1-indexed storage used throughout.
        ++a;
        ++b;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    cd(1);
    reverse(all(ord));
    for(long long c : ord) upd(c, 1, 0);
    cout << "prt" << ans << '\n' << '\n';
    long long q;
    cin >> q;
    for(long long left = q; left != 0; left--){
        int x, y;
        cin >> x >> y;
        const long long vx = x + 1;
        upd(vx, -1, 1);
        ty[vx] = y;
        upd(vx, 1, 1);
        cout << "prt" << ans << '\n';
    }
}
signed main(){
    TOI_is_so_de;
    long long t = 1;
    //cin >> t;
    while(t--){
        solve();
    }
    #ifdef local
    CHECK();
    #endif
}
#else 
#include "joitour.h"
void init(int n, vector<int>typ, vector<int>ea, vector<int>eb, int q){
    for(long long i = 1; i <= n; i ++){
        ty[i] = typ[i - 1];
    }
    for(long long i = 1; i < n; i++){
        ea[i - 1]++;
        eb[i - 1]++;
        adj[ea[i - 1]].pb(eb[i - 1]);
        adj[eb[i - 1]].pb(ea[i - 1]);
    }
    cd(1);
    reverse(all(ord));
    for(auto v : ord){
        upd(v, 1, 0);
    }
}
void change(int x, int y){
    upd(x + 1, -1, 1);
    ty[x + 1] = y;
    upd(x + 1, 1, 1);
}
long long num_tours(){
    return ans;
}
#endif
/*
input:
 
*/
#ifdef local
void CHECK(){
    cerr << "\n[Time]: " << 1000.0 * clock() / CLOCKS_PER_SEC << " ms.\n";
    function<bool(string,string)> compareFiles = [](string p1, string p2)->bool {
        std::ifstream file1(p1);
        std::ifstream file2(p2);
        if(!file1.is_open() || !file2.is_open()) return false;
        std::string line1, line2;
        while (getline(file1, line1) && getline(file2, line2)) {
            if (line1 != line2)return false;
        }
        long long cnta = 0, cntb = 0;
        while(getline(file1,line1))cnta++;
        while(getline(file2,line2))cntb++;
        return cntb - cnta <= 1;
    };
    bool check = compareFiles("output.txt","expected.txt");
    if(check) cerr<<"ACCEPTED\n";
    else cerr<<"WRONG ANSWER!\n";
}
#else
#endif



# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...