Submission #1345792

# | Time | Username | Problem | Language | Result | Execution time | Memory
1345792 | | MisterReaper | JOI tour (JOI24_joitour) | C++20 | 100 / 100 | 2577 ms | 321656 KiB
#include "joitour.h"
#include <bits/stdc++.h>

// 64-bit signed integer used for the tour counters (ans and the DS cnt* fields).
using i64 = long long;

// Local debugging hook: when DEBUG is defined, pull in the author's local
// debug.h; otherwise debug(...) expands to a discarded no-op expression.
#if defined(DEBUG)
    #include "/home/ahmetalp/Desktop/Workplace/debug.h"
#else
    #define debug(...) static_cast<void>(0)
#endif

namespace {
    // Tree input, kept at file scope so the helper routines can reach it.
    int N;                              // number of vertices (set by init())
    std::vector<int> F;                 // F[v]: current colour of vertex v (rewritten by change())
    std::vector<int> U;                 // U[i], V[i]: endpoints of edge i, as passed to init()
    std::vector<int> V;
    std::vector<std::vector<int>> adj;  // adjacency lists built from U/V in init()

    // Binary indexed (Fenwick) tree over positions [0, n).
    // It is used in two ways:
    //  * point update / range query: modify(p, x), get(l, r), oth(l, r);
    //  * range update / point query:  modify(l, r, x), then get(p).
    // `all` is the running sum of every point update. It lets oth(l, r), the
    // total outside [l, r), be answered in O(log n). It is only meaningful in
    // the point-update usage; for range updates the +x / -x pair cancels out.
    struct fenwick {
        int n = 0;
        int all = 0;
        std::vector<int> tree;
        fenwick() {}
        fenwick(int n_) : n(n_), tree(n + 1) {}
        // (Re)initialise to n_ zeroed positions. `all` is reset too, so a
        // re-initialised tree really starts empty.
        void init(int n_) {
            n = n_;
            all = 0;
            tree.assign(n + 1, 0);
        }
        // Add x at position p. A position p >= n only updates `all`; the
        // range form below relies on this when r == n.
        void modify(int p, int x) {
            all += x;
            for (p += 1; p <= n; p += p & -p) {
                tree[p] += x;
            }
        }
        // Add x to every position in [l, r) (difference-array trick); read back with get(p).
        void modify(int l, int r, int x) {
            modify(l, +x);
            modify(r, -x);
        }
        // Prefix sum over [0, p]. A negative p yields 0. A p >= n is clamped to
        // the full prefix instead of reading past the end of `tree`.
        int get(int p) {
            if (p >= n) {
                p = n - 1;
            }
            int res = 0;
            for (p += 1; p > 0; p -= p & -p) {
                res += tree[p];
            }
            return res;
        }
        // Sum over the half-open range [l, r).
        int get(int l, int r) {
            // assert(l <= r);
            return get(r - 1) - get(l - 1);
        }
        // Sum of all point updates that fall outside [l, r).
        int oth(int l, int r) {
            return all - get(l, r);
        }
    };

    // Global answer reported by num_tours(). Every DS::set call adds or
    // subtracts that centroid's contribution to it in place.
    i64 ans = 0;

    struct DS {
        i64 cnt10 = 0;
        i64 cnt12 = 0;
        i64 cnt02 = 0;
        std::map<int, i64> top_cnt12;
        std::map<int, i64> top_cnt10;
        fenwick fen0;
        fenwick fen1;
        fenwick fen2;
        int n = 0;
        int tim = 0;
        int r = -1;
        int col = -1;
        std::map<int, int> tin;
        std::map<int, int> tout;
        std::map<int, int> top;
        DS() {}
        void init(int n_, int r_) {
            n = n_;
            fen0.init(n);
            fen1.init(n);
            fen2.init(n);
            r = r_;
        }
        void open(int v) {
            // assert(!tin.contains(v));
            // tin[v] = tim++;
            tim++;
        }
        void close(int v) {
            // assert(!tout.contains(v));
            // tout[v] = tim;
        }
        void set(int v, int x, int tin_v, int tout_v, int top_v, int tin_top_v, int tout_top_v) {
            if (x < 0) {
                // closes
                x = -x - 1;
                if (v == r) {
                    // root
                    if (x == 0) {
                        ans -= cnt12;
                    } else if (x == 1) {
                        ans -= cnt02;
                    } else {
                        ans -= cnt10;
                    }
                    col = -1;
                } else {
                    if (x == 0) {
                        if (col == 2) {
                            ans -= fen1.get(tin_v);
                        }
                        if (col == 1) {
                            ans -= fen2.oth(tin_top_v, tout_top_v);
                        }
                        ans -= 1LL * fen1.get(tin_v) * fen2.oth(tin_top_v, tout_top_v);
                        ans -= (cnt12 - top_cnt12[top_v]);
                        cnt02 -= fen2.oth(tin_top_v, tout_top_v);
                        fen0.modify(tin_v, -1);
                        cnt10 -= fen1.get(tin_v);
                        top_cnt10[top_v] -= fen1.get(tin_v);
                    } else if (x == 1) {
                        if (col == 0) {
                            ans -= fen2.get(tin_v, tout_v);
                        }
                        if (col == 2) {
                            ans -= fen0.get(tin_v, tout_v);
                        }
                        ans -= 1LL * fen0.oth(tin_top_v, tout_top_v) * fen2.get(tin_v, tout_v);
                        ans -= 1LL * fen2.oth(tin_top_v, tout_top_v) * fen0.get(tin_v, tout_v);
                        fen1.modify(tin_v, tout_v, -1);
                        cnt10 -= fen0.get(tin_v, tout_v);
                        cnt12 -= fen2.get(tin_v, tout_v);
                        top_cnt10[top_v] -= fen0.get(tin_v, tout_v);
                        top_cnt12[top_v] -= fen2.get(tin_v, tout_v);
                    } else {
                        if (col == 0) {
                            ans -= fen1.get(tin_v);
                        }
                        if (col == 1) {
                            ans -= fen0.oth(tin_top_v, tout_top_v);
                        }
                        ans -= 1LL * fen1.get(tin_v) * fen0.oth(tin_top_v, tout_top_v);
                        ans -= (cnt10 - top_cnt10[top_v]);
                        cnt02 -= fen0.oth(tin_top_v, tout_top_v);
                        fen2.modify(tin_v, -1);
                        cnt12 -= fen1.get(tin_v);
                        top_cnt12[top_v] -= fen1.get(tin_v);
                    }
                }
            } else {
                if (v == r) {
                    // root
                    if (x == 0) {
                        ans += cnt12;
                    } else if (x == 1) {
                        ans += cnt02;
                    } else {
                        ans += cnt10;
                    }
                    col = x;
                } else {
                    if (x == 0) {
                        if (col == 2) {
                            ans += fen1.get(tin_v);
                        }
                        if (col == 1) {
                            ans += fen2.oth(tin_top_v, tout_top_v);
                        }
                        ans += 1LL * fen1.get(tin_v) * fen2.oth(tin_top_v, tout_top_v);
                        ans += (cnt12 - top_cnt12[top_v]);
                        cnt02 += fen2.oth(tin_top_v, tout_top_v);
                        fen0.modify(tin_v, +1);
                        cnt10 += fen1.get(tin_v);
                        top_cnt10[top_v] += fen1.get(tin_v);
                    } else if (x == 1) {
                        if (col == 0) {
                            ans += fen2.get(tin_v, tout_v);
                        }
                        if (col == 2) {
                            ans += fen0.get(tin_v, tout_v);
                        }
                        ans += 1LL * fen0.oth(tin_top_v, tout_top_v) * fen2.get(tin_v, tout_v);
                        ans += 1LL * fen2.oth(tin_top_v, tout_top_v) * fen0.get(tin_v, tout_v);
                        fen1.modify(tin_v, tout_v, +1);
                        cnt10 += fen0.get(tin_v, tout_v);
                        cnt12 += fen2.get(tin_v, tout_v);
                        top_cnt10[top_v] += fen0.get(tin_v, tout_v);
                        top_cnt12[top_v] += fen2.get(tin_v, tout_v);
                    } else {
                        if (col == 0) {
                            ans += fen1.get(tin_v);
                        }
                        if (col == 1) {
                            ans += fen0.oth(tin_top_v, tout_top_v);
                        }
                        ans += 1LL * fen1.get(tin_v) * fen0.oth(tin_top_v, tout_top_v);
                        ans += (cnt10 - top_cnt10[top_v]);
                        cnt02 += fen0.oth(tin_top_v, tout_top_v);
                        fen2.modify(tin_v, +1);
                        cnt12 += fen1.get(tin_v);
                        top_cnt12[top_v] += fen1.get(tin_v);
                    }
                }
            }
        }
    };

    // Centroid-decomposition state.
    std::vector<DS> centros;  // centros[c]: counters for the component whose centroid is c
    std::vector<int> act;     // act[v] != 0 until v has been chosen as a centroid by dnq()
    std::vector<int> siz;     // subtree sizes from the most recent calc_sizes() call
    // pars[v]: one entry per centroid whose component contains v, in the order
    // those centroids were processed:
    //   {centroid r, tin_v, tout_v, top_v, tin_top_v, tout_top_v}
    std::vector<std::vector<std::array<int, 6>>> pars;
    // Fill siz[] for the tree of still-active vertices rooted at v. pr is v's
    // parent; callers pass v itself when v is the root. Inactive vertices act
    // as cut points.
    // Uses an explicit stack instead of recursion, because the recursion depth
    // equals the component height (up to N on a path) and could overflow the
    // call stack.
    void calc_sizes(int v, int pr) {
        // Each entry is (vertex, parent). Vertices are recorded in pre-order, so
        // walking `order` backwards sees every child before its parent.
        std::vector<std::pair<int, int>> todo{{v, pr}};
        std::vector<std::pair<int, int>> order;
        while (!todo.empty()) {
            const auto [x, p] = todo.back();
            todo.pop_back();
            order.emplace_back(x, p);
            siz[x] = 1;
            for (int u : adj[x]) {
                if (act[u] && u != p) {
                    todo.emplace_back(u, x);
                }
            }
        }
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            // The root's "parent" may be itself (or a vertex outside the
            // component), so it is never folded into anything.
            if (it->first != v) {
                siz[it->second] += siz[it->first];
            }
        }
    }
    // Walk from v towards the heavy side until no active child subtree holds
    // more than half of the tot vertices, and return that vertex. siz[] must
    // come from calc_sizes() rooted at the walk's starting vertex.
    int get_centroid(int v, int pr, int tot) {
        for (;;) {
            int heavy = -1;
            for (int u : adj[v]) {
                if (act[u] && u != pr && siz[u] * 2 > tot) {
                    heavy = u;
                    break;
                }
            }
            if (heavy == -1) {
                return v;
            }
            pr = v;
            v = heavy;
        }
    }
    // Register every active vertex reachable from v (without going back through
    // pr) with centroid r. Vertices are visited in pre-order, following adjacency
    // order. Each visited x gets its position from centros[r].tim, and
    // {r, tin, tin + siz[x], top, tin_top, tout_top} is appended to pars[x].
    // "top" is the child of r whose branch contains x; for the children of r
    // it is taken from the child's own position and size. Entry and exit of
    // each vertex are reported through centros[r].open / close.
    //
    // Uses an explicit stack instead of recursion, because the depth can reach
    // the component height (up to N). Visit order, positions, pars[] order and
    // open/close order are the same as a recursive DFS.
    void insert(int r, int v, int pr, int top_v, int tin_top_v, int tout_top_v) {
        struct Frame {
            int x;
            int parent;
            int top;
            int tin_top;
            int tout_top;
            bool branch_root;  // parent is r: x starts a new branch, so top info comes from x
            bool leaving;      // post-order marker: only close(x) remains
        };
        std::vector<Frame> todo;
        todo.push_back({v, pr, top_v, tin_top_v, tout_top_v, false, false});
        while (!todo.empty()) {
            Frame f = todo.back();
            todo.pop_back();
            if (f.leaving) {
                centros[r].close(f.x);
                continue;
            }
            int tim = centros[r].tim;
            if (f.branch_root) {
                f.top = f.x;
                f.tin_top = tim;
                f.tout_top = tim + siz[f.x];
            }
            debug(r, f.x, tim, tim + siz[f.x], f.top, f.tin_top, f.tout_top);
            pars[f.x].push_back({r, tim, tim + siz[f.x], f.top, f.tin_top, f.tout_top});
            centros[r].open(f.x);
            // Pushed before the children, so it pops after their whole subtrees.
            todo.push_back({f.x, f.parent, f.top, f.tin_top, f.tout_top, false, true});
            const bool children_start_branches = (f.x == r);
            // Push children in reverse, so they are visited in adjacency order.
            for (auto it = adj[f.x].rbegin(); it != adj[f.x].rend(); ++it) {
                const int u = *it;
                if (!act[u] || u == f.parent) {
                    continue;
                }
                todo.push_back({u, f.x, f.top, f.tin_top, f.tout_top, children_start_branches, false});
            }
        }
    }
    // Build the centroid decomposition of the active component containing r:
    //  1. find its centroid c;
    //  2. index the component into centroid c's DS;
    //  3. deactivate c;
    //  4. recurse into every component left among c's neighbours.
    // Each level works on a component no larger than half of its parent's,
    // so the recursion stays shallow.
    void dnq(int r) {
        calc_sizes(r, r);
        const int component_size = siz[r];
        const int c = get_centroid(r, r, component_size);
        calc_sizes(c, c);  // re-root the sizes at the centroid for insert()
        centros[c].init(component_size, c);
        insert(c, c, -1, -1, -1, -1);
        act[c] = 0;
        for (int u : adj[c]) {
            // act[] is re-checked every iteration: earlier recursive calls deactivate vertices.
            if (act[u]) {
                dnq(u);
            }
        }
    }
};

