#include "joitour.h"
#include<bits/stdc++.h>
using namespace std;
// Contest shorthand macros. Of these, only `ll`, `pb` and `vi` are used in this
// file; MOD, MOD1, all, en, ff, ss and pii appear unused here.
#define ll long long int
#define MOD (1000000000+7)
#define MOD1 (998244353)
#define pb push_back
#define all(x) x.begin(), x.end()
#define en cout << '\n'
#define ff first
#define ss second
#define pii pair<int,int>
#define vi vector<int>
// Capacity of every vertex-indexed global array below; presumably the maximum n
// plus some slack (TODO confirm against the problem's constraints).
const int N = 2e5+100;
// Fenwick (binary indexed) tree of int sums over positions 1..n.
// In this file it is used both as "point add / prefix sum" and, by pairing an
// add at l with an opposite add at r+1, as "range add / point query".
struct Fenwick{
    int n;
    vector<int> t;

    // Reset to n zero-valued positions (indices 1..n; t[0] is never read).
    void init(int size){
        n = size;
        t.assign(n + 1, 0);
    }

    // Add x at position v (v must be >= 1). Positions beyond n are ignored,
    // which lets callers close a range at n + 1 without a bounds check.
    void add(int v, int x){
        for(; v <= n; v += v & -v){
            t[v] += x;
        }
    }

    // Sum of positions 1..v (v must be <= n); yields 0 when v <= 0.
    int get(int v){
        int total = 0;
        for(; v > 0; v &= v - 1){   // strip the lowest set bit
            total += t[v];
        }
        return total;
    }
};
// ---- Global state shared by the decomposition and the update routine ----

