제출 #529044

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
529044 | | Haruto810198 | Construction of Highway (JOI18_construction) | C++17 | 100 / 100 | 757 ms | 39616 KiB
#include <bits/stdc++.h>

using namespace std;

// NOTE: from here on every `int` in this file is really `long long`
// (which is why the entry point further down is spelled `signed main`).
#define int long long
// Likewise every later `double` is `long double`.
#define double long double

// Inclusive counting loop: i takes the values l, l+d, ... while i <= r.
#define FOR(i, l, r, d) for(int i=(l); i<=(r); i+=(d))
// Container size converted to the (macro-widened) signed `int`.
#define szof(x) ((int)(x).size())

// Type shorthands; because of the `int` macro their elements are long long.
#define vi vector<int>
#define pii pair<int, int>

// Pair member shorthands.
#define F first
#define S second

// Container / utility shorthands.
#define pb push_back
#define eb emplace_back
#define mkp make_pair

// Segment-tree shorthands: current node, left child, right child.
// They expand in terms of identifiers named `st` and `cidx`, so they are only
// usable where both of those names are in scope.
#define V st[cidx]
#define LC st[cidx*2]
#define RC st[cidx*2+1]

// Lowest set bit of x (0 when x is 0).
#define lsb(x) ((x)&(-(x)))

// Generic constants; nothing in the code shown in this file reads them
// (apart from INF being used to define LNF).
const int INF = INT_MAX;
// The product is evaluated in long long because of the `int` macro above.
const int LNF = INF*INF;
const int MOD = 1000000007;
const int mod = 998244353;
const double eps = 1e-12;

//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("unroll-loops")

// Capacity of the fixed-size arrays declared below (the segment tree
// allocates four times this many nodes).
const int MAX = 100010;

// ---- tree and insertion order ----

// Number of vertices; read first in main().
int n;
// val[i]: value attached to vertex i. main() later replaces each value by its
// 1-based rank among the distinct values.
int val[MAX];
// qv[i]: the endpoint of the i-th input edge that main() records as the child,
// i.e. the vertex handed to solve() in step i.
int qv[MAX];
// vis[x]: x is vertex 1 or has already been recorded as a child while the
// edges were being read.
bool vis[MAX];
// ch[u]: children of u, in input order.
vi ch[MAX];

// ---- heavy-light decomposition ----

// par[x]: parent of x (the root started from in find_heavy() is its own parent).
int par[MAX];
// sz[x]: number of vertices in the subtree of x.
int sz[MAX];
// heavy[x]: child of x with the largest subtree, or -1 when x has no child.
int heavy[MAX];
// fr[x]: topmost vertex of the heavy chain that contains x.
int fr[MAX];
// dfn[x]: 1-based visiting index assigned by HLD().
int dfn[MAX];
// Last visiting index handed out by HLD().
int ts;

// First HLD pass over the subtree rooted at v (recursive).
// Fills sz[], heavy[] and par[] for v and everything below it:
//   - sz[v]    : subtree size,
//   - heavy[v] : child with the largest subtree (the first such child in
//                ch[v] order on ties), or -1 when v has no children,
//   - par[c]   : v, for every child c of v.
// The vertex the top-level call starts from keeps par[v] == v.
void find_heavy(int v){
	sz[v] = 1;
	heavy[v] = -1;
	// Provisional self-parent; the caller's loop overwrites it right after
	// this call returns, so only the starting vertex keeps it.
	par[v] = v;

	int best = 0;
	for(const int child : ch[v]){
		find_heavy(child);
		par[child] = v;
		sz[v] += sz[child];
		// Strict comparison: an equally large later child does not replace
		// the current heavy child.
		if(best < sz[child]){
			best = sz[child];
			heavy[v] = child;
		}
	}
}

// Second HLD pass (recursive): numbers the subtree of v and records chain heads.
//   v   : vertex being visited,
//   frv : topmost vertex of the heavy chain v belongs to.
// The heavy child is visited first and stays on the same chain, so every chain
// occupies consecutive dfn values; each other child starts a chain of its own.
// Relies on heavy[] having been filled by find_heavy().
void HLD(int v, int frv){
	dfn[v] = ++ts;
	fr[v] = frv;

	const int hv = heavy[v];
	if(hv != -1){
		HLD(hv, frv);
	}

	for(const int child : ch[v]){
		if(child != hv){
			HLD(child, child);
		}
	}
}

// Splits the path from vertex 1 down to v into inclusive dfn ranges
// [first, second], one per heavy chain touched, returned in top-to-bottom order.
// Position 1 is always covered: when the climb stops at vertex 1 without having
// emitted a range that starts at dfn 1 (including the case v == 1), the
// single-position range [1, 1] is added.
vector<pii> find_segments(int v){
	vector<pii> segs;

	// Climb chain by chain: take the piece from the chain head down to the
	// current vertex, then jump to the parent of that chain head.
	for(int cur = v; cur > 1; cur = par[fr[cur]]){
		const int top = fr[cur];
		segs.emplace_back(dfn[top], dfn[cur]);
	}

	const bool root_covered = !segs.empty() && segs.back().first == 1;
	if(!root_covered){
		segs.emplace_back(1, 1);
	}

	// Collected bottom-up; callers want top-down.
	std::reverse(segs.begin(), segs.end());
	return segs;
}

// segment tree

// One node of the segment tree used by SegTree.
struct ST_Node{
	// Inclusive range of positions covered by this node.
	int sl;
	int sr;
	// Value stored at the left end of the range; it is the value of the whole
	// range when `all` is non-zero.
	int val;
	// Non-zero when every position in [sl, sr] holds the same value.
	int all;
	// Pending range assignment not yet applied to this node, or -1 for none.
	int tag;
};

