This submission was migrated from a previous version of oj.uz, which used a different machine for grading. This submission may produce a different result if resubmitted.
#include "joitour.h"
#include "bits/stdc++.h"
#define dbg(x) "[" << #x << " = " << (x) << "]"
using namespace std;
// Capacity used for every per-node array below (2e5 + 5).
const int MAX = 2e5 + 5;
// Row indices into cnt_total / cnt_subtree (rows 0..3) and into ft (rows 0..4).
// num10 / num12: pair counters. For every node of value 1 they accumulate the
// number of value-0 / value-2 nodes found inside that node's DFS interval.
const int num10 = 0;
const int num12 = 1;
// num0 / num2: plain counters of nodes whose value is 0 / 2.
const int num0 = 2;
const int num2 = 3;
// num1: only used as a row of ft (there is no matching cnt_* row).
const int num1 = 4;
// N             - number of nodes, stored by init.
// a[v]          - current value of node v (the code compares it with 0, 1 and 2).
// par[v]        - centroid picked one layer above v in decompose (-1 for the first one).
// depth[v]      - layer at which v was picked as a centroid.
// tin/tout[l][v]- DFS interval of v in the traversal started by its layer-l centroid.
// sz[v]         - subtree size written by dfs_size.
// id_subtree[l][v] - value of timer_subtree while v was visited at layer l, i.e. the
//                 index into cnt_subtree of the child subtree containing v.
// timer_subtree - running index of child subtrees; timer_dfs - running DFS clock.
int N, a[MAX], par[MAX], depth[MAX], tin[18][MAX], tout[18][MAX], sz[MAX], id_subtree[18][MAX], timer_subtree, timer_dfs;
// Adjacency lists of the input tree.
vector<int> adj[MAX];
// used[v] is set once v has been chosen as a centroid; traversals skip such nodes.
bool used[MAX];
// cnt_total[k][c]   - counter k over everything scanned from centroid c.
// cnt_subtree[k][s] - counter k restricted to child subtree number s.
// ans               - running answer returned by num_tours.
long long cnt_total[4][MAX], cnt_subtree[4][MAX], num_ways[MAX], ans;
// num_ways[c] is the term added to ans when centroid c itself has value 1:
// (#value-0 * #value-2 over all of c's child subtrees) minus the same product
// taken inside each single child subtree.
// Binary indexed tree over positions 1..n storing int values.
// It is used in two ways by the callers:
//   * point update + prefix/range sum  (update / query / queryRange)
//   * range add + point read           (updateRange, then query(pos))
// Positions must be >= 1; position 0 is never a valid update target.
struct Fenwick{
    vector<int> bit;   // bit[0] is unused; valid slots are 1 .. bit.size() - 1

    // Tree with n usable positions, all zero.
    Fenwick(int n) : bit(n + 1, 0) {}
    // Empty tree; every update is a no-op and every query returns 0.
    Fenwick() : bit() {}

    // Adds val at position id (positions past the end are silently ignored).
    void update(int id, int val){
        const int limit = (int)bit.size();
        while(id < limit){
            bit[id] += val;
            id += id & (-id);   // jump to the next node covering this position
        }
    }

    // Sum of the values stored at positions 1..id (0 when id <= 0).
    int query(int id){
        int total = 0;
        while(id > 0){
            total += bit[id];
            id &= id - 1;       // strip the lowest set bit
        }
        return total;
    }

    // Difference-array style range add: afterwards query(p) is raised by val
    // exactly for l <= p <= r.
    void updateRange(int l, int r, int val){
        update(l, val);
        update(r + 1, -val);
    }

