답안 #1082661

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1082661 2024-09-01T04:29:29 Z Zero_OP JOI tour (JOI24_joitour) C++17
6 / 100
864 ms 196944 KB
#include "joitour.h"
#include "bits/stdc++.h"

// Debug helper: expands to a stream fragment printing "[expr = value]",
// e.g. cerr << dbg(x). Not referenced elsewhere in the code shown.
#define dbg(x) "[" << #x << " = " << (x) << "]"

using namespace std;

// Capacity of every vertex-indexed array.
constexpr int MAX = 2e5 + 5;

// Statistic slots: first index of cnt_total / cnt_subtree / ft.
constexpr int num10 = 0; // pairs (1-vertex j, 0-vertex x), j on the path centroid -> x
constexpr int num12 = 1; // pairs (1-vertex j, 2-vertex x), j on the path centroid -> x
constexpr int num0 = 2;  // vertices labelled 0
constexpr int num2 = 3;  // vertices labelled 2
constexpr int num1 = 4;  // vertices labelled 1 (only used as a Fenwick slot)

int N;                   // number of vertices
int a[MAX];              // current label of each vertex
int par[MAX];            // parent centroid in the centroid tree (-1 for the root)
int depth[MAX];          // centroid-tree depth of each centroid
// Euler indices (1-based, local to one centroid's component); the first index
// is the depth of that centroid. Assumes centroid-tree depth < 18.
int tin[18][MAX];
int tout[18][MAX];
int sz[MAX];             // subtree sizes written by the latest dfs_size() walk
int id_subtree[18][MAX]; // child subtree (of the depth-`layer` centroid) holding the vertex
int timer_subtree;       // running child-subtree id, advanced by decompose()
int timer_dfs;           // Euler counter, reset for each centroid
vector<int> adj[MAX];    // adjacency lists
bool used[MAX];          // vertex already chosen as a centroid
// cnt_total[slot][c]   : statistic over centroid c's component (c excluded).
// cnt_subtree[slot][id]: the same statistic restricted to one child subtree.
long long cnt_total[4][MAX];
long long cnt_subtree[4][MAX];
// num_ways[c]: (0-vertex, 2-vertex) pairs lying in different child subtrees of c,
// i.e. the tours whose 1-vertex is c when a[c] == 1.
long long num_ways[MAX];
long long ans;           // current number of tours

// Binary indexed tree over positions 1..n (slot 0 is unused).
// Used two ways in this file: point-add with prefix/range sums, and, through
// updateRange() + query(), range-add with point reads (difference form).
struct Fenwick{
    vector<int> bit;

    Fenwick() : bit() {}
    Fenwick(int n) : bit(n + 1, 0) {}

    // Adds val at position id; positions >= bit.size() are silently dropped.
    // Precondition: id >= 1 (id == 0 would never advance).
    void update(int id, int val){
        const int limit = (int)bit.size();
        while(id < limit){
            bit[id] += val;
            id += id & (-id);
        }
    }

    // Sum of positions 1..id (0 when id <= 0).
    int query(int id){
        int total = 0;
        while(id > 0){
            total += bit[id];
            id &= id - 1; // strip the lowest set bit
        }
        return total;
    }

    // Adds val to every position in [l, r] as observed by point reads query(p).
    void updateRange(int l, int r, int val){
        update(l, val);
        update(r + 1, -val);
    }

    // Sum of positions l..r.
    int queryRange(int l, int r){
        return query(r) - query(l - 1);
    }
};

// Per-centroid Fenwick trees, indexed ft[slot][centroid] with the slot
// constants defined above. decompose() assigns each centroid fresh trees
// sized to its component. ft[num0] / ft[num2] hold point counts of 0- / 2-
// labelled vertices by Euler index; ft[num1] is used in range-add /
// point-query form, so query(tin) counts 1-labelled vertices whose Euler
// interval contains tin. The num10 / num12 slots are not read in the code shown.
Fenwick ft[5][MAX];

// Recomputes sz[] for the not-yet-used part of the tree hanging from u
// (entered from pre) and returns that size, which is also stored in sz[u].
int dfs_size(int u, int pre){
    int total = 1;
    for(const int v : adj[u]){
        if(v == pre || used[v]) continue;
        total += dfs_size(v, u);
    }
    return sz[u] = total;
}

