Submission #978233

#TimeUsernameProblemLanguageResultExecution timeMemory
978233model_codeJOI tour (JOI24_joitour)C++17
100 / 100
607 ms97908 KiB
// static toptree 解 O((N + Q) log N)
#include "joitour.h"
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

/*
    頂点 0 を根とする根付き木で考える.
    上の boundary vertex を u, 下の boundary vertex を v とする.
    頂点属性を辺属性に変換するため,頂点の情報を,その頂点から上に伸びる辺に載せる.(ダミーの頂点を頂点 0 の上に作る)
    したがって,各 cluster は F[u] が分からないものとして情報を集めることになる.
    各 cluster が持つ情報:
    f_v  : F[v]
    c_012: cluster 内の 0 - 1 - 2 パスの個数
    c_01u: cluster 内の 0 - 1 - u パスの個数
    c_u12: cluster 内の u - 1 - 2 パスの個数
    c_0u2: cluster 内の 0 - u - 2 パスの個数
    c_01v: cluster 内の 0 - 1 - v パス (v が 1 を担当しても良い) の個数
    c_v12: cluster 内の v - 1 - 2 パス (v が 1 を担当しても良い) の個数
    c_0uv: cluster 内の 0 - u - v パスの個数
    c_u1v: cluster 内の u - 1 - v パス (v が 1 を担当しても良い) の個数
    c_vu2: cluster 内の v - u - 2 パスの個数
    c_0  : cluster 内の 0 の個数
    c_2  : cluster 内の 2 の個数
*/
struct T {
    long long f_v, c_012, c_01u, c_u12, c_0u2, c_01v, c_v12, c_0uv, c_u1v, c_vu2, c_0, c_2;
};
T one(int f) {
    T x = {};
    x.f_v = f;
    if (f == 0) x.c_0++;
    if (f == 1) x.c_u1v++;
    if (f == 2) x.c_2++;
    return x;
}
T rake(const T& a, const T& b) {
    // a.u == b.u である 2 つの cluster をマージ
    return T{
        a.f_v,
        a.c_012 + b.c_012 + a.c_01u * b.c_2 + a.c_0 * b.c_u12 + b.c_0 * a.c_u12 + b.c_01u * a.c_2,
        a.c_01u + b.c_01u,
        a.c_u12 + b.c_u12,
        a.c_0u2 + b.c_0u2 + a.c_0 * b.c_2 + b.c_0 * a.c_2,
        a.c_01v + b.c_01u + b.c_0 * a.c_u1v,
        a.c_v12 + b.c_u12 + a.c_u1v * b.c_2,
        a.c_0uv + b.c_0,
        a.c_u1v,
        a.c_vu2 + b.c_2,
        a.c_0 + b.c_0,
        a.c_2 + b.c_2};
}
T compress(const T& a, const T& b) {
    // a.v == b.u である 2 つの cluster をマージ
    return T{
        b.f_v,
        a.c_012 + b.c_012 + a.c_01v * b.c_2 + a.c_0 * b.c_u12 + b.c_0 * a.c_v12 + b.c_01u * a.c_2 + b.c_0u2 * (a.f_v == 1),
        a.c_01u + b.c_01u + b.c_0 * a.c_u1v,
        a.c_u12 + b.c_u12 + a.c_u1v * b.c_2,
        a.c_0u2 + a.c_0uv * b.c_2 + b.c_0 * a.c_vu2,
        a.c_01v + b.c_01v + a.c_0 * b.c_u1v + b.c_0uv * (a.f_v == 1),
        a.c_v12 + b.c_v12 + b.c_u1v * a.c_2 + (a.f_v == 1) * b.c_vu2,
        a.c_0uv,
        a.c_u1v + b.c_u1v,
        a.c_vu2,
        a.c_0 + b.c_0,
        a.c_2 + b.c_2};
}

vector<T> D;
vector<int> P, L, R, T;

void update(int i) {
    D[i] = (T[i] ? compress : rake)(D[L[i]], D[R[i]]);
}

void init(int N, vector<int> F, vector<int> U, vector<int> V, int Q) {
    vector g(N, vector<int>{});
    for (int i = 0; i < N - 1; i++) {
        int u = U[i], v = V[i];
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // static toptree の構築
    P.resize(N);
    L.resize(N);
    R.resize(N);
    T.resize(N);
    for (int f : F) D.push_back(one(f));

    vector<int> sz(N, 1);  // 部分木の大きさ
    {
        // heavy path の構築
        auto dfs = [&](auto dfs, int i) -> void {
            for (int j : g[i]) {
                // erase(g[j], i);
                g[j].erase(find(begin(g[j]), end(g[j]), i));
                dfs(dfs, j);
                sz[i] += sz[j];
            }
            if (size(g[i])) {
                // iter_swap(ranges::max_element(g[i], {}, [&](int i) { return sz[i]; }), begin(g[i]));
                auto it = max_element(begin(g[i]), end(g[i]), [&](int i, int j) { return sz[i] < sz[j]; });
                iter_swap(it, begin(g[i]));
            }
        };
        dfs(dfs, 0);
    }

    // data[i] と data[j] を rake/compress して新しいノードを作る
    auto op_D = [&](int i, int j, int t) -> int {
        const int K = size(D);
        T.push_back(t);
        L.push_back(i);
        R.push_back(j);
        P.push_back(0);
        D.emplace_back();
        P[i] = K;
        P[j] = K;
        return K;
    };

    auto merge_dfs = [&](auto merge_dfs, const vector<pair<int, int>>& a, int t) -> pair<int, int> {
        // 頂点数の合計が半分になるように分割して再帰的に rake/compress
        if (size(a) == 1) return a[0];
        int sum_s = 0;
        for (auto [i, s] : a) sum_s += s;
        vector<pair<int, int>> b, c;
        for (auto [i, s] : a) {
            (sum_s > s ? b : c).emplace_back(i, s);
            sum_s -= s * 2;
        }
        auto [i, si] = merge_dfs(merge_dfs, b, t);
        auto [j, sj] = merge_dfs(merge_dfs, c, t);
        return {op_D(i, j, t), si + sj};
    };

    auto collect_r = [&](auto collect_r, auto collect_c, int i) -> pair<int, int> {
        // 頂点 i から下に出るすべての辺を heavy 方向へ rake
        vector<pair<int, int>> childs;
        childs.emplace_back(g[i][0], 1);
        for (int j = 1; j < size(g[i]); j++) childs.push_back(collect_c(collect_r, collect_c, g[i][j]));
        return merge_dfs(merge_dfs, childs, 0);
    };

    auto collect_c = [&](auto collect_r, auto collect_c, int i) -> pair<int, int> {
        // par[i] -> i 辺とその先の heavy path を compress
        vector<pair<int, int>> path = {{i, 1}};
        while (size(g[i])) {
            path.push_back(collect_r(collect_r, collect_c, i));
            i = g[i][0];
        }
        return merge_dfs(merge_dfs, path, 1);
    };
    collect_c(collect_r, collect_c, 0);
    assert(size(D) == N * 2 - 1);

    // 木に沿って計算
    for (int i = N; i < size(D); i++) update(i);
}

void change(int X, int Y) {
    D[X] = one(Y);
    while (P[X]) update(X = P[X]);
}

long long num_tours() {
    return D.back().c_012;
}

Compilation message (stderr)

joitour.cpp: In instantiation of 'init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)::<lambda(auto:3, auto:4, int)> [with auto:3 = init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)::<lambda(auto:3, auto:4, int)>; auto:4 = init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)::<lambda(auto:5, auto:6, int)>]':
joitour.cpp:152:37:   required from 'init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)::<lambda(auto:5, auto:6, int)> [with auto:5 = init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)::<lambda(auto:3, auto:4, int)>; auto:6 = init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)::<lambda(auto:5, auto:6, int)>]'
joitour.cpp:157:39:   required from here
joitour.cpp:144:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  144 |         for (int j = 1; j < size(g[i]); j++) childs.push_back(collect_c(collect_r, collect_c, g[i][j]));
      |                         ~~^~~~~~~~~~~~
In file included from /usr/include/c++/10/cassert:44,
                 from joitour.cpp:4:
joitour.cpp: In function 'void init(int, std::vector<int>, std::vector<int>, std::vector<int>, int)':
joitour.cpp:158:20: warning: comparison of integer expressions of different signedness: 'std::vector<T>::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
  158 |     assert(size(D) == N * 2 - 1);
      |            ~~~~~~~~^~~~~~~~~~~~
joitour.cpp:161:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<T>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  161 |     for (int i = N; i < size(D); i++) update(i);
      |                     ~~^~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...