//#include "grader.cpp"
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<ll,ll>
#define f first
#define s second
#define all(x) x.begin(),x.end()
#define _ ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
namespace{
// Incrementally maintains the number of triples (x, y, z) with
// a[x] == 0, a[y] == 1, a[z] == 2 and y on the tree path from x to z, under
// single-vertex type changes. Every triple is counted at exactly one centroid:
// the first centroid (in decomposition order) that lies on the path x..z.

// Array bound for 1-based vertex ids.
constexpr long long mxn=200005;

// adj[v]: neighbours of vertex v (1-based).
std::vector<long long> adj[mxn];
// a[v]: current type of vertex v; the code indexes with a[v] and 2-a[v], so
// it assumes a[v] is 0, 1 or 2.
std::vector<long long> a(mxn);
// sz[v]: subtree size computed by the most recent compute_sizes() call.
std::vector<long long> sz(mxn);
// visited[v]: v has already been taken as a centroid (removed from the tree).
std::vector<bool> visited(mxn);

// Per-vertex, per-level data. The three vectors are indexed in parallel by
// level i (0 = centroid of the whole tree, last = level where v is centroid):
//   par[v][i] = {c, k}: c is the level-i centroid and k the index of c's child
//               subtree that contains v (-1 when v == c);
//   st[v][i], en[v][i]: Euler-tour entry/exit time of v in c's component
//               (c itself gets st == 1; its subtrees use times 2, 3, ...).
std::vector<std::vector<long long>> st(mxn);
std::vector<std::vector<long long>> en(mxn);
std::vector<std::vector<std::pair<long long,long long>>> par(mxn);

// Per-centroid aggregates over the vertices of c's component other than c
// itself. Only type indices 0 and 2 are used in cnt/cntc/sum/sumc.
//   cnt[c][t]      number of type-t vertices;
//   cntc[c][k][t]  the same, restricted to child subtree k;
//   sum[c][t]      over type-t vertices w: number of type-1 vertices on the
//                  path from w up to (but excluding) c;
//   sumc[c][k][t]  the same, restricted to child subtree k;
//   mulsum[c]      sum over k of cntc[c][k][0] * cntc[c][k][2].
std::vector<std::array<long long,3>> cnt(mxn);
std::vector<std::vector<std::array<long long,3>>> cntc(mxn);
std::vector<std::array<long long,3>> sum(mxn);
std::vector<std::vector<std::array<long long,3>>> sumc(mxn);
std::vector<long long> mulsum(mxn);

// Current number of counted triples.
long long res=0;

// Fenwick (binary indexed) tree over positions 1..n.
struct BIT{
    std::vector<long long> tree;
    long long n=0;
    // Resets the tree to len positions, all zero.
    void init(long long len){
        n=len;
        tree.assign(n+1,0);
    }
    // Adds val at position pos; positions outside 1..n are ignored
    // (pos < 1 would otherwise never terminate).
    void update(long long pos,long long val){
        if(pos<1) return;
        for(;pos<=n;pos+=pos&-pos) tree[pos]+=val;
    }
    // Sum of positions 1..pos; 0 when pos < 1, clamped to n above.
    long long query(long long pos) const{
        if(pos>n) pos=n;
        long long total=0;
        for(;pos>0;pos-=pos&-pos) total+=tree[pos];
        return total;
    }
};

// Per-centroid Fenwick trees over the Euler times of c's component:
//   bit[c][0], bit[c][2]: +1 at st of each type-0 / type-2 vertex, so a range
//     query over [st(u), en(u)] counts those vertices inside u's subtree;
//   bit[c][1]: every type-1 vertex u adds +1 on [st(u), en(u)], so a point
//     query at st(w) counts type-1 vertices among w and its ancestors below c.
std::vector<std::array<BIT,3>> bit(mxn);
// Euler-tour clock used while recording one centroid level.
long long timer=1;

// Fills sz[] for the unvisited part of the tree hanging from v (entered from
// parent p) and returns sz[v]. Recursive: depth can reach the component size.
long long compute_sizes(long long v,long long p=0){
    sz[v]=1;
    for(long long u:adj[v]){
        if(u==p || visited[u]) continue;
        sz[v]+=compute_sizes(u,v);
    }
    return sz[v];
}

// Starting from v (sz[] filled from v by compute_sizes), walks into any child
// subtree holding more than half of the tot vertices; the vertex where no such
// child exists is returned as the centroid.
long long find_centroid(long long v,long long tot,long long p=0){
    for(long long u:adj[v]){
        if(u==p || visited[u]) continue;
        if(sz[u]*2>tot) return find_centroid(u,tot,v);
    }
    return v;
}

// Euler-tours child subtree number `child` of centroid c, starting at v
// (entered from parent p), appending this level's par/st/en entry to every
// vertex it reaches.
void record_level(long long v,long long p,long long child,long long c){
    st[v].push_back(++timer);
    par[v].push_back({c,child});
    for(long long u:adj[v]){
        if(u==p || visited[u]) continue;
        record_level(u,v,child,c);
    }
    en[v].push_back(timer);
}

// Centroid-decomposes the unvisited component containing v: picks its
// centroid, records the level for every vertex of the component, sizes the
// centroid's per-subtree arrays and Fenwick trees, then recurses into the
// remaining pieces.
void centroid(long long v=1){
    v=find_centroid(v,compute_sizes(v));
    timer=1;
    st[v].push_back(1);
    par[v].push_back({v,-1});
    long long children=0;
    for(long long u:adj[v]){
        if(visited[u]) continue;
        record_level(u,v,children,v);
        children++;
    }
    cntc[v].assign(children,std::array<long long,3>{});
    sumc[v].assign(children,std::array<long long,3>{});
    en[v].push_back(timer);
    for(BIT& fenwick:bit[v]) fenwick.init(timer);
    visited[v]=true;
    for(long long u:adj[v]){
        if(!visited[u]) centroid(u);
    }
}

// Inserts (d == 1) or removes (d == -1) vertex v with its current type a[v]:
// updates every per-centroid aggregate v belongs to and adds d times the
// number of counted triples containing v to res. Call with d == -1 before
// changing a[v] and with d == 1 afterwards.
void add(long long v,long long d){
    long long delta=0;
    for(std::size_t i=0;i<par[v].size();i++){
        const long long c=par[v][i].first;
        const long long k=par[v][i].second;
        if(v==c){
            // v is this level's centroid; it is not stored in its own
            // aggregates, it only reads them.
            if(a[v]==1){
                // y == v: x and z in different child subtrees.
                delta+=cnt[v][2]*cnt[v][0]-mulsum[v];
            }else{
                // v is an endpoint: the other endpoint anywhere in the
                // component, y on the path from it up to v.
                delta+=sum[v][2-a[v]];
            }
            continue;
        }
        const long long tin=st[v][i];
        const long long tout=en[v][i];
        if(a[v]==1){
            // v is (or stops being) a type-1 ancestor of its whole subtree.
            bit[c][1].update(tin,d);
            bit[c][1].update(tout+1,-d);
            const long long below0=bit[c][0].query(tout)-bit[c][0].query(tin-1);
            const long long below2=bit[c][2].query(tout)-bit[c][2].query(tin-1);
            sum[c][0]+=below0*d;
            sumc[c][k][0]+=below0*d;
            sum[c][2]+=below2*d;
            sumc[c][k][2]+=below2*d;
            // y == v, one endpoint below v, the other in a different subtree.
            delta+=below0*(cnt[c][2]-cntc[c][k][2]);
            delta+=below2*(cnt[c][0]-cntc[c][k][0]);
            // y == v, one endpoint below v, the other endpoint is c itself.
            // A type-1 centroid cannot be an endpoint, so it adds nothing.
            if(a[c]==0){
                delta+=below2;
            }else if(a[c]==2){
                delta+=below0;
            }
        }else{
            const long long own=a[v];      // 0 or 2
            const long long other=2-own;   // type of the opposite endpoint
            const long long ones=bit[c][1].query(tin);
            bit[c][own].update(tin,d);
            cnt[c][own]+=d;
            cntc[c][k][own]+=d;
            sum[c][own]+=ones*d;
            sumc[c][k][own]+=ones*d;
            const long long other_outside=cnt[c][other]-cntc[c][k][other];
            // Other endpoint in a different subtree, y on its side.
            delta+=sum[c][other]-sumc[c][k][other];
            // Other endpoint in a different subtree, y on v's side.
            delta+=ones*other_outside;
            if(a[c]==1){
                // y == c, other endpoint in a different subtree.
                delta+=other_outside;
            }else if(a[c]==other){
                // c is the other endpoint, y on v's side.
                delta+=ones;
            }
            mulsum[c]+=d*cntc[c][k][other];
        }
    }
    res+=delta*d;
}
}  // namespace
void init(int n, std::vector<int> F, std::vector<int> U, std::vector<int> V,
          int Q){
    // Grader ids are 0-based; the internal structures use 1-based vertices.
    for(int v=0; v<n; ++v){
        a[v+1]=F[v];
    }
    // Edge e joins U[e] and V[e]; store it in both adjacency lists.
    for(int e=0; e+1<n; ++e){
        const long long from=U[e]+1;
        const long long to=V[e]+1;
        adj[to].push_back(from);
        adj[from].push_back(to);
    }
    // Build the centroid decomposition once, then insert every vertex with
    // its initial type so that res holds the initial count.
    centroid();
    for(long long v=1; v<=n; ++v){
        add(v,1);
    }
}
void change(int x, int y) {
    // Grader id x is 0-based; internal vertices are 1-based.
    const long long v = x + 1;
    // Retract v under its old type, retype it, then re-insert it so that res
    // reflects the new type.
    add(v, -1);
    a[v] = y;
    add(v, 1);
}
long long num_tours() {
    // res is updated on every add() call, so answering needs no extra work.
    const long long tours = res;
    return tours;
}
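/* Sample input kept for local testing (presumably in the format read by the
   grader referenced on the first line): n, the n vertex types, the n-1 edges,
   then the number of queries.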
11
0 1 1 2 1 0 2 1 2 2 1
0 1
0 2
0 3
1 4
1 5
2 7
2 8
3 9
5 6
9 10
0
*/
/*
Pasted judge results table (not C++ source; kept inside a comment so the file
compiles):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/