// Starting at u (entered from pre), keeps stepping into an unused neighbour
// whose sz[] exceeds target / 2 and returns the first vertex with no such
// neighbour. With sz[] filled by dfs_size() from the same start vertex and
// target equal to that component size, the result is a centroid.
int find_centroid(int u, int pre, int target){
    const int half = target / 2;
    bool moved = true;
    while(moved){
        moved = false;
        for(const int v : adj[u]){
            if(v != pre && !used[v] && sz[v] > half){
                pre = u;
                u = v;
                moved = true;
                break;
            }
        }
    }
    return u;
}

// Centroid whose component dfs_subtree() is currently walking; decompose()
// sets it before the walks.
int c;

// Walks the child subtree of centroid `c` rooted at u (pre = u's parent, never
// re-entered) and, for centroid-tree depth `layer`:
//  * assigns Euler indices tin/tout (1-based, local to c's component) and tags
//    u with the current child-subtree id timer_subtree;
//  * counts 0- and 2-labelled vertices (num0 / num2) into the subtree and
//    centroid totals and marks them in the matching Fenwick trees;
//  * for a 1-labelled u, counts the 0-/2-labelled vertices inside u's subtree
//    (num10 / num12 pairs) and range-marks u's Euler interval in ft[num1].
// timer_subtree is advanced only by decompose(), between child subtrees, so it
// stays constant for the whole walk.
void dfs_subtree(int u, int pre, int layer){
    tin[layer][u] = ++timer_dfs;
    id_subtree[layer][u] = timer_subtree;

    if(a[u] == 0){
        ft[num0][c].update(tin[layer][u], +1);
        ++cnt_total[num0][c];
        ++cnt_subtree[num0][timer_subtree];
    } else if(a[u] == 2){
        ft[num2][c].update(tin[layer][u], +1);
        ++cnt_total[num2][c];
        ++cnt_subtree[num2][timer_subtree];
    }

    for(int v : adj[u]) if(v != pre && !used[v]){
        dfs_subtree(v, u, layer);
    }

    tout[layer][u] = timer_dfs;

    if(a[u] == 1){
        // u's whole subtree has been visited, so these range counts are final.
        ft[num1][c].updateRange(tin[layer][u], tout[layer][u], +1);
        const int cnt10 = ft[num0][c].queryRange(tin[layer][u], tout[layer][u]);
        const int cnt12 = ft[num2][c].queryRange(tin[layer][u], tout[layer][u]);
        cnt_subtree[num10][timer_subtree] += cnt10;
        cnt_subtree[num12][timer_subtree] += cnt12;
        cnt_total[num10][c] += cnt10;
        cnt_total[num12][c] += cnt12;
    }
}

// Centroid-decomposes the unused component containing u. `layer` is the
// centroid-tree depth of the centroid found here and `p` its parent centroid
// (-1 at the root). For that centroid it builds the Fenwick trees and counters
// that change() later maintains, adds to `ans` every tour whose path passes
// through the centroid, then recurses into the remaining pieces.
void decompose(int u, int layer, int p){
    int size = dfs_size(u, -1);
    u = find_centroid(u, -1, size);
    used[u] = true;
    par[u] = p;
    depth[u] = layer;

    // Only the num0 / num1 / num2 trees are ever read; the num10 / num12 slots
    // are left unallocated, saving two (size + 1)-int arrays per centroid.
    ft[num0][u] = Fenwick(size);
    ft[num1][u] = Fenwick(size);
    ft[num2][u] = Fenwick(size);

    c = u;
    timer_dfs = 0;
    for(int v : adj[u]) if(!used[v]){
        dfs_subtree(v, u, layer);
        // Pairs taken from the same child subtree do not pass through u:
        // subtract them here; the all-pairs products are added below.
        num_ways[u] -= cnt_subtree[num0][timer_subtree] * cnt_subtree[num2][timer_subtree];
        ans -= cnt_subtree[num0][timer_subtree] * cnt_subtree[num12][timer_subtree];
        ans -= cnt_subtree[num2][timer_subtree] * cnt_subtree[num10][timer_subtree];
        ++timer_subtree;
    }

    num_ways[u] += cnt_total[num0][u] * cnt_total[num2][u];
    ans += cnt_total[num0][u] * cnt_total[num12][u];
    ans += cnt_total[num2][u] * cnt_total[num10][u];

    // Tours in which the centroid itself is one of the three vertices.
    if(a[u] == 0) ans += cnt_total[num12][u];
    if(a[u] == 1) ans += num_ways[u];
    if(a[u] == 2) ans += cnt_total[num10][u];

    for(int v : adj[u]) if(!used[v]){
        decompose(v, layer + 1, u);
    }
}

// Grader entry point: stores the labels F, builds adjacency lists from the
// N - 1 edges (U[e], V[e]) and runs the centroid decomposition, which also
// computes the initial answer. Q is not used here.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
    ::N = N;
    for(int v = 0; v < N; ++v){
        a[v] = F[v];
    }

    for(int e = 0; e + 1 < N; ++e){
        const int x = U[e];
        const int y = V[e];
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    decompose(0, 0, -1);
}

void change(int u, int nval) {
    if(a[u] == nval) return;
    int dep = depth[u], c = u;
    while(c != -1){
        if(c == u){
            if(a[u] == 0) ans -= cnt_total[num12][u];
            if(a[u] == 1) ans -= num_ways[u];
            if(a[u] == 2) ans -= cnt_total[num10][u];

            if(nval == 0) ans += cnt_total[num12][u];
            if(nval == 1) ans += num_ways[u];
            if(nval == 2) ans += cnt_total[num10][u];
        } else{
            int id = id_subtree[dep][u];

            ///remove data
            if(a[c] == 0) ans -= cnt_total[num12][c];
            if(a[c] == 1) ans -= num_ways[c];
            if(a[c] == 2) ans -= cnt_total[num10][c];

            num_ways[c] += cnt_subtree[num0][id] * cnt_subtree[num2][id];
            ans += cnt_subtree[num0][id] * cnt_subtree[num12][id];
            ans += cnt_subtree[num2][id] * cnt_subtree[num10][id];

            num_ways[c] -= cnt_total[num0][c] * cnt_total[num2][c];
            ans -= cnt_total[num0][c] * cnt_total[num12][c];
            ans -= cnt_total[num2][c] * cnt_total[num10][c];

            if(a[u] == 0){
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num10][id] -= contribution;
                --cnt_subtree[num0][id];

                cnt_total[num10][c] -= contribution;
                --cnt_total[num0][c];
                ft[num0][c].update(tin[dep][u], -1);
            } else if(a[u] == 1){
                int cnt10 = ft[num0][c].queryRange(tin[dep][u], tout[dep][u]);
                int cnt12 = ft[num2][c].queryRange(tin[dep][u], tout[dep][u]);
                cnt_subtree[num10][id] -= cnt10;
                cnt_subtree[num12][id] -= cnt12;
                cnt_total[num10][c] -= cnt10;
                cnt_total[num12][c] -= cnt12;
                ft[num1][c].updateRange(tin[dep][u], tout[dep][u], -1);
            } else{
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num12][id] -= contribution;
                --cnt_subtree[num2][id];
                
                cnt_total[num12][c] -= contribution;
                --cnt_total[num2][c];
                ft[num2][c].update(tin[dep][u], -1);
            }

            //end remove data
            //replace data

            if(nval == 0){
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num10][id] += contribution;
                ++cnt_subtree[num0][id];

                cnt_total[num10][c] += contribution;
                ++cnt_total[num0][c];
                ft[num0][c].update(tin[dep][u], -1);
            } else if(nval == 1){
                int cnt10 = ft[num0][c].queryRange(tin[dep][u], tout[dep][u]);
                int cnt12 = ft[num2][c].queryRange(tin[dep][u], tout[dep][u]);
                cnt_subtree[num10][id] += cnt10;
                cnt_subtree[num12][id] += cnt12;
                cnt_total[num10][c] += cnt10;
                cnt_total[num12][c] += cnt12;
                ft[num1][c].updateRange(tin[dep][u], tout[dep][u], +1);
            } else{
                int contribution = ft[num1][c].query(tin[dep][u]);
                cnt_subtree[num12][id] += contribution;
                ++cnt_subtree[num2][id];
                
                cnt_total[num12][c] += contribution;
                ++cnt_total[num2][c];
                ft[num2][c].update(tin[dep][u], -1);
            }

            num_ways[c] += cnt_total[num0][c] * cnt_total[num2][c];
            ans += cnt_total[num0][c] * cnt_total[num12][c];
            ans += cnt_total[num2][c] * cnt_total[num10][c];

            num_ways[c] -= cnt_subtree[num0][id] * cnt_subtree[num2][id];
            ans -= cnt_subtree[num0][id] * cnt_subtree[num12][id];
            ans -= cnt_subtree[num2][id] * cnt_subtree[num10][id];

            if(a[c] == 0) ans += cnt_total[num12][c];
            if(a[c] == 1) ans += num_ways[c];
            if(a[c] == 2) ans += cnt_total[num10][c];

            //end replace data
        }

        c = par[c]; --dep;
    }

    a[u] = nval;
}

