#include <bits/stdc++.h>
#include <bits/extc++.h>
#include "joitour.h"
using namespace __gnu_pbds;
using namespace std;
// #pragma GCC optimize("Ofast")
// #pragma GCC optimize ("unroll-loops")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#define ff first
#define sc second
#define pb push_back
#define ll long long
#define pll pair<ll, ll>
#define pii pair<int, int>
#define ull unsigned long long
#define all(x) (x).begin(),(x).end()
#define rall(x) (x).rbegin(),(x).rend()
// 32-bit and 64-bit Mersenne Twister engines, each seeded from the steady
// clock's current tick count. Neither is referenced in the code visible here.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
mt19937_64 rngl(chrono::steady_clock::now().time_since_epoch().count());
// #define int long long
// #define int unsigned long long
// #define ordered_set(T) tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>
// #define ordered_multiset(T) tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>
// const ll mod = 1e9 + 7;
// const ll mod = 998244353;
// Sentinel values 10^9 and 10^18, written as exact integer literals.
const ll inf = 1000000000LL;
const ll biginf = 1000000000000000000LL;
// Capacity used for every per-vertex array in this file.
const int maxN = 200000 + 15;
// Fenwick (binary indexed) tree over positions 1..n holding 64-bit sums.
//
// Supports point updates and prefix/range sum queries. Position 0 is not a
// valid index; slot 0 of the backing vector is unused.
struct fenwick {
    // Backing storage: fn.size() == n + 1, positions are 1-based.
    std::vector<long long> fn;

    // Creates an empty tree; every update is a no-op and every query is 0.
    fenwick() : fn() {}

    // Creates a tree over positions 1..n, all initialised to zero.
    fenwick(int n) : fn(n + 1, 0) {}

    // Adds d to position x. Positions outside 1..n are ignored; in
    // particular x == 0 must be rejected because (0 & -0) == 0 would make
    // the loop below spin forever.
    void update(int x, int d) {
        if (x <= 0) return;
        const int limit = static_cast<int>(fn.size());
        for (; x < limit; x += (x & -x)) fn[x] += d;
    }

    // Returns the sum of positions 1..x as a 64-bit value.
    // x is clamped to the valid range, so x <= 0 yields 0 and x > n yields
    // the total. The accumulator is long long so that sums exceeding the
    // int range are not truncated.
    long long get(int x) {
        const int last = static_cast<int>(fn.size()) - 1;
        if (x > last) x = last;
        long long sum = 0;
        for (; x > 0; x -= (x & -x)) sum += fn[x];
        return sum;
    }

    // Returns the sum of positions l..r (inclusive).
    long long query(int l, int r) {
        return get(r) - get(l - 1);
    }
};
// Fenwick trees indexed [kind][centroid]. decompose() allocates all five
// kinds per centroid; dfs() only touches kinds 0, 1 and 2.
fenwick fn[5][maxN];
// used[v] is set once v has been chosen as a centroid.
bool used[maxN];
// Adjacency lists of the input tree.
vector<int> g[maxN];
// f[v]      : value attached to vertex v (compared against 0, 1, 2).
// sz[v]     : subtree size scratch filled by getsize().
// in/out/id : per centroid-depth entry time, exit time and child-subtree id.
// tim, stim : running DFS timestamp and running child-subtree id.
// ex[c]     : pairs (0-vertex, 2-vertex) in different child subtrees of c.
// par, dep  : parent and depth of each centroid in the centroid tree.
ll n, ans, f[maxN], sz[maxN], in[20][maxN], out[20][maxN], id[20][maxN], tim, stim, ex[maxN], par[maxN], dep[maxN];
// Counter rows (cnt is indexed by centroid, scnt by child-subtree id):
//   0 - vertices with value 0
//   1 - vertices with value 1
//   2 - vertices with value 2
//   3 - (0, 1) pairs where the 0-vertex lies in the 1-vertex's subtree
//   4 - (1, 2) pairs where the 2-vertex lies in the 1-vertex's subtree
// Rows 3 and 4 are both written and read, so five rows are required; with
// only four rows every access to row 4 ran past the end of the array.
ll cnt[5][maxN], scnt[5][maxN];
// Centroid whose component is currently being scanned by dfs().
int cur_cent;
void getsize(int v, int p) {
sz[v] = 1;
for (int u : g[v]) {
if (u == p || used[u]) continue;
getsize(u, v); sz[v] += sz[u];
}
}
// Walks from v towards the first unused neighbour (other than the vertex we
// came from) whose sz[] is more than half of `need`, and returns the vertex
// where no such neighbour exists. Relies on sz[] filled by getsize().
int getcent(int v, int p, int need) {
    while (true) {
        int heavy = -1;
        for (int u : g[v]) {
            if (u != p && !used[u] && 2 * sz[u] > need) {
                heavy = u;
                break;
            }
        }
        if (heavy == -1) return v;
        p = v;
        v = heavy;
    }
}
void dfs(int v, int p, int d) {
in[d][v] = ++tim;
id[d][v] = stim;
if (f[v] == 0) {
fn[0][cur_cent].update(in[d][v], 1);
cnt[0][cur_cent]++; scnt[0][stim]++;
}
if (f[v] == 2) {
fn[2][cur_cent].update(in[d][v], 1);
cnt[2][cur_cent]++; scnt[2][stim]++;
}
for (int u : g[v]) {
if (u == p || used[u]) continue;
dfs(u, v, d); if (p == -1) stim++;
}
out[d][v] = tim;
if (f[v] == 1) {
fn[1][cur_cent].update(in[d][v], 1);
fn[1][cur_cent].update(out[d][v] + 1, -1);
ll cnt01 = fn[0][cur_cent].query(in[d][v], out[d][v]);
ll cnt12 = fn[2][cur_cent].query(in[d][v], out[d][v]);
cnt[3][cur_cent] += cnt01; cnt[4][cur_cent] += cnt12;
scnt[3][stim] += cnt01; scnt[4][stim] += cnt12;
}
}
void decompose(int v, int d, int p) {
getsize(v, -1);
int cur = sz[v];
v = getcent(v, -1, cur);
used[v] = 1; par[v] = p; dep[v] = d;
for (int i = 0; i < 5; i++)
fn[i][v] = fenwick(cur);
tim = 0; cur_cent = v;
for (int u : g[v]) {
if (used[u]) continue;
dfs(u, v, d);
ex[v] -= scnt[0][stim] * scnt[2][stim];
ans -= scnt[3][stim] * scnt[2][stim];
ans -= scnt[0][stim] * scnt[4][stim];
stim++;
}
ex[v] += cnt[0][v] * cnt[2][v];
ans += cnt[0][v] * cnt[4][v];
ans += cnt[3][v] * cnt[2][v];
if (f[v] == 0) ans += cnt[4][v];
if (f[v] == 1) ans += ex[v];
if (f[v] == 2) ans += cnt[3][v];
for (int u : g[v]) {
if (!used[u]) decompose(u, d + 1, v);
}
}
void init(int _n, vector<int> _f, vector<int> _u, vector<int> _v, int _q) {
n = _n;
for (int i = 0; i < n; i++) f[i] = _f[i];
for (int i = 0; i < n - 1; i++) {
g[_u[i]].pb(_v[i]);
g[_v[i]].pb(_u[i]);
}
decompose(0, 0, -1);
}
// Update hook taking a vertex v and a value x. The body is empty: neither
// f[] nor ans is modified, so num_tours() keeps returning the value that
// init() computed.
// NOTE(review): the commented-out driver below calls change(X, Y) before each
// num_tours() print, so presumably this should set f[v] = x and adjust ans.
// The in/out/id/par/dep arrays are filled during decomposition but never read
// in the visible code. TODO: implement.
void change(int v, int x) {
}
// Returns the current value of the global tour counter ans.
ll num_tours() {
    const ll total = ans;
    return total;
}
// int main() {
// int N;
// assert(scanf("%d", &N) == 1);
// std::vector<int> F(N);
// for (int i = 0; i < N; i++) {
// assert(scanf("%d", &F[i]) == 1);
// }
// std::vector<int> U(N - 1), V(N - 1);
// for (int j = 0; j < N - 1; j++) {
// assert(scanf("%d %d", &U[j], &V[j]) == 2);
// }
// int Q;
// assert(scanf("%d", &Q) == 1);
// init(N, F, U, V, Q);
// printf("%lld\n", num_tours());
// fflush(stdout);
// for (int k = 0; k < Q; k++) {
// int X, Y;
// assert(scanf("%d %d", &X, &Y) == 2);
// change(X, Y);
// printf("%lld\n", num_tours());
// fflush(stdout);
// }
// }
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |