//Huyduocdithitp
#include <bits/stdc++.h>
// 64-bit signed integer type used for the indices, counters and values in this file.
typedef long long ll;
// Pair of two ll values.
#define pii pair<ll, ll>
// Short aliases for the two members of a pair.
#define fi first
#define se second
// Base name of the optional input/output files opened by `start`.
#define TASK "mansion"
// If TASK.in can be opened, rebinds stdin to TASK.in and stdout to TASK.out.
#define start if(fopen(TASK".in","r")){freopen(TASK".in","r",stdin);freopen(TASK".out","w",stdout);}
// Turns off synchronisation between C++ streams and C stdio and unties cin from cout.
#define faster ios_base::sync_with_stdio(false);cin.tie(NULL);
// Capacity of every fixed-size global array.
#define N 200005
// Makes `endl` expand to a newline character, so writing it does not flush the stream.
#define endl '\n'
using namespace std;
// n     - number of vertices, read in main
// c[]   - per-vertex value; replaced in pre() by its 1-based rank among the distinct values
// par[] - parent of each vertex as assigned by dfs() (never assigned for the dfs start vertex, so it stays 0)
// bit[] - Fenwick tree storage, used over indices 1..m
// m     - number of distinct values in c[], set in pre()
ll n, c[N], par[N], bit[N], m;
// canh[i] - the two endpoints of the i-th input edge, kept in input order
pii canh[N];
// Fenwick tree point update: adds v to position u.
// Walks upward from u by repeatedly adding the lowest set bit, touching every
// node of bit[] whose range contains u, until the index exceeds m.
void update(ll u, ll v) {
    for (ll i = u; i <= m; i += i & (-i)) {
        bit[i] += v;
    }
}
// Fenwick tree prefix query: returns the sum of positions 1..u.
// Walks downward from u by repeatedly stripping the lowest set bit; returns 0
// when u <= 0.
ll get(ll u) {
    ll total = 0;
    for (ll i = u; i > 0; i -= i & (-i)) {
        total += bit[i];
    }
    return total;
}
// ns - copy of the input values, sorted and deduplicated in pre() for rank compression.
vector<ll> ns;
// adj[u] - neighbours of vertex u (each input edge is stored in both directions).
vector<ll> adj[N];
// Heavy-light decomposition bookkeeping:
//   cur_chain / cur_pos - next chain id and next tour position handed out by hld()
//   chain_head[id]      - first vertex placed on chain id
//   chain_id[u]         - chain that vertex u belongs to
//   tour[p] / pos[u]    - vertex at tour position p, and the tour position of vertex u
//   sz[u]               - subtree size computed by dfs()
//   high[u]             - depth of u below the dfs start vertex
ll cur_chain, cur_pos, chain_head[N], chain_id[N], tour[N], pos[N], sz[N], high[N];
// One stack per chain, indexed by the chain's head vertex. Each entry is a
// {value, count} run; lay() pops from the top to cover the chain prefix that
// starts at the head, so the top run is the one nearest the head.
// Backed by std::vector rather than the default std::deque: with libstdc++
// every empty deque already owns a heap-allocated map plus a 512-byte node,
// so N default stacks would cost well over 100 MB before any push.
stack<pii, vector<pii>> st[N];
// Depth-first pass from u (entered from `parent`) that fills, for every vertex
// below u: par[] (its parent), high[] (depth, one more than the parent's) and
// sz[] (number of vertices in its subtree, itself included).
void dfs(ll u, ll parent) {
    sz[u] = 1;
    for (const ll child : adj[u]) {
        if (child != parent) {
            par[child] = u;
            high[child] = high[u] + 1;
            dfs(child, u);
            sz[u] += sz[child];
        }
    }
}
// Heavy-light decomposition step for vertex u (entered from `parent`).
// Places u on the current chain (making it the head if the chain has none yet),
// assigns its tour position, then continues the same chain into the child with
// the largest subtree and opens a fresh chain for every other child.
// Requires sz[] to have been filled by dfs().
void hld(ll u, ll parent) {
    if (chain_head[cur_chain] == 0) {
        chain_head[cur_chain] = u;
    }
    chain_id[u] = cur_chain;
    pos[u] = cur_pos;
    tour[cur_pos] = u;
    ++cur_pos;

    // Pick the child with the strictly largest subtree; 0 means "no child"
    // (sz[0] is 0, so any real child beats it).
    ll heavy = 0;
    for (const ll child : adj[u]) {
        if (child != parent && sz[child] > sz[heavy]) {
            heavy = child;
        }
    }

    // The heavy child extends the current chain.
    if (heavy != 0) {
        hld(heavy, u);
    }

    // Every remaining child starts a new chain.
    for (const ll child : adj[u]) {
        if (child == parent || child == heavy) continue;
        ++cur_chain;
        hld(child, u);
    }
}
// Scratch lists of {value, count} runs: .fi holds the value, .se holds how many vertices carry it.
// vt receives the runs popped from a single chain by lay(); vt1 accumulates them over one cv() call.
vector<pii> vt, vt1; // xet.fi stores the value, xet.se stores the number of vertices
// Processes the chain that contains u for one step of cv().
//
// 1. Pops runs from the chain's stack until they cover the prefix
//    head..u of the chain, appending each popped {value, count} piece to vt
//    (top of stack first). A run that is only partly consumed is pushed back
//    with its remaining count.
// 2. Pushes a single run with value c[v] that covers the same prefix, i.e.
//    the whole prefix is overwritten with v's value.
//
// `first` marks the first chain visited by cv(), the one that contains u itself:
//   - if v is on that same chain, one extra slot is requested and the pushed
//     run is one vertex longer, so that it also covers v
//     (presumably v is u's successor on the chain - confirm against the caller);
//   - otherwise v's own chain additionally gets a run {c[v], 1} on top, after
//     shrinking whatever run was on top of it by one.
void lay(ll u, ll v, bool first) {
    const ll root = chain_head[chain_id[u]];
    const ll prefix_len = high[u] - high[root] + 1;  // vertices from the chain head down to u
    const bool same_chain = chain_id[u] == chain_id[v];

    // Number of vertices still to be taken from the stack.
    ll socon = prefix_len;
    if (first && same_chain) ++socon;

    while (!st[root].empty() && socon > 0) {
        pii xet = st[root].top();
        st[root].pop();
        const ll x = min(socon, xet.se);
        vt.push_back({xet.fi, x});
        socon -= x;
        xet.se -= x;
        if (xet.se > 0) {
            // Only part of this run lies inside the prefix; keep the rest.
            st[root].push(xet);
        }
    }

    if (!first) {
        st[root].push({c[v], prefix_len});
        return;
    }

    if (same_chain) {
        // The new run also covers v, hence one more vertex than the prefix.
        st[root].push({c[v], prefix_len + 1});
        return;
    }

    st[root].push({c[v], prefix_len});
    const ll root_v = chain_head[chain_id[v]];
    if (!st[root_v].empty()) {
        // Take one vertex away from the current top run of v's chain before
        // placing v's own single-vertex run on top of it.
        --st[root_v].top().se;
        if (st[root_v].top().se == 0) st[root_v].pop();
    }
    st[root_v].push({c[v], 1});
}
// Handles one edge (u, v) and prints one number followed by a newline.
//
// Climbs from u to the top of the tree one chain at a time. For every chain,
// lay() moves the runs covering the visited prefix into vt and overwrites that
// prefix with c[v]; the runs are appended to vt1 in reverse pop order, so vt1
// lists each chain's runs from u's end towards the chain head.
//
// The printed value is the sum, over the runs in vt1 order, of
//   (number of vertices in earlier runs with a strictly smaller value) * (run length),
// computed with the Fenwick tree over compressed values. The tree is restored
// to its previous contents before returning.
void cv(ll u, ll v) {
    bool first = true;
    while (u != 0) {
        vt.clear();
        lay(u, v, first);
        first = false;
        // lay() yields runs head-first; store them in the opposite order.
        vt1.insert(vt1.end(), vt.rbegin(), vt.rend());
        // Jump to the vertex above the current chain's head (0 once past the top).
        u = par[chain_head[chain_id[u]]];
    }

    ll ans = 0;
    for (const pii &run : vt1) {
        ans += get(run.fi - 1) * run.se;
        update(run.fi, run.se);
    }
    // Undo the insertions so the next call starts from the same Fenwick state.
    for (const pii &run : vt1) {
        update(run.fi, -run.se);
    }

    vt1.clear();
    vt.clear();
    cout << ans << endl;
}
void pre() {
cur_chain = cur_pos = 1;
dfs(1, 1);
hld(1, 1);
sort(ns.begin(), ns.end());
ns.resize(unique(ns.begin(), ns.end()) - ns.begin());
m = ns.size();
for (int i = 1; i <= n; i ++) {
c[i] = lower_bound(ns.begin(), ns.end(), c[i]) - ns.begin() + 1;
}
for (int i = 1; i <= n - 1; i ++) {
cv(canh[i].fi, canh[i].se);
}
}
signed main() {
faster;
cin >> n;
for (int i = 1; i <= n; i ++) {
cin >> c[i];
ns.push_back(c[i]);
}
for (int i = 1; i <= n - 1; i ++) {
cin >> canh[i].fi >> canh[i].se;
adj[canh[i].fi].push_back(canh[i].se);
adj[canh[i].se].push_back(canh[i].fi);
}
pre();
return 0;
}
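// NOTE: the lines below are residue copied from the online judge's results table; they are not part of the program.
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |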