This submission was migrated from the previous version of oj.uz, which used a different machine for grading. This submission may have a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
// NOTE: `int` is redefined as `long long` for the rest of the file, so every
// `int` below (globals, array elements, loop counters, return types) is 64-bit.
#define int long long
// `double` is likewise widened to `long double`.
#define double long double
// FOR(i, l, r, d): loop i from l up to r inclusive, stepping by d.
#define FOR(i, l, r, d) for(int i=(l); i<=(r); i+=(d))
#define szof(x) ((int)(x).size())
#define vi vector<int>
#define pii pair<int, int>
#define F first
#define S second
#define pb push_back
#define eb emplace_back
#define mkp make_pair
// Shorthands used inside SegTree for the current node and its two children
// (heap layout: the children of node cidx are cidx*2 and cidx*2+1).
#define V st[cidx]
#define LC st[cidx*2]
#define RC st[cidx*2+1]
// Lowest set bit of x; the step size of the Fenwick tree (BIT).
#define lsb(x) ((x)&(-(x)))
// INF*INF is evaluated in 64-bit (because of the `int` macro) and fits.
// INF, LNF, MOD, mod and eps do not appear to be used elsewhere in this file.
const int INF = INT_MAX;
const int LNF = INF*INF;
const int MOD = 1000000007;
const int mod = 998244353;
const double eps = 1e-12;
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("unroll-loops")
// Capacity of every per-vertex / per-position array; presumably the problem's
// vertex limit plus slack — TODO confirm against the constraints.
const int MAX = 100010;
// tree, query
// Number of vertices.
int n;
// Vertex values indexed by vertex id; main() compresses them to ranks
// 1..(number of distinct values).
int val[MAX];
// qv[i]: the vertex newly attached by the i-th input edge; solve(qv[i])
// then works on the path 1 --> par[qv[i]].
int qv[MAX]; // 1 --> qv[i]
// Vertices already connected while reading edges (vertex 1 is marked first).
bool vis[MAX];
vi ch[MAX]; // children of each vertex (tree rooted at vertex 1)
// HLD
// par: parent vertex (the root is its own parent); sz: subtree size.
int par[MAX], sz[MAX];
// heavy: child with the largest subtree, -1 for a leaf;
// fr: head (top vertex) of the heavy chain containing the vertex;
// dfn: 1-based position in the heavy-child-first DFS order.
int heavy[MAX], fr[MAX], dfn[MAX];
// Last dfn value handed out by HLD().
int ts;
void find_heavy(int v){
sz[v] = 1;
int Max = 0;
heavy[v] = -1;
par[v] = v;
for(int i : ch[v]){
find_heavy(i);
if(sz[i] > Max){
Max = sz[i];
heavy[v] = i;
}
sz[v] += sz[i];
par[i] = par[v];
}
}
void HLD(int v, int frv){
ts++;
dfn[v] = ts;
fr[v] = frv;
if(heavy[v] != -1) HLD(heavy[v], frv);
for(int i : ch[v]){
if(i == heavy[v]) continue;
HLD(i, i);
}
}
// Splits the root-to-v path (1 --> v) into dfn intervals [l, r], one per heavy
// chain crossed, returned in root-to-v order. The interval containing the
// root always starts at dfn 1; if the walk stopped without producing one,
// [1, 1] (the root alone) is added.
vector<pii> find_segments(int v){
    vector<pii> pieces;
    for(int cur = v; cur > 1; cur = par[fr[cur]]){
        pieces.eb(dfn[fr[cur]], dfn[cur]);
    }
    bool has_root_piece = !pieces.empty() && pieces.back().F == 1;
    if(!has_root_piece) pieces.eb(1, 1);
    reverse(pieces.begin(), pieces.end());
    return pieces;
}
// segment tree
// One node of SegTree.
struct ST_Node{
    int sl;    // left end of the covered position range
    int sr;    // right end of the covered position range
    int val;   // value at the leftmost position sl (valid once pending tags are pushed)
    int all;   // 1 if every position in [sl, sr] holds the same value, else 0
    int tag;   // pending range assignment, -1 when there is none
};
// Segment tree over positions [1, n] (dfn order) supporting:
//   - modify: assign one value to every position of a range (lazy tags),
//   - query : list the range as consecutive runs of equal values.
// Per node (ST_Node): [sl, sr] is the covered range, val is the value at the
// leftmost position, all == 1 iff the whole range holds one value, and tag is a
// pending assignment (-1 = none). A node's val/all are only up to date after
// its own tag has been pushed.
struct SegTree{
    ST_Node st[4*MAX];
    // Apply node cidx's pending assignment to the node itself and forward it
    // to both children (a leaf has no children to forward to).
    void push(int cidx){
        if(V.tag == -1) return;
        V.val = V.tag;
        V.all = 1;
        if(V.sl < V.sr){
            LC.tag = V.tag;
            RC.tag = V.tag;
        }
        V.tag = -1;
    }
    // Flush pending tags on cidx and its children, then recompute cidx:
    // val = leftmost value, all = both halves uniform with the same value.
    void pull(int cidx){
        push(cidx);
        if(V.sl < V.sr){
            push(cidx*2);
            push(cidx*2+1);
            V.val = LC.val;
            V.all = (LC.val == RC.val and LC.all and RC.all);
        }
    }
    // Build node cidx over positions [cl, cr].
    // NOTE(review): leaf cl is seeded with val[cl], i.e. indexed by vertex id
    // rather than by the vertex whose dfn is cl. The two agree only at
    // position 1 (dfn[1] == 1); solve() assigns every other position via
    // modify() before it is ever queried, so this appears harmless here.
    void build(int cidx, int cl, int cr){
        V.sl = cl;
        V.sr = cr;
        V.tag = -1;
        if(cl < cr){
            int mid = (cl + cr) / 2;
            build(cidx*2, cl, mid);
            build(cidx*2+1, mid+1, cr);
            pull(cidx);
        }
        else{
            V.val = val[cl];
            V.all = 1;
        }
    }
    // Assign mval to every position in [ml, mr].
    void modify(int cidx, int ml, int mr, int mval){
        if(mr < V.sl or V.sr < ml) return;
        if(ml <= V.sl and V.sr <= mr){
            V.tag = mval;   // applied lazily by a later push()
            return;
        }
        // Partial overlap: flush this node's own pending assignment into its
        // children first. Otherwise the pull() below would push that stale tag
        // down and overwrite the values just written into the children.
        push(cidx);
        modify(cidx*2, ml, mr, mval);
        modify(cidx*2+1, ml, mr, mval);
        pull(cidx);
    }
    // Append to arr, left to right, (value, length) pairs covering the part of
    // [ql, qr] under node cidx; each pair is a uniform node. Adjacent pairs may
    // carry the same value (runs are not merged).
    void query(vector<pii>& arr, int cidx, int ql, int qr){
        if(qr < V.sl or V.sr < ql) return;
        pull(cidx);
        if(ql <= V.sl and V.sr <= qr and V.all){
            arr.eb(V.val, V.sr - V.sl + 1);
            return;
        }
        query(arr, cidx*2, ql, qr);
        query(arr, cidx*2+1, ql, qr);
    }
};
SegTree st;
// BIT
// Fenwick tree over positions [1, MAX-1]: point add, prefix sum.
struct BIT{
    int Node[MAX];
    // Add val at position pos. Positions <= 0 are ignored: lsb(0) == 0, so the
    // update loop would never advance (and a negative pos indexes out of range).
    void modify(int pos, int val){
        if(pos <= 0) return;
        while(pos < MAX){
            Node[pos] += val;
            pos += lsb(pos);
        }
    }
    // Sum over positions [1, pos]. A pos past the last slot is clamped to
    // MAX-1, returning the total instead of reading out of bounds.
    int query(int pos){
        if(pos >= MAX) pos = MAX - 1;
        int ret = 0;
        while(pos > 0){
            ret += Node[pos];
            pos -= lsb(pos);
        }
        return ret;
    }
};
BIT bit;
// solve
int solve(int v){
int pv = par[v];
// 1 --> pv : [ ][ ][ ] ...
vector<pii> segs_HLD = find_segments(pv); // [l, r]
vector<pii> segs_ST; // <val, cnt>
for(pii p : segs_HLD) st.query(segs_ST, 1, p.F, p.S);
/*
cerr<<"solve "<<v<<" : "<<endl;
cerr<<"HLD : ";
for(pii p : segs_HLD){
cerr<<"["<<p.F<<", "<<p.S<<"] ";
}
cerr<<endl;
cerr<<"arr : ";
for(pii p : segs_ST){
FOR(i, 1, p.S, 1){
cerr<<p.F;
}
cerr<<" ";
}
cerr<<endl;
*/
// find ans.
int ret = 0;
for(pii p : segs_ST){
ret += p.S * (bit.query(MAX-1) - bit.query(p.F)); // [p.F+1, ...)
bit.modify(p.F, p.S);
}
// init. BIT
for(pii p : segs_ST){
bit.modify(p.F, -p.S);
}
// modify ST
for(pii p : segs_HLD){
st.modify(1, p.F, p.S, val[v]);
}
st.modify(1, dfn[v], dfn[v], val[v]);
//cerr<<"ret = "<<ret<<endl<<endl;
return ret;
}
signed main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
// in
cin>>n;
FOR(i, 1, n, 1){
cin>>val[i];
}
vis[1] = 1;
FOR(i, 1, n-1, 1){
int u, v;
cin>>u>>v;
if(vis[v]) swap(u, v);
qv[i] = v;
vis[v] = 1;
ch[u].pb(v);
}
// val -> [1, n]
map<int, int> mp;
vi tmp;
FOR(i, 1, n, 1){
tmp.pb(val[i]);
}
sort(tmp.begin(), tmp.end());
for(int i : tmp){
if(mp.find(i) == mp.end()){
int sz = szof(mp);
mp[i] = sz+1;
}
}
FOR(i, 1, n, 1){
val[i] = mp[val[i]];
}
// HLD
find_heavy(1);
HLD(1, 1);
/*
cerr<<"dfn : ";
FOR(i, 1, n, 1){
cerr<<dfn[i]<<" ";
}
cerr<<endl;
*/
// build segment tree
st.build(1, 1, n);
// solve
FOR(i, 1, n-1, 1){
//solve(qv[i]);
cout<<solve(qv[i])<<'\n';
}
return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |