제출 #1195874

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1195874 | — | vjudge2 | JOI tour (JOI24_joitour) | C++20 | 6 / 100 | 1456 ms | 287608 KiB
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;

// Fixed 32-bit integer type, used for the parameters of the grader-facing
// functions init/change. NOTE(review): presumably these must match joitour.h.
using i32 = int32_t;
// Widens every later `int` in this file to 64 bits: tour counts are products of
// vertex counts (up to ~2e5 each) and overflow 32 bits. Side effect: the
// `int num_tours()` definition below actually declares `long long num_tours()`.
#define int long long

const int N = 2e5 + 5;  // capacity for vertex-indexed arrays (max vertices + slack)
// Bound on the centroid-decomposition depth. get_cent only accepts a vertex whose
// remaining pieces all have at most half of the component, so component sizes halve
// per level and the depth is at most floor(log2(N)) + 1 = 18 < LG.
const int LG = 20;

// n: number of vertices.  bl[u]: 1 once u has been taken as a centroid (blocks traversal).
// freq: not referenced in the code here.  sz[u]: subtree size computed by get_sz.
// tot: size of the component currently being decomposed.
// timer/tin/tout: Euler-tour times from the most recent dfs (each level overwrites them).
// f[u]: colour of u (0, 1 or 2).  ans[c]: tours attributed to centroid c by calc.
// total: the value returned by num_tours.
int n, bl[N], freq[N], sz[N], tot, timer = 0, tin[N], tout[N], f[N], ans[N], total = 0;
// adj: tree adjacency lists.  anc[u]: centroids whose component contained u, in
// decomposition order (outermost first); anc[u].size() - 1 is u's deepest level.
vector<int> adj[N], anc[N];
// Only referenced by commented-out code.
map<int, pair<int, int>> cd[N];
// par[u][l]: neighbour of centroid anc[u][l] that roots the child subtree containing u
// (written by dadfs; left unset at the level where u itself is the centroid).
// Sized by LG rather than a repeated literal so the depth bound lives in one place.
int par[N][LG];

// (centroid, child root) -> {#colour 0, #colour 1, #colour 2,
//   #(colour-1 ancestor of colour-0) pairs, #(colour-1 ancestor of colour-2) pairs}
// for that child subtree, with ancestry taken when rooted at the centroid.
map<pair<int, int>, array<int, 5>> mp;
// info[c] = {s0, s1, s2, s01, s21, p(0,21), p(2,01), p(0,2)}: the first five are the
// mp entries of c summed over its child subtrees; the last three are sums over child
// subtrees of the products cnt0*cnt21, cnt2*cnt01 and cnt0*cnt2.
array<int, 8> info[N];

// Fenwick (binary indexed) tree with three independent layers, one per colour.
// Positions are 1-based (1..maxn); upd/pref/query take the layer index first.
struct Fenwick {
    int maxn;
    vector<vector<int>> fen;

    // (Re)allocates all three layers, zero-filled, for positions 1..sz.
    void init(int sz) {
        maxn = sz;
        fen.assign(3, vector<int>(maxn + 1, 0));
    }

    // Adds k at position id of layer f. id must be >= 1 (id == 0 never advances).
    void upd(int f, int id, int k) {
        auto& layer = fen[f];
        while (id <= maxn) {
            layer[id] += k;
            id += id & -id;
        }
    }

    // Sum of layer f over positions 1..id (0 when id == 0).
    int pref(int f, int id) {
        const auto& layer = fen[f];
        int acc = 0;
        while (id != 0) {
            acc += layer[id];
            id &= id - 1;  // drop the lowest set bit
        }
        return acc;
    }

    // Sum of layer f over the closed range l..r.
    int query(int f, int l, int r) {
        const int upto_r = pref(f, r);
        const int before_l = pref(f, l - 1);
        return upto_r - before_l;
    }
} fen[N];

// Stores in sz[u] the number of vertices in u's subtree when rooted away from p,
// counting only vertices not yet removed as centroids (bl == 0).
void get_sz(int u, int p) {
    int count = 1;
    for (int nb : adj[u]) {
        if (nb != p && !bl[nb]) {
            get_sz(nb, u);
            count += sz[nb];
        }
    }
    sz[u] = count;
}

// Looks for a centroid of the current component (size `tot`, sizes from get_sz):
// a vertex whose largest remaining piece after removal has at most tot / 2 vertices.
// Returns u if u qualifies, otherwise the largest id any child call returned (-1 if none).
int get_cent(int u, int p) {
    int best = -1;
    int heaviest = 0;
    for (int w : adj[u]) {
        if (w == p || bl[w]) continue;
        const int found = get_cent(w, u);
        if (found > best) best = found;
        if (sz[w] > heaviest) heaviest = sz[w];
    }
    // The piece "above" u, i.e. everything outside u's subtree.
    if (tot - sz[u] > heaviest) heaviest = tot - sz[u];
    return (heaviest <= tot / 2) ? u : best;
}

// Scratch buffers shared by the decomposition:
//   node - vertices of the current component in DFS preorder (filled by dfs);
//   tmp  - vertices of one child subtree of the current centroid (filled by dadfs).
vector<int> node, tmp;

// Euler-tours the component of u (away from p, skipping removed vertices):
// assigns tin/tout from the global `timer`, appends `cent` to anc[] of every
// visited vertex, and records the vertices in `node` in preorder.
void dfs(int u, int p, int cent) {
    node.push_back(u);
    anc[u].push_back(cent);
    tin[u] = ++timer;
    for (int w : adj[u])
        if (w != p && !bl[w]) dfs(w, u, cent);
    tout[u] = timer;
}

// Collects into `tmp` every vertex of the child subtree (of centroid `cent`) that
// contains u, and records the subtree root rt in par[x][level], where level is
// the index of x's most recent entry in anc[x]. `cent` is skipped explicitly
// because it is not yet marked in bl[] while this runs.
void dadfs(int u, int p, int rt, int cent) {
    const auto level = anc[u].size() - 1;
    par[u][level] = rt;
    tmp.push_back(u);
    for (int w : adj[u]) {
        const bool skip = (w == p) || bl[w] || (w == cent);
        if (!skip) dadfs(w, u, rt, cent);
    }
}

// Sets ans[u] to the number of triples (a, b, c) with colours 0, 1, 2 whose a-c
// path passes through centroid u, using the aggregates in info[u]:
//   u is the 0-end:  s21;   u is the 2-end:  s01;
//   u is the middle: s0*s2 minus same-subtree pairs p(0,2);
// plus, for endpoints in different subtrees with b inside one of them,
//   s0*s21 - p(0,21)  and  s2*s01 - p(2,01).
void calc(int u) {
    const auto& s = info[u];
    int res = 0;
    switch (f[u]) {
        case 0: res += s[4]; break;
        case 1: res += s[0] * s[2] - s[7]; break;
        case 2: res += s[3]; break;
    }
    res += s[0] * s[4] - s[5];
    res += s[2] * s[3] - s[6];
    ans[u] = res;
}

// Running sum of the component sizes handled by dnc. It is only incremented, never
// read, in the code visible here; presumably a leftover debugging counter.
int cnter = 0;

void dnc(int u) {
    get_sz(u, -1);
    tot = sz[u];
    cnter += tot;
    int cent = get_cent(u, -1);
    timer = 0;
    node.clear();
    dfs(cent, -1, cent);
    fen[cent].init(timer);
    for (auto& x : node) fen[cent].upd(f[x], tin[x], 1);
    vector<vector<int>> ps(3, vector<int> (timer+1, 0));
    for (auto& x : node) ps[f[x]][tin[x]]++;
    for (int i = 0; i < 3; i++) for (int j = 1; j <= timer; j++) ps[i][j] += ps[i][j - 1];
    // for (int i = 0; i < 3; i++) {
    //     for (int j = 1; j <= timer; j++) {
    //         fen[cent].fen[i][j] += ps[i][j];
    //         int nxt = j + (j & -j);
    //         if (nxt <= timer) fen[cent].fen[i][nxt] += fen[cent].fen[i][j];
    //     }
    // }
    for (int i = 0; i < 8; i++) info[cent][i] = 0;
    for (auto& v : adj[cent]) if (!bl[v]) {
        tmp.clear();
        dadfs(v, -1, v, cent);
        array<int, 5> arr = {0, 0, 0, 0, 0};
        for (int i = 0; i < 3; i++) arr[i] = ps[i][tout[v]] - ps[i][tin[v] - 1];
        for (auto& x : tmp) {
            if (f[x] == 1) {
                arr[3] += ps[0][tout[x]] - ps[0][tin[x] - 1];
                arr[4] += ps[2][tout[x]] - ps[2][tin[x] - 1];
            }
        }
        mp[{cent, v}] = arr;
        for (int i = 0; i < 5; i++) info[cent][i] += arr[i];
        info[cent][5] += arr[0] * arr[4];
        info[cent][6] += arr[2] * arr[3];
        info[cent][7] += arr[0] * arr[2];
    }
    calc(cent);
    bl[cent] = 1;
    for (auto& v : adj[cent]) if (!bl[v]) dnc(v);
}

