Submission #1082662

# Submission time Handle Problem Language Result Execution time Memory
1082662 2024-09-01T04:45:15 Z Zero_OP JOI tour (JOI24_joitour) C++17
Compilation error
0 ms 0 KB
#include "joitour.h"
#include "bits/stdc++.h"

#define dbg(x) "[" << #x << " = " << (x) << "]"

using namespace std;

// Capacity of every per-vertex array; must exceed the largest N passed to init.
const int MAX = 200000 + 5;

// Slot indices into cnt_total / cnt_subtree (slots 0..3) and ft (slots 0..4).
// Counts are kept per centroid, with the centroid's component rooted at it:
//   num10 : pairs (1-vertex, 0-vertex inside that 1-vertex's subtree)
//   num12 : pairs (1-vertex, 2-vertex inside that 1-vertex's subtree)
//   num0  : vertices coloured 0
//   num2  : vertices coloured 2
//   num1  : vertices coloured 1 (only an ft slot; holds subtree range marks)
const int num10 = 0;
const int num12 = 1;
const int num0 = 2;
const int num2 = 3;
const int num1 = 4;

// Tree and centroid-decomposition state.
int N;                    // number of vertices
int a[MAX];               // current colour of each vertex
int par[MAX];             // parent centroid, -1 for the root centroid
int depth[MAX];           // layer at which each vertex was chosen as centroid
int tin[18][MAX];         // Euler entry time, per decomposition layer
int tout[18][MAX];        // largest tin inside the vertex's subtree, per layer
int sz[MAX];              // scratch subtree sizes for the centroid search
int id_subtree[18][MAX];  // child-subtree id (index into cnt_subtree), per layer
int timer_subtree;        // next unused child-subtree id
int timer_dfs;            // Euler clock, restarted for every centroid
vector<int> adj[MAX];     // adjacency lists of the tree
bool used[MAX];           // true once the vertex has been chosen as a centroid

// Per-centroid totals and per-child-subtree counters, indexed by the slots above.
long long cnt_total[4][MAX];
long long cnt_subtree[4][MAX];
// num_ways[c]: number of (0-vertex, 2-vertex) pairs lying in different child
// subtrees of centroid c, i.e. the tours counted when c itself is the 1.
long long num_ways[MAX];
long long ans;            // current number of tours, reported by num_tours()

// Binary indexed tree over positions 1..n holding int values.
// It is used in two ways in this file:
//   * point update + prefix/range sum   -> update + query / queryRange
//   * range update + point query        -> updateRange + query
struct Fenwick{
    std::vector<int> bit;

    // Tree for positions 1..n, all zero.
    explicit Fenwick(int n) : bit(n + 1, 0) {}
    // Empty tree; must be reassigned before being used for real data.
    Fenwick() : bit() {}

    // Adds val at position id. Positions <= 0 are ignored: with id == 0 the
    // step `id & -id` is 0 and the loop below would never terminate.
    // Positions past the end are ignored by the loop bound.
    void update(int id, int val){
        if(id <= 0) return;
        for(; id < (int)bit.size(); id += id & (-id)){
            bit[id] += val;
        }
    }

    // Sum of positions 1..id. id is clamped to the tree size, so querying past
    // the end (or an empty tree) cannot read out of bounds.
    int query(int id){
        id = std::min(id, (int)bit.size() - 1);
        int sum = 0;
        for(; id > 0; id -= id & (-id)){
            sum += bit[id];
        }
        return sum;
    }

    // Adds val to every position in [l, r]; read back with query(pos).
    void updateRange(int l, int r, int val){
        update(l, val);
        update(r + 1, -val);
    }

    // Sum of positions l..r (point-update mode).
    int queryRange(int l, int r){
        return query(r) - query(l - 1);
    }
};

// Per-centroid Fenwick trees, indexed ft[slot][centroid]; decompose sizes them
// to the centroid's component. Slots num0 / num2 hold point counts of 0- / 2-
// vertices by Euler time, slot num1 holds range marks over 1-vertex subtrees.
Fenwick ft[5][MAX];

// Fills sz[] for the not-yet-removed component containing u, rooted at u
// (pre is the vertex we came from, -1 at the root), and returns sz[u].
int dfs_size(int u, int pre){
    int total = 1;
    for(const int nxt : adj[u]){
        if(nxt == pre || used[nxt]) continue;
        total += dfs_size(nxt, u);
    }
    sz[u] = total;
    return total;
}

// Walks from u towards any child whose subtree holds more than target / 2
// vertices (sizes from the last dfs_size call) and returns the vertex where no
// such child exists, i.e. a centroid of the component.
int find_centroid(int u, int pre, int target){
    const int half = target / 2;
    for(;;){
        int heavy = -1;
        for(const int nxt : adj[u]){
            if(nxt != pre && !used[nxt] && sz[nxt] > half){
                heavy = nxt;
                break;
            }
        }
        if(heavy == -1) return u;
        pre = u;
        u = heavy;
    }
}

// Centroid currently being processed; set by decompose before dfs_subtree runs.
int c;

// Visits u's subtree inside the current centroid's component (the component is
// rooted at centroid c; pre is u's parent there). For every visited vertex it:
//   * assigns Euler times tin/tout at decomposition layer `layer`,
//   * tags it with the current child-subtree id timer_subtree,
//   * adds its colour to ft[*][c], cnt_total[*][c] and cnt_subtree[*][id].
// decompose advances timer_subtree after each child of c, so the whole call
// tree shares one id.
void dfs_subtree(int u, int pre, int layer){
    tin[layer][u] = ++timer_dfs;
    id_subtree[layer][u] = timer_subtree;

    if(a[u] == 0){
        ft[num0][c].update(tin[layer][u], +1);
        ++cnt_total[num0][c];
        ++cnt_subtree[num0][timer_subtree];
    } else if(a[u] == 2){
        ft[num2][c].update(tin[layer][u], +1);
        ++cnt_total[num2][c];
        ++cnt_subtree[num2][timer_subtree];
    }

    // Every caller passes a real vertex as `pre`, so no per-child id handling
    // is needed here (the former `pre == -1` branch could never fire).
    for(int v : adj[u]) if(v != pre && !used[v]){
        dfs_subtree(v, u, layer);
    }

    tout[layer][u] = timer_dfs;

    // All descendants are inserted by now: a 1 at u pairs with every 0 / 2 in
    // its subtree. Its subtree range is also marked in ft[num1] so that a later
    // point query at a descendant counts u as a 1-ancestor.
    if(a[u] == 1){
        ft[num1][c].updateRange(tin[layer][u], tout[layer][u], +1);
        int cnt10 = ft[num0][c].queryRange(tin[layer][u], tout[layer][u]);
        int cnt12 = ft[num2][c].queryRange(tin[layer][u], tout[layer][u]);
        cnt_subtree[num10][timer_subtree] += cnt10;
        cnt_subtree[num12][timer_subtree] += cnt12;
        cnt_total[num10][c] += cnt10;
        cnt_total[num12][c] += cnt12;
    }
}

