#include<bits/stdc++.h>
using namespace std;
// Base file name used by the commented-out freopen calls in main.
#define task "a"
// Short-hands; se, fi and ii are not referenced by the code in this file.
#define se second
#define fi first
// Integer type (at least 64 bits) used for the inversion counts.
#define ll long long
#define ii pair<ll, ll>
// Capacity of the node-indexed arrays (assumes n < mxN; TODO confirm against the problem bound).
const long mxN = 2e5 + 7;
// Global state (globals start zero-initialised):
//   n        number of nodes; node 1 is the root
//   a[i]     value of node i; main replaces it by its index in val (coordinate compression)
//   lf[i]    ordering key filled in by DFS; cmp puts the child with the largest key first,
//            and that first child continues its parent's chain in HLD
//   root[i]  top node of the chain that contains i
//   in[i]    1-based position of i in HLD's visiting order; index into tree[]
//   par[i]   parent of i (par[1] stays 0, which terminates the upward walk in Add)
//   h[i]     depth of i below node 1 (h[1] == 0)
//   stt      chain head HLD is currently extending (0 = next node opens a new chain)
//   num      last position handed out by HLD
//   tree[]   segment tree over positions 1..n: a node stores the value shared by its whole
//            range, or -1 when the range holds different values
//   v[i]     node introduced by the i-th input edge (nodes are added in this order)
//   bit[]    Fenwick tree of counts indexed by compressed value + 1; cleared after each Add
int n, a[mxN], lf[mxN], root[mxN], in[mxN], par[mxN], h[mxN], stt, num, tree[mxN * 4], v[mxN], bit[mxN];
//   val      sorted distinct input values plus 0 (the compression table)
//   w[u]     children of u (reordered by DFS)
//   mem      compressed values inserted into bit[] during the current Add, kept for cleanup
vector<int> val, w[mxN], mem;
// Strict-weak-ordering predicate for sorting child lists: a child with a
// larger lf[] key is placed before one with a smaller key.
bool cmp(int u, int v)
{
    const bool u_goes_first = lf[v] < lf[u];
    return u_goes_first;
}
// Visits the subtree of j. For every node below j it records the parent
// (par[]) and depth (h[]). For every node in the subtree it stores the
// subtree size in lf[], and it sorts each child list with cmp so that the
// child with the largest subtree comes first. HLD then extends the current
// chain through that first child.
//
// Why subtree size: choosing the heavy child by size (true heavy-light
// decomposition) means that each light edge on the way up at least doubles
// the subtree size. Any root path therefore crosses O(log n) chains. Ranking
// children by deepest-leaf depth (long-path decomposition) only guarantees
// O(sqrt n) chains per path, and every Add walks one such path.
//
// NOTE(review): recursive; the recursion depth equals the height of the tree.
void DFS(int j)
{
    lf[j] = 1; // j itself
    for (int i : w[j])
    {
        h[i] = h[j] + 1;
        par[i] = j;
        DFS(i);
        lf[j] += lf[i];
    }
    sort(w[j].begin(), w[j].end(), cmp);
}
// Walks the subtree of j in preorder. Each node gets its position in[]
// (1, 2, ... in visiting order) and the head root[] of its chain. The first
// child in w[j] continues j's chain; every later child opens a new chain.
// Expects stt == 0 on the top-level call and leaves stt == 0 on return.
void HLD(int j)
{
    if (stt == 0)
        stt = j; // no chain is open, so j heads a new one
    root[j] = stt;
    num += 1;
    in[j] = num;
    for (size_t k = 0; k < w[j].size(); ++k)
        HLD(w[j][k]);
    stt = 0; // the next node visited (a later sibling) opens its own chain
}
// Lazy push-down: copies node j's uniform value into both children.
// A stored value of -1 marks a mixed range, so there is nothing to push.
void Down(int j)
{
    const int shared = tree[j];
    if (shared != -1)
        tree[2 * j] = tree[2 * j + 1] = shared;
}
// Pull-up: node j keeps its children's value when both agree, otherwise -1.
void Up(int j)
{
    const int lhs = tree[2 * j];
    const int rhs = tree[2 * j + 1];
    tree[j] = (lhs == rhs) ? lhs : -1;
}
// Range assignment: sets every position in [u, v] to nw.
// j is the segment-tree node covering positions [l, r] (the root is 1, [1, n]).
void Upd(int u, int v, int nw, int j = 1, int l = 1, int r = n)
{
    const bool disjoint = v < l || r < u;
    if (disjoint)
        return;
    const bool covered = u <= l && r <= v;
    if (covered)
    {
        tree[j] = nw; // whole range shares nw; children are refreshed lazily by Down
        return;
    }
    Down(j);
    const int mid = (l + r) / 2;
    Upd(u, v, nw, 2 * j, l, mid);
    Upd(u, v, nw, 2 * j + 1, mid + 1, r);
    Up(j);
}
// Fenwick prefix sum: returns bit[1] + ... + bit[j] (0 when j == 0).
ll Get_BIT(int j)
{
    int total = 0;
    for (; j != 0; j &= j - 1) // strip the lowest set bit each step
        total += bit[j];
    return total;
}
// Fenwick point update: adds inc to position j and to every cell covering it.
// Cells at index >= val.size() are skipped, so the tree effectively spans
// positions 1..val.size()-1.
void Upd_BIT(int j, int inc)
{
    for (; static_cast<size_t>(j) < val.size(); j += j & -j)
        bit[j] += inc;
}
// Counts, over positions u..v of the segment tree, how many values already
// recorded in the Fenwick tree are strictly smaller than each position's
// value. Each visited value is then recorded too (at index value + 1) and
// pushed onto mem for later cleanup.
// j is the node covering [l, r]. Returns the summed count.
//
// Positions are visited from right to left (larger in[] first). A position
// is therefore compared only against positions processed before it: those to
// its right in this range, plus whatever the caller recorded earlier.
ll Get(int u, int v, int j = 1, int l = 1, int r = n)
{
    if (u > r || l > v)
        return 0;
    if (u <= l && r <= v && tree[j] != -1)
    {
        // Uniform range: all len positions share tree[j], and equal values never count.
        const int len = r - l + 1;
        ll res = Get_BIT(tree[j]) * len;
        Upd_BIT(tree[j] + 1, len);
        mem.push_back(tree[j]);
        return res;
    }
    Down(j);
    int mid = (l + r) / 2;
    // The right half must run before the left half because both calls mutate
    // the Fenwick tree. The operands of `+` have unspecified evaluation order
    // in C++, so the calls are sequenced explicitly.
    ll res = Get(u, v, j * 2 + 1, mid + 1, r);
    res += Get(u, v, j * 2, l, mid);
    return res;
}
// Zeroes every Fenwick cell that Upd_BIT(j, ...) would touch. This undoes
// such an update without clearing the whole array.
void Erase(int j)
{
    for (; static_cast<size_t>(j) < val.size(); j += j & -j)
        bit[j] = 0;
}
// Activates node j. It first stores a[j] at position in[j], then walks from
// par[j] up to the root one chain segment at a time. For each segment it:
//   * counts pairs on that root path (upper node, lower node) whose upper
//     value is strictly greater than the lower value, via Get;
//   * overwrites the segment with a[j].
// The Fenwick cells touched along the way are reset afterwards, and the
// total count is printed on its own line.
void Add(int j)
{
    const int value = a[j];
    Upd(in[j], in[j], value);
    ll ans = 0;
    for (int cur = par[j]; cur != 0; cur = par[root[cur]])
    {
        const int head = root[cur];
        ans += Get(in[head], in[cur]);
        Upd(in[head], in[cur], value);
    }
    for (int used : mem)
        Erase(used + 1);
    mem.clear();
    cout << ans << '\n';
}
// Reads n, the node values, and the n-1 edges "parent child" in the order
// the nodes are added. Builds the chain decomposition of the final tree,
// then replays the additions and prints one answer per added node.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    //freopen(task".INP", "r", stdin);
    //freopen(task".OUT", "w", stdout);

    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];

    // Coordinate compression; 0 is part of the value set as well.
    val.assign(a + 1, a + n + 1);
    val.push_back(0);
    sort(val.begin(), val.end());
    val.erase(unique(val.begin(), val.end()), val.end());
    for (int i = 1; i <= n; i++)
        a[i] = static_cast<int>(lower_bound(val.begin(), val.end(), a[i]) - val.begin());

    for (int i = 1; i < n; i++)
    {
        int parent_node;
        cin >> parent_node >> v[i];
        w[parent_node].push_back(v[i]);
    }

    DFS(1);
    HLD(1);
    Upd(1, 1, a[1]); // the root is present from the start
    for (int i = 1; i < n; i++)
        Add(v[i]);
    return 0;
}
/*
Judge results table pasted from the submission page (results had not loaded yet):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/