// Grader entry point: stores the tree (N vertices, edges U[e]-V[e]) and the
// initial colours F, runs the centroid decomposition from vertex 0, and adds
// every centroid's ans[] into `total`. Q is not used here.
void init(i32 _N, std::vector<i32> F, std::vector<i32> U, std::vector<i32> V,
          i32 Q) {
    n = _N;
    for (int e = 0; e + 1 < n; ++e) {
        const int a = U[e];
        const int b = V[e];
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    for (int v = 0; v < n; ++v) f[v] = F[v];
    dnc(0);
    for (int c = 0; c < n; ++c) total += ans[c];
}

// Grader update: recolours vertex X to colour Y and refreshes `total`, the number
// of triples (a, b, c) with f[a] == 0, f[b] == 1, f[c] == 2 and b on the a-c path.
//
// The previous body subtracted and re-added the same stale ans[X] and left the
// centroid-ancestor loop empty, so `total` never changed after init. This version
// recounts from scratch in O(n). A colour-1 vertex b lies on the a-c path exactly
// when a and c fall in different components of the tree with b removed, so b
// contributes
//     Z * T - sum over components i of z_i * t_i,
// where Z and T are the global numbers of colour-0 and colour-2 vertices, and
// z_i and t_i are those numbers inside component i.
// TODO: replace with an incremental O(log^2 n) update over the centroids in anc[X].
void change(i32 X, i32 Y) {
    f[X] = Y;
    total = 0;
    if (n <= 0) return;

    // Iterative BFS from vertex 0 (avoids deep recursion on path-shaped trees).
    vector<int> order;
    order.reserve(n);
    vector<int> parent(n, -1);
    vector<char> seen(n, 0);
    order.push_back(0);
    seen[0] = 1;
    for (size_t i = 0; i < order.size(); ++i) {
        const int u = order[i];
        for (const int v : adj[u]) {
            if (seen[v]) continue;
            seen[v] = 1;
            parent[v] = u;
            order.push_back(v);
        }
    }

    // sub0[u] / sub2[u]: colour-0 / colour-2 vertices in u's rooted subtree.
    // cross[u]: sum over children c of u of sub0[c] * sub2[c].
    // Reverse BFS order processes every child before its parent.
    vector<int> sub0(n, 0), sub2(n, 0), cross(n, 0);
    for (size_t i = order.size(); i-- > 0;) {
        const int u = order[i];
        if (f[u] == 0) ++sub0[u];
        if (f[u] == 2) ++sub2[u];
        const int p = parent[u];
        if (p >= 0) {
            sub0[p] += sub0[u];
            sub2[p] += sub2[u];
            cross[p] += sub0[u] * sub2[u];
        }
    }

    const int zeros = sub0[0];
    const int twos = sub2[0];
    for (const int b : order) {
        if (f[b] != 1) continue;
        // The component containing b's parent (empty for the root).
        const int up0 = zeros - sub0[b];
        const int up2 = twos - sub2[b];
        total += zeros * twos - cross[b] - up0 * up2;
    }
}

// Grader query: the number of valid tours for the current colouring, as held in `total`.
int num_tours() {
    const int current = total;
    return current;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...