// Current type (0, 1 or 2) of every vertex, written by init() and change().
int tp[N];
// Subtree sizes from the most recent pre() call, used to locate centroids.
int s[N];
// Euler-tour entry/exit times inside the component currently indexed by dfs().
int TIN[N];
int TOUT[N];
// Next Euler timestamp to hand out; f() resets it to 1 for every centroid.
int timer = 1;
// dp[c][t]: counters for the vertex with TIN == t in centroid c's component.
// Slots: 0, 1, 2 = number of vertices of that type; 3 = pairs (type-1 x, type-0 y)
// with x above y; 4 = pairs (type-1 x, type-2 y) with x above y.
vector<array<ll, 5>> dp[N];
// Running answer returned by num_tours().
ll ans;
// bad[c]: sum over the child subtrees of centroid c of (#type-0 * #type-2).
ll bad[N];
// Adjacency lists of the tree.
vector<int> g[N];
// Vertices that have already served as a centroid.
bitset<N> vis;
// T[c][k]: Fenwick tree over centroid c's Euler order for vertices of type k.
Fenwick T[N][3];
// C[v]: one record {centroid, top child, TIN[v], TOUT[v], TIN[top child]}
// for every centroid component that contains v.
vector<array<int, 5>> C[N];
// Number of (0, 1, 2) combinations read off dp row `v` of centroid `c`.
// The cross term is always counted: a type-2 vertex with a "1 above 0" pair,
// plus a type-0 vertex with a "1 above 2" pair. When `flag` is set, the
// combinations in which the centroid itself (of type `tpp`) takes part are
// added too: as the middle (type 1), or as the 0 or 2 endpoint.
ll calc(int v, int c, int tpp, bool flag){
    const auto& row = dp[c][v];
    ll through = row[2] * row[3] + row[0] * row[4];
    if(!flag) return through;
    switch(tpp){
        case 1:  return through + row[0] * row[2];   // centroid is the middle
        case 0:  return through + row[4];            // centroid is the 0 end
        default: return through + row[3];            // centroid is the 2 end
    }
}
// Recomputes s[] (subtree sizes) for the not-yet-visited component that
// contains v, treating v as the root and p as its parent (pass p == v at the top).
// Recursion depth equals the height of that component.
void pre(int v, int p){
    int size = 1;
    for(int u: g[v]){
        if(u == p || vis[u]) continue;
        pre(u, v);
        size += s[u];
    }
    s[v] = size;
}
// Size of the component being decomposed; set by f() just before centro() runs.
int num;
// Starting at v (whose parent is p under the rooting used by pre()), walk down
// into any unvisited child whose subtree holds at least ceil(num / 2) vertices.
// The vertex where no such child exists is returned as the centroid.
int centro(int v, int p){
    bool descended = true;
    while(descended){
        descended = false;
        for(int u: g[v]){
            if(u == p || vis[u] || s[u] < (num + 1) / 2) continue;
            p = v;
            v = u;
            descended = true;
            break;   // leave the loop before g[v] is re-read for the new v
        }
    }
    return v;
}
// Euler-tour DFS over the component of centroid c (vertices not marked in vis).
//   v - current vertex, p - its parent in this traversal,
//   c - the centroid whose component is being indexed,
//   d - the child of c whose subtree contains v (d == c when v == c).
// The caller must reset `timer` first. TIN/TOUT get consecutive values, and one dp
// row is appended per vertex, so dp[c][TIN[v]] is v's row. Counters in a row,
// restricted to v's subtree and excluding c itself:
//   [0],[1],[2] = number of vertices of type 0/1/2,
//   [3] = pairs (type-1 x, type-0 y) with x a proper ancestor of y,
//   [4] = pairs (type-1 x, type-2 y) with x a proper ancestor of y.
// Also appends (c, d, TIN[v], TOUT[v], TIN[d]) to C[v] for change(), and marks
// v (if v != c) in T[c][tp[v]].
// NOTE(review): recursion depth equals the component height (up to n on a path).
void dfs(int v, int p, int c, int d){
TIN[v] = timer++;
// Row index == TIN[v], because row 0 is a placeholder pushed by f().
dp[c].pb(array<ll,5>{0ll});
for(int u: g[v]){
if(u != p && !vis[u]){
// Children of the centroid start a new top-level subtree (d = u).
dfs(u, v, c, (v == c ? u : d));
for(int j = 0; j < 5; ++j){
dp[c][TIN[v]][j] += dp[c][TIN[u]][j];
}
// A type-1 v sits above every 0/2 vertex in u's subtree.
if(tp[v] == 1 && v != c) dp[c][TIN[v]][3] += dp[c][TIN[u]][0], dp[c][TIN[v]][4] += dp[c][TIN[u]][2];
}
}
TOUT[v] = timer - 1;
// TIN[d] was assigned earlier in this same traversal (d is an ancestor of v).
C[v].pb({c, d, TIN[v], TOUT[v], TIN[d]});
if(v != c){
// Types 0/2 are point marks at TIN. Type 1 is a range [TIN, TOUT], so a point
// query T[c][1].get(TIN[w]) counts the type-1 ancestors-or-self of w.
T[c][tp[v]].add(TIN[v], 1);
if(tp[v] == 1) T[c][tp[v]].add(TOUT[v] + 1, -1);
dp[c][TIN[v]][tp[v]]++;
}
}
// Centroid-decomposition step for the unvisited component that contains v:
//  1. pre() computes subtree sizes and centro() picks the centroid, which replaces v.
//  2. dfs() indexes the component rooted at the centroid (Euler times, dp rows,
//     T[v][*], C[*]).
//  3. ans gains every combination whose path goes through this centroid.
//     Combinations lying entirely inside one child subtree are subtracted; the
//     recursive calls below handle those pieces.
//  4. The centroid is marked in vis and each remaining neighbouring piece is
//     processed recursively.
void f(int v){
pre(v, v);
num = s[v];
v = centro(v, v);
// Euler timestamps are per component: restart at 1, so the centroid gets TIN 1.
timer = 1;
T[v][0].init(num);
T[v][1].init(num);
T[v][2].init(num);
dp[v].clear();
// Placeholder row 0, so that dp[v][t] is the row of the vertex with TIN == t.
dp[v].pb(array<ll,5>{0ll});
dfs(v, v, v, v);
// Row 1 is the centroid's row, i.e. the whole component (centroid excluded).
ans += calc(1, v, tp[v], true);
bad[v] = 0;
for(int u: g[v]){
if(!vis[u]){
// Drop cross combinations that lie entirely inside u's subtree.
ans -= calc(TIN[u], v, tp[u], false);
// Same-subtree (0, 2) pairs, overcounted by row[0]*row[2] when the centroid is type 1.
bad[v] += dp[v][TIN[u]][0] * dp[v][TIN[u]][2];
}
}
if(tp[v] == 1) ans -= bad[v];
vis[v] = 1;
for(int u: g[v]){
if(!vis[u]) f(u);
}
}
// Grader entry point: builds the tree from the n - 1 edges (U[e], V[e]), stores
// the initial vertex types F, and computes the initial answer with a full
// centroid decomposition from vertex 0. Q is not needed here.
void init(int n, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
    for(int e = 0; e < n - 1; ++e){
        const int a = U[e];
        const int b = V[e];
        g[a].push_back(b);
        g[b].push_back(a);
    }
    std::copy(F.begin(), F.begin() + n, tp);
    ans = 0;
    f(0);
}
void change(int v, int Y) {
int A = tp[v];
int B = Y;
for(auto [centro, low, tin, tout, low_tin]: C[v]){
// cout << v << ' ' << centro << ' ' << low << ' ' << tin << ' ' << tout << ' ' << low_tin << '\n';
// cout << T[centro][2].get(tout) << "f\n";
if(v == centro){
ans -= calc(1, v, A, true);
if(A == 1) ans += bad[v];
ans += calc(1, v, B, true);
if(B == 1) ans -= bad[v];
continue;
}
// removal
ans -= calc(1, centro, tp[centro], true);
if(tp[centro] == 1) ans += bad[centro];
ans += calc(low_tin, centro, tp[low], false);
bad[centro] -= dp[centro][low_tin][0] * dp[centro][low_tin][2];
// changes - removal
dp[centro][1][A]--;
if(A == 0){
dp[centro][1][3] -= T[centro][1].get(tin);
}else if(A == 1){
dp[centro][1][3] -= T[centro][0].get(tout) - T[centro][0].get(tin - 1);
dp[centro][1][4] -= T[centro][2].get(tout) - T[centro][2].get(tin - 1);
}else{
dp[centro][1][4] -= T[centro][1].get(tin);
}
dp[centro][low_tin][A]--;
if(A == 0){
dp[centro][low_tin][3] -= T[centro][1].get(tin);
}else if(A == 1){
dp[centro][low_tin][3] -= T[centro][0].get(tout) - T[centro][0].get(tin - 1);
dp[centro][low_tin][4] -= T[centro][2].get(tout) - T[centro][2].get(tin - 1);
}else{
dp[centro][low_tin][4] -= T[centro][1].get(tin);
}
T[centro][A].add(tin, -1);
if(A == 1) T[centro][A].add(tout + 1, 1);
// changes - addition
T[centro][B].add(tin, 1);
if(B == 1) T[centro][B].add(tout + 1, -1);
dp[centro][1][B]++;
if(B == 0){
dp[centro][1][3] += T[centro][1].get(tin);
}else if(B == 1){
dp[centro][1][3] += T[centro][0].get(tout) - T[centro][0].get(tin - 1);
// cout << tin << ' ' << tout << ' ' << << "f\n";
dp[centro][1][4] += T[centro][2].get(tout) - T[centro][2].get(tin - 1);
}else{
dp[centro][1][4] += T[centro][1].get(tin);
}
dp[centro][low_tin][B]++;
if(B == 0){
dp[centro][low_tin][3] += T[centro][1].get(tin);
}else if(B == 1){
dp[centro][low_tin][3] += T[centro][0].get(tout) - T[centro][0].get(tin - 1);
dp[centro][low_tin][4] += T[centro][2].get(tout) - T[centro][2].get(tin - 1);
}else{
dp[centro][low_tin][4] += T[centro][1].get(tin);
}
// recalculation
ans -= calc(low_tin, centro, tp[low], false);
bad[centro] += dp[centro][low_tin][0] * dp[centro][low_tin][2];
ans += calc(1, centro, tp[centro], true);
if(tp[centro] == 1) ans -= bad[centro];
}
tp[v] = Y;
}
// Grader entry point: current number of valid tours, which init() and change()
// maintain incrementally in `ans`.
long long num_tours() {
    const long long current = ans;
    return current;
}
/*
Judge results table pasted from the submission page. It is not part of the
program and is kept as a comment so the file compiles:
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/