# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
1000202 | caterpillow | JOI tour (JOI24_joitour) | C++17 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long;
using pl = pair<ll, ll>;
#define vt vector
#define f first
#define s second
#define all(x) x.begin(), x.end()
#define pb push_back
#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define ROF(i, a, b) for (int i = (b) - 1; i >= (a); i--)
#define F0R(i, b) FOR (i, 0, b)
#define endl '\n'
#define debug(x) do{auto _x = x; cerr << #x << " = " << _x << endl;} while(0)
const ll INF = 1e18;
// Iterative (bottom-up) segment tree over ints supporting point assignment and
// inclusive range-sum queries.
//
// Layout: `seg` is a 1-based binary heap of size 2 * n; the leaf for position i
// lives at seg[n + i] and every internal node k stores seg[2k] + seg[2k + 1].
struct SegTree {
    int n;                 // number of leaves, always a power of two
    std::vector<int> seg;  // heap-ordered node sums

    // Sizes the tree for at least `_n` positions and zeroes every node.
    // `assign` is used instead of `resize` so that calling init() on a tree
    // that was already in use does not keep stale sums from the previous run.
    void init(int _n) {
        n = 1;
        while (n < _n) n *= 2;
        seg.assign(2 * n, 0);
    }

    // Sets position `i` (0-based) to `v`, then recomputes each ancestor's sum.
    void upd(int i, int v) {
        int node = i + n;
        seg[node] = v;
        for (node /= 2; node >= 1; node /= 2) {
            seg[node] = seg[2 * node] + seg[2 * node + 1];
        }
    }