// Centroid decomposition of the component containing u.
//   layer : depth of the centroid found here (0 for the root centroid)
//   p     : parent centroid, -1 for the root
// For the chosen centroid it builds its Fenwick trees and counters through
// dfs_subtree and adds to `ans` the tours whose path passes through it within
// this component. It then recurses into the remaining pieces.
void decompose(int u, int layer, int p){
    int size = dfs_size(u, -1);
    u = find_centroid(u, -1, size);
    used[u] = true;
    par[u] = p;
    depth[u] = layer;

    // Only the num0 / num1 / num2 trees are ever updated or queried (here, in
    // dfs_subtree and in change); num10 / num12 live purely in cnt_total /
    // cnt_subtree, so no tree is built for them.
    ft[num0][u] = Fenwick(size);
    ft[num1][u] = Fenwick(size);
    ft[num2][u] = Fenwick(size);

    c = u;
    timer_dfs = 0;
    for(int v : adj[u]) if(!used[v]){
        dfs_subtree(v, u, layer);
        // Pairs lying entirely inside one child subtree do not pass through u;
        // remove them from the total products added below.
        num_ways[u] -= cnt_subtree[num0][timer_subtree] * cnt_subtree[num2][timer_subtree];
        ans -= cnt_subtree[num0][timer_subtree] * cnt_subtree[num12][timer_subtree];
        ans -= cnt_subtree[num2][timer_subtree] * cnt_subtree[num10][timer_subtree];
        ++timer_subtree;
    }

    // Tours whose 1 sits on the 0's side (num10 x num2) or the 2's side
    // (num12 x num0), plus the (0, 2) pairs that would use u as the 1.
    num_ways[u] += cnt_total[num0][u] * cnt_total[num2][u];
    ans += cnt_total[num0][u] * cnt_total[num12][u];
    ans += cnt_total[num2][u] * cnt_total[num10][u];

    // Tours in which the centroid itself is the 0, the 1 or the 2.
    if(a[u] == 0) ans += cnt_total[num12][u];
    if(a[u] == 1) ans += num_ways[u];
    if(a[u] == 2) ans += cnt_total[num10][u];

    for(int v : adj[u]) if(!used[v]){
        decompose(v, layer + 1, u);
    }
}

// Grader entry point: records the N vertex colours and the N - 1 tree edges
// (U[e], V[e]), then runs the centroid decomposition, which also computes the
// initial answer. Q (the number of later updates) is not used here.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
    ::N = N;
    std::copy(F.begin(), F.begin() + N, a);
    for(int e = 0; e + 1 < N; ++e){
        const int x = U[e];
        const int y = V[e];
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    decompose(0, 0, -1);
}

void change(int u, int nval) {
    if(a[u] == nval) return;
    int dep = depth[u], c = u;
    while(c != -1){
        if(c == u){
            if(a[u] == 0) ans -= cnt_total[num12][u];
            if(a[u] == 1) ans -= num_ways[u];
            if(a[u] == 2) ans -= cnt_total[num10][u];

            if(nval == 0) ans += cnt_total[num12][u];
            if(nval == 1) ans += num_ways[u];
            if(nval == 2) ans += cnt_total[num10][u];
        } else{
            int id = id_subtree[dep][u];

            ///remove data
            if(a[c] == 0) ans -= cnt_total[num12][c];
            if(a[c] == 1) ans -= num_ways[c];
            if(a[c] == 2) ans -= cnt_total[num10][c];

            num_ways[c] += cnt_subtree[num0][id] * cnt_subtree[num2][id];
            ans += cnt_subtree[num0][id] * cnt_subtree[num12][id];
            ans += cnt_subtree[num2][id] * cnt_subtree[num10][id];

            num_ways[c] -= cnt_total[num0][c] * cnt_total[num2][c];
            ans -= cnt_total[num0][c] * cnt_total[num12][c];
            ans -= cnt_total[num2][c] * cnt_total[num10][c];

            if(a[u] == 0){
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num10][id] -= contribution;
                --cnt_subtree[num0][id];

                cnt_total[num10][c] -= contribution;
                --cnt_total[num0][c];
                ft[num0][c].update(tin[dep][u], -1);
            } else if(a[u] == 1){
                int cnt10 = ft[num0][c].queryRange(tin[dep][u], tout[dep][u]);
                int cnt12 = ft[num2][c].queryRange(tin[dep][u], tout[dep][u]);
                cnt_subtree[num10][id] -= cnt10;
                cnt_subtree[num12][id] -= cnt12;
                cnt_total[num10][c] -= cnt10;
                cnt_total[num12][c] -= cnt12;
                ft[num1][c].updateRange(tin[dep][u], tout[dep][u], -1);
            } else{
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num12][id] -= contribution;
                --cnt_subtree[num2][id];
                
                cnt_total[num12][c] -= contribution;
                --cnt_total[num2][c];
                ft[num2][c].update(tin[dep][u], -1);
            }

            //end remove data
            //replace data

            if(nval == 0){
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num10][id] += contribution;
                ++cnt_subtree[num0][id];

                cnt_total[num10][c] += contribution;
                ++cnt_total[num0][c];
                ft[num0][c].update(tin[dep][u], +1);
            } else if(nval == 1){
                int cnt10 = ft[num0][c].queryRange(tin[dep][u], tout[dep][u]);
                int cnt12 = ft[num2][c].queryRange(tin[dep][u], tout[dep][u]);
                cnt_subtree[num10][id] += cnt10;
                cnt_subtree[num12][id] += cnt12;
                cnt_total[num10][c] += cnt10;
                cnt_total[num12][c] += cnt12;
                ft[num1][c].updateRange(tin[dep][u], tout[dep][u], +1);
            } else{
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num12][id] += contribution;
                ++cnt_subtree[num2][id];
                
                cnt_total[num12][c] += contribution;
                ++cnt_total[num2][c];
                ft[num2][c].update(tin[dep][u], +1);
            }

            num_ways[c] += cnt_total[num0][c] * cnt_total[num2][c];
            ans += cnt_total[num0][c] * cnt_total[num12][c];
            ans += cnt_total[num2][c] * cnt_total[num10][c];

            num_ways[c] -= cnt_subtree[num0][id] * cnt_subtree[num2][id];
            ans -= cnt_subtree[num0][id] * cnt_subtree[num12][id];
            ans -= cnt_subtree[num2][id] * cnt_subtree[num10][id];

            if(a[c] == 0) ans += cnt_total[num12][c];
            if(a[c] == 1) ans += num_ways[c];
            if(a[c] == 2) ans += cnt_total[num10][c];

            //end replace data
        }

        c = par[c]; --dep;
    }

    a[u] = nval;
}

