#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
// 32-bit integer alias, used in the init/change signatures because plain
// `int` is redefined on the next line.
using i32 = int32_t;
// From here on every `int` in this file is a 64-bit `long long`.
#define int long long
// Capacity of the per-node arrays below.
const int N = 2e5 + 5;
// Not referenced anywhere in the code visible in this file.
const int LG = 20;
// n      : number of nodes (set in init).
// bl[v]  : set to 1 once v has been processed as a centroid (see dnc).
// freq   : not referenced in the visible code.
// sz[v]  : subtree size written by get_sz.
// tot    : size of the component currently being split (set in dnc).
// timer, tin[v], tout[v] : Euler-tour clock and entry/exit times written by
//          dfs; they are overwritten by every dnc call.
// f[v]   : label of node v, copied from F in init.
// ans[c] : value calc() stores for centroid c.
// total  : sum of ans[] accumulated in init and returned by num_tours.
int n, bl[N], freq[N], sz[N], tot, timer = 0, tin[N], tout[N], f[N], ans[N], total = 0;
// adj[v] : adjacency list of v.
// anc[v] : centroids appended by dfs, one per dnc call whose component
//          contained v.
vector<int> adj[N], anc[N];
// Only mentioned in a commented-out line of dfs; never filled in the visible code.
map<int, pair<int, int>> cd[N];
// par[v][k] : the neighbour of centroid anc[v][k] whose side contains v
//             (written by dadfs).
int par[N][20];
// Keyed by {centroid, neighbour of that centroid}. The value holds the counts
// dnc computes for that neighbour's side: [0..2] nodes labelled 0/1/2,
// [3] the "01" pair count, [4] the "21" pair count.
map<pair<int, int>, array<int, 5>> mp; // cent, child, 0, 1, 2, 01, 21
// Per-centroid totals over all sides: the five counts above summed, followed
// by the sums of the per-side products arr[0]*arr[4], arr[2]*arr[3] and
// arr[0]*arr[2] (see dnc).
array<int, 8> info[N]; // s0, s1, s2, s01, s21, p(0,21), p(2,01), p(0,2);
struct Fenwick {
int maxn;
vector<vector<int>> fen;
void init(int sz) {
fen.resize(3);
maxn = sz;
for (int i = 0; i < 3; i++) fen[i].assign(maxn + 1, 0);
}
void upd(int f, int id, int k) {
for (; id <= maxn; id += (id & -id)) fen[f][id] += k;
}
int pref(int f, int id) {
int res = 0;
for (; id; id -= (id & -id)) res += fen[f][id];
return res;
}
int query(int f, int l, int r) {
return pref(f, r) - pref(f, l - 1);
}
} fen[N];
// Fills sz[] for the part of the tree reachable from u without crossing p or
// any node already marked in bl[]; sz[u] counts u itself plus those descendants.
void get_sz(int u, int p) {
    int count = 1;
    for (const auto nxt : adj[u]) {
        if (nxt != p && !bl[nxt]) {
            get_sz(nxt, u);
            count += sz[nxt];
        }
    }
    sz[u] = count;
}
// Searches the subtree of u (parent p, skipping nodes marked in bl[]) for a
// node whose largest neighbouring piece has at most tot / 2 nodes.
// Relies on sz[] and tot having been set for the current component.
// Returns u if it qualifies, otherwise the largest result from its children,
// which is -1 when none qualifies.
int get_cent(int u, int p) {
    int heaviest = 0;
    int found = -1;
    for (const auto child : adj[u]) {
        if (child == p || bl[child]) continue;
        found = max(found, get_cent(child, u));
        heaviest = max(heaviest, sz[child]);
    }
    // The piece on the parent's side counts as well.
    heaviest = max(heaviest, tot - sz[u]);
    return heaviest <= tot / 2 ? u : found;
}
// Scratch lists shared by the traversals: `node` receives every vertex dfs
// visits, `tmp` every vertex dadfs visits. Callers clear them before use.
vector<int> node, tmp;

// Euler tour over the unblocked component containing u (entered from p).
// Assigns tin/tout from the global timer, appends `cent` to anc[u] and records
// each visited vertex in `node`.
void dfs(int u, int p, int cent) {
    timer += 1;
    tin[u] = timer;
    node.push_back(u);
    anc[u].push_back(cent);
    for (const auto nxt : adj[u]) {
        if (nxt != p && !bl[nxt]) dfs(nxt, u, cent);
    }
    tout[u] = timer;
    // cd[cent][u] = {tin[u], tout[u]};
}
// Walks the side of centroid `cent` that starts at its neighbour `rt`.
// Every visited vertex is appended to `tmp`, and `rt` is stored in par[u] at
// the index of u's most recently pushed anc entry.
void dadfs(int u, int p, int rt, int cent) {
    par[u][anc[u].size() - 1] = rt;
    tmp.push_back(u);
    for (const auto nxt : adj[u]) {
        const bool skip = (nxt == p) || bl[nxt] || (nxt == cent);
        if (!skip) dadfs(nxt, u, rt, cent);
    }
}
// Recomputes ans[u] from the aggregates in info[u] and the label f[u].
// info[u] = {s0, s1, s2, s01, s21, p(0,21), p(2,01), p(0,2)}.
void calc(int u) {
    const auto& s = info[u];

    // Contribution that depends on u's own label.
    int own = 0;
    switch (f[u]) {
        case 0: own = s[4]; break;
        case 1: own = s[0] * s[2] - s[7]; break;
        case 2: own = s[3]; break;
        default: break;
    }

    // Pairings of a single label with a two-label count, minus the products
    // where both come from the same side.
    const int zero_with_21 = s[0] * s[4] - s[5];
    const int two_with_01 = s[2] * s[3] - s[6];

    ans[u] = own + zero_with_21 + two_with_01;
}
// Centroid decomposition of the unblocked component containing u.
//
// For the component's centroid `cent` this
//   1. runs an Euler tour rooted at `cent`, so that the tin/tout range of any
//      vertex covers exactly the vertices below it as seen from the centroid;
//   2. fills fen[cent] with one entry per vertex (tree f[x], position tin[x]);
//   3. for every neighbour v of `cent`, computes the five counts of v's side
//      (nodes labelled 0/1/2, and for each node labelled 1 the number of 0s /
//      2s below it), stores them in mp[{cent, v}] and folds them into
//      info[cent];
//   4. derives ans[cent] via calc, marks `cent` as blocked and recurses into
//      the remaining pieces.
void dnc(int u) {
    get_sz(u, -1);
    tot = sz[u];
    int cent = get_cent(u, -1);

    // The tour must start at the centroid, not at u. All range queries below
    // assume subtrees hang below `cent`. Starting at u would make the range of
    // the neighbour on the path to u include `cent` and all its other sides.
    timer = 0;
    node.clear();
    dfs(cent, -1, cent);

    fen[cent].init(timer);
    for (auto& x : node) fen[cent].upd(f[x], tin[x], 1);

    // ps[c][t]: number of vertices labelled c with tin <= t.
    vector<vector<int>> ps(3, vector<int>(timer + 1, 0));
    for (auto& x : node) ps[f[x]][tin[x]]++;
    for (int i = 0; i < 3; i++) {
        for (int j = 1; j <= timer; j++) ps[i][j] += ps[i][j - 1];
    }

    for (auto& v : adj[cent]) {
        if (bl[v]) continue;
        tmp.clear();
        dadfs(v, -1, v, cent);

        array<int, 5> arr = {0, 0, 0, 0, 0};
        // Label counts over v's whole side.
        for (int i = 0; i < 3; i++) arr[i] = ps[i][tout[v]] - ps[i][tin[v] - 1];
        // For every vertex labelled 1, count the 0s and 2s below it.
        for (auto& x : tmp) {
            if (f[x] == 1) {
                arr[3] += ps[0][tout[x]] - ps[0][tin[x] - 1];
                arr[4] += ps[2][tout[x]] - ps[2][tin[x] - 1];
            }
        }

        mp[{cent, v}] = arr;
        for (int i = 0; i < 5; i++) info[cent][i] += arr[i];
        // Same-side products, subtracted again in calc.
        info[cent][5] += arr[0] * arr[4];
        info[cent][6] += arr[2] * arr[3];
        info[cent][7] += arr[0] * arr[2];
    }

    calc(cent);
    bl[cent] = 1;
    for (auto& v : adj[cent]) {
        if (!bl[v]) dnc(v);
    }
}
// Builds the tree and the initial answer.
//
// _N : number of nodes (0-based ids).
// F  : label of every node, copied into f[].
// U,V: endpoints of the _N - 1 edges.
// Q  : not used by this function.
//
// After the call `total` holds the sum of ans[] over all nodes.
void init(i32 _N, std::vector<i32> F, std::vector<i32> U, std::vector<i32> V,
          i32 Q) {
    (void)Q;
    n = _N;
    for (int i = 0; i < n - 1; i++) {
        adj[U[i]].push_back(V[i]);
        adj[V[i]].push_back(U[i]);
    }
    for (int i = 0; i < n; i++) f[i] = F[i];
    dnc(0);
    for (int i = 0; i < n; i++) total += ans[i];
}
// Sets the label of node X to Y and brings `total` up to date.
//
// The per-centroid aggregates (info[], mp, ans[]) depend on the labels, so
// after a label change they are rebuilt from scratch by rerunning the
// decomposition, and `total` is re-summed from ans[].
// NOTE(review): this costs one full rebuild per call. It keeps num_tours()
// correct, but an incremental update along anc[X] would be needed for large
// inputs.
void change(i32 X, i32 Y) {
    if (f[X] == Y) return;  // label unchanged, nothing to recompute
    f[X] = Y;

    // Clear everything dnc accumulates into, so the rebuild starts clean.
    for (int i = 0; i < n; i++) {
        bl[i] = 0;
        anc[i].clear();
        info[i].fill(0);
        ans[i] = 0;
    }
    mp.clear();

    dnc(0);

    total = 0;
    for (int i = 0; i < n; i++) total += ans[i];
}
// Returns the current value of the global `total`.
// The return type is 64-bit: `int` is #defined to `long long` in this file.
int num_tours() {
return total;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |