#include <bits/stdc++.h>
using namespace std;
// Optimisation pragma. It appears after <bits/stdc++.h>, so it presumably does not affect
// code already pulled in by that header — NOTE(review): move it above the #include if intended.
#pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// Ofast, O0, O1, O2, O3, unroll-loops, fast-math, trapv
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define Mp make_pair
#define sep ' '
// Redefines `endl` as a plain newline so `cout << ... << endl` does not flush.
#define endl '\n'
#define F first
#define S second
#define pb push_back
#define all(x) (x).begin(),(x).end()
// Prints `res` and terminates; the expansion already ends in ';'.
#define kill(res) cout << res << '\n', exit(0);
#define set_dec(x) cout << fixed << setprecision(x);
// Unsyncs C++ streams from C stdio and unties cin/cout for faster I/O.
#define fast_io ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define file_io freopen("input.txt", "r", stdin) ; freopen("output.txt", "w", stdout);
// Seeded RNG; not used by any visible code.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Number of binary-lifting levels: par[v][k] is the 2^k-th ancestor, k < L (2^19 > N).
const ll L = 19;
// Block size for a sqrt decomposition; only referenced from commented-out code in main.
const ll sq = 1000;
// Array capacity for city-indexed tables (cities are 1..n).
const ll N = 1e5 + 50;
// Not used by any visible code.
const ll Mod = 1e9 + 7;
// n: city count; c[v]: value of city v; par[v][k]: binary-lifting ancestors (0 = none);
// h[v]: depth from city 1 (root, depth 0); mx[v]: edge index scratch used by prep().
ll n, c[N], par[N][L], h[N], mx[N];
// edge[i]: city attached by the i-th input edge (edge[0] = 1, set in main);
// adj[u]: cities attached below u (parent -> child lists only).
vector<int> edge, adj[N];
// Roots the subtree of v under parent p (0 = no parent). It fills the binary-lifting row
// par[v][k] (the 2^k-th ancestor; the row stays 0 once the jumps run past the root) and
// the depth h[] of every node below v. adj[] is walked as a child list, which matches
// main inserting only adj[u].pb(v). h[v] must already be set (h[1] is 0 as a global).
// Recursion depth equals the height of the tree.
void dfs(int v, int p = 0){
    par[v][0] = p;
    for(int k = 1; k < L && par[v][k - 1] != 0; ++k)
        par[v][k] = par[par[v][k - 1]][k - 1];
    for(int child : adj[v]){
        h[child] = h[v] + 1;
        dfs(child, v);
    }
}
// Returns the k-th ancestor of v by taking the jump par[v][b] for every set bit b of k.
// Only bits 0..L-1 of k are examined. Climbing past the root lands on 0.
int getPar(int v, int k){
    for(int b = 0; b < L; ++b){
        const bool jump = (k >> b) & 1;
        if(jump) v = par[v][b];
    }
    return v;
}
// Lowest common ancestor of v and u via binary lifting. First the deeper node is lifted to
// the other's depth; then both climb in lockstep while their 2^k-th ancestors differ, which
// leaves them as children of the LCA.
int lca(int v, int u){
    if(h[u] > h[v]) swap(u, v);          // v is now the deeper (or equally deep) node
    v = getPar(v, h[v] - h[u]);
    if(v == u) return v;                 // one node was an ancestor of the other
    for(int k = L - 1; k >= 0; --k){
        if(par[v][k] == par[u][k]) continue;
        v = par[v][k];
        u = par[u][k];
    }
    return par[v][0];
}
// Post-order pass over the subtree of v. It raises mx[v] to the largest mx[] value found in
// its subtree. If that value is non-zero, it overwrites c[v] with c[edge[mx[v]]], the value
// of the city added by edge mx[v].
// NOTE(review): only called from commented-out sqrt-decomposition code in main.
void prep(int v){
    for(int child : adj[v]){
        prep(child);
        if(mx[child] > mx[v]) mx[v] = mx[child];
    }
    if(mx[v] != 0) c[v] = c[edge[mx[v]]];
}
// Scratch buffer for solve(): sorted, de-duplicated values of `block` (coordinate
// compression). Also reused by commented-out experiments in main.
vector<int> vec;
// Run-length description of the current root-to-parent path, built by main and read by
// solve(). Each entry is (value, number of consecutive cities with it), ordered from the root down.
vector<pii> block;
// Fenwick tree over compressed value ids 1..m (used by upd/get); m = number of distinct values.
int bit[N], m;
// Fenwick-tree point update: adds y at position x (1-based). Every covering node up to m
// is updated.
void upd(int x, int y){
    while(x <= m){
        bit[x] += y;
        x += x & -x;                     // step to the next node whose range covers x
    }
}
// Fenwick-tree prefix query: returns the total weight stored at positions 1..x.
int get(int x){
    int total = 0;
    while(x){
        total += bit[x];
        x &= x - 1;                      // clear the lowest set bit (same as x -= x & -x)
    }
    return total;
}
// Counts weighted inversions in the global run-length sequence `block`. Entry (x, y) stands
// for y consecutive copies of value x, in order. The result is the number of index pairs
// (s, t), s before t, with value_s > value_t. Equal values never count, so a single run
// contributes nothing on its own.
//
// Side effects: rebuilds the globals vec (compressed values), m (their count) and bit[1..m].
ll solve(){
    // Coordinate-compress the distinct values to ids 1..m.
    vec.clear();
    for(const auto& run : block) vec.push_back(run.first);
    sort(vec.begin(), vec.end());
    vec.resize(unique(vec.begin(), vec.end()) - vec.begin());
    m = vec.size();

    // upd/get only touch indices 1..m, so resetting those is enough. Clearing all N
    // slots per call would cost O(N) for every query.
    fill(bit, bit + m + 1, 0);

    ll seen = 0;   // total number of elements (copies) processed so far
    ll res = 0;
    for(const auto& run : block){
        const int id = lower_bound(vec.begin(), vec.end(), run.first) - vec.begin() + 1;
        // Earlier elements strictly greater than this value, i.e. not counted by the
        // prefix sum up to id.
        const ll greaterBefore = seen - get(id);
        // Each of the run.second copies forms an inversion with every one of them.
        res += greaterBefore * run.second;
        upd(id, run.second);
        seen += run.second;
    }
    return res;
}
// Entry point. Reads n, the values c[1..n], and n-1 edges "u v" (v is attached below u),
// then roots the tree at city 1. For every edge i it prints solve()'s inversion count for
// the path from the root to the parent of edge[i].
//
// Model implemented below: edge j is treated as having set every city on the path
// root -> edge[j] to c[edge[j]], with more recent edges taking precedence. Walking j from
// i-1 down to 1, depth h[lca(v, edge[j])] tells how far along v's path edge j still shows
// through past the depths already claimed by newer edges.
// Worst case is O(n^2 log n): nested i/j loops with an O(L) lca per step.
int main(){
    fast_io;
    cin >> n;
    for(int i = 1; i <= n; i++) cin >> c[i];
    int u, v;
    edge.pb(1);   // edge[0] = root sentinel, so edge[i] is the city added by the i-th edge
    for(int i = 1; i < n; i++){
        cin >> u >> v;
        adj[u].pb(v); edge.pb(v);
    }
    dfs(1);
    for(int i = 1; i < n; i++){
        block.clear();
        // tah = deepest depth on v's path already assigned to a more recent edge (-1 = none).
        int tah = -1; v = edge[i];
        for(int j = i-1; j > 0; j--){
            u = lca(v, edge[j]);
            if(h[u] > tah){
                block.pb({c[edge[j]], h[u] - tah});
                tah = h[u];
                // Assuming each input edge attaches a new city below an existing one, an older
                // city edge[j] is never v or below v. So lca(v, edge[j]) is at most v's
                // parent (depth h[v]-1). Once that depth is claimed no older edge can add a
                // run, and the remaining lca calls are wasted work.
                if(tah == h[v] - 1) break;
            }
        }
        // Disabled experiments, kept for reference.
        /*vec.clear();
        while(h[v] - 1 > tah){
            v = par[v][0]; vec.pb(c[v]);
        }
        reverse(all(vec));
        for(int j: vec) block.pb({j, 1});*/
        /*vec.clear();
        for(auto [x, y]: block) for(int j = 0; j < y; j++) vec.pb(x);
        ll res = 0;
        for(int j = 0; j < vec.size(); j++) for(int k = j+1; k < vec.size(); k++) if(vec[j] > vec[k]) res++;
        cout << res << endl;*/
        cout << solve() << endl;
        /*
        if(i % sq == 0){
            fill(mx, mx + N, 0);
            for(int j = i; j > i - sq; j--) mx[edge[j]] = j;
            prep(1);
        }*/
    }
}
Compilation message
construction.cpp: In function 'll solve()':
construction.cpp:98:14: warning: structured bindings only available with '-std=c++17' or '-std=gnu++17'
98 | for(auto [x, y]: block){
| ^
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 6748 KB | Output is correct |
| 2 | Correct | 1 ms | 6748 KB | Output is correct |
| 3 | Incorrect | 2 ms | 6748 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 6748 KB | Output is correct |
| 2 | Correct | 1 ms | 6748 KB | Output is correct |
| 3 | Incorrect | 2 ms | 6748 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 6748 KB | Output is correct |
| 2 | Correct | 1 ms | 6748 KB | Output is correct |
| 3 | Incorrect | 2 ms | 6748 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |