#include "joitour.h"
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#define ll long long
#define ld long double
#define ull unsigned long long
#define ff first
#define ss second
#define pii pair<int,int>
#define pll pair<long long, long long>
#define vi vector<int>
#define vl vector<long long>
#define pb push_back
#define rep(i, b) for(int i = 0; i < (b); ++i)
#define rep2(i,a,b) for(int i = a; i <= (b); ++i)
#define rep3(i,a,b,c) for(int i = a; i <= (b); i+=c)
#define count_bits(x) __builtin_popcountll((x))
#define all(x) (x).begin(),(x).end()
#define siz(x) (int)(x).size()
#define forall(it,x) for(auto& it:(x))
using namespace __gnu_pbds;
using namespace std;
// Order-statistics set of ints (supports find_by_order / order_of_key).
// NOTE(review): not referenced anywhere else in this file as shown.
using ordered_set = __gnu_pbds::tree<int, __gnu_pbds::null_type, std::less<int>,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;
//mt19937 mt;void random_start(){mt.seed(chrono::time_point_cast<chrono::milliseconds>(chrono::high_resolution_clock::now()).time_since_epoch().count());}
//ll los(ll a, ll b) {return a + (mt() % (b-a+1));}
// Sentinel "infinity" values and the usual prime modulus.
// Written as integer literals on purpose: the floating-point form 1e18+40 is
// evaluated in double precision, where neighbouring values near 1e18 are 128
// apart, so it rounds to exactly 1e18 and the +40 margin is silently lost.
const int INF = 1000000050;
const long long INF_L = 1000000000000000040LL;
const long long MOD = 1000000007LL;
// Position record of one vertex inside one centroid's component.
// A vertex receives one such record for every centroid whose component
// contains it (appended by dfs_add, or by centroid() for the centroid itself).
struct elm
{
    int centr;  // the centroid owning this component
    int p1;     // preorder index of the vertex in the component (centroid = 1)
    int p2;     // largest preorder index inside the vertex's own subtree
    int s1;     // [s1, s2]: preorder range of the centroid's child subtree that
    int s2;     //   contains the vertex (the whole component for the centroid)
};
// Lazy-propagation segment tree over positions 0 .. tree_siz/2 with
// range add (add_seg) and range sum (get_sum).
//
// Every node keeps the exact sum of its positions: a pending add of x on a
// node covering len positions contributes x*len to that node's sum, both when
// it is applied (add_seg2) and when it is pushed to the children (spych).
struct segtree
{
    int tree_siz = 0;               // highest node index; node 1 is the root
    std::vector<long long> sum;     // sum[node]: total over the node's positions
    std::vector<long long> oper;    // oper[node]: add still owed to every position below node

    // Allocates a zero-filled tree whose positions cover at least 0 .. n.
    void setup(int n)
    {
        int leaves = 1;
        while(leaves <= n) leaves <<= 1;    // smallest power of two above n
        tree_siz = 2*leaves-1;
        sum.assign(tree_siz+1, 0);
        oper.assign(tree_siz+1, 0);
    }

    // Number of positions covered by node akt. The tree is perfect, so this
    // depends only on the node's depth (root = depth 0).
    long long node_len(int akt) const
    {
        const int depth = 31 - __builtin_clz(static_cast<unsigned>(akt));
        return static_cast<long long>((tree_siz+1)/2) >> depth;
    }

    // Pushes the pending add of internal node akt down to its two children.
    // Must only be called on internal nodes (children must exist).
    void spych(int akt)
    {
        const long long x = oper[akt];
        if(x == 0) return;
        const long long child_len = node_len(akt)/2;
        for(int c = akt*2; c <= akt*2+1; ++c)
        {
            sum[c] += x*child_len;
            oper[c] += x;
        }
        oper[akt] = 0;
    }

    // Adds x to every position in [s1, s2] within node akt covering [p1, p2].
    void add_seg2(int akt, int p1, int p2, int s1, int s2, long long x)
    {
        if(p2 < s1 || p1 > s2) return;
        if(p1 >= s1 && p2 <= s2)
        {
            sum[akt] += x*(p2-p1+1);     // one x per covered position
            oper[akt] += x;
            return;
        }
        spych(akt);
        const int mid = (p1+p2)/2;
        add_seg2(akt*2,p1,mid,s1,s2,x);
        add_seg2(akt*2+1,mid+1,p2,s1,s2,x);
        sum[akt] = sum[akt*2]+sum[akt*2+1];
    }

    // Sum of positions in [s1, s2] within node akt covering [p1, p2].
    long long get_sum2(int akt, int p1, int p2, int s1, int s2)
    {
        if(p2 < s1 || p1 > s2) return 0;
        if(p1 >= s1 && p2 <= s2) return sum[akt];
        spych(akt);
        const int mid = (p1+p2)/2;
        return get_sum2(akt*2,p1,mid,s1,s2)+get_sum2(akt*2+1,mid+1,p2,s1,s2);
    }

    // Adds x to every position in [l, r].
    void add_seg(int l, int r, long long x)
    {
        add_seg2(1,0,tree_siz/2,l,r,x);
    }

    // Returns the sum of positions in [l, r].
    long long get_sum(int l, int r)
    {
        return get_sum2(1,0,tree_siz/2,l,r);
    }
};
ll cur_ans = 0;
int F[200001];
vi graph[200001];
bool odw[200001];
int sub[200001];
int cur_pre = 1;
int pre[200001];
int maxpre[200001];
vector<elm> elms[200001];
struct centr
{
segtree tree_1,tree_2,tree_3;
int n;
int cc = -1;
ll sum_1_2 = 0, sum_3_2 = 0, sum_1_3 = 0, sum_1 = 0, sum_3 = 0;
unordered_map<int,ll> map_1_2, map_3_2, map_1, map_3;
void setup(int N)
{
n = N+2;
tree_1.setup(n);
tree_2.setup(n);
tree_3.setup(n);
}
void set_vert(int p1, int p2, int s1, int s2, int f, int p)
{
if(p1 != 1)
{
if(f == 1)
{
int two = tree_2.get_sum(p1,p1);
cur_ans += (two*(sum_3-map_3[s1])+sum_3_2-map_3_2[s1])*p;
tree_1.add_seg(p1,p1,p);
sum_1_2 += (two-(cc==2))*p;
map_1_2[s1] += (two-(cc==2))*p;
sum_1 += p;
map_1[s1] += p;
sum_1_3 += (sum_3 - map_3[s1] - (cc==3))*p;
}
if(f == 2)
{
int one = tree_1.get_sum(p1,p2);
int three = tree_3.get_sum(p1,p2);
cur_ans += (one*(sum_3-map_3[s1])+three*(sum_1-map_1[s1]))*p;
tree_2.add_seg(p1,p2,p);
if(p1 != 1)
{
sum_1_2 += one*p;
map_1_2[s1] += one*p;
sum_3_2 += three*p;
map_3_2[s1] += three*p;
}
}
if(f == 3)
{
int two = tree_2.get_sum(p1,p1);
cur_ans += (two*(sum_1-map_1[s1])+sum_1_2-map_1_2[s1])*p;
tree_3.add_seg(p1,p1,p);
sum_3_2 += (two-(cc==2))*p;
map_3_2[s1] += (two-(cc==2))*p;
sum_3 += p;
map_3[s1] += p;
sum_1_3 += (sum_1 - map_1[s1] - (cc == 1))*p;
}
}
else
{
cc = f;
if(f == 1)
{
cur_ans += sum_3_2*p;
sum_1 += p;
}
if(f == 2)
{
cur_ans += sum_1_3*p;
tree_2.add_seg(p1,p2,p);
}
if(f == 3)
{
cur_ans += sum_1_2*p;
sum_3 += p;
}
}
}
};
centr centr_str[200001];
void dfs_sub(int v, int pop)
{
sub[v] = 1;
forall(it,graph[v]) if(it != pop && !odw[it])
{
dfs_sub(it,v);
sub[v] += sub[it];
}
}
void dfs_pre(int v, int pop)
{
pre[v] = cur_pre++;
forall(it,graph[v]) if(it != pop && !odw[it]) dfs_pre(it,v);
maxpre[v] = cur_pre-1;
}
void dfs_add(int v, int pop, int c, int s1, int s2)
{
elms[v].pb({c,pre[v],maxpre[v],s1,s2});
forall(it,graph[v]) if(it != pop && !odw[it]) dfs_add(it,v,c,s1,s2);
}
void centroid(int v, int n)
{
dfs_sub(v,v);
int pop = v;
while(true)
{
pii best = {-1,-1};
forall(it,graph[v]) if(it != pop && !odw[it]) best = max(best,{sub[it],it});
if(best.ff > n/2)
{
pop = v;
v = best.ss;
}
else break;
}
odw[v] = 1;
centr_str[v].setup(n);
dfs_sub(v,v);
cur_pre = 1;
dfs_pre(v,v);
forall(it,graph[v]) if(!odw[it]) dfs_add(it,v,v,pre[it],maxpre[it]);
elms[v].pb({v,1,maxpre[v],1,maxpre[v]});
forall(it,graph[v]) if(!odw[it]) centroid(it,sub[it]);
}
void change_vert(int v, int f)
{
forall(it,elms[v])
{
if(F[v] != -1) centr_str[it.centr].set_vert(it.p1,it.p2,it.s1,it.s2,F[v],-1);
centr_str[it.centr].set_vert(it.p1,it.p2,it.s1,it.s2,f,1);
}
F[v] = f;
}
void init(int n, vi F2, vi U, vi V, int Q)
{
rep(i,n) F[i] = -1;
rep(i,n-1)
{
graph[U[i]].pb(V[i]);
graph[V[i]].pb(U[i]);
}
centroid(0,n);
rep(i,n) change_vert(i,F2[i]+1);
}
// Grader entry point: vertex X now has colour Y (stored internally as Y + 1).
void change(int X, int Y)
{
    const int internal_colour = Y + 1;
    change_vert(X, internal_colour);
}
ll num_tours()
{
return cur_ans;
}