#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Fenwick (binary indexed) tree over int values, addressed by 1-based positions.
struct AIB
{
    std::vector<int> v;

    // Discards any previous content and allocates siz + 2 zeroed cells,
    // so positions 1 .. siz + 1 can be updated.
    void init(int siz)
    {
        v.assign(static_cast<std::size_t>(siz + 2), 0);
    }

    // Returns the sum of everything added at positions 1 .. poz (0 for poz == 0).
    int qry(int poz)
    {
        assert(static_cast<std::size_t>(poz) < v.size());
        int total = 0;
        while (poz > 0)
        {
            total += v[poz];
            poz &= poz - 1; // clears the lowest set bit
        }
        return total;
    }

    // Adds newv at position poz (poz must be positive).
    void upd(int poz, int newv)
    {
        assert(0 < poz);
        int cur = poz;
        while (static_cast<std::size_t>(cur) < v.size())
        {
            v[cur] += newv;
            cur += cur & -cur; // jump to the next cell covering this position
        }
    }
};
// Number of nodes; zero until init() stores N (init() asserts it is still zero on entry).
int n;
// Adjacency lists of the tree, filled by init() from the U/V edge arrays.
// NOTE(review): 200005 is presumably the maximum N plus slack — confirm against the task limits.
vector<int> con[200005];
// iss[v] is set once v has been chosen as a centroid (see dfs_centroid); traversals skip such nodes.
bool iss[200005];
// Not referenced anywhere in the visible code.
int idkin[200005], idkout[200005];
struct INFO_CEN
{
vector<int> tin, tout, which_suba;
vector<ll> cnt10, cnt12, cnt0, cnt2;
ll tot10, tot12, tot0, tot2;
ll p0, p1, p2;
AIB lant[3], normal[3];
int root, isroot[3];
map<int,int> nrm;
vector<int> nodes;
void dfs_timer(int nod, int par)
{
nodes.push_back(nod);
for(int adj:con[nod])
{
if(adj == par || iss[adj])
continue;
dfs_timer(adj, nod);
}
}
int timer;
vector<int> subs;
void dfs(int nod, int par)
{
int nrmnod = nrm[nod], nrmadj;
tin[nrmnod] = ++timer;
for(int adj:con[nod])
{
if(adj == par || iss[adj])
continue;
nrmadj = nrm[adj];
if(which_suba[nrmnod] == -1)
{
which_suba[nrmadj] = nrmadj;
subs.push_back(nrmadj);
}
else
which_suba[nrmadj] = which_suba[nrmnod];
dfs(adj, nod);
}
tout[nrmnod] = timer;
}
void init(int coproot)
{
root = coproot;
dfs_timer(root, -1);
for(int i=0;i<nodes.size();i++)
nrm[nodes[i]] = i;
tin.resize(nodes.size() + 2, 0);
tout = which_suba = tin;
cnt0.resize(nodes.size() + 2, 0);
cnt2 = cnt10 = cnt12 = cnt0;
timer = 0;
which_suba[nrm[root]] = -1;
dfs(root, -1);
for(int c=0;c<3;c++)
{
lant[c].init(timer);
normal[c].init(timer);
}
//cerr<<root<<": "<<timer<<" root: timer\n";
}
ll recalc()
{
//cerr<<root<<" incepe recalc\n";
ll sum = 0;
sum += (ll)tot0 * tot12;
sum += (ll)tot2 * tot10;
sum -= p0;
sum -= p1;
/*for(int s:subs)
{
sum -= (ll)cnt0[s] * cnt12[s];//p0
sum -= (ll)cnt2[s] * cnt10[s];//p1
}*/
if(isroot[0])
{
sum += tot12;
}
else if(isroot[2])
{
sum += tot10;
}
else if(isroot[1])
{
sum += (ll)tot0 * tot2;
sum -= p2;
}
return sum;
}
void upd(int nod, int tip, int newv)
{
if(nod == root)
{
isroot[tip] += newv;
return;
}
nod = nrm[nod];
int s = which_suba[nod];
p0 -= (ll)cnt0[s] * cnt12[s];
p1 -= (ll)cnt2[s] * cnt10[s];
p2 -= (ll)cnt0[s] * cnt2[s];
int aux;
if(tip == 0)
{
aux = newv;
cnt0[s] += aux;
tot0 += aux;
aux = newv * lant[1].qry(tin[nod]);
cnt10[s] += aux;
tot10 += aux;
//cerr<<"add 0\n";
}
else if(tip == 2)
{
aux = newv;
cnt2[s] += aux;
tot2 += aux;
aux = newv * lant[1].qry(tin[nod]);
cnt12[s] += aux;
tot12 += aux;
}
else
{
assert(tip == 1);
aux = newv * (normal[0].qry(tout[nod]) - normal[0].qry(tin[nod] - 1));
cnt10[s] += aux;
tot10 += aux;
aux = newv * (normal[2].qry(tout[nod]) - normal[2].qry(tin[nod] - 1));
cnt12[s] += aux;
tot12 += aux;
}
p0 += (ll)cnt0[s] * cnt12[s];
p1 += (ll)cnt2[s] * cnt10[s];
p2 += (ll)cnt0[s] * cnt2[s];
normal[tip].upd(tin[nod], newv);
lant[tip].upd(tin[nod], newv);
lant[tip].upd(tout[nod] + 1, -newv);
}
};
// info[c] is the INFO_CEN that dfs_centroid() initialises with node c as its root.
INFO_CEN info[200005];
// Parent of each centroid in the centroid tree; init() stores -1 for the top-level centroid.
int centroid_parent[200005];
// Running total maintained by add() and returned by num_tours().
ll ans;
void add(int nod, int tip, int newv)
{
int c = nod;
while(c != -1)
{
ans -= info[c].recalc();
info[c].upd(nod, tip, newv);
ans += info[c].recalc();
c = centroid_parent[c];
}
}
// Subtree sizes written by get_sizes() and read by find_centroid() / dfs_centroid().
int siz[200005];
// Starting at `nod` (entered from `par`), repeatedly steps into the first
// not-yet-removed neighbour whose subtree holds more than half of `tot`
// nodes, and returns the node where no such neighbour exists.
// Relies on siz[] having been filled by get_sizes() for this component.
int find_centroid(int nod, int par, int tot)
{
    bool moved = true;
    while (moved)
    {
        moved = false;
        for (int adj : con[nod])
        {
            if (iss[adj] || adj == par)
                continue;
            if (siz[adj] * 2 > tot)
            {
                par = nod;
                nod = adj;
                moved = true;
                break;
            }
        }
    }
    return nod;
}
// Fills siz[] for the subtree hanging below `nod` (entered from `par`),
// ignoring nodes already marked in iss[].
void get_sizes(int nod, int par)
{
    int total = 1;
    for (int adj : con[nod])
    {
        if (adj != par && !iss[adj])
        {
            get_sizes(adj, nod);
            total += siz[adj];
        }
    }
    siz[nod] = total;
}
int dfs_centroid(int nod)
{
get_sizes(nod, -1);
nod = find_centroid(nod, -1, siz[nod]);
info[nod].init(nod);
iss[nod] = 1;
for(int adj:con[nod])
{
if(!iss[adj])
{
int who = dfs_centroid(adj);
centroid_parent[who] = nod;
}
}
return nod;
}
// Current type of every node: copied from F in init() and overwritten by change().
vector<int> mytype;
// Grader entry point: builds the tree from the N - 1 edges (U[i], V[i]),
// runs the centroid decomposition from node 0 and registers the initial type
// F[i] of every node. Q is not used. Must be called exactly once.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q)
{
    assert(!n);
    n = N;
    for (int e = 0; e < N - 1; ++e)
    {
        const int a = U[e];
        const int b = V[e];
        con[a].push_back(b);
        con[b].push_back(a);
    }
    const int top = dfs_centroid(0);
    centroid_parent[top] = -1;
    mytype = F;
    for (int node = 0; node < N; ++node)
    {
        add(node, mytype[node], +1);
    }
}
// Grader entry point: changes the type of node x to y by removing its old
// marker and adding one of the new type.
void change(int x, int y)
{
    const int old_type = mytype[x];
    add(x, old_type, -1);
    mytype[x] = y;
    add(x, y, +1);
}
// Grader entry point: returns the total currently maintained in `ans`.
ll num_tours()
{
    const ll current_total = ans;
    return current_total;
}