제출 #1000202

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1000202 |  | caterpillow | JOI tour (JOI24_joitour) | C++17 | 컴파일 에러 | 0 ms | 0 KiB |
#include <bits/stdc++.h>

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")

using namespace std;

// Short type aliases and competitive-programming shorthand macros.
using ll = long long;
// pl, ROF and INF are not referenced elsewhere in this file.
using pl = pair<ll, ll>;
#define vt vector
// NOTE: `f` and `s` are object-like macros, so every later token spelled
// exactly `f` or `s` (the intended `.f` / `.s` member access, but also any
// variable or parameter with that name) is rewritten to `first` / `second`.
#define f first
#define s second
#define all(x) x.begin(), x.end() 
#define pb push_back
// FOR iterates an int index over [a, b); ROF iterates [a, b) in reverse;
// F0R is FOR starting at 0. Passing a container's .size() as the bound
// compares int with an unsigned size type and triggers -Wsign-compare.
#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define ROF(i, a, b) for (int i = (b) - 1; i >= (a); i--)
#define F0R(i, b) FOR (i, 0, b)
// endl is replaced by '\n', so streaming it no longer flushes.
#define endl '\n'
#define debug(x) do{auto _x = x; cerr << #x << " = " << _x << endl;} while(0)
const ll INF = 1e18;

// Iterative (bottom-up) segment tree over int values.
// Supports assigning a single position and summing an inclusive range.
// Leaves occupy seg[n, 2n); seg[1] is the root; node k has children 2k, 2k+1.
struct SegTree {
    int n;                 // leaf count: smallest power of two >= requested size
    std::vector<int> seg;  // 1-indexed heap layout, seg[0] unused

    // Prepare the tree for positions [0, size); always at least one leaf.
    void init(int size) {
        n = 1;
        while (n < size) n <<= 1;
        seg.resize(2 * n);
    }

    // Store `value` at position `pos`, then recompute every ancestor sum.
    void upd(int pos, int value) {
        int node = pos + n;
        seg[node] = value;
        for (node >>= 1; node >= 1; node >>= 1) {
            seg[node] = seg[node << 1] + seg[(node << 1) | 1];
        }
    }

    // Sum of positions lo..hi, both inclusive (0 when lo > hi).
    int query(int lo, int hi) {
        int total = 0;
        int left = lo + n;
        int right = hi + n + 1;  // half-open upper bound
        while (left < right) {
            if (left & 1) total += seg[left++];
            if (right & 1) total += seg[--right];
            left >>= 1;
            right >>= 1;
        }
        return total;
    }
};

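/*

Approach: centroid decomposition on the tree.

A tour is a triple of vertices coloured 0, 1 and 2 such that the 1-vertex lies
on the path between the 0-vertex and the 2-vertex. For each centroid (the
"root" of its component), count the tours whose 0-2 path passes through that
root. Every tour is counted at exactly one centroid: the first one chosen
that lies on its 0-2 path.

Write the position of a tour's vertices as {subtree A, root, subtree B}, where
A and B are different children of the root. The cases are:
    1. {{0, 1}, {}, {2}}   - the 1 is an ancestor of the 0 inside A
    2. {{0}, {}, {1, 2}}   - the 1 is an ancestor of the 2 inside B
    3. {{0}, {1}, {2}}     - the root itself is the 1
    4. {{0, 1}, {2}, {}}   - the root is the 2
    5. {{}, {0}, {1, 2}}   - the root is the 0

When a vertex (restaurant) changes colour, subtract the tours that used its
old colour and add the tours that use its new colour. Do this at every
centroid whose component contains the vertex.

Required query:
    1. the number of 1's on the path between two vertices (answered with HLD).

Vertex labels are remapped to DFS (Euler-tour) times separately for each centroid.

*/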

using pi = pair<int, int>;

// Everything stored for one centroid `root` of the decomposition.
// Vertices of its component are addressed by their Euler entry time `tin`
// from dfs_time (the root itself has tin 0). "Child i" means the subtree
// hanging off adj[root][i].
// Scalars get default member initialisers, so a Centroid is valid even when
// it is default-constructed instead of value-initialised.
struct Centroid {
    int root = -1;
    // tout[tin]: last Euler time inside the subtree of the vertex entered at tin
    vt<int> tout;
    // Euler-position flags for colour-0 / colour-2 vertices. Summing [tin, tout]
    // gives the number of 0's / 2's in that vertex's subtree.
    SegTree tree0, tree2;

    // cnt10[i]: over colour-0 vertices in child i, the total number of colour-1
    //           vertices strictly between each of them and the root;
    // cnt12[i]: the same for colour-2 vertices
    vt<ll> cnt10, cnt12;
    // number of colour-0 / colour-2 vertices in child i (root excluded)
    vt<int> cnt0, cnt2;
    // sums of the per-child counters above over all children
    ll tot10 = 0, tot12 = 0, tot0 = 0, tot2 = 0;
    // subtree[tin]: child index containing that vertex, -1 for the root
    vt<int> subtree;
    // subtree_times[i]: Euler-time range of child i, as recorded by dfs_time
    vt<pi> subtree_times;

    // number of tours whose 0-2 path passes through root
    ll ans = 0;
};

// Problem size as passed to init(); q is stored but not read elsewhere in this file.
int n, q;
// Adjacency lists of the input tree.
vt<vt<int>> adj;
// centroids[r]: data for the decomposition component whose centroid is r.
vt<Centroid> centroids;
// parents[u]: one entry per centroid whose component contains u, outermost
// centroid first, holding (that centroid, Euler time of u in its DFS).
vt<vt<pi>> parents; // centroid root, time
// Current colour (0, 1 or 2) of every vertex.
vt<int> colour;
// Running answer: the sum of centroids[r].ans over all centroids.
ll gans;

// Decomposition scratch state: subtree sizes of the component being split,
// and vertices already chosen as centroids (excluded from later components).
vt<int> sz;
vt<bool> done;

// Root the current component (vertices not yet `done`) at u.
// Fills sz[] with subtree sizes and returns the size of u's subtree.
int dfs_sz(int u, int par = -1) {
    int total = 1;
    for (int nb : adj[u]) {
        if (nb != par && !done[nb]) total += dfs_sz(nb, u);
    }
    sz[u] = total;
    return total;
}

// Walk from u towards any child whose subtree holds more than half of the
// tsz vertices, and stop at the first vertex that has no such child.
// Requires sz[] to be filled by dfs_sz rooted at the same starting vertex.
int find_centroid(int u, int tsz, int par = -1) {
    int cur = u;
    int from = par;
    for (;;) {
        int heavy = -1;
        for (int nb : adj[cur]) {
            if (nb == from || done[nb]) continue;
            if (2 * sz[nb] > tsz) {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) return cur;
        from = cur;
        cur = heavy;
    }
}

// DFS over the current component of centroid obj.root, assigning Euler entry
// times. t is pre-incremented; decomp starts it at -1, so the root gets 0.
//   u       : current vertex;  par : its DFS parent (-1 at the root)
//   ones    : number of colour-1 vertices strictly between obj.root and u
//   subtree : index in adj[obj.root] of the root child whose subtree holds u
//             (-1 while visiting the root itself)
// For every vertex this:
//   - records (obj.root, tin) in parents[u];
//   - fills obj.subtree and obj.tout;
//   - marks colour-0/2 vertices in tree0/tree2;
//   - for non-root vertices, accumulates cnt0/cnt2 and the
//     (1 above 0) / (1 above 2) pair totals cnt10/cnt12.
// For each root child i, subtree_times[i] becomes the inclusive
// (first, last) range of Euler times inside that child.
void dfs_time(int u, int& t, Centroid& obj, ll ones, int par = -1, int subtree = -1) {
    int tin = ++t;
    parents[u].pb({obj.root, tin});
    obj.subtree[tin] = subtree;

    if (colour[u] == 0) obj.tree0.upd(tin, 1);
    if (colour[u] == 2) obj.tree2.upd(tin, 1);
    if (subtree != -1) {
        // update 10's and 12's (the root's own colour is never counted here)
        if (colour[u] == 0) {
            obj.cnt0[subtree]++;
            obj.cnt10[subtree] += ones;
        } else if (colour[u] == 1) {
            ones++;  // u lies between the root and every descendant of u
        } else {
            obj.cnt2[subtree]++;
            obj.cnt12[subtree] += ones;
        }
    }

    F0R (i, (int) adj[u].size()) {
        int v = adj[u][i];
        if (v == par || done[v]) continue;
        if (subtree == -1) {
            // entry time of child i: v will be assigned t + 1
            obj.subtree_times[i].first = t + 1;
        }
        dfs_time(v, t, obj, ones, u, subtree == -1 ? i : subtree);
        if (subtree == -1) {
            // exit time of child i (the original wrote .f twice, clobbering the entry time)
            obj.subtree_times[i].second = t;
        }
    }

    obj.tout[tin] = t;
}

// Centroid-decompose the component containing u (vertices whose done[] is unset).
// Steps:
//   1. Pick the component's centroid r and build centroids[r]: Euler times,
//      per-child counters, tree0/tree2.
//   2. Count the tours whose 0-2 path passes through r into centroids[r].ans
//      and add that to gans.
//   3. Mark r done and recurse into each remaining neighbouring component.
// NOTE(review): dfs_sz and dfs_time are recursive, to the depth of the component.
void decomp(int u = 0) {
    int tsz = dfs_sz(u);
    int r = find_centroid(u, tsz);

    Centroid& obj = centroids[r];
    obj.root = r;
    obj.tout = obj.subtree = vt<int>(tsz);

    // Per-child counters are indexed by position in adj[r]; slots that belong
    // to neighbours which are already done simply stay zero.
    int deg = (int) adj[r].size();
    obj.cnt0 = obj.cnt2 = vt<int>(deg);
    obj.cnt10 = obj.cnt12 = vt<ll>(deg);
    obj.subtree_times.resize(deg);

    obj.tree0.init(tsz);
    obj.tree2.init(tsz);

    // dfs_time pre-increments t, so r gets Euler time 0 and times span 0..tsz-1.
    int t = -1;
    dfs_time(r, t, obj, 0);

    obj.tot0 = accumulate(all(obj.cnt0), 0ll);
    obj.tot2 = accumulate(all(obj.cnt2), 0ll);
    obj.tot10 = accumulate(all(obj.cnt10), 0ll);
    obj.tot12 = accumulate(all(obj.cnt12), 0ll);

    // Count from zero. Do not rely on whoever sized `centroids` having
    // zero-initialised ans.
    obj.ans = 0;
    F0R (i, deg) {
        int v = adj[r][i];
        if (done[v]) continue;

        // 0 with a 1 above it in child i, 2 in another child
        obj.ans += 1ll * obj.cnt10[i] * (obj.tot2 - obj.cnt2[i]);
        // 0 in child i, 2 with a 1 above it in another child
        obj.ans += 1ll * obj.cnt0[i] * (obj.tot12 - obj.cnt12[i]);
        // 0 in child i, the root is the 1, 2 in another child
        if (colour[r] == 1) obj.ans += 1ll * obj.cnt0[i] * (obj.tot2 - obj.cnt2[i]);
    }
    // Root is the 0: any 2 with a 1 above it. Root is the 2: any 0 with a 1 above it.
    if (colour[r] == 0) obj.ans += obj.tot12;
    if (colour[r] == 2) obj.ans += obj.tot10;

    gans += obj.ans;

    done[r] = true;
    for (int v : adj[r]) {
        if (!done[v]) decomp(v);
    }
}

// Heavy-light decomposition of the whole input tree, rooted at vertex 0.
// A SegTree indexed by HLD position lets per-vertex values be summed along
// any path: upd() sets a vertex's value, query() sums the values on a path.
struct HLD {
    int t;  // next HLD position to hand out in dfs_hld
    // per vertex: subtree size, HLD position, parent, head of its heavy chain, depth
    vt<int> sz, pos, par, root, depth;
    // private copy of the adjacency lists; dfs_sz turns it into child lists
    vt<vt<int>> adj;
    SegTree seg;
    // Copy the tree and size every array to the global vertex count n.
    void init(vt<vt<int>>& _adj) {
        t = 0;
        sz = pos = par = root = depth = vt<int>(n);
        adj = _adj;
        seg.init(n);
    }
    // Compute sizes, parents and depths below u.
    // - The parent is erased from each child's list, so adj becomes child lists.
    // - The heaviest child is kept in adj[u][0]: after visiting v, it is swapped
    //   into slot 0 if it is heavier. Slot 0 has always been visited already,
    //   so every child is processed exactly once.
    // NOTE(review): recursive; the recursion depth equals the tree height.
    int dfs_sz(int u) {
        sz[u] = 1;
        for (int& v : adj[u]) {
            par[v] = u;
            depth[v] = depth[u] + 1;
            adj[v].erase(find(all(adj[v]), u));
            sz[u] += dfs_sz(v);
            if (sz[v] > sz[adj[u][0]]) swap(v, adj[u][0]);
        }
        return sz[u];
    }
    // Assign positions in preorder, visiting the heavy child adj[u][0] first,
    // so each heavy chain occupies a contiguous range of positions.
    // root[v] is the chain head; root[0] stays 0 from the zero-filled init.
    void dfs_hld(int u) {
        pos[u] = t++;
        for (int& v : adj[u]) {
            root[v] = (v == adj[u][0] ? root[u] : v);
            dfs_hld(v);
        }
    }
    // Build the decomposition rooted at vertex 0 (call after init).
    void gen() {
        dfs_sz(0);
        dfs_hld(0);
    }
    // Sum of the values of all vertices on the u-v path, both endpoints included.
    int query(int u, int v) {
        int res = 0;
        while (root[u] != root[v]) {
            // lift whichever endpoint has the deeper chain head
            if (depth[root[u]] > depth[root[v]]) swap(u, v);
            res += seg.query(pos[root[v]], pos[v]);
            v = par[root[v]];
        }
        if (depth[u] > depth[v]) swap(u, v);
        return res + seg.query(pos[u], pos[v]);
    }
    // Set the value stored for vertex u to v.
    void upd(int u, int v) {
        seg.upd(pos[u], v);
    }
};

// Global HLD instance. init() gives every colour-1 vertex the value 1, and
// change() keeps it in sync, so query(u, v) counts the 1's on the u-v path.
HLD hld;

// Re-account centroid `obj` for vertex u being recoloured from prev_c to new_c.
//   obj    : a centroid whose component contains u
//   u      : the vertex being recoloured (may be obj.root itself)
//   tin    : u's Euler entry time in obj's DFS (the value stored in parents[u])
//   prev_c : u's old colour;  new_c : u's new colour (0, 1 or 2)
// Phase 1 subtracts every tour counted at obj that uses u with colour prev_c,
// and removes u's old colour from the counters. Phase 2 adds every tour that
// uses u with colour new_c, and inserts it into the counters. Finally gans is
// adjusted by the net change in obj.ans.
// change() calls this BEFORE it writes colour[u] or updates hld. So inside this
// function colour[u] == prev_c, and hld still counts u as a 1 iff prev_c == 1.
// The "- (colour[u] == 1)" correction in both phases relies on exactly that
// to exclude u itself from the path count.
void upd(Centroid& obj, int u, int tin, int prev_c, int new_c) {

    // root child (index into adj[obj.root]) containing u; -1 when u is the root
    int i = obj.subtree[tin];
    // last Euler time inside u's subtree, so [tin, tout] covers u's subtree
    int tout = obj.tout[tin];

    ll prev_ans = obj.ans;

    // handle removal

    // not root
    if (u != obj.root) {

        // the root's child at the top of u's subtree
        int subroot = adj[obj.root][i];

        // subtract answer
        if (prev_c == 0) {
            // tours in which u is the 0

            // number of 1's on the path from u up to subroot, excluding u
            ll par1s = hld.query(u, subroot) - (colour[u] == 1);
            
            // a 1 above u, a 2 in another child
            obj.ans -= par1s * (obj.tot2 - obj.cnt2[i]);
            // a 2 with a 1 above it in another child
            obj.ans -= obj.tot12 - obj.cnt12[i];
            // the root is the 1, a 2 in another child
            if (colour[obj.root] == 1) obj.ans -= obj.tot2 - obj.cnt2[i];
            // the root is the 2, a 1 above u
            if (colour[obj.root] == 2) obj.ans -= par1s;

            // update counts
            obj.cnt0[i]--;
            obj.tot0--;
            obj.tree0.upd(tin, 0);
            obj.tot10 -= par1s;
            obj.cnt10[i] -= par1s;
        } else if (prev_c == 1) {
            // Tours in which u is the 1: a 0 or 2 below u, plus the other
            // endpoint elsewhere. t0 / t2 = number of 0's / 2's in u's subtree.
            int t0 = obj.tree0.query(tin, tout);
            int t2 = obj.tree2.query(tin, tout);
            // a 0 below u, a 2 in another child
            obj.ans -= t0 * (obj.tot2 - obj.cnt2[i]);
            // a 2 below u, a 0 in another child
            obj.ans -= t2 * (obj.tot0 - obj.cnt0[i]);
            // a 0 below u, the root is the 2
            if (colour[obj.root] == 2) obj.ans -= t0;
            // a 2 below u, the root is the 0
            if (colour[obj.root] == 0) obj.ans -= t2;

            // upd: u no longer sits above those 0's / 2's as a 1
            obj.cnt10[i] -= t0;
            obj.tot10 -= t0;
            obj.cnt12[i] -= t2;
            obj.tot12 -= t2;
        } else {
            // tours in which u is the 2 (mirror of the prev_c == 0 case)

            ll par1s = hld.query(u, subroot) - (colour[u] == 1);

            obj.ans -= par1s * (obj.tot0 - obj.cnt0[i]);
            obj.ans -= obj.tot10 - obj.cnt10[i];
            if (colour[obj.root] == 1) obj.ans -= obj.tot0 - obj.cnt0[i];
            if (colour[obj.root] == 0) obj.ans -= par1s;

            // upd
            obj.cnt2[i]--;
            obj.tot2--;
            obj.tree2.upd(tin, 0);
            obj.tot12 -= par1s;
            obj.cnt12[i] -= par1s;
        }
    } else {
        // u is the root. Only tours using the root as the 0, 1 or 2 change.
        // The root's own tree0/tree2 entry (tin 0) is left untouched. It is never
        // inside a queried [tin, tout] range, because only non-root tins are queried.

        if (prev_c == 0) {
            // root is the 0: any 2 with a 1 above it
            obj.ans -= obj.tot12;
        } else if (prev_c == 1) {
            // Root is the 1: a 0 and a 2 in different children.
            // NOTE(review): this loop is O(deg(obj.root)) per call. Maintaining
            // sum_j cnt0[j] * cnt2[j] would make it O(1) (tot0 * tot2 - sum).
            ll sub = 0;
            F0R (j, adj[obj.root].size()) { 
                sub += obj.cnt0[j] * (obj.tot2 - obj.cnt2[j]);
            }
            obj.ans -= 1ll * sub;
        } else {
            // root is the 2: any 0 with a 1 above it
            obj.ans -= obj.tot10;
        }
    }

    // now handle addition (the counters no longer contain u's old colour)

    if (u != obj.root) {

        int subroot = adj[obj.root][i];

        // add answer
        if (new_c == 0) {
            // tours in which u is the 0. colour[u] and hld still hold the OLD
            // colour, so the correction below removes u's old 1 if it had one.

            ll par1s = hld.query(u, subroot) - (colour[u] == 1);

            obj.ans += par1s * (obj.tot2 - obj.cnt2[i]);
            obj.ans += obj.tot12 - obj.cnt12[i];
            if (colour[obj.root] == 1) obj.ans += obj.tot2 - obj.cnt2[i];
            if (colour[obj.root] == 2) obj.ans += par1s;

            // update counts
            obj.cnt0[i]++;
            obj.tot0++;
            obj.tree0.upd(tin, 1);
            obj.tot10 += par1s;
            obj.cnt10[i] += par1s;
        } else if (new_c == 1) {
            // tours in which u is the 1. tree0/tree2 no longer mark u itself.
            int t0 = obj.tree0.query(tin, tout);
            int t2 = obj.tree2.query(tin, tout);
            obj.ans += t0 * (obj.tot2 - obj.cnt2[i]);
            obj.ans += t2 * (obj.tot0 - obj.cnt0[i]);
            if (colour[obj.root] == 2) obj.ans += t0;
            if (colour[obj.root] == 0) obj.ans += t2;

            // upd
            obj.cnt10[i] += t0;
            obj.tot10 += t0;
            obj.cnt12[i] += t2;
            obj.tot12 += t2;
        } else {
            // tours in which u is the 2 (mirror of new_c == 0)

            ll par1s = hld.query(u, subroot) - (colour[u] == 1);

            obj.ans += par1s * (obj.tot0 - obj.cnt0[i]);
            obj.ans += obj.tot10 - obj.cnt10[i];
            if (colour[obj.root] == 1) obj.ans += obj.tot0 - obj.cnt0[i];
            if (colour[obj.root] == 0) obj.ans += par1s;

            // upd
            obj.cnt2[i]++;
            obj.tot2++;
            obj.tree2.upd(tin, 1);
            obj.tot12 += par1s;
            obj.cnt12[i] += par1s;
        }
    } else {
        // u is the root: the mirror of the root removal above
        if (new_c == 0) {
            obj.ans += obj.tot12;
        } else if (new_c == 1) {
            // NOTE(review): O(deg(obj.root)) per call, as in the removal phase.
            ll sub = 0;
            F0R (j, adj[obj.root].size()) {
                sub += obj.cnt0[j] * (obj.tot2 - obj.cnt2[j]);
            }
            obj.ans += 1ll * sub;
        } else {
            obj.ans += obj.tot10;
        }
    }

    // replace this centroid's old contribution to the global answer
    gans -= prev_ans;
    gans += obj.ans;

}

void change(int u, int c) {
    if (c == colour[u]) return;
    for (auto [cent, tin] : parents[u]) {
        upd(centroids[cent], u, tin, colour[u], c);
    }
    if (colour[u] == 1) hld.upd(u, 0);
    colour[u] = c;
    if (c == 1) hld.upd(u, 1);
}

void init(int N, vt<int> F, vt<int> U, vt<int> V, int Q) {
    n = N;
    q = Q;
    colour = F;
    parents.resize(n);
    centroids.resize(n);
    adj.resize(n);
    F0R (i, n - 1) {
        adj[U[i]].pb(V[i]);
        adj[V[i]].pb(U[i]);
    }

    hld.init(adj);
    hld.gen();

    F0R (i, n) {
        if (colour[i] == 1) hld.upd(i, 1);
    }

    sz.resize(n);
    done.resize(n);

    gans = 0;

    decomp();
}

// Grader entry point: the number of tours for the current colouring.
// This is the running total that decomp() and upd() keep in gans.
ll num_tours() {
    const ll total = gans;
    return total;
}

// Local test driver only. The judge links its own grader main(), which calls
// init / change / num_tours. Defining main here unconditionally makes the
// link fail with "multiple definition of `main'", so the driver is compiled
// only when LOCAL is defined (e.g. g++ -DLOCAL ...).
// Input: N, the N initial colours, the N-1 edges "U V", Q, then Q lines "X Y"
// (recolour vertex X to Y). Prints num_tours() after init and after each change.
#ifdef LOCAL
int main() {

    cin.tie(0)->sync_with_stdio(0);

    int n, q;
    cin >> n;
    vt<int> col(n);
    vt<int> a(n - 1), b(n - 1);
    F0R (i, n) cin >> col[i];
    F0R (i, n - 1) cin >> a[i] >> b[i];
    cin >> q;
    init(n, col, a, b, q);
    cout << num_tours() << endl;
    F0R (i, q) {
        int u, c;
        cin >> u >> c;
        change(u, c);
        cout << num_tours() << endl;
    }
    return 0;
}
#endif

컴파일 시 표준 에러 (stderr) 메시지

joitour.cpp: In function 'void dfs_time(int, int&, Centroid&, ll, int, int)':
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:135:5: note: in expansion of macro 'F0R'
  135 |     F0R (i, adj[u].size()) {
      |     ^~~
joitour.cpp: In function 'void decomp(int)':
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:174:5: note: in expansion of macro 'F0R'
  174 |     F0R (i, adj[r].size()) {
      |     ^~~
joitour.cpp: In function 'void upd(Centroid&, int, int, int, int)':
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:309:13: note: in expansion of macro 'F0R'
  309 |             F0R (j, adj[obj.root].size()) {
      |             ^~~
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:374:13: note: in expansion of macro 'F0R'
  374 |             F0R (j, adj[obj.root].size()) {
      |             ^~~
joitour.cpp: At global scope:
joitour.cpp:429:1: warning: ISO C++ forbids declaration of 'main' with no type [-Wreturn-type]
  429 | main() {
      | ^~~~
/usr/bin/ld: /tmp/ccB5h0IK.o: in function `main':
stub.cpp:(.text.startup+0x0): multiple definition of `main'; /tmp/ccwdSLPL.o:joitour.cpp:(.text.startup+0x0): first defined here
collect2: error: ld returned 1 exit status