Submission #985740

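| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 985740 | | maomao90 | JOI tour (JOI24_joitour) | C++17 | 100 / 100 | 595 ms | 172056 KiB |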

// Hallelujah, praise the one who set me free
// Hallelujah, death has lost its grip on me
// You have broken every chain, There's salvation in your name
// Jesus Christ, my living hope
#include <bits/stdc++.h> 
#include "joitour.h"
using namespace std;

typedef long long ll;

#ifndef DEBUG
#define cerr if (0) cerr
#endif

template <typename T = int>
struct StaticTopTree {
    struct Path {
        struct Mat {
            ll ab, a, b, c;
            void incre(ll x, ll y) {
                ab += x * b + y * a + x * y * c;
                a += x * c;
                b += y * c;
            }
            Mat& operator+=(const Mat &o) {
                ab += o.ab;
                a += o.a;
                b += o.b;
                c += o.c;
                return *this;
            }
            Mat operator+(const Mat &o) const {
                Mat res = *this;
                res += o;
                return res;
            }
        };
        ll c0, c2;
        Mat mp, mu;
        ll uc0, uc2;
    };
    struct Point {
        ll c0, c2, c02, rp, ru, uc0, uc2;
    };
    Path initVertex(int u) {
        if (a[u] == 1) {
            typename Path::Mat m = {0, 0, 0, 1};
            return {0, 0, m, m, 0, 0};
        } else {
            bool b0 = a[u] == 0, b2 = a[u] == 2;
            typename Path::Mat m = {0, 0, 0, 0};
            return {b0, b2, m, m, 0, 0};
        }
    }
    Path addVertex(int u, Point x) {
        if (a[u] == 1) {
            return {x.c0, x.c2, 
                {x.rp + x.c02, 0, 0, 1},
                {x.ru + x.c0 * x.c2, x.c0, x.c2, 1},
                x.uc0, x.uc2};
        } else {
            int c0 = x.c0 + (a[u] == 0), c2 = x.c2 + (a[u] == 2);
            return {c0, c2, {x.rp, 0, 0, 0}, {x.ru, 0, 0, 0}, x.uc0, x.uc2};
        }
    }
    Point addEdge(Path x) {
        return {x.c0, x.c2, x.c0 * x.c2, x.mp.ab, x.mu.ab,
            x.uc0 + x.mu.a, x.uc2 + x.mu.b};
    }
    Path compress(Path lx, Path rx) {
        lx.c0 += rx.c0;
        lx.c2 += rx.c2;
        lx.mp.incre(rx.c0, rx.c2);
        lx.mu.incre(rx.c0, rx.c2);
        lx.mp += rx.mp;
        lx.mu += rx.mu;
        lx.uc0 += rx.uc0;
        lx.uc2 += rx.uc2;
        return lx;
    }
    Point rake(Point lx, Point rx) {
        return {lx.c0 + rx.c0, lx.c2 + rx.c2, lx.c02 + rx.c02, lx.rp + rx.rp, lx.ru + rx.ru, lx.uc0 + rx.uc0, lx.uc2 + rx.uc2};
    }
    StaticTopTree() {}
    StaticTopTree(int n, vector<vector<int>> &adj): n(n), adj(adj) {
        lc.resize(4 * n + 5);
        rc.resize(4 * n + 5);
        p.resize(4 * n + 5);
        op.resize(4 * n + 5);
        path.resize(4 * n + 5);
        point.resize(4 * n + 5);
        ptr = n + 1;
        hld(1, -1);
        r = buildPath(1).first;
    }
    void init(vector<T> &_a) {
        a = _a;
        dfs(r);
    }
    void update(int u, T &x) {
        a[u] = x;
        apply(u);
        while (p[u]) {
            u = p[u];
            apply(u);
        }
    }
    Path query(int u) {
        Path res = path[u];
        while (p[u]) {
            if (op[p[u]] == AddEdge) {
                break;
            }
            if (lc[p[u]] == u) {
                res = compress(res, path[rc[p[u]]]);
            }
            u = p[u];
        }
        return res;
    }
    private:
    enum Op {
        InitVertex,
        AddVertex,
        AddEdge,
        Compress,
        Rake
    };
    int n;
    vector<vector<int>> adj;
    vector<T> a;
    int r, ptr;
    vector<int> lc, rc, p;
    vector<Op> op;
    vector<Path> path;
    vector<Point> point;
    int hld(int u, int p) {
        int sub = 1, mx = 0;
        for (int i = 0; i < adj[u].size(); i++) {
            int v = adj[u][i];
            if (v == p) {
                adj[u].erase(adj[u].begin() + i);
                i--;
                continue;
            }
            int vsub = hld(v, u);
            sub += vsub;
            if (vsub > mx) {
                mx = vsub;
                if (i) {
                    swap(adj[u][i], adj[u][0]);
                }
            }
        }
        return sub;
    }
    void dfs(int u) {
        if (lc[u]) {
            dfs(lc[u]);
        }
        if (rc[u]) {
            dfs(rc[u]);
        }
        apply(u);
    }
    void apply(int u) {
        if (op[u] == InitVertex) {
            path[u] = initVertex(u);
        } else if (op[u] == AddVertex) {
            path[u] = addVertex(u, point[lc[u]]);
        } else if (op[u] == AddEdge) {
            point[u] = addEdge(path[lc[u]]);
        } else if (op[u] == Compress) {
            path[u] = compress(path[lc[u]], path[rc[u]]);
        } else if (op[u] == Rake) {
            point[u] = rake(point[lc[u]], point[rc[u]]);
        }
    }
    inline void add(int u, int l, int r, Op o) {
        lc[u] = l; p[l] = u;
        rc[u] = r; p[r] = u;
        op[u] = o;
    }
    pair<int, int> merge(vector<pair<int, int>> &lst, Op o) {
        if (lst.size() == 1) {
            return lst[0];
        }
        int tot = 0;
        for (auto [u, s] : lst) {
            tot += s;
        }
        vector<pair<int, int>> lft, rht;
        for (auto [u, s] : lst) {
            (tot > s ? lft : rht).push_back({u, s});
            tot -= 2 * s;
        }
        auto [lu, ls] = merge(lft, o);
        auto [ru, rs] = merge(rht, o);
        add(ptr, lu, ru, o);
        return {ptr++, ls + rs + 1};
    }
    pair<int, int> buildPath(int u) {
        vector<pair<int, int>> lst;
        lst.push_back(buildVertex(u));
        while (!adj[u].empty()) {
            u = adj[u][0];
            lst.push_back(buildVertex(u));
        }
        return merge(lst, Compress);
    }
    pair<int, int> buildVertex(int u) {
        if (adj[u].size() <= 1) {
            op[u] = InitVertex;
            return {u, 1};
        }
        vector<pair<int, int>> lst;
        for (int i = 1; i < adj[u].size(); i++) {
            lst.push_back(buildEdge(adj[u][i]));
        }
        auto [v, s] = merge(lst, Rake);
        add(u, v, 0, AddVertex);
        return {u, s + 1};
    }
    pair<int, int> buildEdge(int u) {
        auto [v, s] = buildPath(u);
        add(ptr, v, 0, AddEdge);
        return {ptr++, s + 1};
    }
};

// State shared by the grader entry points below (vertex ids are 1-indexed).
int n;                    // number of vertices
vector<int> a;            // a[v] = current type of vertex v
vector<vector<int>> adj;  // adjacency lists
StaticTopTree<int> stt;   // top tree built over adj
ll c[3];                  // c[t] = number of vertices currently of type t

void init(int N, vector<int> F, vector<int> U, vector<int> V, int Q) {
    n = N;
    a.resize(n + 1);
    for (int i = 0; i < n; i++) {
        a[i + 1] = F[i];
        c[a[i + 1]]++;
    }
    adj.resize(n + 1);
    for (int i = 0; i < n - 1; i++) {
        adj[U[i] + 1].push_back(V[i] + 1);
        adj[V[i] + 1].push_back(U[i] + 1);
    }
    stt = StaticTopTree(n, adj);
    stt.init(a);
}

// Grader entry point: vertex x (0-indexed) now has type y.
void change(int x, int y) {
    const int v = x + 1;  // internal ids are 1-indexed
    --c[a[v]];
    ++c[y];
    a[v] = y;
    stt.update(v, y);
}

// Answer = C0*C1*C2
//        - sum over u with f[p[u]] == 1 of cnt0[u] * cnt2[u]
//        - sum over u with f[u] == 1 of (C0 - cnt0[u]) * (C2 - cnt2[u]),
// where cntX[u] is the number of type-X vertices in the subtree of u and
// CX = c[X]. The last sum has c[1] terms, so its constant part is C0*C1*C2.
// As used here, the whole-tree aggregate supplies:
//   mp.ab         -> the first sum,
//   uc0 + mu.a    -> sum of cnt0[u] over type-1 u,
//   uc2 + mu.b    -> sum of cnt2[u] over type-1 u,
//   mu.ab         -> sum of cnt0[u] * cnt2[u] over type-1 u.
ll num_tours() {
    const StaticTopTree<int>::Path agg = stt.query(1);
    const ll all = c[0] * c[1] * c[2];
    const ll belowMiddle = agg.mp.ab;
    const ll sum0 = agg.uc0 + agg.mu.a;
    const ll sum2 = agg.uc2 + agg.mu.b;
    const ll aboveMiddle = all - c[0] * sum2 - c[2] * sum0 + agg.mu.ab;
    return all - belowMiddle - aboveMiddle;
}

Compilation message (stderr)

joitour.cpp: In instantiation of 'int StaticTopTree<T>::hld(int, int) [with T = int]':
joitour.cpp:94:9:   required from 'StaticTopTree<T>::StaticTopTree(int, std::vector<std::vector<int> >&) [with T = int]'
joitour.cpp:250:31:   required from here
joitour.cpp:140:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  140 |         for (int i = 0; i < adj[u].size(); i++) {
      |                         ~~^~~~~~~~~~~~~~~
joitour.cpp: In instantiation of 'std::pair<int, int> StaticTopTree<T>::buildVertex(int) [with T = int]':
joitour.cpp:205:23:   required from 'std::pair<int, int> StaticTopTree<T>::buildPath(int) [with T = int]'
joitour.cpp:95:13:   required from 'StaticTopTree<T>::StaticTopTree(int, std::vector<std::vector<int> >&) [with T = int]'
joitour.cpp:250:31:   required from here
joitour.cpp:218:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  218 |         for (int i = 1; i < adj[u].size(); i++) {
      |                         ~~^~~~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...