// Grader entry point: number of tours for the current colouring, as maintained
// in `ans` by init / change.
long long num_tours() {
    const long long total = ans;
    return total;
}

// Local test driver only. The official grader links its own stub that defines
// main() and calls init / change / num_tours. Defining main here
// unconditionally fails at link time with "multiple definition of `main'".
// Compile with -DLOCAL to use this driver; it reads task.inp and writes task.out.
#ifdef LOCAL
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    if(!freopen("task.inp", "r", stdin) || !freopen("task.out", "w", stdout)){
        cerr << "cannot open task.inp / task.out\n";
        return 1;
    }

    int N; cin >> N;
    vector<int> F(N);
    for(int i = 0; i < N; ++i) cin >> F[i];

    // A tree on N vertices has exactly N - 1 edges.
    vector<int> U(max(N - 1, 0)), V(max(N - 1, 0));
    for(int i = 0; i < N - 1; ++i) cin >> U[i] >> V[i];

    int Q; cin >> Q;
    init(N, F, U, V, Q);

    cout << num_tours() << "\n";
    while(Q--){
        int X, Y;
        cin >> X >> Y;
        change(X, Y);
        cout << num_tours() << '\n';
    }

    return 0;
}
#endif

/*


Test 1 : 
Input
3
0 1 2
0 1
1 2
0

Output : 
1

Test 2 : 
Input :
3
0 1 2
0 1
1 2
2
2 0
0 2

Output
1
0
1


Test 3 : 
Input : 
7
1 0 2 2 0 1 0
0 1
0 2
1 3
1 4
2 5
2 6
7
0 0
1 1
2 0
3 0
4 2
5 2
6 2

Output :
3
0
4
4
0
4
5
5


*/

Compilation message

joitour.cpp: In function 'int main()':
joitour.cpp:261:12: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  261 |     freopen("task.inp", "r", stdin);
      |     ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
joitour.cpp:262:12: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  262 |     freopen("task.out", "w", stdout);
      |     ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
/usr/bin/ld: /tmp/ccwyd08c.o: in function `main':
stub.cpp:(.text.startup+0x0): multiple definition of `main'; /tmp/ccbETdwf.o:joitour.cpp:(.text.startup+0x0): first defined here
collect2: error: ld returned 1 exit status