#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros used throughout this file.
#define F first
#define S second
#define all(x) x.begin(),x.end()
#define pii pair<long long,long long>
#define pb push_back
// Container size as a signed long long (avoids signed/unsigned comparisons).
#define sz(x) (long long)(x.size())
#define chmin(x,y) x=min(x,y)
// Local build (compiled with -Dlocal): forward-declare CHECK(), defined at the end of this
// file under the same guard, and redirect stdin/stdout to the author's local files.
// NOTE(review): CHECK() reads the relative path "output.txt"; it only sees the file written
// here when run from /Users/iantsai/cpp -- confirm.
#ifdef local
void CHECK();
void setio(){
freopen("/Users/iantsai/cpp/input.txt","r",stdin);
freopen("/Users/iantsai/cpp/output.txt","w",stdout);
}
#else
// Grader build: no redirection.
void setio(){}
#endif
// Untie/unsync iostreams for speed, then apply the (possibly empty) local redirection.
#define TOI_is_so_de ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);setio();
// Bound for the per-vertex global arrays; vertices are stored 1-based.
const long long mxn = 2e5 + 5;
struct BIT{
    // bit[d]: 1-indexed Fenwick array for dimension d (dimensions 0, 1 and 2).
    std::vector<std::vector<long long>>bit;
    // sum[d]: running total of values added to dimension d by point updates; a range update
    // on dimension 1 additionally counts its value once (its two point updates cancel).
    std::vector<long long>sum;

    // Reset to three zeroed trees covering positions 1..tmr.
    void init(long long tmr){
        sum.assign(3, 0LL);
        bit.assign(3, std::vector<long long>(tmr + 1, 0LL));
    }

    // Point update: add v at position p (p >= 1) of dimension d.
    void m(long long d, long long p, long long v){
        sum[d] += v;
        std::vector<long long> &tree = bit[d];
        const long long limit = static_cast<long long>(tree.size());
        while(p < limit){
            tree[p] += v;
            p += p & -p;
        }
    }

    // Range update: add v to every position in [l, r] of dimension d, so a later prefix
    // query at x returns the total over all ranges that contain x.
    void m(long long d, long long l, long long r, long long v){
        m(d, l, v);
        m(d, r + 1, -v);
        if(d == 1) sum[d] += v;
    }

    // Prefix sum of dimension d over positions 1..p.
    long long q(long long d, long long p){
        long long total = 0;
        for(; p > 0; p &= p - 1) total += bit[d][p];
        return total;
    }

    // Sum of dimension d over positions l..r.
    long long q(long long d, long long l, long long r){
        const long long upto_r = q(d, r);
        const long long before_l = q(d, l - 1);
        return upto_r - before_l;
    }
};
// ans: current number of valid tours. ty[v]: type of vertex v (0, 1 or 2).
// cnt: scratch subtree sizes for the centroid search. tmr: Euler-tour clock.
// neg[c]: number of (type-0, type-2) vertex pairs lying in the same child subtree of centroid c.
long long ans = 0, ty[mxn], cnt[mxn], tmr, neg[mxn];
// For each vertex v, one entry per centroid level that contains v (outermost first):
// p = that centroid, in/out = v's Euler interval in the centroid's tour,
// num = label of the centroid's child subtree holding v (-1 at v's own level).
// ord: centroids in discovery order.
vector<long long>adj[mxn], p[mxn], in[mxn], out[mxn], num[mxn], ord;
// cal[c][t] (t = 0 or 2): sum over type-t vertices u of c's component of the number of
// type-1 vertices on the path u..c (c excluded). This grows quadratically with component
// size (about 2.5e9 on a 2e5-vertex path), so it must be 64-bit; it was `int` before.
long long cal[mxn][3];
// sum[c][id][t]: number of type-t vertices in child subtree id of centroid c;
// s2[c][id][t]: that subtree's share of cal[c][t].
vector<vector<long long>>sum[mxn], s2[mxn];
// vis[v]: v has already been chosen as a centroid.
bool vis[mxn];
// bt[c]: Fenwick trees over c's Euler tour (every vertex of c's component except c).
BIT bt[mxn];
void upd(long long v, long long d, bool f){
if(d == -1){
//cout << ans << '\n';
long long pa = p[v].back(), l = in[v].back(), r = out[v].back(), id = num[v].back();
if(ty[v] == 1) ans += d * (bt[pa].sum[0] * bt[pa].sum[2] - neg[pa]);
else if(ty[v] == 0){
ans += d * cal[pa][2];
}
else if(ty[v] == 2){
ans += d * cal[pa][0];
}
//cout << v << ' ' << pa << ' ' << bt[pa].sum[0] << ' ' << bt[pa].sum[2] << ' ' << neg[pa] << ' ' << cal[pa][0] << ' ' << cal[pa][2] << ' ' << ans << '\n';
}
for(long long _ = 0; _ < sz(p[v]) - 1; _++){
long long pa = p[v][_], l = in[v][_], r = out[v][_], id = num[v][_];
//cout << v << ' ' << pa << ' ' << l << ' ' << r << ' ' << id << ' ' << ans << '\n';
if(ty[v] == 0){
if(d == 1){
bt[pa].m(0, l, d);
sum[pa][id][0] += d;
neg[pa] += sum[pa][id][2] * d;
}
int val = d * bt[pa].q(1, l);//10
cal[pa][0] += val;
s2[pa][id][0] += val;
ans += (bt[pa].sum[2] - sum[pa][id][2]) * val;
ans += (cal[pa][2] - s2[pa][id][2]) * d;
if(d == -1){
bt[pa].m(0, l, d);
sum[pa][id][0] += d;
neg[pa] += sum[pa][id][2] * d;
}
}
else if(ty[v] == 1){
if(d == 1){
bt[pa].m(1, l, r, d);
sum[pa][id][1] += d;
}
int val = bt[pa].q(0, l, r) * d;//10
cal[pa][0] += val;//10
s2[pa][id][0] += val;
ans += val * (bt[pa].sum[2] - sum[pa][id][2]);
val = bt[pa].q(2, l, r) * d;//12
cal[pa][2] += val;//12
s2[pa][id][2] += val;
ans += val * (bt[pa].sum[0] - sum[pa][id][0]);
if(d == -1){
bt[pa].m(1, l, r, d);
sum[pa][id][1] += d;
}
}
else{
if(d == 1){
bt[pa].m(2, l, d);
sum[pa][id][2] += d;
neg[pa] += sum[pa][id][0] * d;
}
int val = d * bt[pa].q(1, l);//12
cal[pa][2] += val;
s2[pa][id][2] += val;
ans += (bt[pa].sum[0] - sum[pa][id][0]) * val;
ans += (cal[pa][0] - s2[pa][id][0]) * d;
if(d == -1){
bt[pa].m(2, l, d);
sum[pa][id][2] += d;
neg[pa] += sum[pa][id][0] * d;
}
}
}
if(d == 1){
//cout << ans << '\n';
long long pa = p[v].back(), l = in[v].back(), r = out[v].back(), id = num[v].back();
if(ty[v] == 1) ans += d * (bt[pa].sum[0] * bt[pa].sum[2] - neg[pa]);
else if(ty[v] == 0){
ans += d * (cal[pa][2]);
}
else if(ty[v] == 2){
ans += d * (cal[pa][0]);
}
//cout << v << ' ' << pa << ' ' << bt[pa].sum[0] << ' ' << bt[pa].sum[2] << ' ' << neg[pa] << ' ' << cal[pa][0] << ' ' << cal[pa][2] << ' ' << ans << '\n';
}
}
void cnt_sz(long long v, long long pa){
    // Size of v's subtree (rooted away from pa) among vertices not yet used as centroids.
    long long total = 1;
    for(const long long child : adj[v]){
        if(child != pa && !vis[child]){
            cnt_sz(child, v);
            total += cnt[child];
        }
    }
    cnt[v] = total;
}
long long find(long long v, long long pa, long long tar){
    // Walk toward the first unvisited child whose subtree holds more than half of tar
    // vertices; stop at the first vertex with no such child (the centroid).
    while(true){
        long long heavy = -1;
        for(const long long u : adj[v]){
            if(u != pa && !vis[u] && cnt[u] * 2 > tar){
                heavy = u;
                break;
            }
        }
        if(heavy == -1) return v;
        pa = v;
        v = heavy;
    }
}
void dfs(long long v, long long pa, long long top, long long cc){
    // Entering v: record centroid `top` for this level, the label `cc` of top's child
    // subtree that contains v, and v's Euler entry time.
    num[v].pb(cc);
    p[v].pb(top);
    in[v].pb(++tmr);
    for(const long long nxt : adj[v]){
        const bool blocked = (nxt == pa) || vis[nxt];
        if(!blocked) dfs(nxt, v, top, cc);
    }
    // Leaving v: all of v's subtree has entry times within [in, out].
    out[v].pb(tmr);
}
// Centroid decomposition of the unvisited component containing v.
// For the chosen centroid c it:
//  - marks c visited and appends it to ord;
//  - Euler-tours the component (c gets position 1 and subtree label -1; each remaining
//    child subtree gets label 0, 1, ...);
//  - sizes the per-subtree counters and c's Fenwick trees;
//  - recurses into every remaining component.
void cd(long long v){
    cnt_sz(v, v);
    v = find(v, v, cnt[v]);
    vis[v] = 1;
    ord.pb(v);
    tmr = 0;
    long long cc = 0;
    p[v].pb(v);
    in[v].pb(++tmr);
    num[v].pb(-1);
    for(auto u : adj[v]){
        if(vis[u]) continue;
        dfs(u, v, v, cc);
        cc++;
    }
    out[v].pb(tmr);
    sum[v] = vector<vector<long long>>(cc, vector<long long>(3));
    s2[v] = vector<vector<long long>>(cc, vector<long long>(3));
    // A single init() allocates all three trees; it used to sit in a loop that rebuilt
    // them three times.
    bt[v].init(tmr);
    for(auto u : adj[v]){
        if(!vis[u]) cd(u);
    }
}
#ifdef local
void solve(){
    // Local driver: reads n, the n vertex types, n-1 edges (0-based) and q changes "x y",
    // printing the tour count (prefixed with "prt") initially and after every change.
    long long n;
    cin >> n;
    for(long long v = 1; v <= n; ++v) cin >> ty[v];
    for(long long e = 0; e + 1 < n; ++e){
        long long a, b;
        cin >> a >> b;
        ++a;
        ++b;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    cd(1);
    // Reverse discovery order, so each centroid is inserted after its whole component.
    reverse(all(ord));
    for(const long long c : ord) upd(c, 1, 0);
    cout << "prt" << ans << '\n' << '\n';
    long long q;
    cin >> q;
    while(q--){
        int x, y;
        cin >> x >> y;
        const int vert = x + 1;
        upd(vert, -1, 1);
        ty[vert] = y;
        upd(vert, 1, 1);
        cout << "prt" << ans << '\n';
    }
}
signed main(){
TOI_is_so_de;
long long t = 1;
//cin >> t;
while(t--){
solve();
}
#ifdef local
CHECK();
#endif
}
#else
#include "joitour.h"
void init(int n, vector<int>typ, vector<int>ea, vector<int>eb, int q){
for(long long i = 1; i <= n; i ++){
ty[i] = typ[i - 1];
}
for(long long i = 1; i < n - 1; i++){
adj[++ea[i - 1]].pb(++eb[i - 1]);
adj[eb[i - 1]].pb(ea[i - 1]);
}
cd(1);
reverse(all(ord));
for(auto v : ord){
upd(v, 1, 0);
}
}
void change(int x, int y){
    // Grader ids are 0-based; internal ids are 1-based.
    const long long vert = x + 1;
    // Retract the vertex with its old type, retype it, then re-insert it.
    upd(vert, -1, 1);
    ty[vert] = y;
    upd(vert, 1, 1);
}
long long num_tours(){
    // The count is maintained incrementally by upd(), so answering is O(1).
    const long long current = ans;
    return current;
}
#endif
/*
input:
*/
#ifdef local
// Compare a produced output stream against the expected one, line by line.
// Common lines must match exactly. Once either stream is exhausted, any lines left in the
// other must be blank (whitespace only), so a trailing empty line is tolerated.
// The old check (cntb - cnta <= 1) accepted arbitrarily many extra output lines and
// also a missing final line.
static bool outputs_match(std::istream& got, std::istream& want){
    auto blank = [](const std::string& s){
        for(unsigned char ch : s) if(!std::isspace(ch)) return false;
        return true;
    };
    std::string a, b;
    for(;;){
        const bool has_a = static_cast<bool>(std::getline(got, a));
        const bool has_b = static_cast<bool>(std::getline(want, b));
        if(has_a && has_b){
            if(a != b) return false;
            continue;
        }
        if(!has_a && !has_b) return true;
        // Exactly one stream still has lines left; every one of them must be blank.
        std::istream& rest = has_a ? got : want;
        std::string line = has_a ? a : b;
        do{
            if(!blank(line)) return false;
        }while(std::getline(rest, line));
        return true;
    }
}

// Local-only report: elapsed CPU time, then whether output.txt matches expected.txt.
void CHECK(){
    std::cerr << "\n[Time]: " << 1000.0 * std::clock() / CLOCKS_PER_SEC << " ms.\n";
    // Output written through cout/stdout may still be buffered; flush before reading it back.
    std::cout.flush();
    std::fflush(stdout);
    // NOTE(review): relative paths -- they resolve against the current working directory.
    std::ifstream produced("output.txt");
    std::ifstream expected("expected.txt");
    // A file that cannot be opened counts as a mismatch.
    const bool ok = produced.is_open() && expected.is_open() && outputs_match(produced, expected);
    if(ok) std::cerr << "ACCEPTED\n";
    else std::cerr << "WRONG ANSWER!\n";
}
#else
#endif
/*
The table below appears to be judge-page residue captured together with this source.
It is not part of the program and is kept only as a comment:
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/