struct SegTree{
	
	ST_Node st[4*MAX];
		
	void push(int cidx){
		if(V.tag == -1) return;

		V.val = V.tag;
		V.all = 1;
		if(V.sl < V.sr){
			LC.tag = V.tag;
			RC.tag = V.tag;
		}

		V.tag = -1;
	}

	void pull(int cidx){
		push(cidx);

		if(V.sl < V.sr){
			push(cidx*2);
			push(cidx*2+1);

			V.val = LC.val;
			V.all = (LC.val == RC.val and LC.all and RC.all);
		}
	}

	void build(int cidx, int cl, int cr){
		V.sl = cl;
		V.sr = cr;
		V.tag = -1;
		if(cl < cr){
			int mid = (cl + cr) / 2;
			build(cidx*2, cl, mid);
			build(cidx*2+1, mid+1, cr);
			pull(cidx);
		}
		else{
			V.val = val[cl];
			V.all = 1;
		}
	}

	void modify(int cidx, int ml, int mr, int mval){
		if(mr < V.sl or V.sr < ml) return;
		if(ml <= V.sl and V.sr <= mr){
			V.tag = mval;
			return;
		}

		modify(cidx*2, ml, mr, mval);
		modify(cidx*2+1, ml, mr, mval);
		pull(cidx);
	}

	void query(vector<pii>& arr, int cidx, int ql, int qr){
		if(qr < V.sl or V.sr < ql) return;

		pull(cidx);
		if(ql <= V.sl and V.sr <= qr and V.all){
			arr.eb(V.val, V.sr - V.sl + 1);
			return;
		}
		
		query(arr, cidx*2, ql, qr);
		query(arr, cidx*2+1, ql, qr);
	}
	
};

SegTree st;

// BIT

// Fenwick tree (binary indexed tree) over positions 1 .. MAX-1 holding sums.
struct BIT{

	int Node[MAX];

	// Adds val to position pos. Valid positions are 1 .. MAX-1.
	void modify(int pos, int val){
		// Position 0 has no lowest set bit, so the walk below would never
		// advance, and a negative position would write before the array.
		// Neither is a valid slot: flag it in debug builds and leave the tree
		// untouched otherwise.
		assert(pos >= 1);
		if(pos < 1) return;

		while(pos < MAX){
			Node[pos] += val;
			pos += lsb(pos);
		}
	}

	// Returns the sum of positions 1 .. pos (0 when pos <= 0).
	int query(int pos){
		// Every slot lives below MAX, so a larger bound simply means
		// "everything"; clamping avoids reading past the array.
		if(pos >= MAX) pos = MAX - 1;

		int ret = 0;
		while(pos > 0){
			ret += Node[pos];
			pos -= lsb(pos);
		}
		return ret;
	}

};

BIT bit;

// solve
int solve(int v){
	
	int pv = par[v];

	// 1 --> pv : [  ][  ][  ] ...
	vector<pii> segs_HLD = find_segments(pv); // [l, r]
	vector<pii> segs_ST; // <val, cnt>
	for(pii p : segs_HLD) st.query(segs_ST, 1, p.F, p.S);
	/*	
	cerr<<"solve "<<v<<" : "<<endl;

	cerr<<"HLD : ";
	for(pii p : segs_HLD){
		cerr<<"["<<p.F<<", "<<p.S<<"] ";
	}
	cerr<<endl;
	
	cerr<<"arr : ";
	for(pii p : segs_ST){
		FOR(i, 1, p.S, 1){
			cerr<<p.F;
		}
		cerr<<" ";
	}
	cerr<<endl;
	*/
	// find ans.
	int ret = 0;
	for(pii p : segs_ST){
		ret += p.S * (bit.query(MAX-1) - bit.query(p.F)); // [p.F+1, ...)
		bit.modify(p.F, p.S);
	}

	// init. BIT
	for(pii p : segs_ST){
		bit.modify(p.F, -p.S);
	}
	
	// modify ST
	for(pii p : segs_HLD){
		st.modify(1, p.F, p.S, val[v]);
	}
	st.modify(1, dfn[v], dfn[v], val[v]);
	
	//cerr<<"ret = "<<ret<<endl<<endl;

	return ret;
}

signed main(){
	
	ios_base::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	
	// in
	cin>>n;
	FOR(i, 1, n, 1){
		cin>>val[i];
	}
	
	vis[1] = 1;
	FOR(i, 1, n-1, 1){
		int u, v;
		cin>>u>>v;
		if(vis[v]) swap(u, v);
		qv[i] = v;
		vis[v] = 1;
		ch[u].pb(v);
	}
	
	// val -> [1, n]
	map<int, int> mp;
	vi tmp;
	FOR(i, 1, n, 1){
		tmp.pb(val[i]);
	}
	sort(tmp.begin(), tmp.end());

	for(int i : tmp){
		if(mp.find(i) == mp.end()){
			int sz = szof(mp);
			mp[i] = sz+1;
		}
	}

	FOR(i, 1, n, 1){
		val[i] = mp[val[i]];
	}

	// HLD
	find_heavy(1);
	HLD(1, 1);
	/*	
	cerr<<"dfn : ";
	FOR(i, 1, n, 1){
		cerr<<dfn[i]<<" ";
	}
	cerr<<endl;
	*/
	// build segment tree
	st.build(1, 1, n);
	
	// solve
	FOR(i, 1, n-1, 1){
		//solve(qv[i]);
		cout<<solve(qv[i])<<'\n';
	}
	
	return 0;

}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...