제출 #1000203

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1000203 | caterpillow | JOI tour (JOI24_joitour) | C++17
86 / 100
3089 ms | 302488 KiB
#include <bits/stdc++.h>
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;

using ll = long long;
using pl = pair<ll, ll>;
#define vt vector
#define f first
#define s second
#define all(x) x.begin(), x.end()
#define pb push_back
#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define ROF(i, a, b) for (int i = (b) - 1; i >= (a); i--)
#define F0R(i, b) FOR (i, 0, b)
#define endl '\n'
#define debug(x) do{auto _x = x; cerr << #x << " = " << _x << endl;} while(0)

const ll INF = 1e18;

// Iterative segment tree: point assignment, inclusive range sum.
struct SegTree {
    int n = 0;
    vt<int> seg;
    void init(int _n) {
        n = 1;
        while (n < _n) n *= 2;
        seg.assign(2 * n, 0);  // assign (not resize) so re-initialisation clears old values
    }
    void upd(int i, int v) {
        i += n;
        seg[i] = v;
        while (i > 1) {
            i /= 2;
            seg[i] = seg[2 * i] + seg[2 * i + 1];
        }
    }
    // Sum over indices [l, r] (both inclusive).
    int query(int l, int r) {
        int res = 0;
        for (l += n, r += n + 1; l < r; l /= 2, r /= 2) {
            if (l & 1) res += seg[l++];
            if (r & 1) res += seg[--r];
        }
        return res;
    }
};

/*
 * Approach: centroid decomposition.
 *
 * A tour is (0-node, 1-node, 2-node) with the 1 on the path between the 0 and the 2.
 * Every tour is counted at the first centroid whose component contains the whole path,
 * i.e. the centroid lies on the path. Writing a tour as {branch A, root, branch B}
 * (oriented from the 0 to the 2), the possible shapes are:
 *   1. {{0, 1}, {},  {2}}      2. {{0}, {},  {1, 2}}      3. {{0}, {1}, {2}}
 *   4. {{0, 1}, {2}, {}}       5. {{},  {0}, {1, 2}}
 * Per branch we keep the number of 0s / 2s and the number of (1 above 0) / (1 above 2)
 * pairs. A recolour removes the node's old tours from every centroid containing it and
 * then adds its new ones. Counting the 1s on a tree path needs an HLD over the
 * original tree.
 */

using pi = pair<int, int>;

struct Centroid {
    int root = -1;
    vt<int> tout;          // tout[tin]: last Euler time inside the subtree entered at time tin
    vt<int> subtree;       // subtree[tin]: index in adj[root] of that node's branch (-1 for the root)
    SegTree tree0, tree2;  // 1 at the Euler time of every colour-0 / colour-2 node; the root's own
                           // entry is only set at build time and is never queried afterwards
    vt<ll> cnt10, cnt12;   // per branch: sum over its 0-nodes / 2-nodes of #1s strictly above them in the branch
    vt<int> cnt0, cnt2;    // per branch: number of 0-nodes / 2-nodes
    ll tot10 = 0, tot12 = 0, tot0 = 0, tot2 = 0;  // the per-branch arrays summed over branches
    ll sum02 = 0;          // sum over branches of cnt0[j] * cnt2[j]
};

int n, q;
vt<vt<int>> adj;
vt<Centroid> centroids;
vt<vt<pi>> parents;  // parents[u]: (centroid root, Euler time of u there), outermost centroid first
vt<int> colour;
ll gans;             // current total number of tours
vt<int> sz;
vt<bool> done;

int dfs_sz(int u, int par = -1) {
    sz[u] = 1;
    for (int v : adj[u]) {
        if (v == par || done[v]) continue;
        sz[u] += dfs_sz(v, u);
    }
    return sz[u];
}

int find_centroid(int u, int tsz, int par = -1) {
    for (int v : adj[u]) {
        if (v == par || done[v]) continue;
        if (sz[v] * 2 > tsz) return find_centroid(v, tsz, u);
    }
    return u;
}

// Euler-tours the current component from obj.root. It records each node's time in
// parents[], fills the colour trees, and accumulates the per-branch counters.
// `ones` is the number of colour-1 nodes strictly above u inside its branch.
void dfs_time(int u, int& t, Centroid& obj, ll ones, int par = -1, int subtree = -1) {
    int tin = ++t;
    parents[u].pb({obj.root, tin});
    obj.subtree[tin] = subtree;
    if (colour[u] == 0) obj.tree0.upd(tin, 1);
    if (colour[u] == 2) obj.tree2.upd(tin, 1);
    if (subtree != -1) {
        if (colour[u] == 0) {
            obj.cnt0[subtree]++;
            obj.cnt10[subtree] += ones;
        } else if (colour[u] == 1) {
            ones++;
        } else {
            obj.cnt2[subtree]++;
            obj.cnt12[subtree] += ones;
        }
    }
    F0R (i, (int) adj[u].size()) {
        int v = adj[u][i];
        if (v == par || done[v]) continue;
        dfs_time(v, t, obj, ones, u, subtree == -1 ? i : subtree);
    }
    obj.tout[tin] = t;
}

