#include <bits/stdc++.h>
using namespace std;
// Global state shared by every routine in this file.
//   maxn        capacity of per-vertex arrays; segment-tree arrays use 4 * maxn nodes
//   lg          number of binary-lifting levels stored in par[][]
//   n           number of vertices (read in main)
//   s           running preorder counter advanced by dfs_set
//   L, R        node v of either segment tree covers the closed interval [L[v], R[v]]
//               (filled by make_tree; both trees share this skeleton)
//   sum         sum segment tree indexed by colour value (get_sum / update_tree_sum)
//   par[u][i]   2^i-th ancestor of u; par[1][*] is never assigned, so it stays 0
//   mx          max segment tree over preorder positions (get_max / update_tree_mx);
//               main stores at st[v] the time stamp (1 for vertex 1, i + 2 for V[i])
//   st, en      preorder index of u and the last preorder index inside u's subtree
//   h           depth of each vertex; h[1] is never assigned, so the root has depth 0
//   U, V        endpoints of the i-th input edge
//   bcnv        maps a time stamp stored in mx back to the vertex that received it
//   c           colour of each vertex, 1-indexed
const int maxn = 2e5 + 12, lg = 20;
int n, s, L[maxn * 4], R[maxn * 4], sum[maxn * 4], par[maxn][lg], mx[maxn * 4], st[maxn], en[maxn], h[maxn], U[maxn], V[maxn], bcnv[maxn], c[maxn];
// Adjacency lists of the whole tree (both directions are pushed in main).
vector <int> conn[maxn];
// Pieces (top vertex, bottom vertex, colour) of the current root path, written by get_seg.
vector <tuple<int, int, int>> seg;
// Visited flags for dfs_set.
bitset <maxn> marked;
// Builds the segment-tree skeleton shared by the sum and max trees.
// Node `ind` covers the closed interval [l, r]; its children are 2*ind and
// 2*ind + 1, which split the interval at the midpoint. Only L[] and R[] are written.
void make_tree(int l, int r, int ind){
    L[ind] = l;
    R[ind] = r;
    if(l < r){
        const int half = (l + r) / 2;
        make_tree(l, half, 2 * ind);
        make_tree(half + 1, r, 2 * ind + 1);
    }
}
// Updates the sum tree rooted at node u over positions [l, r].
// Every node fully inside [l, r] is overwritten with k (add == 0) or has k
// added to it (add != 0), and ancestors are re-summed on the way back.
// There is no lazy propagation, so this is only a correct update when l == r.
void update_tree_sum(int l, int r, int u, int k, bool add = 0){
    const bool disjoint = R[u] < l || r < L[u];
    if(disjoint)
        return;
    const bool covered = L[u] >= l && R[u] <= r;
    if(covered){
        sum[u] = add ? sum[u] + k : k;
        return;
    }
    const int lc = 2 * u, rc = 2 * u + 1;
    update_tree_sum(l, r, lc, k, add);
    update_tree_sum(l, r, rc, k, add);
    sum[u] = sum[lc] + sum[rc];
}
// Assigns k to every node of the max tree (rooted at u) that lies fully
// inside [l, r], then recomputes maxima on the way back up.
// There is no lazy propagation, so this is a correct assignment only when l == r.
void update_tree_mx(int l, int r, int u, int k){
    if(R[u] < l || r < L[u])
        return;
    if(L[u] >= l && R[u] <= r){
        mx[u] = k;
        return;
    }
    const int lc = 2 * u;
    update_tree_mx(l, r, lc, k);
    update_tree_mx(l, r, lc + 1, k);
    mx[u] = max(mx[lc], mx[lc + 1]);
}
// Returns the sum of the sum tree over positions [l, r], searching from node u.
// An empty or out-of-range query contributes 0.
int get_sum(int l, int r, int u){
    if(R[u] < l || r < L[u])
        return 0;
    if(L[u] >= l && R[u] <= r)
        return sum[u];
    const int left_part = get_sum(l, r, 2 * u);
    const int right_part = get_sum(l, r, 2 * u + 1);
    return left_part + right_part;
}
// Returns the maximum of the max tree over positions [l, r], searching from
// node u. Nodes outside the range contribute 0.
int get_max(int l, int r, int u){
    if(R[u] < l || r < L[u])
        return 0;
    if(L[u] >= l && R[u] <= r)
        return mx[u];
    const int left_best = get_max(l, r, 2 * u);
    const int right_best = get_max(l, r, 2 * u + 1);
    return max(left_best, right_best);
}
// Returns the k-th ancestor of u using binary lifting: for each set bit j
// among the lowest lg bits of k, jump par[.][j].
int kpar(int u, int k){
    for(int bit = 0; bit < lg; ++bit){
        if(k & (1 << bit))
            u = par[u][bit];
    }
    return u;
}
int lca(int a, int b){
if(h[a] < h[b])
swap(a, b);
// cout << a << '-' << b << endl;
// cout << h[a] << ' ' << h[b] << endl;
a = kpar(a, h[a] - h[b]);
if(a == b)
return a;
//cout << a << '-' << b << endl;
for(int i = lg - 1; i >= 0; i--)
if(par[a][i] != par[b][i]){
a = par[a][i];
b = par[b][i];
}
return par[a][0];
}
void dfs_set(int u){
marked[u] = 1;
for(int i = 1; i < lg; i++)
par[u][i] = par[par[u][i - 1]][i - 1];
st[u] = s;
s += 1;
for(int v: conn[u])
if(!marked[v]){
h[v] = h[u] + 1;
// cout << u << ' ' << v << ' ' << h[v] << endl;
par[v][0] = u;
dfs_set(v);
}
en[u] = s - 1;
}
void get_seg(int u){
seg.clear();
int now = 1, res = -1;
while(res != u){
// cout << st[now] << ' ' << en[now] << endl;
res = get_max(st[now], en[now], 1);
// cout << res << endl;
res = bcnv[res];
// cout << res << endl;
//cout << u << ' ' << res << endl;
int nc = c[res];
res = lca(res, u);
// cout << res << endl;
// cout << now << ' ' << res << endl;
seg.push_back(make_tuple(now, res, nc));
if(res != u)
now = kpar(u, h[u] - h[res] - 1);
//cout << now << endl;
}
return;
}
// Counts inversions along the path described by seg, which is ordered from
// the root downward. Each piece is `len` vertices of one colour. For every
// piece, the number of earlier vertices with a strictly larger colour
// (get_sum over (colour, n]) is multiplied by len.
// The colour counts are then undone so the sum tree is all zero for the next query.
// Colours must already lie in [0, n], the range the sum tree was built over.
long long get_ans(){
    long long nat = 0;
    for(const auto &piece: seg){
        const int top = get<0>(piece);
        const int bottom = get<1>(piece);
        const int colour = get<2>(piece);
        const int len = h[bottom] - h[top] + 1;
        // Widen before multiplying: both factors are bounded by the path
        // length, so their int product can exceed INT_MAX.
        nat += 1LL * get_sum(colour + 1, n, 1) * len;
        update_tree_sum(colour, colour, 1, len, 1);
    }
    // Reset every touched colour counter to 0.
    for(const auto &piece: seg)
        update_tree_sum(get<2>(piece), get<2>(piece), 1, 0);
    return nat;
}
// Reads n colours and n - 1 edges (U[i], V[i]) in insertion order. For each
// edge, prints the number of inversions of colours on the root->U[i] path,
// counted before V[i] is attached. Attaching V[i] then records time stamp
// i + 2 at V[i]'s preorder position, so later get_seg calls treat V[i]'s
// colour as the colour of every ancestor whose subtree has no newer stamp.
int main(){
    ios_base::sync_with_stdio(false), cin.tie(0);
    cin >> n;
    for(int i = 0; i < n; i++)
        cin >> c[i + 1];
    // Colours are arbitrary ints, but the sum tree indexed by colour only
    // spans [0, n]; anything larger would be silently dropped by
    // update_tree_sum/get_sum. Inversion counts depend only on relative order,
    // so replace each colour by its rank 1..k (k <= n) among the distinct values.
    vector<int> ranks(c + 1, c + n + 1);
    sort(ranks.begin(), ranks.end());
    ranks.erase(unique(ranks.begin(), ranks.end()), ranks.end());
    for(int i = 1; i <= n; i++)
        c[i] = int(lower_bound(ranks.begin(), ranks.end(), c[i]) - ranks.begin()) + 1;
    for(int i = 0; i < n - 1; i++){
        cin >> U[i] >> V[i];
        conn[U[i]].push_back(V[i]);
        conn[V[i]].push_back(U[i]);
    }
    dfs_set(1);
    make_tree(0, n, 1);
    // Vertex 1 is present from the start with time stamp 1.
    update_tree_mx(st[1], st[1], 1, 1);
    bcnv[1] = 1;
    for(int i = 0; i < n - 1; i++){
        get_seg(U[i]);
        const long long ans = get_ans();
        // V[i] joins at time i + 2.
        update_tree_mx(st[V[i]], st[V[i]], 1, i + 2);
        bcnv[i + 2] = V[i];
        cout << ans << "\n";
    }
}
/*
 * Grader results (pasted from the judge page):
 * # | Result    | Execution time | Memory  | Grader output
 * 1 | Correct   | 3 ms           | 5076 KB | Output is correct
 * 2 | Incorrect | 3 ms           | 5076 KB | Output isn't correct
 * 3 | Halted    | 0 ms           | 0 KB    | -
 */
/*
 * Grader results (pasted from the judge page):
 * # | Result    | Execution time | Memory  | Grader output
 * 1 | Correct   | 3 ms           | 5076 KB | Output is correct
 * 2 | Incorrect | 3 ms           | 5076 KB | Output isn't correct
 * 3 | Halted    | 0 ms           | 0 KB    | -
 */
/*
 * Grader results (pasted from the judge page):
 * # | Result    | Execution time | Memory  | Grader output
 * 1 | Correct   | 3 ms           | 5076 KB | Output is correct
 * 2 | Incorrect | 3 ms           | 5076 KB | Output isn't correct
 * 3 | Halted    | 0 ms           | 0 KB    | -
 */