#include<bits/stdc++.h>
using namespace std;
// Shorthand macros used throughout the file.
// Pair member access.
#define F first
#define S second
// Expands to the begin/end iterator pair of a container.
#define all(x) x.begin(),x.end()
// Pair of two 64-bit integers.
#define pii pair<long long,long long>
#define pb push_back
// Container size as a signed 64-bit value (avoids signed/unsigned comparisons).
#define sz(x) (long long)(x.size())
// Assigns min(x, y) to x; note that x is evaluated twice.
#define chmin(x,y) x=min(x,y)
#ifdef local
// Self-check that runs after the solution in local builds only.
void CHECK();

// Local builds: read stdin from, and write stdout to, fixed scratch files.
void setio(){
    freopen("/Users/iantsai/cpp/input.txt", "r", stdin);
    freopen("/Users/iantsai/cpp/output.txt", "w", stdout);
}
#else
// Non-local builds: the standard streams are left untouched.
void setio(){}
#endif
// Start-up boilerplate: detaches iostreams from C stdio, unties cin/cout,
// then calls setio() (a no-op unless `local` is defined).
#define TOI_is_so_de ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);setio();
// Size of every per-vertex global array below (vertices are stored 1-based).
const long long mxn = 2e5 + 5;
// Per-centroid counters maintained by upd(). The last index (0..2) is a
// vertex type.
//
// All counters are 64-bit: family 1 counts pairs of vertices and `neg`
// accumulates products of two counts, both of which exceed the 32-bit range
// once a component holds on the order of 1e5 vertices.
struct node{
    // s[k][c][t]: counter family k (0 or 1), child subtree c, type t.
    std::vector<std::vector<std::vector<long long>>> s;
    // ss[k][t]: the same counters accumulated over all child subtrees.
    std::vector<std::vector<long long>> ss;
    // Running sum that upd() subtracts from ss[0][0] * ss[0][2].
    long long neg = 0;

    // Resets every counter to zero for a centroid with `siz` child subtrees.
    void init(int siz){
        neg = 0;
        ss.assign(2, std::vector<long long>(3, 0));
        s.assign(2, std::vector<std::vector<long long>>(siz, std::vector<long long>(3, 0)));
    }
};
// Three independent Fenwick trees (rows 0..2) over positions 1..tmr.
struct BIT{
    std::vector<std::vector<long long>> bit;

    // Allocates all three rows, zero-filled, with valid positions 1..tmr.
    void init(long long tmr){
        bit.assign(3, std::vector<long long>(tmr + 1, 0));
    }

    // Point update: adds `v` at position `p` of row `d`.
    // Positions past the end are ignored; `p` must be positive.
    void m(long long d, long long p, long long v){
        std::vector<long long> &row = bit[d];
        const long long limit = static_cast<long long>(row.size());
        while(p < limit){
            row[p] += v;
            p += p & -p;
        }
    }

    // Difference-style range update: +v at `l`, -v just past `r`, so that a
    // prefix query at x reflects `v` exactly when l <= x <= r.
    void m(long long d, long long l, long long r, long long v){
        m(d, l, v);
        m(d, r + 1, -v);
    }

    // Prefix sum of row `d` over positions 1..p.
    long long q(long long d, long long p){
        const std::vector<long long> &row = bit[d];
        long long total = 0;
        while(p != 0){
            total += row[p];
            p -= p & -p;
        }
        return total;
    }

    // Sum of row `d` over positions l..r.
    long long q(long long d, long long l, long long r){
        const long long upto_r = q(d, r);
        const long long before_l = q(d, l - 1);
        return upto_r - before_l;
    }
};
// ans: running total returned by num_tours(); ty[v]: current type of vertex v;
// cnt[v]: subtree size written by cnt_sz(); tmr: DFS timer, reset by cd() for
// every centroid.
long long ans = 0, ty[mxn], cnt[mxn], tmr;
// adj: adjacency lists. p/in/out/num each receive one entry per centroid whose
// component contains the vertex: that centroid, the vertex's DFS entry and exit
// times inside the component, and the index of the child subtree it lies in
// (-1 for the centroid itself). ord: centroids in the order cd() picked them.
vector<long long>adj[mxn], p[mxn], in[mxn], out[mxn], num[mxn], ord;
// NOTE(review): not referenced anywhere in the visible source.
long long cal[mxn][3];
// Counters of the component rooted at each centroid (see struct node).
node dt[mxn];
// vis[v] is set once v has been chosen as a centroid.
bool vis[mxn];
// Fenwick trees over the DFS order of each centroid's component.
BIT bt[mxn];
void upd(long long v, long long d){
if(d == -1){
long long pa = p[v].back(), l = in[v].back(), r = out[v].back(), id = num[v].back();
if(ty[v] == 1) ans += (dt[pa].ss[0][0] * dt[pa].ss[0][2] - dt[pa].neg) * d;
else{
ans += dt[pa].ss[1][2 - ty[v]] * d;
}
}
for(long long _ = 0; _ < sz(p[v]) - 1; _++){
long long pa = p[v][_], l = in[v][_], r = out[v][_], id = num[v][_];
auto m = [&](){
if(ty[v] == 1){
bt[pa].m(1, l, r, d);
dt[pa].ss[0][1] += d;
dt[pa].s[0][id][1] += d;
}
else{
bt[pa].m(ty[v], l, 1);
dt[pa].s[0][id][ty[v]] += d;
dt[pa].ss[0][ty[v]] += d;
}
};
auto calc = [&](){
if(ty[v] == 1){
for(int j = 0; j < 3; j++){
if(j == 1) continue;
long long val = bt[pa].q(j, l, r);
dt[pa].ss[1][j] += val * d;
dt[pa].s[1][id][j] += val * d;
ans += val * (dt[pa].ss[0][2 - j] - dt[pa].s[0][id][2 - j]);
}
}
else{
dt[pa].neg += dt[pa].s[0][id][2 - ty[v]];
long long val = bt[pa].q(1, l);//10
dt[pa].ss[1][ty[v]] += val * d;
dt[pa].s[1][id][ty[v]] += val * d;
ans += (dt[pa].ss[0][2 - ty[v]] - dt[pa].s[0][id][2 - ty[v]]) * val * d;
ans += (dt[pa].ss[1][2 - ty[v]] - dt[pa].s[1][id][2 - ty[v]]);
}
};
if(d == 1) m();
calc();
if(d == -1) m();
}
if(d == 1){
long long pa = p[v].back(), l = in[v].back(), r = out[v].back(), id = num[v].back();
if(ty[v] == 1) ans += (dt[pa].ss[0][0] * dt[pa].ss[0][2] - dt[pa].neg) * d;
else{
ans += dt[pa].ss[1][2 - ty[v]] * d;
}
}
}
// Fills cnt[] with subtree sizes for the part of the tree reachable from `v`
// without stepping back to `pa` or into an already chosen centroid.
void cnt_sz(long long v, long long pa){
    long long &here = cnt[v];
    here = 1;
    for(const long long u : adj[v]){
        const bool blocked = (u == pa) || vis[u];
        if(!blocked){
            cnt_sz(u, v);
            here += cnt[u];
        }
    }
}
// Walks from `v` towards the first child whose subtree holds more than half of
// `tar` vertices (sizes taken from cnt[]) and returns the vertex where no such
// child exists.
long long find(long long v, long long pa, long long tar){
    for(;;){
        bool moved = false;
        long long next = v;
        for(const long long u : adj[v]){
            if(u == pa || vis[u]) continue;
            if(cnt[u] * 2 > tar){
                next = u;
                moved = true;
                break;
            }
        }
        if(!moved) return v;
        pa = v;
        v = next;
    }
}
// Records, for every vertex reachable from `v` inside the current component,
// one level of decomposition data: the centroid `top`, the child-subtree index
// `cc`, and the entry/exit times taken from the global timer `tmr`.
void dfs(long long v, long long pa, long long top, long long cc){
    p[v].push_back(top);
    num[v].push_back(cc);
    tmr += 1;
    in[v].push_back(tmr);
    for(const long long u : adj[v]){
        if(u != pa && !vis[u]){
            dfs(u, v, top, cc);
        }
    }
    // Exit time is the largest entry time handed out inside v's subtree.
    out[v].push_back(tmr);
}
// Centroid decomposition of the component containing `v`.
// Picks the component's centroid, numbers the component in DFS order starting
// at the centroid (time 1), sizes the centroid's counters and Fenwick trees,
// and then recurses into each remaining piece.
void cd(long long v){
    cnt_sz(v, v);
    v = find(v, v, cnt[v]);
    vis[v] = true;
    ord.push_back(v);

    // The centroid is the root of its own level: time 1, child index -1.
    tmr = 0;
    p[v].push_back(v);
    in[v].push_back(++tmr);
    num[v].push_back(-1);

    long long cc = 0;  // number of child subtrees numbered so far
    for(const long long u : adj[v]){
        if(vis[u]) continue;
        dfs(u, v, v, cc);
        cc++;
    }
    out[v].push_back(tmr);

    dt[v].init(cc);
    // BIT::init allocates all three rows, so one call is enough.
    bt[v].init(tmr);

    // `tmr` is reset by each recursive call, so it must not be read after this point.
    for(const long long u : adj[v]){
        if(vis[u]) continue;
        cd(u);
    }
}
#ifdef local
// Local driver: reads the vertex types, the 0-based edge list and the queries
// from stdin, printing the total after the initial build and after every query.
void solve(){
    long long n;
    cin >> n;
    for(long long v = 1; v <= n; ++v){
        cin >> ty[v];
    }
    for(long long e = 1; e < n; ++e){
        long long a, b;
        cin >> a >> b;
        // Input is 0-based; vertices are stored 1-based.
        ++a;
        ++b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    cd(1);
    // Insert the vertices in the reverse of the order the centroids were picked.
    reverse(ord.begin(), ord.end());
    for(const long long v : ord){
        upd(v, 1);
    }
    cout << "prt" << ans << '\n' << '\n';

    long long q;
    cin >> q;
    for(; q != 0; --q){
        long long x, y;
        cin >> x >> y;
        const long long v = x + 1;
        upd(v, -1);
        ty[v] = y;
        upd(v, 1);
        cout << "prt" << ans << '\n';
    }
}
// Local entry point: sets up I/O, runs solve() once per test case, then runs
// the local self-check.
signed main(){
    TOI_is_so_de;
    long long t = 1;
    //cin >> t;
    for(; t != 0; --t){
        solve();
    }
#ifdef local
    CHECK();
#endif
}
#else
#include "joitour.h"
// Grader entry point: stores the vertex types and the 0-based edge list in the
// 1-based globals, builds the centroid decomposition and inserts every vertex.
// `q` is not used.
void init(int n, vector<int>typ, vector<int>ea, vector<int>eb, int q){
    (void)q;
    for(long long idx = 0; idx < n; ++idx){
        ty[idx + 1] = typ[idx];
    }
    for(long long e = 0; e + 1 < n; ++e){
        const long long a = ea[e] + 1;
        const long long b = eb[e] + 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    cd(1);
    // Insert the vertices in the reverse of the order the centroids were picked.
    reverse(ord.begin(), ord.end());
    for(const long long v : ord){
        upd(v, 1);
    }
}
// Grader entry point: sets the type of 0-based vertex `x` to `y` by removing
// the vertex, rewriting its type and inserting it again.
void change(int x, int y){
    const int v = x + 1;  // vertices are stored 1-based
    upd(v, -1);
    ty[v] = y;
    upd(v, 1);
}
// Grader entry point: returns the total currently held in `ans`.
long long num_tours(){
    const long long total = ans;
    return total;
}
#endif
/*
input:
*/
#ifdef local
// Local self-check: prints the CPU time used so far, then compares
// "output.txt" with "expected.txt" line by line and reports the verdict on
// stderr.
void CHECK(){
    const double elapsed_ms = 1000.0 * clock() / CLOCKS_PER_SEC;
    cerr << "\n[Time]: " << elapsed_ms << " ms.\n";

    // Returns false if either file cannot be opened or a pair of lines differs.
    // Otherwise the files match when the second has at most one more leftover
    // line than the first.
    // NOTE(review): extra lines in the first file are never rejected - confirm
    // that this is intended.
    auto same_files = [](const string &first_path, const string &second_path) -> bool {
        ifstream first(first_path);
        ifstream second(second_path);
        if(!(first.is_open() && second.is_open())){
            return false;
        }
        string a, b;
        while(getline(first, a) && getline(second, b)){
            if(a != b){
                return false;
            }
        }
        long long left_first = 0;
        long long left_second = 0;
        while(getline(first, a)){
            ++left_first;
        }
        while(getline(second, b)){
            ++left_second;
        }
        return left_second - left_first <= 1;
    };

    if(same_files("output.txt", "expected.txt")){
        cerr<<"ACCEPTED\n";
    }
    else{
        cerr<<"WRONG ANSWER!\n";
    }
}
#else
#endif
/*
Submission result table (web-page residue, kept for reference only):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/