    // Sum of the values stored at positions l..r.
    int queryRange(int l, int r){
        const int upto_r = query(r);
        const int before_l = query(l - 1);
        return upto_r - before_l;
    }
};
// ft[k][c]: Fenwick tree of kind k owned by centroid c, indexed by tin[layer][v].
// Rows num0 / num2 receive +1/-1 point updates for nodes of value 0 / 2; row num1
// receives range updates over [tin, tout] of nodes of value 1, so a prefix query
// at tin[v] counts the value-1 nodes whose interval covers v.
// Only rows num0, num1 and num2 are updated or queried in this file.
Fenwick ft[5][MAX];
// Computes sz[] for the component of not-yet-used nodes reachable from u,
// treating pre as the node we arrived from. Returns sz[u].
int dfs_size(int u, int pre){
    int total = 1;                          // u itself
    for(int v : adj[u]){
        if(v == pre || used[v]) continue;   // do not walk back or into removed centroids
        total += dfs_size(v, u);
    }
    sz[u] = total;
    return total;
}
// Walks from u towards the first not-yet-used neighbour whose sz[] exceeds
// target / 2, repeating until no such neighbour exists, and returns the node
// where the walk stops. sz[] must have been filled by dfs_size beforehand.
int find_centroid(int u, int pre, int target){
    const int half = target / 2;
    while(true){
        int heavy = -1;
        for(int v : adj[u]){
            if(v != pre && !used[v] && sz[v] > half){
                heavy = v;
                break;
            }
        }
        if(heavy == -1) return u;   // no oversized child: stop here
        pre = u;
        u = heavy;
    }
}
// Centroid whose component is currently being scanned: assigned in decompose
// right before the dfs_subtree calls, and read by dfs_subtree to pick the
// ft / cnt_total slots to fill.
int c;
// Scans the part of the current centroid's component hanging below u (pre is the
// node we came from) and records, for layer `layer`:
//   * tin / tout of every visited node and the child-subtree index it belongs to,
//   * the value-0 / value-2 counters (per child subtree and per centroid),
//   * for every value-1 node, the number of value-0 / value-2 nodes in its own
//     DFS interval, accumulated into the num10 / num12 counters.
// Relies on the globals c, timer_dfs and timer_subtree set up by the caller.
void dfs_subtree(int u, int pre, int layer){
    const int enter = ++timer_dfs;
    tin[layer][u] = enter;
    id_subtree[layer][u] = timer_subtree;

    const int colour = a[u];
    if(colour == 0 || colour == 2){
        // Values 0 and 2 are handled symmetrically: one point in the matching tree.
        const int kind = (colour == 0) ? num0 : num2;
        ft[kind][c].update(enter, +1);
        ++cnt_total[kind][c];
        ++cnt_subtree[kind][timer_subtree];
    }

    for(int v : adj[u]){
        if(v == pre || used[v]) continue;
        dfs_subtree(v, u, layer);
        // Only relevant for a call made with pre == -1; no call in this file does that.
        if(pre == -1) ++timer_subtree;
    }

    const int leave = timer_dfs;
    tout[layer][u] = leave;
    if(colour != 1) return;

    // Mark the whole interval of u so that a prefix query at any node inside it
    // counts u, then pair u with every value-0 / value-2 node in that interval.
    ft[num1][c].updateRange(enter, leave, +1);
    const int zeros_below = ft[num0][c].queryRange(enter, leave);
    const int twos_below = ft[num2][c].queryRange(enter, leave);
    cnt_subtree[num10][timer_subtree] += zeros_below;
    cnt_subtree[num12][timer_subtree] += twos_below;
    cnt_total[num10][c] += zeros_below;
    cnt_total[num12][c] += twos_below;
}
// Centroid-decomposes the component of not-yet-used nodes containing u.
//   u     - any node of the component (replaced by the chosen centroid below)
//   layer - layer of the centroid being created (stored in depth[])
//   p     - centroid of the layer above, stored in par[] (-1 for the first call)
// For the chosen centroid it scans every child subtree with dfs_subtree, fills the
// cnt_* counters and Fenwick trees, adds this centroid's share to ans, and then
// recurses into the remaining components.
void decompose(int u, int layer, int p){
    int size = dfs_size(u, -1);
    u = find_centroid(u, -1, size);
    used[u] = true;
    par[u] = p;
    depth[u] = layer;
    // Only the num0 / num1 / num2 trees are ever updated or queried in this file,
    // so the num10 / num12 slots are left default-constructed (empty) instead of
    // allocating size + 1 unused ints each for every centroid.
    ft[num0][u] = Fenwick(size);
    ft[num1][u] = Fenwick(size);
    ft[num2][u] = Fenwick(size);
    c = u;
    timer_dfs = 0;
    for(int v : adj[u]) if(!used[v]){
        dfs_subtree(v, u, layer);
        // Subtract the products whose two factors come from the same child subtree;
        // the matching whole-component products are added after the loop.
        num_ways[u] -= cnt_subtree[num0][timer_subtree] * cnt_subtree[num2][timer_subtree];
        ans -= cnt_subtree[num0][timer_subtree] * cnt_subtree[num12][timer_subtree];
        ans -= cnt_subtree[num2][timer_subtree] * cnt_subtree[num10][timer_subtree];
        ++timer_subtree;   // next child subtree gets a fresh cnt_subtree slot
    }
    num_ways[u] += cnt_total[num0][u] * cnt_total[num2][u];
    ans += cnt_total[num0][u] * cnt_total[num12][u];
    ans += cnt_total[num2][u] * cnt_total[num10][u];
    // Share that depends on the centroid's own value.
    if(a[u] == 0) ans += cnt_total[num12][u];
    if(a[u] == 1) ans += num_ways[u];
    if(a[u] == 2) ans += cnt_total[num10][u];
    for(int v : adj[u]) if(!used[v]){
        decompose(v, layer + 1, u);
    }
}
// Stores the node values and the tree edges, then builds the whole centroid
// decomposition (which also computes the initial ans).
//   N    - number of nodes
//   F    - F[i] is the initial value of node i
//   U, V - endpoints of the N - 1 edges
//   Q    - not used by this function
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
    ::N = N;   // the parameter shadows the global of the same name
    std::copy_n(F.begin(), N, a);
    for(int e = 0; e + 1 < N; ++e){
        const int x = U[e];
        const int y = V[e];
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    decompose(0, 0, -1);
}
void change(int u, int nval) {
if(a[u] == nval) return;
int dep = depth[u], c = u;
while(c != -1){
if(c == u){
if(a[u] == 0) ans -= cnt_total[num12][u];
if(a[u] == 1) ans -= num_ways[u];
if(a[u] == 2) ans -= cnt_total[num10][u];
if(nval == 0) ans += cnt_total[num12][u];
if(nval == 1) ans += num_ways[u];
if(nval == 2) ans += cnt_total[num10][u];
} else{
int id = id_subtree[dep][u];
///remove data
if(a[c] == 0) ans -= cnt_total[num12][c];
if(a[c] == 1) ans -= num_ways[c];
if(a[c] == 2) ans -= cnt_total[num10][c];
num_ways[c] += cnt_subtree[num0][id] * cnt_subtree[num2][id];
ans += cnt_subtree[num0][id] * cnt_subtree[num12][id];
ans += cnt_subtree[num2][id] * cnt_subtree[num10][id];
num_ways[c] -= cnt_total[num0][c] * cnt_total[num2][c];
ans -= cnt_total[num0][c] * cnt_total[num12][c];
ans -= cnt_total[num2][c] * cnt_total[num10][c];
if(a[u] == 0){
int contribution = ft[num1][c].query(tin[dep][u]);
cnt_subtree[num10][id] -= contribution;
--cnt_subtree[num0][id];
cnt_total[num10][c] -= contribution;
--cnt_total[num0][c];
ft[num0][c].update(tin[dep][u], -1);
} else if(a[u] == 1){
int cnt10 = ft[num0][c].queryRange(tin[dep][u], tout[dep][u]);
int cnt12 = ft[num2][c].queryRange(tin[dep][u], tout[dep][u]);
cnt_subtree[num10][id] -= cnt10;
cnt_subtree[num12][id] -= cnt12;
cnt_total[num10][c] -= cnt10;
cnt_total[num12][c] -= cnt12;
ft[num1][c].updateRange(tin[dep][u], tout[dep][u], -1);
} else{
int contribution = ft[num1][c].query(tin[dep][u]);
cnt_subtree[num12][id] -= contribution;
--cnt_subtree[num2][id];
cnt_total[num12][c] -= contribution;
--cnt_total[num2][c];
ft[num2][c].update(tin[dep][u], -1);
}
//end remove data
//replace data
if(nval == 0){
int contribution = ft[num1][c].query(tin[dep][u]);
cnt_subtree[num10][id] += contribution;
++cnt_subtree[num0][id];
cnt_total[num10][c] += contribution;
++cnt_total[num0][c];
ft[num0][c].update(tin[dep][u], +1);
} else if(nval == 1){
int cnt10 = ft[num0][c].queryRange(tin[dep][u], tout[dep][u]);
int cnt12 = ft[num2][c].queryRange(tin[dep][u], tout[dep][u]);
cnt_subtree[num10][id] += cnt10;
cnt_subtree[num12][id] += cnt12;
cnt_total[num10][c] += cnt10;
cnt_total[num12][c] += cnt12;
ft[num1][c].updateRange(tin[dep][u], tout[dep][u], +1);
} else{
int contribution = ft[num1][c].query(tin[dep][u]);
cnt_subtree[num12][id] += contribution;
++cnt_subtree[num2][id];
cnt_total[num12][c] += contribution;
++cnt_total[num2][c];
ft[num2][c].update(tin[dep][u], +1);
}
num_ways[c] += cnt_total[num0][c] * cnt_total[num2][c];
ans += cnt_total[num0][c] * cnt_total[num12][c];
ans += cnt_total[num2][c] * cnt_total[num10][c];
num_ways[c] -= cnt_subtree[num0][id] * cnt_subtree[num2][id];
ans -= cnt_subtree[num0][id] * cnt_subtree[num12][id];
ans -= cnt_subtree[num2][id] * cnt_subtree[num10][id];
if(a[c] == 0) ans += cnt_total[num12][c];
if(a[c] == 1) ans += num_ways[c];
if(a[c] == 2) ans += cnt_total[num10][c];
//end replace data
}
c = par[c]; --dep;
}
a[u] = nval;
}
// Returns the running answer maintained by init and change.
long long num_tours() {
    const long long tours = ans;
    return tours;
}
/*
Test 1 :
Input
3
0 1 2
0 1
1 2
0
Output :
1
Test 2 :
Input :
3
0 1 2
0 1
1 2
2
2 0
0 2
Output
1
0
1
Test 3 :
Input :
7
1 0 2 2 0 1 0
0 1
0 2
1 3
1 4
2 5
2 6
7
0 0
1 1
2 0
3 0
4 2
5 2
6 2
Output : 3
0
4
4
0
4
5
5
*/
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |