#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Colour of each node; the counting code distinguishes the values 0, 1 and 2.
vector<int> f;
// Adjacency lists of the tree, indexed by node id.
vector<vector<int>> al;
// Number of nodes, as given to init().
int n;
// Tour count accumulated by dfs().
ll ans = 0;
// Per-node subtree aggregates filled in by dfs():
//   [0] number of colour-0 nodes in the subtree
//   [1] number of colour-2 nodes in the subtree (note: index 1 holds colour 2)
//   [2] pairs (a, b) in the subtree with f[a]==0, f[b]==1, b a proper ancestor of a
//   [3] pairs (c, b) in the subtree with f[c]==2, f[b]==1, b a proper ancestor of c
vector<array<ll, 4>> dp;
void dfs(int x, int p){
if(f[x] == 0){
dp[x][0]++;
}
if(f[x] == 2){
dp[x][1]++;
}
for(auto it : al[x]){
if(it == p)continue;
dfs(it,x);
ans += dp[x][2] * dp[it][1];
ans += dp[x][3] * dp[it][0];
ans += dp[it][2] * dp[x][1];
ans += dp[it][3] * dp[x][0];
if(f[x] == 0){
//printf("HERE\n");
//ans += dp[it][3];
}
if(f[x] == 1){
dp[x][2] += dp[it][0];
dp[x][3] += dp[it][1];
}
if(f[x] == 2){
//ans += dp[it][2];
}
dp[x][0] += dp[it][0];
dp[x][1] += dp[it][1];
dp[x][2] += dp[it][2];
dp[x][3] += dp[it][3];
}
//printf("x %d, ans %lld, f[x] %d , p0 %lld, p1 %lld, p2 %lld, p3 %lld\n",
//x,ans, f[x], dp[x][0], dp[x][1],dp[x][2],dp[x][3]);
}
// Stores the tree: N nodes with colours F, and N-1 undirected edges U[i]-V[i].
// Any adjacency from a previous init() call is discarded. Q is not used by this
// implementation.
void init(int N, vector<int> F, vector<int> U, vector<int> V,
          int Q) {
    (void)Q;
    n = N;
    f = std::move(F);  // F is taken by value, so take ownership instead of copying
    // assign (not resize) so that a repeated init() does not keep stale edges.
    al.assign(n, vector<int>());
    dp.assign(n + 1, array<ll, 4>{});
    for (int i = 0; i < n - 1; ++i) {
        al[U[i]].push_back(V[i]);
        al[V[i]].push_back(U[i]);
    }
}
// Recolours node X to colour Y. The accumulated count `ans` is cleared so that
// the next num_tours() call starts counting from zero.
void change(int X, int Y) {
    f[X] = Y;
    ans = 0;
}
// Returns the number of triples (a, b, c) with f[a]==0, f[b]==1, f[c]==2 and b on
// the tree path between a and c, recomputed from scratch with one O(N) pass
// rooted at node 0. Safe to call repeatedly: all state is reset on entry.
long long num_tours() {
    // Reset here too: change() is not necessarily called between two queries,
    // and dfs() only ever adds to ans.
    ans = 0;
    if (n <= 0) return 0;  // empty tree: nothing to root a DFS at
    dp.assign(n + 1, array<ll, 4>{});
    dfs(0, -1);  // -1: root has no parent (never equals a real node id)
    return ans;
}