#include <bits/stdc++.h>
using namespace std;
// 64-bit accumulator type for inversion counts.
using ll = long long int;
// Number of binary-lifting levels; 2^P must exceed the deepest vertex depth.
constexpr int P = 17;
// Capacity of every per-vertex array and of the segment tree's leaf layer:
// vertex ids and pre-order positions must stay below N (= 2^P = 131072).
constexpr int N = 1 << P;
// Point-assign / range-max segment tree stored bottom-up:
// leaf for position p lives at Tree[p + N], node x has children 2x and 2x+1,
// and Tree[1] is the root (Tree[0] is never read by query/update).
struct SegTree {
    int Tree[2*N];

    // Writes val at position u, then refreshes the maximum of every ancestor.
    void update(int u, int val) {
        int node = u + N;
        Tree[node] = val;
        for (node >>= 1; node >= 1; node >>= 1)
            Tree[node] = max(Tree[node << 1], Tree[(node << 1) | 1]);
    }

    // Maximum over the inclusive position range [a, b]; intended for a <= b.
    int query(int a, int b) {
        int lo = a + N;
        int hi = b + N;
        int best = max(Tree[lo], Tree[hi]);
        while (hi - lo > 1) {
            // lo is a left child: its right sibling lies strictly inside the range.
            if ((lo & 1) == 0)
                best = max(best, Tree[lo + 1]);
            // hi is a right child: its left sibling lies strictly inside the range.
            if ((hi & 1) == 1)
                best = max(best, Tree[hi - 1]);
            lo >>= 1;
            hi >>= 1;
        }
        return best;
    }
};
// Max segment tree over pre-order positions; Solve() writes update times into
// it, and unwritten leaves keep their zero initial value.
SegTree T1;
// Adjacency lists of the tree, indexed by vertex id.
vector<int> Graph[N];
// Input edges (a, b) in the order they were read.
vector<pair<int, int>> Edges;
// Lift[v][i]: the 2^i-th ancestor of v, or 0 once past the root.
int Lift[N][P];
// Segment[v]: inclusive pre-order range [first, second] covered by v's subtree.
pair<int, int> Segment[N];
// Depth[v]: depth of v with the root at depth 1 (Depth[0] = 0 for the sentinel).
int Depth[N];
// C[v]: value read for vertex v.
int C[N];
// Color[t]: value assigned by the update performed at time t (Color[0] = C[1]).
int Color[N];
// Next pre-order number to hand out; starts at 1 so position 0 is never used.
int pre_cnt{1};
/*
 * Roots the tree at u, with `parent` as its parent (main calls DFS(1, 0), so 0
 * acts as a sentinel). For every vertex x reached it fills:
 *   Depth[x]   - Depth[parent of x] + 1,
 *   Lift[x][i] - the 2^i-th ancestor of x (0 once past the root),
 *   Segment[x] - inclusive [first, last] pre-order positions of x's subtree,
 *                numbered from the global counter pre_cnt.
 * Neighbours are visited in Graph[] order, exactly as a recursive DFS would.
 * An explicit stack is used instead of recursion because a path-shaped tree
 * with ~1e5 vertices would otherwise need ~1e5 nested call frames.
 */
void DFS(int u, int parent)
{
    struct Frame {
        int node;
        int par;
        size_t next;  // index into Graph[node] of the next neighbour to examine
    };
    vector<Frame> frames;

    // Pre-order work for x (runs when x is first reached), then push its frame.
    // Lift[p][*] is already complete because p was entered before x.
    auto enter = [&frames](int x, int p) {
        Depth[x] = Depth[p] + 1;
        Lift[x][0] = p;
        for (int i = 1; i < P; i++)
            Lift[x][i] = Lift[Lift[x][i-1]][i-1];
        Segment[x].first = pre_cnt++;
        frames.push_back({x, p, 0});
    };

    enter(u, parent);
    while (!frames.empty()) {
        // Copy the fields out: enter() may reallocate `frames`.
        const int x = frames.back().node;
        const int p = frames.back().par;
        const size_t i = frames.back().next;
        if (i < Graph[x].size()) {
            frames.back().next++;
            const int v = Graph[x][i];
            if (v != p)
                enter(v, x);
        } else {
            // Every descendant has been numbered: the subtree ends at pre_cnt - 1.
            Segment[x].second = pre_cnt - 1;
            frames.pop_back();
        }
    }
}
/*
 * Weighted ascending-pair count via merge sort.
 *
 * V holds (key, weight) entries. Returns the sum of
 * V[i].second * V[j].second over all index pairs i < j with
 * V[i].first < V[j].first (strict). On return, V is sorted by key, ascending.
 *
 * Solve() fills V with colour runs ordered from the queried vertex up towards
 * the root, so each counted pair is an inversion on the root-to-vertex path.
 * An empty V returns 0; previously it recursed without end.
 */
long long Find(vector<pair<int, int>>& V)
{
    const size_t n = V.size();
    if (n <= 1)
        return 0;

    const auto mid = V.begin() + n / 2;
    vector<pair<int, int>> A(V.begin(), mid);
    vector<pair<int, int>> B(mid, V.end());
    long long sol = Find(A) + Find(B);

    size_t idx = 0;
    long long cnt = 0;  // total weight of A entries already merged (key < current B key)
    V.clear();
    for (const auto& b : B) {
        while (idx < A.size() && A[idx].first < b.first) {
            V.push_back(A[idx]);
            cnt += A[idx].second;
            idx++;
        }
        // Every merged A entry precedes b in the original order and has a smaller key.
        sol += static_cast<long long>(b.second) * cnt;
        V.push_back(b);
    }
    V.insert(V.end(), A.begin() + idx, A.end());
    return sol;
}
/*
 * Processes Edges in input order. For each edge (a, b):
 *   1. walks from a up to the root, cutting the path into maximal runs of
 *      vertices whose subtree-max update time is equal, and records
 *      (Color of that time, run length) for each run;
 *   2. prints Find() over those runs;
 *   3. writes the next update time at b's pre-order position and records
 *      C[b] as that time's colour. Every ancestor of b has b inside its
 *      Segment range, so it now reports this (largest so far) time.
 * Time 0 means "no update seen in the subtree" and maps to C[1].
 * The parameter n is not used.
 */
void Solve(int n)
{
    Color[0] = C[1];
    vector<pair<int, int>> runs;
    int stamp = 1;

    // Latest update time visible from vertex x: max over x's pre-order range.
    auto latest = [](int x) {
        return T1.query(Segment[x].first, Segment[x].second);
    };

    for (const auto& [from, to] : Edges) {
        for (int cur = from, t = latest(from); cur != 0; ) {
            const int bottom = cur;
            // Climb to the highest ancestor that still reports time t.
            for (int k = P - 1; k >= 0; --k) {
                const int up = Lift[cur][k];
                if (up != 0 && latest(up) == t)
                    cur = up;
            }
            runs.push_back({Color[t], Depth[bottom] - Depth[cur] + 1});
            cur = Lift[cur][0];
            t = latest(cur);
        }
        cout << Find(runs) << "\n";
        T1.update(Segment[to].first, stamp);
        Color[stamp] = C[to];
        ++stamp;
        runs.clear();
    }
}
/*
 * Reads n, the values C[1..n], then n-1 edges "a b". Builds an undirected
 * adjacency list, keeps the edges in reading order, roots the tree at
 * vertex 1 and answers every edge via Solve().
 */
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    for (int v = 1; v <= n; ++v)
        cin >> C[v];

    for (int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        Graph[a].push_back(b);
        Graph[b].push_back(a);
        Edges.emplace_back(a, b);
    }

    DFS(1, 0);  // vertex 0 is the sentinel parent of the root
    Solve(n);
}
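/*
 * Grader status tables pasted from the submission page (results had not
 * loaded yet). Kept inside a comment so the file still compiles.
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 */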