#include "bits/stdc++.h"
using namespace std;
#ifdef Nero
#include "Deb.h"
#else
#define deb(...)
#endif
// Capacity of every per-node array; nodes are 1-indexed, so presumably n <= 1e5 (confirm against the problem limits).
const int N = 1e5 + 5;
// a[v]: colour of node v (read in main).
int a[N];
// Fenwick tree indexed by colour value; maintained by upd/qry and used by solve.
int tree[N];
// pr[v]: parent of v; sz[v]: size of v's subtree; dep[v]: depth of v (filled by dfs, not read elsewhere).
int pr[N], sz[N], dep[N];
// cnt: next unused chain index; id[v]: 1-based position of v inside its chain counted from the chain top;
// top[v]: topmost node of v's chain; chain[v]: index of the chain v belongs to.
int cnt, id[N], top[N], chain[N];
// in_chain[k]: colours of chain k from its top downwards, run-length encoded as (colour, run length) pairs.
// NOTE(review): some standard libraries allocate a buffer in the default deque constructor, so this array
// of ~1e5 deques may cost significant memory even for tiny inputs — verify against the memory limit.
deque<pair<int, int>> in_chain[N];
// g[v]: adjacency list of the tree (edges are added in both directions in main).
vector<int> g[N];
// Fenwick-tree point update: adds v to position idd.
// Positions <= 0 are ignored. Without the guard, idd == 0 loops forever
// (0 & -0 == 0) and a negative idd writes before the start of tree.
// Positions >= N have no slot and are skipped by the loop condition.
inline void upd (int idd, int v) {
    if (idd <= 0) {
        return;
    }
    while (idd < N) {
        tree[idd] += v;
        idd += idd & -idd;
    }
}
inline long long qry (int idd) {
long long ret = 0;
while (idd) {
ret += tree[idd];
idd -= idd & -idd;
}
return ret;
}
// First pass over the tree rooted at v (p = node we came from).
// Records parent and depth for every child, and accumulates subtree sizes into sz.
void dfs(int v, int p) {
    sz[v] = 1;
    for (size_t k = 0; k < g[v].size(); ++k) {
        const int child = g[v][k];
        if (child != p) {
            pr[child] = v;
            dep[child] = dep[v] + 1;
            dfs(child, v);
            sz[v] += sz[child];
        }
    }
}
// Heavy-light decomposition pass.
// v: current node, p: node we came from, tp: top node of v's chain.
// Appends v's colour to its chain's run-length list. It then continues the
// same chain into the child with the largest subtree (the first such child
// wins ties). Every other child starts a new chain with index cnt and position 1.
void dfs_hld(int v, int p, int tp) {
    top[v] = tp;
    auto& runs = in_chain[chain[v]];
    // Extend the last run when the colour repeats, otherwise open a new run.
    if (!runs.empty() && runs.back().first == a[v]) {
        ++runs.back().second;
    } else {
        runs.emplace_back(a[v], 1);
    }
    int heavy = 0;  // node 0 does not exist and has sz[0] == 0
    for (int child : g[v]) {
        if (child == p) {
            continue;
        }
        if (sz[child] > sz[heavy]) {
            heavy = child;
        }
    }
    if (heavy != 0) {
        chain[heavy] = chain[v];
        id[heavy] = id[v] + 1;
        dfs_hld(heavy, v, tp);
    }
    for (int child : g[v]) {
        if (child == p || child == heavy) {
            continue;
        }
        chain[child] = cnt++;
        id[child] = 1;
        dfs_hld(child, v, child);
    }
}
// Extracts the colours on the path from the root down to v, then recolours that whole path to c.
// Returns one run list per chain, ordered from v's chain up to the root's chain.
// Inside each list the runs go from the chain top down to the deepest path node
// of that chain, as (colour, count) pairs.
vector<vector<pair<int, int>>> get(int v, int c) {
    vector<vector<pair<int, int>>> ret;
    while (v) {
        deque<pair<int, int>>& deq = in_chain[chain[v]];
        // The path covers the first id[v] nodes of this chain (top .. v).
        const int to_delete = id[v];
        int deleted = 0;
        vector<pair<int, int>> tv;
        while (!deq.empty() && deleted < to_delete) {
            pair<int, int>& run = deq.front();
            const int need = to_delete - deleted;
            if (run.second > need) {
                // Only part of this run lies on the path: split it.
                tv.emplace_back(run.first, need);
                run.second -= need;
                break;
            }
            tv.push_back(run);
            deleted += run.second;
            deq.pop_front();
        }
        // Invariant (not checked): the deque held at least to_delete nodes,
        // because its run lengths sum to the chain length and id[v] <= that length.
        deq.emplace_front(c, to_delete);
        ret.push_back(std::move(tv));  // move: the per-chain list is not used again
        v = pr[top[v]];
    }
    return ret;
}
// Counts inversions along a root-to-node path given as run lists in the order get() produces
// (deepest chain first, each chain top-down). An inversion is a pair (ancestor, descendant)
// whose ancestor colour is strictly larger.
// Uses the global Fenwick tree as scratch space and restores it to its prior state before returning.
long long solve(const vector<vector<pair<int, int>>>& vec) {
    long long ret = 0;
    // Walk the chains from the root's chain downwards (reverse of the given order).
    for (auto it = vec.rbegin(); it != vec.rend(); ++it) {
        for (const auto& [c, freq] : *it) {
            // Every one of the freq equal-coloured nodes in this run is an inversion
            // with each earlier (shallower) node of strictly greater colour.
            const long long greater = qry(N - 1) - qry(c);
            ret += greater * freq;
            upd(c, freq);
        }
    }
    // Undo all insertions so the tree is clean for the next query.
    for (const auto& runs : vec) {
        for (const auto& [c, freq] : runs) {
            upd(c, -freq);
        }
    }
    return ret;
}
// Reads n colours and n - 1 edges (qu[i] is presumably the existing endpoint, qv[i] the newly added node).
// For each edge, prints the inversion count on the path root -> qu[i], then recolours that path to a[qv[i]].
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }
    // Colours index the Fenwick tree, which only has N slots. Input colours are not
    // bounded by N (presumably up to 1e9 — confirm with the statement). Only their
    // relative order matters, so replace each colour by its rank in [1, n].
    vector<int> vals(a + 1, a + n + 1);
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    for (int i = 1; i <= n; ++i) {
        a[i] = static_cast<int>(lower_bound(vals.begin(), vals.end(), a[i]) - vals.begin()) + 1;
    }
    vector<int> qu(n - 1), qv(n - 1);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        qu[i] = u, qv[i] = v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 1);
    id[1] = 1;
    cnt = 2;
    chain[1] = 1;
    dfs_hld(1, 1, 1);
    for (int i = 0; i < n - 1; ++i) {
        // Pass the temporary straight through so no per-query copy of the run lists is made.
        cout << solve(get(qu[i], a[qv[i]])) << '\n';
    }
    return 0;
}
/*
 * Judge report (pasted; signal 11 = segmentation fault):
 * #  | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output
 * 1  | Runtime error | 79 ms                | 146504 KB       | Execution killed with signal 11
 * 2  | Halted        | 0 ms                 | 0 KB            | -
 */
/*
 * Judge report (pasted; signal 11 = segmentation fault):
 * #  | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output
 * 1  | Runtime error | 79 ms                | 146504 KB       | Execution killed with signal 11
 * 2  | Halted        | 0 ms                 | 0 KB            | -
 */
/*
 * Judge report (pasted; signal 11 = segmentation fault):
 * #  | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output
 * 1  | Runtime error | 79 ms                | 146504 KB       | Execution killed with signal 11
 * 2  | Halted        | 0 ms                 | 0 KB            | -
 */