#include <bits/stdc++.h>
using namespace std;
// Binary indexed (Fenwick) tree over positions 1..100000 holding int sums.
// `ft` is the single global instance.
struct FENWICK_TREE
{
    // tree[i] stores the partial sum of the block of positions ending at i;
    // slot 0 is never touched.
    int tree[100001];

    // Zero every node that Update(x, ...) would modify. Whole nodes are
    // cleared (they may also carry other positions' contributions), so the
    // tree is fully zeroed only once Reset has been called for every
    // position that was updated.
    inline void Reset(int x)
    {
        for (; x <= 100000; x += x & -x)
            tree[x] = 0;
    }

    // Add v at position x; expects 1 <= x (x == 0 would never advance).
    inline void Update(int x, int v)
    {
        for (; x <= 100000; x += x & -x)
            tree[x] += v;
    }

    // Prefix sum over positions 1..x; returns 0 for x <= 0.
    inline int Get(int x)
    {
        int sum = 0;
        for (; x > 0; x -= x & -x)
            sum += tree[x];
        return sum;
    }
} ft;
constexpr int kMaxNodes = 100000; // capacity of every per-node array below

int n; // number of nodes read from input
int a; // number of distinct colours after coordinate compression
int b; // NOTE(review): never referenced in the visible code

// Rooted-tree / heavy-light decomposition data.
int head[kMaxNodes];  // top node of the heavy chain containing each node
int sz[kMaxNodes];    // subtree size
int par[kMaxNodes];   // parent node
int depth[kMaxNodes]; // depth (main sets depth[0] = 1 for the root)

// Colour data.
int compress[kMaxNodes]; // sorted colour values used for compression
int c[kMaxNodes];        // per-node colour, replaced by its 1-based rank

long long res; // answer computed for the current edge

pair<int, int> e[kMaxNodes]; // edges, 0-indexed (first -> second)
vector<int> g[kMaxNodes];    // child lists built from the edges

// v[h]: stack of (depth, colour) entries for the chain headed by h,
// pushed by Update and consumed by Get; cur / temp: scratch buffers for Get.
vector<pair<int, int>> v[kMaxNodes], cur, temp;
// Walk the subtree of `node` through the child lists in g. For every
// descendant, set par[] and depth[] (relative to depth[node], which the caller
// sets). Set sz[] to subtree sizes for node and all its descendants.
// An explicit stack replaces recursion; the resulting values are the same.
inline void DFS(int node)
{
    vector<int> preorder;
    vector<int> pending(1, node);
    while (!pending.empty())
    {
        const int u = pending.back();
        pending.pop_back();
        preorder.push_back(u);
        sz[u] = 1;
        for (int child : g[u])
        {
            par[child] = u;
            depth[child] = depth[u] + 1;
            pending.push_back(child);
        }
    }
    // Every descendant of x appears after x in preorder, so walking it
    // backwards finalises sz[x] before x is added into its parent.
    for (auto it = preorder.rbegin(); it != preorder.rend(); ++it)
    {
        if (*it != node)
        {
            sz[par[*it]] += sz[*it];
        }
    }
}
// Heavy-light decomposition of the subtree of `node`, using sz[] from DFS.
// The caller sets head[node]. Every descendant then gets head[] = the top
// node of its heavy chain. A node's heavy child is its first child (in g
// order) with the largest subtree size. The heavy child continues the parent's
// chain, and every other child starts a new chain headed by itself.
// All light children must be visited: stopping after the first one would
// leave later light subtrees with head[] == 0 (the root's chain).
// An explicit stack keeps the call depth constant regardless of tree height.
inline void HLD(int node)
{
    vector<int> pending(1, node);
    while (!pending.empty())
    {
        const int u = pending.back();
        pending.pop_back();

        // First child with maximum subtree size (strict '>' keeps the first).
        int heavy = -1;
        for (int child : g[u])
        {
            if (heavy == -1 || sz[child] > sz[heavy])
            {
                heavy = child;
            }
        }

        for (int child : g[u])
        {
            head[child] = (child == heavy) ? head[u] : child;
            pending.push_back(child);
        }
    }
}
// Record colour `val` for the path from `node` up to the root. For each chain
// on the way, push (depth of the current node, val) onto v[chain head], then
// continue from the parent of that chain's head. Stop after the chain headed
// by node 0.
inline void Update(int node, int val)
{
    for (;;)
    {
        const int top = head[node];
        v[top].push_back({depth[node], val});
        if (top == 0)
        {
            break;
        }
        node = par[top];
    }
}
inline void Get(int node)
{
temp.clear();
while (!v[head[node]].empty() && (temp.empty() || temp.back().first < depth[node]))
{
temp.push_back(v[head[node]].back());
v[head[node]].pop_back();
if (temp.back().first > depth[node])
{
v[head[node]].push_back(temp.back());
temp.back().first = depth[node];
}
}
while (!temp.empty())
{
cur.push_back(temp.back());
temp.pop_back();
}
if (head[node] != 0)
{
Get(par[head[node]]);
}
}
// Reads n colours and n-1 edges (1-indexed, parent then child). For each
// edge, prints the number of pairs (x, y) on the root->parent path where x is
// shallower than y and colour[x] > colour[y]. The path is then recoloured
// (root->child) with the child's colour.
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    for (int i = 0; i < n; ++i)
    {
        cin >> c[i];
        compress[i] = c[i];
    }
    // Coordinate-compress colours to ranks 1..a. After unique(), only
    // [compress, compress + a) is sorted and duplicate-free; the tail holds
    // unspecified values, so the binary search must stop at compress + a.
    sort(compress, compress + n);
    a = unique(compress, compress + n) - compress;
    for (int i = 0; i < n; ++i)
    {
        c[i] = lower_bound(compress, compress + a, c[i]) - compress + 1;
    }

    for (int i = 0; i < n - 1; ++i)
    {
        cin >> e[i].first >> e[i].second;
        e[i].first--;
        e[i].second--;
        g[e[i].first].push_back(e[i].second);
    }
    head[0] = 0;
    depth[0] = 1;
    DFS(0);
    HLD(0);
    Update(0, c[0]);

    for (int i = 0; i < n - 1; ++i)
    {
        res = 0;
        cur.clear();
        // NOTE(review): assumes e[i].first was the root or the child of an
        // earlier edge, so its path already carries colour entries.
        Get(e[i].first);
        Update(e[i].second, c[e[i].second]);

        // cur lists segments deepest first; segment j spans depths
        // (cur[j+1].first, cur[j].first] with colour cur[j].second, and the
        // last segment spans depths 1..cur.back().first.
        const size_t segments = cur.size();
        for (size_t j = 0; j < segments; ++j)
        {
            if (j + 1 < segments)
            {
                const int len = cur[j].first - cur[j + 1].first;
                ft.Update(cur[j].second, len);
                // Nodes in this segment paired with deeper nodes of smaller colour.
                res += static_cast<long long>(len) * ft.Get(cur[j].second - 1);
            }
            else
            {
                res += static_cast<long long>(cur[j].first) * ft.Get(cur[j].second - 1);
            }
        }
        // Undo this query's Fenwick updates (clears every touched node).
        for (auto &seg : cur)
        {
            ft.Reset(seg.second);
        }
        cout << res << "\n";
    }
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |