Submission #1199135

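| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1199135 |  | otarius | JOI tour (JOI24_joitour) | C++20 | 100 / 100 | 2324 ms | 322472 KiB |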
#include <bits/stdc++.h>
#include <bits/extc++.h>
#include "joitour.h"
using namespace __gnu_pbds;
using namespace std;

// #pragma GCC optimize("Ofast")
// #pragma GCC optimize ("unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")

// Contest shorthand macros. Only `pb` and `ll` are used in this file; the rest
// are template leftovers.
#define ff first
#define sc second
#define pb push_back
#define ll long long
#define pll pair<ll, ll>
#define pii pair<int, int>
#define ull unsigned long long
#define all(x) (x).begin(),(x).end()
#define rall(x) (x).rbegin(),(x).rend()

// Time-seeded random generators (template leftovers; not used in this file).
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
mt19937_64 rngl(chrono::steady_clock::now().time_since_epoch().count());

// #define int long long
// #define int unsigned long long

// #define ordered_set(T) tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>
// #define ordered_multiset(T) tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>

// const ll mod = 1e9 + 7;
// const ll mod = 998244353;

// Sentinel constants (not used in this file).
const ll inf = 1e9;
const ll biginf = 1e18;
// Capacity of every per-vertex global array (200015 entries).
const int maxN = 2 * 1e5 + 15;

// Fenwick (binary indexed) tree over 1-based positions with 64-bit sums.
// Callers use it in two ways:
//   * point update + prefix/range sum (fn[0], fn[2]);
//   * range add via two point updates + point query with get() (fn[1]).
struct fenwick {
    std::vector<long long> fn;  // slot 0 is unused; valid positions are 1..fn.size()-1

    fenwick() : fn() {}
    // Allocates n + 3 slots so positions 1..n+2 (e.g. `out + 1` with out <= n) are valid.
    fenwick(int n) : fn(n + 3, 0) {}

    // Adds d at position x. Non-positive x is ignored: position 0 is not part
    // of the tree and `x & -x` would be 0 there, so the loop would never end.
    void update(int x, int d) {
        if (x <= 0) return;
        const int size = static_cast<int>(fn.size());
        for (; x < size; x += (x & -x)) fn[x] += d;
    }

    // Prefix sum over positions 1..x, returned at the same 64-bit width as the
    // stored values so large sums are not truncated. x past the last position
    // is clamped to it; x <= 0 yields 0.
    long long get(int x) const {
        const int last = static_cast<int>(fn.size()) - 1;
        if (x > last) x = last;
        long long ans = 0;
        for (; x > 0; x -= (x & -x)) ans += fn[x];
        return ans;
    }

    // Sum over positions l..r (inclusive).
    long long query(int l, int r) const {
        return get(r) - get(l - 1);
    }
};

// Per-centroid Fenwick trees: fn[k][c] belongs to centroid c.
// Only k = 0, 1, 2 are updated/queried (by dfs() and change()):
//   fn[0][c] / fn[2][c] - +1 at in[] of each label-0 / label-2 vertex;
//   fn[1][c]            - +1 over the [in, out] range of each label-1 vertex.
fenwick fn[5][maxN];

// used[v]: v has already been taken as a centroid.
bool used[maxN];
// Adjacency lists of the tree.
vector<int> g[maxN];
// n       - number of vertices;  ans - current answer returned by num_tours();
// f       - current label of each vertex (0, 1 or 2 as used in this file);
// sz      - subtree sizes computed by getsize();
// in/out/id[d][v] - Euler entry/exit time of v, and the index of the centroid's
//           child subtree that contains v, inside the component of v's depth-d
//           centroid ancestor (20 rows assume the centroid-tree depth stays
//           below 20 for the input bounds - TODO confirm);
// tim     - Euler-time counter (reset for every centroid);
// stim    - child-subtree counter (never reset; indexes scnt);
// ex      - per centroid: cnt[0]*cnt[2] minus the same-subtree products, i.e.
//           label-0/label-2 pairs lying in different child subtrees;
// par/dep - parent (-1 for the root) and depth in the centroid tree.
ll n, ans, f[maxN], sz[maxN], in[20][maxN], out[20][maxN], id[20][maxN], tim, stim, ex[maxN], par[maxN], dep[maxN];

// cnt[k][c]: totals over centroid c's component, excluding c itself;
// scnt[k][s]: the same totals restricted to child subtree s (an stim value).
ll cnt[5][maxN], scnt[5][maxN];

// Index legend for cnt / scnt (index 1 is not used by either array):
// 0 - 0        label-0 vertices
// 1 - 1
// 2 - 2        label-2 vertices
// 0-1 - 3      pairs (label-1 vertex w, label-0 vertex inside w's subtree)
// 1-2 - 4      pairs (label-1 vertex w, label-2 vertex inside w's subtree)

// Centroid currently being processed; set by decompose() and read by dfs().
int cur_cent;
// Recomputes sz[] for the part of the tree reachable from v without passing
// through p or any vertex already marked used[]: afterwards sz[x] is the number
// of such vertices in x's subtree when that part is rooted at v.
// Iterative: vertices are collected in BFS order (each one after its parent),
// then sizes are accumulated back-to-front, so a child's size is final before
// it is added to its parent.
void getsize(int v, int p) {
    std::vector<std::pair<int, int>> order;  // (vertex, parent it was reached from)
    order.emplace_back(v, p);
    for (std::size_t i = 0; i < order.size(); ++i) {
        const auto [x, from] = order[i];  // copy: emplace_back below may reallocate
        sz[x] = 1;
        for (int u : g[x]) {
            if (u != from && !used[u]) order.emplace_back(u, x);
        }
    }
    for (std::size_t i = order.size(); i-- > 1;) {
        sz[order[i].second] += sz[order[i].first];
    }
}
// Starting at v (entered from p), repeatedly steps into the first neighbour
// (ignoring p and used vertices) whose sz[] is more than half of `need`, and
// returns the vertex where no such neighbour exists. With need = component
// size and sz[] filled by getsize() from the same root, that is a centroid.
int getcent(int v, int p, int need) {
    for (;;) {
        int next = -1;
        for (int u : g[v]) {
            if (u != p && !used[u] && 2 * sz[u] > need) {
                next = u;
                break;
            }
        }
        if (next < 0) return v;
        p = v;
        v = next;
    }
}
// Depth-first traversal of the current centroid's component (cur_cent),
// started by decompose() on each unused neighbour of the centroid.
//   v - vertex being visited; p - vertex we came from;
//   d - centroid depth, i.e. the row of in/out/id to fill.
// For every visited vertex it records the Euler entry time in[d][v] (1-based,
// from tim), the exit time out[d][v] (last entry time inside v's subtree) and
// id[d][v] = stim, the index of the centroid's child subtree containing v.
// It also fills fn[*][cur_cent], cnt[*][cur_cent] and scnt[*][stim]
// (see the index legend next to the cnt/scnt declarations).
void dfs(int v, int p, int d) {
    in[d][v] = ++tim;
    id[d][v] = stim;

    // Label-0 vertex: mark its entry position and count it.
    if (f[v] == 0) {
        fn[0][cur_cent].update(in[d][v], 1);
        cnt[0][cur_cent]++; scnt[0][stim]++;
    }

    // Label-2 vertex: same bookkeeping in tree/counter 2.
    if (f[v] == 2) {
        fn[2][cur_cent].update(in[d][v], 1);
        cnt[2][cur_cent]++; scnt[2][stim]++;
    }

    for (int u : g[v]) {
        if (u == p || used[u]) continue;
        // NOTE(review): p is never -1 here (decompose() passes the centroid and
        // the recursive call passes v), so this stim++ never runs; decompose()
        // advances stim after each child subtree instead.
        dfs(u, v, d); if (p == -1) stim++;
    }

    out[d][v] = tim;
    if (f[v] == 1) {
        // +1 over v's Euler range [in, out]: a later fn[1].get(in[x]) counts
        // the label-1 vertices whose subtree contains x.
        fn[1][cur_cent].update(in[d][v], 1);
        fn[1][cur_cent].update(out[d][v] + 1, -1);

        // v's whole subtree has been visited, so fn[0] / fn[2] already hold
        // every label-0 / label-2 vertex below v.
        ll cnt01 = fn[0][cur_cent].query(in[d][v], out[d][v]);
        ll cnt12 = fn[2][cur_cent].query(in[d][v], out[d][v]);
        cnt[3][cur_cent] += cnt01; cnt[4][cur_cent] += cnt12;
        scnt[3][stim] += cnt01; scnt[4][stim] += cnt12;
    }
}

// Processes the not-yet-decomposed component containing v:
//   1. finds its centroid c (getsize + getcent) and records par[c] = p, dep[c] = d;
//   2. allocates c's Fenwick trees and runs dfs() over every child subtree,
//      filling cnt[*][c] and scnt[*][s] for each child subtree s;
//   3. adds c's terms to `ans` / ex[c] (products over the whole component minus
//      the same-subtree products, plus the term selected by c's own label);
//   4. recurses into the remaining sub-components at depth d + 1.
//   v - any vertex of the component; d - depth of the new centroid;
//   p - parent centroid, or -1 for the root.
void decompose(int v, int d, int p) {
    getsize(v, -1);
    int cur = sz[v];
    v = getcent(v, -1, cur);

    used[v] = 1; par[v] = p; dep[v] = d;

    // Only trees 0 (label-0 points), 1 (label-1 ranges) and 2 (label-2 points)
    // are ever updated or queried; fn[3] / fn[4] are never used, so they are not
    // allocated (saves two size-`cur` arrays per centroid).
    for (int i = 0; i < 3; i++)
        fn[i][v] = fenwick(cur);

    tim = 0; cur_cent = v;
    for (int u : g[v]) {
        if (used[u]) continue;
        dfs(u, v, d);

        // Remove the products whose two factors come from the same child
        // subtree (index stim); they are handled at deeper centroids.
        ex[v] -= scnt[0][stim] * scnt[2][stim];
        ans -= scnt[3][stim] * scnt[2][stim];
        ans -= scnt[0][stim] * scnt[4][stim];

        stim++;
    }

    ex[v] += cnt[0][v] * cnt[2][v];
    ans += cnt[0][v] * cnt[4][v];
    ans += cnt[3][v] * cnt[2][v];

    // Term that depends on the centroid's own label.
    if (f[v] == 0) ans += cnt[4][v];
    if (f[v] == 1) ans += ex[v];
    if (f[v] == 2) ans += cnt[3][v];

    for (int u : g[v]) {
        if (!used[u]) decompose(u, d + 1, v);
    }
}
void init(int _n, vector<int> _f, vector<int> _u, vector<int> _v, int _q) {
    n = _n;
    for (int i = 0; i < n; i++) f[i] = _f[i];
    for (int i = 0; i < n - 1; i++) {
        g[_u[i]].pb(_v[i]);
        g[_v[i]].pb(_u[i]);
    }

    decompose(0, 0, -1);
}

































// Index codes used for cnt / scnt / fn below:
//   0   - label-0 vertices
//   1   - label-1 vertices (fn[1] only)
//   2   - label-2 vertices
//   3   - "0-1": pairs (label-1 w, label-0 vertex inside w's subtree)
//   4   - "1-2": pairs (label-1 w, label-2 vertex inside w's subtree)
//
// Relabels vertex v to x and updates `ans` incrementally. Does nothing if v
// already has label x. Walks v's centroid-tree ancestors (v itself first, then
// par[] up to -1), with d tracking the centroid depth used to index in/out/id.
void change(int v, int x) {
    if (f[v] == x) return;
    int d = dep[v], cur = v;
    while (cur != -1) {
        if (cur == v) {
            // v is the centroid here and is not part of its own counters, so
            // only the label-dependent term is swapped.
            if (f[v] == 0) ans -= cnt[4][v];
            if (f[v] == 1) ans -= ex[v];
            if (f[v] == 2) ans -= cnt[3][v];

            if (x == 0) ans += cnt[4][v];
            if (x == 1) ans += ex[v];
            if (x == 2) ans += cnt[3][v];
        } else {
            // Child subtree of centroid `cur` that contains v.
            int sid = id[d][v];

            // Undo every term decompose() contributed for `cur`: the term for
            // cur's own label, the whole-component products, and (by adding
            // them back) the same-subtree corrections for subtree sid.
            if (f[cur] == 0) ans -= cnt[4][cur];
            if (f[cur] == 1) ans -= ex[cur];
            if (f[cur] == 2) ans -= cnt[3][cur];

            ex[cur] -= cnt[0][cur] * cnt[2][cur];
            ans -= cnt[0][cur] * cnt[4][cur];
            ans -= cnt[3][cur] * cnt[2][cur];

            ex[cur] += scnt[0][sid] * scnt[2][sid];
            ans += scnt[0][sid] * scnt[4][sid];
            ans += scnt[3][sid] * scnt[2][sid];

            // Remove v's old label from the counters and Fenwick trees.
            if (f[v] == 0) {
                // Label-1 vertices whose [in, out] range covers v, i.e. label-1
                // ancestors of v in cur's DFS tree.
                int cntt = fn[1][cur].get(in[d][v]);
                scnt[3][sid] -= cntt;
                scnt[0][sid]--;

                cnt[3][cur] -= cntt;
                cnt[0][cur]--;

                fn[0][cur].update(in[d][v], -1);
            } else if (f[v] == 2) {
                int cntt = fn[1][cur].get(in[d][v]);
                scnt[4][sid] -= cntt;
                scnt[2][sid]--;

                cnt[4][cur] -= cntt;
                cnt[2][cur]--;

                fn[2][cur].update(in[d][v], -1);
            } else {
                // Old label 1: drop the pairs v formed with label-0 / label-2
                // vertices in its subtree, then remove its range from fn[1].
                int cnt01 = fn[0][cur].query(in[d][v], out[d][v]);
                int cnt12 = fn[2][cur].query(in[d][v], out[d][v]);

                scnt[3][sid] -= cnt01; scnt[4][sid] -= cnt12;
                cnt[3][cur] -= cnt01; cnt[4][cur] -= cnt12;

                fn[1][cur].update(in[d][v], -1);
                fn[1][cur].update(out[d][v] + 1, 1);
            }



            // Insert v's new label (mirror image of the removal above).
            if (x == 0) {
                int cntt = fn[1][cur].get(in[d][v]);
                scnt[3][sid] += cntt;
                scnt[0][sid]++;

                cnt[3][cur] += cntt;
                cnt[0][cur]++;

                fn[0][cur].update(in[d][v], 1);
            } else if (x == 2) {
                int cntt = fn[1][cur].get(in[d][v]);
                scnt[4][sid] += cntt;
                scnt[2][sid]++;

                cnt[4][cur] += cntt;
                cnt[2][cur]++;

                fn[2][cur].update(in[d][v], 1);
            } else {
                int cnt01 = fn[0][cur].query(in[d][v], out[d][v]);
                int cnt12 = fn[2][cur].query(in[d][v], out[d][v]);

                scnt[3][sid] += cnt01; scnt[4][sid] += cnt12;
                cnt[3][cur] += cnt01; cnt[4][cur] += cnt12;

                fn[1][cur].update(in[d][v], 1);
                fn[1][cur].update(out[d][v] + 1, -1);
            }

            // Re-apply cur's terms with the updated counters.
            ex[cur] += cnt[0][cur] * cnt[2][cur];
            ans += cnt[0][cur] * cnt[4][cur];
            ans += cnt[3][cur] * cnt[2][cur];

            ex[cur] -= scnt[0][sid] * scnt[2][sid];
            ans -= scnt[0][sid] * scnt[4][sid];
            ans -= scnt[3][sid] * scnt[2][sid];

            if (f[cur] == 0) ans += cnt[4][cur];
            if (f[cur] == 1) ans += ex[cur];
            if (f[cur] == 2) ans += cnt[3][cur];
        }

        // Move to the parent centroid, one level shallower.
        cur = par[cur]; d--;
    }

    f[v] = x;
}
// Grader query: the current answer, maintained by init() and change().
ll num_tours() { return ans; }

// int main() {
//   int N;
//   assert(scanf("%d", &N) == 1);

//   std::vector<int> F(N);
//   for (int i = 0; i < N; i++) {
//     assert(scanf("%d", &F[i]) == 1);
//   }

//   std::vector<int> U(N - 1), V(N - 1);
//   for (int j = 0; j < N - 1; j++) {
//     assert(scanf("%d %d", &U[j], &V[j]) == 2);
//   }

//   int Q;
//   assert(scanf("%d", &Q) == 1);

//   init(N, F, U, V, Q);
//   printf("%lld\n", num_tours());
//   fflush(stdout);

//   for (int k = 0; k < Q; k++) {
//     int X, Y;
//     assert(scanf("%d %d", &X, &Y) == 2);

//     change(X, Y);
//     printf("%lld\n", num_tours());
//     fflush(stdout);
//   }
// }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...