이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
// Make the best become better
// No room for laziness
#include<bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
using namespace std;
using ll = long long;
using ld = long double;
using ull = unsigned long long;
// Random number generator seeded from the steady clock; not referenced in the visible part of this file.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Capacity of the global per-vertex arrays.
const int maxN = 1e5 + 5;
// Not referenced in the visible part of this file.
const int mod = 1e9 + 7;
// Large sentinel; query() uses it to build a search key that sorts after every real segment starting at the same position.
const ll oo = 1e9;
// n: number of vertices. c[i]: value of vertex i, replaced in ReadInput() by its 1-based rank among the distinct values.
int n, c[maxN];
// Adjacency lists; every input edge is stored in both directions.
vector<int> adj[maxN];
// e[i]: endpoints of the i-th edge, in input order (1-based).
pair<int, int> e[maxN];
// p[v]: parent of v as assigned by dfs(); sz[v]: size of v's subtree as computed by dfs().
int p[maxN], sz[maxN];
// chainhead[k]: first vertex placed on chain k; chainid[v]: chain containing v; nChain: index of the most recently opened chain in HLD().
int chainhead[maxN], chainid[maxN], nChain = 0;
// tin[v]: 1-based visiting order of v assigned by HLD(); Time: last value handed out.
int tin[maxN], Time = 0;
// Reads the input from stdin: the vertex count n, the n vertex values, and
// then n - 1 edges. Vertex values are replaced by their 1-based rank among
// the distinct values; edges go into the adjacency lists (both directions)
// and into e[] in input order.
void ReadInput()
{
    cin >> n;

    vector<int> distinctVals;
    for (int idx = 1; idx <= n; ++idx)
    {
        cin >> c[idx];
        distinctVals.push_back(c[idx]);
    }

    // Sorted, duplicate-free list of all values that occur.
    sort(distinctVals.begin(), distinctVals.end());
    const auto newEnd = unique(distinctVals.begin(), distinctVals.end());
    distinctVals.erase(newEnd, distinctVals.end());

    for (int idx = 1; idx <= n; ++idx)
    {
        // Every c[idx] is present in distinctVals, so its position + 1 is
        // the 1-based rank.
        const auto pos = lower_bound(distinctVals.begin(), distinctVals.end(), c[idx]);
        c[idx] = static_cast<int>(pos - distinctVals.begin()) + 1;
    }

    for (int idx = 1; idx < n; ++idx)
    {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
        e[idx] = make_pair(a, b);
    }
}
// Depth-first traversal from u that records p[child] for every vertex below
// u and stores the subtree size of each visited vertex in sz[].
// p[u] must already be set by the caller (it stays 0 for the start vertex).
void dfs(int u)
{
    const int parent = p[u];
    sz[u] = 1;
    for (const int child : adj[u])
    {
        if (child != parent)
        {
            p[child] = u;
            dfs(child);
            sz[u] += sz[child];
        }
    }
}
// Assigns u to the chain currently identified by nChain, gives it the next
// visiting index in tin[], and recurses: first into the child with the
// largest subtree (which stays on the same chain and therefore gets the next
// tin value), then into every other child, each of which opens a new chain.
// Requires p[] and sz[] to have been filled in by dfs().
void HLD(int u)
{
    // The first vertex to reach an unused chain index becomes its head.
    if (chainhead[nChain] == 0)
        chainhead[nChain] = u;
    chainid[u] = nChain;
    tin[u] = ++Time;

    const int parent = p[u];

    // Pick the child with the largest subtree; ties keep the earliest one.
    int heavy = -1;
    for (const int child : adj[u])
    {
        if (child == parent)
            continue;
        if (heavy == -1 || sz[child] > sz[heavy])
            heavy = child;
    }

    // A leaf has nothing left to do.
    if (heavy == -1)
        return;

    HLD(heavy);

    for (const int child : adj[u])
    {
        if (child != parent && child != heavy)
        {
            ++nChain;
            HLD(child);
        }
    }
}
// A segment [l, r] carrying a single value `col`.
// Segments are ordered by (l, r) only; `col` does not take part in the
// comparison.
struct TSeg
{
    int l;
    int r;
    int col;

    bool operator<(const TSeg& other) const
    {
        // Lexicographic comparison on (l, r).
        return std::tie(l, r) < std::tie(other.l, other.r);
    }
};
// Fenwick (binary indexed) tree storage; update() adds to it and get() reads prefix sums from it.
int bit[maxN];
// Every index of bit[] written by update() since the last reset; query() uses it to zero only the touched cells.
vector<int> bin;
// Fenwick-tree point update: adds val at index pos (1-based, up to n).
// Each touched cell index is also appended to `bin` so the tree can later be
// cleared without scanning the whole array.
void update(int pos, int val)
{
    for (int i = pos; i <= n; i += i & -i)
    {
        bin.push_back(i);
        bit[i] += val;
    }
}
// Fenwick-tree prefix query: returns the sum of the values stored at
// indices 1..pos (0 when pos is 0).
int get(int pos)
{
    int total = 0;
    for (int i = pos; i != 0; i -= i & -i)
        total += bit[i];
    return total;
}
// S[h]: for the chain whose head vertex is h, the set of segments (ranges of tin positions, each with one col value) stored for that chain.
set<TSeg> S[maxN];
int query(int u, int col)
{
int res = 0;
while(u)
{
int t = chainhead[chainid[u]];
TSeg v = {-1, -1, -1};
while(!S[t].empty())
{
auto it = S[t].lower_bound({tin[u], oo, oo});
if(it == S[t].begin()) break;
it--;
TSeg tmp = *it;
S[t].erase(it);
// if(col == 8) cout << tmp.l << " " << tmp.r << '\n';
if(tmp.r <= tin[u])
{
res += get(tmp.col - 1) * (tmp.r - tmp.l + 1);
update(tmp.col, tmp.r - tmp.l + 1);
}
else
{
// cout << tmp.l << " " << tmp.r << '\n';
v = {tin[u] + 1, tmp.r, tmp.col};
res += get(tmp.col - 1) * (tin[u] - tmp.l + 1);
update(tmp.col, tin[u] - tmp.l + 1);
}
}
if(v.col != -1) S[t].insert(v);
S[t].insert({tin[t], tin[u], col});
u = p[t];
}
for(int v : bin)
bit[v] = 0;
bin.clear();
return res;
}
void Solve()
{
dfs(1);
HLD(1);
for(int i=1; i<=n; i++)
{
S[chainhead[chainid[i]]].insert({tin[i], tin[i], c[i]});
}
for(int i=1; i<n; i++)
{
int u = e[i].fi, v = e[i].se;
cout << query(u, c[v]) << '\n';
}
}
// Entry point: switches the standard streams to fast, untied mode, then
// reads the input and prints the answers.
int32_t main()
{
    // freopen("sol.inp", "r", stdin);
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    ReadInput();
    Solve();
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |