#include "joitour.h"
#include<bits/stdc++.h>
using namespace std;
// 64-bit counter type: the number of (0,1,2)-coloured triples can exceed 2^31.
using ll = long long;

// Contest-style shorthands (later code relies on vll and pb).
#define F first
#define S second
#define pb push_back
#define pll std::pair<ll, ll>
#define vll std::vector<ll>

// Capacity for vertex indices 0 .. mxN-1.
const ll mxN = 200005;

ll a[mxN];     // current colour of each vertex (used as an index 0..2)
vll adj[mxN];  // undirected adjacency lists of the tree

// Per-vertex subtree aggregates, tree rooted at vertex 0:
//   dp[v][0..2] : number of vertices of colour 0 / 1 / 2 in subtree(v)
//   dp[v][3]    : sum over colour-1 vertices y in subtree(v) of dp[y][0]
//   dp[v][4]    : sum over colour-1 vertices y in subtree(v) of dp[y][2]
ll dp[mxN][5];

ll sum[3];     // global count of each colour (filled by dfs, otherwise unused)
ll con[mxN];   // per-vertex contribution to the answer, computed by dfs2
ll ans;        // running total returned by num_tours
// Pass 1 (bottom-up) over the subtree hanging from `cur`. The neighbour `par`
// is treated as cur's parent; pass -1 for the root.
// For every visited vertex v it:
//   * clears dp[v][0..4],
//   * leaves dp[v][c] (c = 0, 1, 2) equal to the number of vertices of
//     colour c in subtree(v), v itself included,
//   * adds v's colour to the global tally sum[].
// Uses an explicit stack instead of recursion. A path-shaped tree with ~2e5
// vertices would otherwise need ~2e5 nested calls, which can overflow a
// default-sized call stack.
void dfs(ll cur, ll par){
    struct Frame { ll v, p; size_t next; };  // next = index of next neighbour to try
    vector<Frame> stk;
    // Pre-order work of the recursive version: reset v and count its own colour.
    auto enter = [&](ll v, ll p){
        for(ll i=0;i<5;i++) dp[v][i]=0;
        dp[v][a[v]]++;
        sum[a[v]]++;
        stk.push_back({v, p, 0});
    };
    enter(cur, par);
    while(!stk.empty()){
        Frame &f = stk.back();
        const ll v = f.v;
        if(f.next < adj[v].size()){
            const ll chd = adj[v][f.next++];
            // enter() may reallocate stk; f is not touched after this call.
            if(chd != f.p) enter(chd, v);
            continue;
        }
        // All children of v are finished: fold v's colour counts into its parent.
        stk.pop_back();
        if(!stk.empty()){
            const ll up = stk.back().v;
            for(ll i=0;i<3;i++) dp[up][i]+=dp[v][i];
        }
    }
}
// Pass 2 (bottom-up), run after dfs(). It relies on dfs having zeroed
// dp[v][3..4] and filled dp[v][0..2].
// For every vertex v in the subtree hanging from `cur` (with `par` as cur's
// parent; -1 for the root), it leaves:
//   dp[v][3] = sum over colour-1 vertices y in subtree(v) of dp[y][0]
//   dp[v][4] = sum over colour-1 vertices y in subtree(v) of dp[y][2]
// i.e. the number of (colour-1 y, colour-0/2 descendant of y) pairs inside
// subtree(v).
// Uses an explicit stack instead of recursion so that deep (path-like) trees
// cannot overflow the call stack.
void dfs3(ll cur, ll par){
    struct Frame { ll v, p; size_t next; };  // next = index of next neighbour to try
    vector<Frame> stk;
    // Pre-order work: a colour-1 vertex seeds the pairs it forms with its subtree.
    auto enter = [&](ll v, ll p){
        if(a[v]==1){
            dp[v][3]=dp[v][0];
            dp[v][4]=dp[v][2];
        }
        stk.push_back({v, p, 0});
    };
    enter(cur, par);
    while(!stk.empty()){
        Frame &f = stk.back();
        const ll v = f.v;
        if(f.next < adj[v].size()){
            const ll chd = adj[v][f.next++];
            // enter() may reallocate stk; f is not touched after this call.
            if(chd != f.p) enter(chd, v);
            continue;
        }
        // Subtree of v complete: add its pair counts to the parent.
        stk.pop_back();
        if(!stk.empty()){
            const ll up = stk.back().v;
            for(ll i=3;i<5;i++) dp[up][i]+=dp[v][i];
        }
    }
}
// Pass 3 (bottom-up), run after dfs() and dfs3(). For each vertex `up` it
// accumulates con[up] and adds it to the global `ans`. For every child v of
// `up`, con[up] receives:
//   dp[v][3] * (#colour-2 in subtree(up) outside subtree(v))
//     -> colour-1 y with a colour-0 descendant x, both inside subtree(v),
//        combined with a colour-2 z elsewhere in subtree(up) (up included);
//   dp[v][4] * (#colour-0 in subtree(up) outside subtree(v))
//     -> the mirror case with the roles of colours 0 and 2 swapped;
//   if up itself has colour 1:
//        dp[v][0] * (#colour-2 in subtree(up) outside subtree(v))
//     -> y = up, with x below child v and z elsewhere below up.
// In each case y lies on the tree path between x and z, and the path's
// topmost vertex is `up`.
// Uses an explicit stack instead of recursion so that deep (path-like) trees
// cannot overflow the call stack.
void dfs2(ll cur, ll par){
    struct Frame { ll v, p; size_t next; };  // next = index of next neighbour to try
    vector<Frame> stk;
    auto enter = [&](ll v, ll p){
        con[v]=0;
        stk.push_back({v, p, 0});
    };
    enter(cur, par);
    while(!stk.empty()){
        Frame &f = stk.back();
        const ll v = f.v;
        if(f.next < adj[v].size()){
            const ll chd = adj[v][f.next++];
            // enter() may reallocate stk; f is not touched after this call.
            if(chd != f.p) enter(chd, v);
            continue;
        }
        // Every child of v has already added its terms to con[v].
        stk.pop_back();
        ans+=con[v];
        if(!stk.empty()){
            const ll up = stk.back().v;
            con[up]+=dp[v][3]*(dp[up][2]-dp[v][2]);
            con[up]+=dp[v][4]*(dp[up][0]-dp[v][0]);
            if(a[up]==1) con[up]+=dp[v][0]*(dp[up][2]-dp[v][2]);
        }
    }
}
// Records the initial colouring and builds the tree's adjacency lists.
//   n    : number of vertices (indices 0 .. n-1)
//   tep  : colour of each vertex
//   u, v : endpoints of the n-1 undirected edges
//   q    : not used here
void init(int n, vector<int> tep, vector<int> u, vector<int> v, int q) {
    for(int idx = 0; idx < n; ++idx) a[idx] = tep[idx];
    for(int e = 0; e + 1 < n; ++e){
        const int x = u[e];
        const int y = v[e];
        adj[x].pb(y);
        adj[y].pb(x);
    }
}
void change(int X, int Y) {
a[X]=Y;
}
long long num_tours() {
ans=0;
for(ll i=0;i<3;i++) sum[i]=0;
dfs(0, -1);
dfs3(0, -1);
dfs2(0, -1);
return ans;
}
/*
 * Residue of an online-judge results page, pasted along with the source
 * (not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
 */