    // Returns the sum of positions l..r, both inclusive (0-based).
    // An empty range (l > r) yields 0 because the loop never runs.
    int query(int l, int r) {
        int total = 0;
        int lo = l + n;
        int hi = r + n + 1;  // switch to a half-open interval [lo, hi)
        while (lo < hi) {
            if (lo & 1) total += seg[lo++];
            if (hi & 1) total += seg[--hi];
            lo /= 2;
            hi /= 2;
        }
        return total;
    }
};
/*
Approach: centroid decomposition on the tree.
For each centroid (root), count the # of paths that go through that root such that there is
a 0 and a 1 in one subtree and a 2 in another.
More generally, consider the total # of tours that go through some root node, and
consider tuples of ordered values in a tour of {subtree 1, root, subtree 2}.
The cases are:
1. {{0, 1}, {}, {2}}
2. {{0}, {}, {1, 2}}
3. {{0}, {1}, {2}}
4. {{0, 1}, {2}, {}}
5. {{}, {0}, {1, 2}}
When performing an update, we need to be able to subtract off the old paths that contained
the existing restaurant and add the new paths that use it.
For that we need to be able to:
1. query the # of 1's on the path between two vertices
We also need to remap node labels for each centroid in the decomposition.
*/
using pi = pair<int, int>;
// Bookkeeping for one centroid of the decomposition. Vertices of the centroid's
// component are addressed by their dfs entry time (assigned in dfs_time, the
// centroid itself getting time 0); child subtrees of the centroid are addressed
// by their index in the centroid's adjacency list.
struct Centroid {
    // Vertex id of the centroid.
    int root;
    // tout[tin]: last dfs time handed out inside the subtree of the vertex entered at `tin`.
    vt<int> tout;
    // Indicator trees over dfs entry times: 1 where the vertex has colour 0 (tree0)
    // or colour 2 (tree2); used to count 0's and 2's in a vertex's subtree.
    SegTree tree0;
    SegTree tree2;
    // Per child subtree: number of (1, 0) pairs, resp. (1, 2) pairs, where the 1 lies
    // on the path from the top of that subtree down to the 0 / 2.
    vt<ll> cnt10;
    vt<ll> cnt12;
    // Per child subtree: number of colour-0, resp. colour-2, vertices.
    vt<int> cnt0;
    vt<int> cnt2;
    // Sums of the four per-subtree arrays above.
    ll tot10;
    ll tot12;
    ll tot0;
    ll tot2;
    // subtree[tin]: index (in the centroid's adjacency list) of the child subtree that
    // contains the vertex entered at `tin`; -1 for the centroid itself.
    vt<int> subtree;
    // Per child subtree: dfs-time bookkeeping written by dfs_time.
    vt<pi> subtree_times;
    // Number of tours accounted to this centroid.
    ll ans;
};
// Vertex count and query count, stored by init().
int n, q;
// Adjacency lists of the input tree, filled by init().
vt<vt<int>> adj;
// Per-centroid bookkeeping, indexed by the centroid's vertex id.
vt<Centroid> centroids;
// parents[u]: one (centroid vertex, dfs entry time of u inside that centroid) pair
// for every centroid whose dfs_time traversal visited u.
vt<vt<pi>> parents; // centroid root, time
// Current colour of every vertex; the code distinguishes the values 0, 1 and 2.
vt<int> colour;
// Global answer: running sum of the per-centroid `ans` values.
ll gans;
// Scratch subtree sizes written by dfs_sz().
vt<int> sz;
// done[v] becomes true once v has been used as a centroid; traversals skip such vertices.
vt<bool> done;
// Computes the size of the subtree hanging below `u` (coming from `par`) inside the
// current component, i.e. ignoring vertices already marked in `done`.
// Stores the size in sz[u] for u and every vertex below it, and returns sz[u].
int dfs_sz(int u, int par = -1) {
    int total = 1;  // u itself
    for (int v : adj[u]) {
        if (v != par && !done[v]) {
            total += dfs_sz(v, u);
        }
    }
    sz[u] = total;
    return total;
}
// Walks from `u` towards the heavy side until no unfinished neighbour (other than
// the one we came from) has a subtree larger than half of `tsz`; returns that vertex.
// Relies on sz[] having been filled by dfs_sz() for the current component.
int find_centroid(int u, int tsz, int par = -1) {
    for (;;) {
        int heavy = -1;
        for (int v : adj[u]) {
            if (v == par || done[v]) continue;
            if (2 * sz[v] > tsz) {
                heavy = v;  // first too-large child wins, as in a recursive descent
                break;
            }
        }
        if (heavy == -1) return u;
        par = u;
        u = heavy;
    }
}
void dfs_time(int u, int& t, Centroid& obj, ll ones, int par = -1, int subtree = -1) {
int tin = ++t;
parents[u].pb({obj.root, t});
obj.subtree[tin] = subtree;
if (colour[u] == 0) obj.tree0.upd(tin, 1);
if (colour[u] == 2) obj.tree2.upd(tin, 1);
if (subtree != -1) {
// update 10's and 12's
if (colour[u] == 0) {
obj.cnt0[subtree]++;
obj.cnt10[subtree] += ones;
} else if (colour[u] == 1) {
ones++;
} else {
obj.cnt2[subtree]++;
obj.cnt12[subtree] += ones;
}
}
F0R (i, adj[u].size()) {
int v = adj[u][i];
if (v == par || done[v]) continue;
if (subtree == -1) {
obj.subtree_times[i].f = t;
}
dfs_time(v, t, obj, ones, u, subtree == -1 ? i : subtree);
if (subtree == -1) {
obj.subtree_times[i].f = t;
}
}
obj.tout[tin] = t;
}
void decomp(int u = 0) {
int tsz = dfs_sz(u);
int r = find_centroid(u, tsz);
Centroid& obj = centroids[r];
obj.root = r;
obj.tout = obj.subtree = vt<int>(tsz);
obj.cnt0 = obj.cnt2 = vt<int>(adj[r].size());
obj.cnt10 = obj.cnt12 = vt<ll>(adj[r].size());
obj.subtree_times.resize(adj[r].size());
obj.tree0.init(tsz);
obj.tree2.init(tsz);
int t = -1;
dfs_time(r, t, obj, 0);
obj.tot0 = accumulate(all(obj.cnt0), 0ll);
obj.tot2 = accumulate(all(obj.cnt2), 0ll);
obj.tot10 = accumulate(all(obj.cnt10), 0ll);
obj.tot12 = accumulate(all(obj.cnt12), 0ll);
// calculate answer
F0R (i, adj[r].size()) {
int v = adj[r][i];
if (done[v]) continue;
obj.ans += 1ll * obj.cnt10[i] * (obj.tot2 - obj.cnt2[i]);
obj.ans += 1ll * obj.cnt0[i] * (obj.tot12 - obj.cnt12[i]);
if (colour[r] == 1) obj.ans += 1ll * obj.cnt0[i] * (obj.tot2 - obj.cnt2[i]);
}
if (colour[r] == 0) obj.ans += obj.tot12;
if (colour[r] == 2) obj.ans += obj.tot10;
gans += obj.ans;
done[r] = true;
for (int v : adj[r]) {
if (!done[v]) decomp(v);
}
}
struct HLD {
int t;
vt<int> sz, pos, par, root, depth;
vt<vt<int>> adj;
SegTree seg;
void init(vt<vt<int>>& _adj) {
t = 0;
sz = pos = par = root = depth = vt<int>(n);
adj = _adj;
seg.init(n);
}
int dfs_sz(int u) {
sz[u] = 1;
for (int& v : adj[u]) {
par[v] = u;
depth[v] = depth[u] + 1;
adj[v].erase(find(all(adj[v]), u));
sz[u] += dfs_sz(v);
if (sz[v] > sz[adj[u][0]]) swap(v, adj[u][0]);
}
return sz[u];
}
void dfs_hld(int u) {
pos[u] = t++;
for (int& v : adj[u]) {
root[v] = (v == adj[u][0] ? root[u] : v);
dfs_hld(v);
}
}
void gen() {
dfs_sz(0);
dfs_hld(0);
}
int query(int u, int v) {
int res = 0;
while (root[u] != root[v]) {
if (depth[root[u]] > depth[root[v]]) swap(u, v);
res += seg.query(pos[root[v]], pos[v]);
v = par[root[v]];
}
if (depth[u] > depth[v]) swap(u, v);
return res + seg.query(pos[u], pos[v]);
}
void upd(int u, int v) {
seg.upd(pos[u], v);
}
};
HLD hld;
void upd(Centroid& obj, int u, int tin, int prev_c, int new_c) {
int i = obj.subtree[tin];
int tout = obj.tout[tin];
ll prev_ans = obj.ans;
// handle removal
// not root
if (u != obj.root) {
int subroot = adj[obj.root][i];
// subtract answer
if (prev_c == 0) {
ll par1s = hld.query(u, subroot) - (colour[u] == 1);
obj.ans -= par1s * (obj.tot2 - obj.cnt2[i]);
obj.ans -= obj.tot12 - obj.cnt12[i];
if (colour[obj.root] == 1) obj.ans -= obj.tot2 - obj.cnt2[i];
if (colour[obj.root] == 2) obj.ans -= par1s;
// update counts
obj.cnt0[i]--;
obj.tot0--;
obj.tree0.upd(tin, 0);
obj.tot10 -= par1s;
obj.cnt10[i] -= par1s;
} else if (prev_c == 1) {
int t0 = obj.tree0.query(tin, tout);
int t2 = obj.tree2.query(tin, tout);
obj.ans -= t0 * (obj.tot2 - obj.cnt2[i]);
obj.ans -= t2 * (obj.tot0 - obj.cnt0[i]);
if (colour[obj.root] == 2) obj.ans -= t0;
if (colour[obj.root] == 0) obj.ans -= t2;
// upd
obj.cnt10[i] -= t0;
obj.tot10 -= t0;
obj.cnt12[i] -= t2;
obj.tot12 -= t2;
} else {
ll par1s = hld.query(u, subroot) - (colour[u] == 1);
obj.ans -= par1s * (obj.tot0 - obj.cnt0[i]);
obj.ans -= obj.tot10 - obj.cnt10[i];
if (colour[obj.root] == 1) obj.ans -= obj.tot0 - obj.cnt0[i];
if (colour[obj.root] == 0) obj.ans -= par1s;
// upd
obj.cnt2[i]--;
obj.tot2--;
obj.tree2.upd(tin, 0);
obj.tot12 -= par1s;
obj.cnt12[i] -= par1s;
}
} else {
if (prev_c == 0) {
obj.ans -= obj.tot12;
} else if (prev_c == 1) {
ll sub = 0;
F0R (j, adj[obj.root].size()) {
sub += obj.cnt0[j] * (obj.tot2 - obj.cnt2[j]);
}
obj.ans -= 1ll * sub;
} else {
obj.ans -= obj.tot10;
}
}
// now handle addition
if (u != obj.root) {
int subroot = adj[obj.root][i];
// add answer
if (new_c == 0) {
ll par1s = hld.query(u, subroot) - (colour[u] == 1);
obj.ans += par1s * (obj.tot2 - obj.cnt2[i]);
obj.ans += obj.tot12 - obj.cnt12[i];
if (colour[obj.root] == 1) obj.ans += obj.tot2 - obj.cnt2[i];
if (colour[obj.root] == 2) obj.ans += par1s;
// update counts
obj.cnt0[i]++;
obj.tot0++;
obj.tree0.upd(tin, 1);
obj.tot10 += par1s;
obj.cnt10[i] += par1s;
} else if (new_c == 1) {
int t0 = obj.tree0.query(tin, tout);
int t2 = obj.tree2.query(tin, tout);
obj.ans += t0 * (obj.tot2 - obj.cnt2[i]);
obj.ans += t2 * (obj.tot0 - obj.cnt0[i]);
if (colour[obj.root] == 2) obj.ans += t0;
if (colour[obj.root] == 0) obj.ans += t2;
// upd
obj.cnt10[i] += t0;
obj.tot10 += t0;
obj.cnt12[i] += t2;
obj.tot12 += t2;
} else {
ll par1s = hld.query(u, subroot) - (colour[u] == 1);
obj.ans += par1s * (obj.tot0 - obj.cnt0[i]);
obj.ans += obj.tot10 - obj.cnt10[i];
if (colour[obj.root] == 1) obj.ans += obj.tot0 - obj.cnt0[i];
if (colour[obj.root] == 0) obj.ans += par1s;
// upd
obj.cnt2[i]++;
obj.tot2++;
obj.tree2.upd(tin, 1);
obj.tot12 += par1s;
obj.cnt12[i] += par1s;
}
} else {
if (new_c == 0) {
obj.ans += obj.tot12;
} else if (new_c == 1) {
ll sub = 0;
F0R (j, adj[obj.root].size()) {
sub += obj.cnt0[j] * (obj.tot2 - obj.cnt2[j]);
}
obj.ans += 1ll * sub;
} else {
obj.ans += obj.tot10;
}
}
gans -= prev_ans;
gans += obj.ans;
}
// Changes the colour of vertex `u` to `c`: updates every centroid that contains u,
// then brings the path-sum structure `hld` (which stores 1 on colour-1 vertices) and
// colour[u] up to date. Does nothing when the colour is unchanged.
void change(int u, int c) {
    const int old_c = colour[u];
    if (old_c == c) return;

    // colour[u] must still be the old colour while the centroids are updated.
    for (const auto& entry : parents[u]) {
        const int cent = entry.first;
        const int tin = entry.second;
        upd(centroids[cent], u, tin, old_c, c);
    }

    if (old_c == 1) hld.upd(u, 0);
    colour[u] = c;
    if (c == 1) hld.upd(u, 1);
}
void init(int N, vt<int> F, vt<int> U, vt<int> V, int Q) {
n = N;
q = Q;
colour = F;
parents.resize(n);
centroids.resize(n);
adj.resize(n);
F0R (i, n - 1) {
adj[U[i]].pb(V[i]);
adj[V[i]].pb(U[i]);
}
hld.init(adj);
hld.gen();
F0R (i, n) {
if (colour[i] == 1) hld.upd(i, 1);
}
sz.resize(n);
done.resize(n);
gans = 0;
decomp();
}
// Returns the current number of tours, i.e. the running total kept in `gans`.
ll num_tours() {
    const ll total = gans;
    return total;
}
main() {
cin.tie(0)->sync_with_stdio(0);
int n, q;
cin >> n;
vt<int> col(n);
vt<int> a(n - 1), b(n - 1);
F0R (i, n) cin >> col[i];
F0R (i, n - 1) cin >> a[i] >> b[i];
cin >> q;
init(n, col, a, b, q);
cout << num_tours() << endl;
F0R (i, q) {
int u, c;
cin >> u >> c;
change(u, c);
cout << num_tours() << endl;
}
}