// Builds the centroid tree rooted in u's component and adds every tour through each
// centroid to gans.
void decomp(int u = 0) {
    int tsz = dfs_sz(u);
    int r = find_centroid(u, tsz);
    Centroid& obj = centroids[r];
    obj.root = r;
    int deg = (int) adj[r].size();
    obj.tout.assign(tsz, 0);
    obj.subtree.assign(tsz, 0);
    obj.cnt0.assign(deg, 0);
    obj.cnt2.assign(deg, 0);
    obj.cnt10.assign(deg, 0);
    obj.cnt12.assign(deg, 0);
    obj.tree0.init(tsz);
    obj.tree2.init(tsz);
    int t = -1;
    dfs_time(r, t, obj, 0);
    obj.tot0 = accumulate(all(obj.cnt0), 0ll);
    obj.tot2 = accumulate(all(obj.cnt2), 0ll);
    obj.tot10 = accumulate(all(obj.cnt10), 0ll);
    obj.tot12 = accumulate(all(obj.cnt12), 0ll);
    obj.sum02 = 0;

    ll ans = 0;
    F0R (i, deg) {
        obj.sum02 += 1ll * obj.cnt0[i] * obj.cnt2[i];
        ans += obj.cnt10[i] * (obj.tot2 - obj.cnt2[i]);  // shape 1
        ans += obj.cnt0[i] * (obj.tot12 - obj.cnt12[i]); // shape 2
    }
    if (colour[r] == 0) ans += obj.tot12;                              // shape 5
    else if (colour[r] == 1) ans += obj.tot0 * obj.tot2 - obj.sum02;   // shape 3
    else ans += obj.tot10;                                             // shape 4
    gans += ans;

    done[r] = true;
    for (int v : adj[r]) {
        if (!done[v]) decomp(v);
    }
}

// Heavy-light decomposition of the original tree rooted at 0. It answers
// "number of colour-1 nodes on the path u..v (inclusive)".
struct HLD {
    int t;
    vt<int> sz, pos, par, root, depth;
    vt<vt<int>> adj;
    SegTree seg;
    void init(vt<vt<int>>& _adj) {
        t = 0;
        sz = pos = par = root = depth = vt<int>(n);
        adj = _adj;
        seg.init(n);
    }
    int dfs_sz(int u) {
        sz[u] = 1;
        for (int& v : adj[u]) {
            par[v] = u;
            depth[v] = depth[u] + 1;
            adj[v].erase(find(all(adj[v]), u));  // keep only children
            sz[u] += dfs_sz(v);
            if (sz[v] > sz[adj[u][0]]) swap(v, adj[u][0]);  // heavy child first
        }
        return sz[u];
    }
    void dfs_hld(int u) {
        pos[u] = t++;
        for (int& v : adj[u]) {
            root[v] = (v == adj[u][0] ? root[u] : v);
            dfs_hld(v);
        }
    }
    void gen() {
        dfs_sz(0);
        dfs_hld(0);
    }
    int query(int u, int v) {
        int res = 0;
        while (root[u] != root[v]) {
            if (depth[root[u]] > depth[root[v]]) swap(u, v);
            res += seg.query(pos[root[v]], pos[v]);
            v = par[root[v]];
        }
        if (depth[u] > depth[v]) swap(u, v);
        return res + seg.query(pos[u], pos[v]);
    }
    void upd(int u, int v) { seg.upd(pos[u], v); }
};

HLD hld;

