#include "joitour.h"
#include<bits/stdc++.h>
// Competitive-programming shorthand macros.
#define pb push_back
#define all(v) v.begin(),v.end()
// Inclusive index loops: forf runs i = s..e upward, forb runs i = s..e downward.
#define forf(i,s,e) for(int i = s; i<=e; i++)
#define forb(i,s,e) for(int i = s; i>=e; i--)
// idx: lower_bound position of i in sorted v; comp: erase adjacent duplicates of v.
#define idx(i,v) lower_bound(all(v),i)-v.begin()
#define comp(v) v.erase(unique(all(v)),v.end())
// sz(v) is function-like, so it is not expanded for the array `sz[...]` declared
// below (an identifier not followed by '(' is left alone by the preprocessor).
#define sz(v) (int)v.size()
#define fs first
#define se second
#define SP << " " <<
#define LN << "\n"
#define IO cin.tie(0);cout.tie(0);ios_base::sync_with_stdio(false);
using namespace std;
typedef long long ll;
// Not referenced elsewhere in this source.
ll inf = 1e18;
// Number of vertices (copied from init()'s parameter).
int N;
// C[v]: current colour of vertex v (0, 1 or 2); zero until init() assigns it.
int C[200001];
// Adjacency lists of the input tree.
vector<int> adj[200001];
// Centroid tree: c_ch = child centroids, c_par = parent centroid (-1 for the root),
// c_dep = centroid depth (root is 1), c_chidx = the child centroid whose component
// holds the vertex adjacent to the parent centroid, or -1 if there is none
// (root, or the centroid itself is that vertex); filled by centdecom().
vector<int> c_ch[200001];
int c_par[200001],c_dep[200001],c_chidx[200001];
// sz: subtree sizes from the most recent getsz(); chk: vertex already chosen as a
// centroid; pth: depth tag written by getcent() on every vertex of its walk.
int sz[200001], chk[200001], pth[200001];
// subsum[v] = element-wise sum of dp[c] over the centroid children c of v (kept in
// sync by upd()).
ll subsum[200001][3][3];
// dp[v][i][i] = number of colour-i vertices in v's centroid component.
// dp[v][i][j], i != j: pair counts rebuilt by upd(); presumably pairs (a, b) with
// C[a] = i, C[b] = j and b on the path from a to v's parent centroid -- TODO confirm.
ll dp[200001][3][3];
// Recomputes sz[] for the subtree hanging from `now` (entered from `p`), ignoring
// vertices already marked in chk[], and returns the size stored in sz[now].
int getsz(int now, int p = -1){
    int total = 1;
    for (int next : adj[now]) {
        if (next == p || chk[next]) continue;   // parent edge or removed centroid
        total += getsz(next, now);
    }
    sz[now] = total;
    return sz[now];
}
// Walks from `now` (arrived from `p`) towards any unmarked neighbour whose sz[]
// exceeds cap/2, stopping at the first vertex with no such neighbour and returning
// it. With sz[] freshly computed by getsz() from the start vertex and cap equal to
// that component's size, the result is the component's centroid.
// When d != -1, every vertex on the walk (start and end included) gets pth = d.
int getcent(int now, int p, int cap, int d){
    int from = p;
    for (;;) {
        if (d != -1) pth[now] = d;
        int heavy = -1;
        for (int next : adj[now]) {
            if (next != from && !chk[next] && sz[next] * 2 > cap) {
                heavy = next;   // first qualifying neighbour, same as the recursive form
                break;
            }
        }
        if (heavy == -1) return now;
        from = now;
        now = heavy;
    }
}
// Builds the centroid decomposition of the component containing `now` (vertices not
// yet marked in chk[]), filling c_par, c_ch, c_dep and c_chidx.
// @param now  any vertex of the component
// @param par  not read here; only forwarded unchanged to the recursive calls
// @param d    centroid depth to assign to this component's centroid (root call uses 1)
// @return     the centroid chosen for this component
int centdecom(int now, int par, int d = 1){
now = getcent(now,-1,getsz(now),-1);
c_chidx[now] = -1;
c_dep[now] = d;
chk[now] = 1;
for(int &i : adj[now]){
if(!chk[i]){
// Find the child centroid without tagging first, so pth[i] still holds the tag
// (value d) that our parent's call wrote along its walk to `now`; a neighbour
// carrying that tag lies on the side of the component that faces the parent.
int cent = getcent(i,-1,getsz(i),-1);
if(pth[i] == d) c_chidx[now] = cent;
// Repeat the walk, this time tagging it with d+1 so the child's own call can do
// the same lookup.
cent = getcent(i,-1,getsz(i),d+1);
c_par[cent] = now;
c_ch[now].pb(cent);
int cent2 = centdecom(i,par,d+1);
// Sanity check: the recursive call must choose the same centroid.
assert(cent==cent2);
}
}
return now;
}
// Root of the centroid tree (set in init()).
int cent;
// ans[v][k]: contribution of centroid v to the tour count when C[v] == k,
// accumulated incrementally by ansupd()/ansupd2().
ll ans[200001][3];
// Running total returned by num_tours(); intended to equal the sum of
// ans[v][C[v]] over all vertices v.
ll rans = 0;
// Adds f times the tour terms that centroid `now` derives from subsum[now] into
// ans[now][*]. ans[now][k] is the value used when C[now] == k. Callers apply it with
// f = -1 before subsum[now] changes and f = +1 afterwards.
void ansupd(ll now, ll f){
    const auto &s = subsum[now];
    auto &a = ans[now];
    // Terms that do not depend on the centroid's own colour.
    const ll shared = s[0][1] * s[2][2] * f + s[0][0] * s[2][1] * f;
    for (ll &slot : a) slot += shared;
    // Terms in which the centroid itself is one of the three tour vertices.
    a[0] += s[2][1] * f;
    a[1] += s[0][0] * s[2][2] * f;
    a[2] += s[0][1] * f;
}
// ansupd() on the parent centroid multiplies sums taken over all of its children,
// which includes products of a child with itself. This removes f times child
// `now`'s own share of those products from ans[parent][*]. It does nothing at the
// root of the centroid tree.
void ansupd2(ll now, ll f){
    const int parent = c_par[now];
    if (parent != -1) {
        const auto &d = dp[now];
        auto &a = ans[parent];
        const ll sameChild = d[0][1] * d[2][2] * f + d[0][0] * d[2][1] * f;
        for (int k = 0; k < 3; ++k) a[k] -= sameChild;
        a[1] -= d[0][0] * d[2][2] * f;
    }
}
// Recolours vertex `now` to `to`, rebuilds dp[now], pushes the change into the
// parent's subsum and recurses up the centroid tree. Along the way it swaps each
// ancestor's ans[][] terms out and back in so that rans stays equal to the sum of
// ans[p][C[p]], provided that held on entry. The caller is responsible for
// now's own term ans[now][C[now]] (see change()).
void upd(int now, int to){
int par = c_par[now];
int ch = c_chidx[now];
// Take out the parent's contribution, computed from the old subsum/dp[now],
// before anything changes.
if(par != -1) rans -= ans[par][C[par]],ansupd(par,-1),ansupd2(now,-1);
if(par != -1) forf(i,0,2) forf(j,0,2) subsum[par][i][j] -= dp[now][i][j];
C[now] = to;
// Diagonal: colour counts of now's component (children's components plus now).
forf(i,0,2) dp[now][i][i] = subsum[now][i][i] + ((C[now]==i)?1:0);
// Off-diagonal: rebuilt from subsum[now], now's colour and dp[ch] of the child
// component that faces the parent centroid (ch == -1 when there is none).
forf(i,0,2) forf(j,0,2) {
if(i==j) continue;
dp[now][i][j] = subsum[now][i][j];
if (ch != -1){
dp[now][i][j] += (-dp[ch][i][j]+dp[ch][j][i]);
// NOTE(review): dp[ch][j][j] is the count of ALL colour-j vertices in ch's
// component, not only those on the path from now towards the parent centroid;
// confirm this is intended if dp[][i][j] is meant as an on-path pair count.
dp[now][i][j] += dp[ch][j][j]* (subsum[now][i][i] - dp[ch][i][i]);
}
if (C[now]==j){
dp[now][i][j] += subsum[now][i][i];
if(ch != -1) dp[now][i][j] -= dp[ch][i][i];
}
if(C[now] == i && ch != -1){
dp[now][i][j] += dp[ch][j][j];
}
}
// Publish the new dp[now] into the parent, then refresh the ancestors.
if(par != -1) forf(i,0,2) forf(j,0,2) subsum[par][i][j] += dp[now][i][j];
if(par != -1) upd(par,C[par]);
// Re-add the parent's contribution using the refreshed subsum/dp[now].
if(par != -1) ansupd(par,1),ansupd2(now,1),rans += ans[par][C[par]];
}
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V,
int Q) {
::N = N;
forf(i,0,N-2) adj[U[i]].pb(V[i]) ,adj[V[i]].pb(U[i]);
cent = centdecom(0,-1,1); c_par[cent] = -1;
forf(i,0,N-1){
rans += ans[i][C[i]];
upd(i,F[i]);
rans += ans[i][C[i]];
}
}
void change(int X, int Y) {
rans -= ans[X][C[X]];
upd(X,Y);
rans += ans[X][C[X]];
}
// Grader entry point: reports the running total rans maintained by init()/change().
long long num_tours() {
    const long long total = rans;
    return total;
}
/*
Judge submission-page residue captured together with this source (not part of the
program); kept verbatim inside this comment so the file still compiles:
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/