long long num_tours() {
    return ans;
}

#ifdef LOCAL
// Local harness: reads N, the N labels, N - 1 edges, Q and Q updates "X Y"
// from task.inp; writes the tour count initially and after each update to
// task.out.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    freopen("task.inp", "r", stdin);
    freopen("task.out", "w", stdout);

    int N;
    cin >> N;

    vector<int> F(N);
    for(int &label : F) cin >> label;

    vector<int> U(N), V(N);
    for(int e = 0; e + 1 < N; ++e) cin >> U[e] >> V[e];

    int Q;
    cin >> Q;
    init(N, F, U, V, Q);

    cout << num_tours() << '\n';
    while(Q--){
        int X, Y;
        cin >> X >> Y;
        change(X, Y);
        cout << num_tours() << '\n';
    }

    return 0;
}
#endif //LOCAL

/*


Test 1 : 
Input
3
0 1 2
0 1
1 2
0

Output : 
1

Test 2 : 
Input :
3
0 1 2
0 1
1 2
2
2 0
0 2

Output
1
0
1


Test 3 : 
Input : 
7
1 0 2 2 0 1 0
0 1
0 2
1 3
1 4
2 5
2 6
7
0 0
1 1
2 0
3 0
4 2
5 2
6 2

Output : 3
0
4
4
0
4
5
5


*/
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 30028 KB Output is correct
2 Correct 14 ms 30300 KB Output is correct
3 Correct 11 ms 30188 KB Output is correct
4 Incorrect 13 ms 30296 KB Wrong Answer [1]
5 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 30028 KB Output is correct
2 Correct 14 ms 30300 KB Output is correct
3 Correct 11 ms 30188 KB Output is correct
4 Incorrect 13 ms 30296 KB Wrong Answer [1]
5 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 10 ms 30040 KB Output is correct
2 Correct 837 ms 190136 KB Output is correct
3 Correct 841 ms 189940 KB Output is correct
4 Correct 747 ms 188076 KB Output is correct
5 Correct 864 ms 190768 KB Output is correct
6 Correct 302 ms 178000 KB Output is correct
7 Correct 318 ms 178024 KB Output is correct
8 Correct 654 ms 179024 KB Output is correct
9 Correct 700 ms 178064 KB Output is correct
10 Correct 689 ms 175272 KB Output is correct
11 Correct 693 ms 170148 KB Output is correct
12 Correct 759 ms 186188 KB Output is correct
13 Correct 770 ms 186280 KB Output is correct
14 Correct 671 ms 186028 KB Output is correct
15 Correct 769 ms 185468 KB Output is correct
16 Correct 844 ms 194028 KB Output is correct
17 Correct 11 ms 30040 KB Output is correct
18 Correct 10 ms 30040 KB Output is correct
19 Correct 10 ms 30040 KB Output is correct
20 Correct 11 ms 30040 KB Output is correct
21 Correct 668 ms 165300 KB Output is correct
22 Correct 642 ms 165204 KB Output is correct
23 Correct 594 ms 166284 KB Output is correct
24 Correct 670 ms 167420 KB Output is correct
25 Correct 253 ms 89540 KB Output is correct
26 Correct 266 ms 89452 KB Output is correct
27 Correct 233 ms 89540 KB Output is correct
28 Correct 247 ms 89628 KB Output is correct
29 Correct 348 ms 196688 KB Output is correct
30 Correct 380 ms 196940 KB Output is correct
31 Correct 322 ms 196944 KB Output is correct
32 Correct 356 ms 196904 KB Output is correct
33 Correct 360 ms 174460 KB Output is correct
34 Correct 377 ms 174416 KB Output is correct
35 Correct 303 ms 174400 KB Output is correct
36 Correct 349 ms 174504 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 13 ms 30292 KB Output is correct
2 Correct 10 ms 30052 KB Output is correct
3 Incorrect 13 ms 30040 KB Wrong Answer [1]
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 10 ms 30040 KB Output is correct
2 Incorrect 11 ms 30000 KB Wrong Answer [1]
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 30028 KB Output is correct
2 Correct 14 ms 30300 KB Output is correct
3 Correct 11 ms 30188 KB Output is correct
4 Incorrect 13 ms 30296 KB Wrong Answer [1]
5 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 30028 KB Output is correct
2 Correct 14 ms 30300 KB Output is correct
3 Correct 11 ms 30188 KB Output is correct
4 Incorrect 13 ms 30296 KB Wrong Answer [1]
5 Halted 0 ms 0 KB -