// Grader entry point.
//  * Stores the tree and the initial colours in the file-scope globals.
//  * Builds the centroid decomposition.
//  * Inserts every vertex's colour into each centroid structure that contains
//    it, so num_tours() is valid as soon as this returns.
// Q is not used.
// NOTE(review): the global `ans` is not reset here, so this assumes init() is
// called once per process.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V,
          int Q) {
    (void)Q;
    ::N = N;
    // F, U and V are by-value copies owned by this call. Move them into the
    // globals instead of copying them a second time. From here on only
    // ::F, ::U and ::V hold the data.
    ::F = std::move(F);
    ::U = std::move(U);
    ::V = std::move(V);
    adj.assign(N, {});
    for (int i = 0; i < N - 1; ++i) {
        adj[::U[i]].emplace_back(::V[i]);
        adj[::V[i]].emplace_back(::U[i]);
    }
    centros.assign(N, {});
    act.assign(N, 1);
    siz.assign(N, 1);
    pars.assign(N, {});
    if (N == 0) {
        return;  // empty tree: nothing to decompose (dnq(0) would index empty vectors)
    }
    dnq(0);
    for (int v = 0; v < N; ++v) {
        for (const auto& [r, tin_v, tout_v, top_v, tin_top_v, tout_top_v] : pars[v]) {
            centros[r].set(v, ::F[v], tin_v, tout_v, top_v, tin_top_v, tout_top_v);
        }
    }
}

// Grader entry point: recolour vertex X to Y. Every centroid whose component
// contains X first withdraws X's old colour (encoded as -old - 1). After F[X]
// is updated, every such centroid adds the new colour.
void change(int X, int Y) {
    const auto update_all = [X](int code) {
        for (const auto& [r, tin_v, tout_v, top_v, tin_top_v, tout_top_v] : pars[X]) {
            centros[r].set(X, code, tin_v, tout_v, top_v, tin_top_v, tout_top_v);
        }
    };
    update_all(-F[X] - 1);
    F[X] = Y;
    update_all(F[X]);
}

// Grader entry point: the tour count for the current colouring, as maintained
// incrementally by init() and change().
i64 num_tours() {
    const i64 current = ans;
    return current;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...