// Adds (sign = +1) or removes (sign = -1) every tour through obj.root that uses node u
// with colour c, and updates obj's counters to match. tin is u's Euler time in obj.
// Precondition: colour[] and hld already describe u as having colour c, so for c != 1
// the HLD path count does not include u itself.
void toggle(Centroid& obj, int u, int tin, int c, int sign) {
    ll delta = 0;
    if (u == obj.root) {
        if (c == 0) delta = obj.tot12;
        else if (c == 1) delta = obj.tot0 * obj.tot2 - obj.sum02;  // 0 and 2 in different branches, O(1)
        else delta = obj.tot10;
    } else if (c == 1) {
        int i = obj.subtree[tin];
        int tout = obj.tout[tin];
        int rc = colour[obj.root];
        // Every 0 / 2 below u gains (or loses) one 1 on its path up to the branch root.
        ll t0 = obj.tree0.query(tin, tout);
        ll t2 = obj.tree2.query(tin, tout);
        delta = t0 * (obj.tot2 - obj.cnt2[i]) + t2 * (obj.tot0 - obj.cnt0[i]);
        if (rc == 2) delta += t0;
        if (rc == 0) delta += t2;
        obj.cnt10[i] += sign * t0;
        obj.tot10 += sign * t0;
        obj.cnt12[i] += sign * t2;
        obj.tot12 += sign * t2;
    } else {
        // c is 0 or 2. The two cases are mirror images, so pick the matching counters.
        int i = obj.subtree[tin];
        int rc = colour[obj.root];
        bool zero = (c == 0);
        int other = zero ? 2 : 0;  // colour needed at the opposite end of the tour
        vt<int>& cntSame = zero ? obj.cnt0 : obj.cnt2;
        vt<ll>& pairSame = zero ? obj.cnt10 : obj.cnt12;
        ll& totSame = zero ? obj.tot0 : obj.tot2;
        ll& pairTotSame = zero ? obj.tot10 : obj.tot12;
        SegTree& treeSame = zero ? obj.tree0 : obj.tree2;
        int otherHere = zero ? obj.cnt2[i] : obj.cnt0[i];
        ll otherOutside = (zero ? obj.tot2 : obj.tot0) - otherHere;
        ll otherPairsOutside = zero ? obj.tot12 - obj.cnt12[i] : obj.tot10 - obj.cnt10[i];
        ll ones = hld.query(u, adj[obj.root][i]);  // 1s between u and its branch root

        delta = ones * otherOutside + otherPairsOutside;
        if (rc == 1) delta += otherOutside;
        if (rc == other) delta += ones;

        cntSame[i] += sign;
        totSame += sign;
        treeSame.upd(tin, sign > 0 ? 1 : 0);
        pairSame[i] += sign * ones;
        pairTotSame += sign * ones;
        obj.sum02 += 1ll * sign * otherHere;  // cnt0[i] * cnt2[i] changed by sign * (other count)
    }
    gans += sign * delta;
}

// Recolours node u to c. Its old tours are removed everywhere, the colour and the HLD
// 1-marks are updated, and then the new tours are added.
void change(int u, int c) {
    if (c == colour[u]) return;
    for (auto [cent, tin] : parents[u]) toggle(centroids[cent], u, tin, colour[u], -1);
    if (colour[u] == 1) hld.upd(u, 0);
    colour[u] = c;
    if (c == 1) hld.upd(u, 1);
    for (auto [cent, tin] : parents[u]) toggle(centroids[cent], u, tin, c, +1);
}

// Builds all structures for a tree with N nodes, colours F, and edges (U[i], V[i]).
// Safe to call more than once: all global state is reset.
void init(int N, vt<int> F, vt<int> U, vt<int> V, int Q) {
    n = N;
    q = Q;
    colour = F;
    adj.assign(n, vt<int>());
    F0R (i, n - 1) {
        adj[U[i]].pb(V[i]);
        adj[V[i]].pb(U[i]);
    }
    hld.init(adj);
    hld.gen();
    F0R (i, n) {
        if (colour[i] == 1) hld.upd(i, 1);
    }
    parents.assign(n, vt<pi>());
    centroids.assign(n, Centroid{});
    sz.assign(n, 0);
    done.assign(n, false);
    gans = 0;
    decomp();
}

ll num_tours() { return gans; }

컴파일 시 표준 에러 (stderr) 메시지

joitour.cpp: In function 'void dfs_time(int, int&, Centroid&, ll, int, int)':
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:135:5: note: in expansion of macro 'F0R'
  135 |     F0R (i, adj[u].size()) {
      |     ^~~
joitour.cpp: In function 'void decomp(int)':
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:174:5: note: in expansion of macro 'F0R'
  174 |     F0R (i, adj[r].size()) {
      |     ^~~
joitour.cpp: In function 'void upd(Centroid&, int, int, int, int)':
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:309:13: note: in expansion of macro 'F0R'
  309 |             F0R (j, adj[obj.root].size()) {
      |             ^~~
joitour.cpp:15:42: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   15 | #define FOR(i, a, b) for (int i = (a); i < (b); i++)
      |                                          ^
joitour.cpp:17:19: note: in expansion of macro 'FOR'
   17 | #define F0R(i, b) FOR (i, 0, b)
      |                   ^~~
joitour.cpp:374:13: note: in expansion of macro 'F0R'
  374 |             F0R (j, adj[obj.root].size()) {
      |             ^~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...