This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may give a different result if resubmitted.
#include <bits/stdc++.h>
// Fast iostream setup; written as a bare statement at the top of main().
#define gibon ios::sync_with_stdio(false); cin.tie(0);
// Short-hand aliases: fi/se for pair members, plus pair type aliases.
#define fi first
#define se second
#define all(x) x.begin(), x.end()
#define pdd pair<long double, long double>
#define pii pair<int, int>
#define pll pair<ll, ll>
#define pvv pair<vector<int>, vector<int>>
// GCC optimization hints.
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
typedef long long ll;
using namespace std;
// Array bound for vertex-indexed arrays (ci[] is also indexed by color, up to M).
const int mxN=200005;
// mxM, MOD, INF, dx/dy and ddx/ddy are template leftovers, not referenced in the visible code.
const int mxM=10000010;
// Number of binary-lifting levels; 2^lgN is larger than mxN.
const int lgN=20;
const ll MOD=1000000007;
const ll INF=1e11;
int dx[4]={1, 0, -1, 0}, dy[4]={0, 1, 0, -1};
int ddx[8]={1, 1, 0, -1, -1, -1, 0, 1}, ddy[8]={0, -1, -1, -1, 0, 1, 1, 1};
// N: number of vertices; M: number of colors (C[i] values), both read in main().
int N, M;
// C[i]: color of vertex i.
int C[mxN];
// Adjacency lists of the tree.
vector <int> v[mxN];
// r1: deepest vertex from vertex 1; r2: deepest vertex from r1 (two farthest-vertex searches).
int r1, r2;
// dep[x]: depth from the current root; sub[x]: height of x's subtree in edges (both from dfs1);
// par[x][k]: 2^k-th ancestor of x, 0 above the root (binary-lifting table from make_sps).
int dep[mxN], sub[mxN], par[mxN][lgN];
// s[x]: one (path length, neighbor) pair per neighbor of x, filled by make_s().
multiset <pii> s[mxN];
bool pr2[mxN]; /// true for r2 and every ancestor of r2 in the tree rooted at r1 (the r1-r2 path); set in make_s()
// Not referenced in the visible code.
pii pnt[mxN];
// ans[i]: answer for vertex i, accumulated by dfs2 and printed by main().
int ans[mxN];
// Roots the tree below `now`: sets dep[] and par[][0] for every descendant and
// sub[x] = height of x's subtree in edges (0 for a leaf).
// The caller sets dep[now] and par[now][0] for the starting vertex itself.
void dfs1(int now, int pre=-1)
{
    int height = 0;
    for(int child : v[now])
    {
        if(child == pre) continue;
        dep[child] = dep[now] + 1;
        par[child][0] = now;
        dfs1(child, now);
        height = max(height, sub[child] + 1);
    }
    sub[now] = height;
}
// Builds the binary-lifting table: the 2^k-th ancestor is the 2^(k-1)-th ancestor
// of the 2^(k-1)-th ancestor. Vertex 0 is never written and stays 0 as the sentinel above the root.
void make_sps()
{
    for(int k=1; k<lgN; k++)
    {
        for(int x=1; x<=N; x++)
        {
            const int mid = par[x][k-1];
            par[x][k] = par[mid][k-1];
        }
    }
}
// Lowest common ancestor of a and b via the par[][] lifting table.
int lca(int a, int b)
{
    int x = a, y = b;
    if(dep[x] < dep[y]) swap(x, y);
    // Raise the deeper vertex x to y's depth.
    for(int k=lgN-1; k>=0; k--)
        if(dep[x] - (1<<k) >= dep[y]) x = par[x][k];
    if(x == y) return x;
    // Raise both while their ancestors differ; they end just below the LCA.
    for(int k=lgN-1; k>=0; k--)
    {
        if(par[x][k] == par[y][k]) continue;
        x = par[x][k];
        y = par[y][k];
    }
    return par[x][0];
}
// Number of edges on the tree path between a and b.
int dis(int a, int b)
{
    const int meet = lca(a, b);
    return (dep[a] - dep[meet]) + (dep[b] - dep[meet]);
}
// Vertex reached by walking `val` edges from s toward e along the tree path.
// NOTE(review): meaningful only for 0 <= val <= dis(s, e); not called in the visible code.
int mov(int s, int e, int val)
{
    const int c = lca(s, e);
    int from = s;
    int steps = val;
    if(val > dep[s] - dep[c])
    {
        // The target lies on the e side of the LCA: climb from e instead.
        from = e;
        steps = dep[s] + dep[e] - 2*dep[c] - val;
    }
    for(int k=lgN-1; k>=0; k--)
        if(steps & (1<<k)) from = par[from][k];
    return from;
}
// Marks r2 and its ancestors (the r1-r2 path) in pr2[], then, for every edge i-child
// (child one level deeper), records in s[] the length of the longest path leaving each
// endpoint through the other:
//   s[i]     gets (sub[child]+1, child)  -- going down into child's subtree;
//   s[child] gets (up+1, i)              -- going up through i, where up is dep[i]
//                                           (distance to r1), or the larger of that and
//                                           dis(i, r2) when child is off the r1-r2 path.
// Needs dfs1 rooted at r1 and make_sps() (dis() uses the lifting table).
void make_s()
{
    for(int cur=r2; cur!=0; cur=par[cur][0]) pr2[cur]=true;
    for(int i=1; i<=N; i++)
    {
        for(int child : v[i])
        {
            if(dep[child] <= dep[i]) continue; // handle each edge once, from its upper end
            s[i].insert(pii(sub[child]+1, child));
            int up = dep[i];
            if(!pr2[child]) up = max(up, dis(i, r2));
            s[child].insert(pii(up+1, i));
        }
    }
}
// One undo-log record for dfs2's rollback vector `ht`.
// typ.fi is a tag: 0 = ci[typ.se] changed (f: old value, t: new value);
//                  1 = ct[typ.se] changed (pf: old pair, pt: new pair);
//                  2 = idx changed (f: old idx, t: new idx; typ.se is 0).
struct qry{
    pii typ{};
    int f{};
    int t{};
    pii pf{};
    pii pt{};
    qry() = default;
    qry(int a, int b, int f, int t) : typ(a, b), f(f), t(t) {}
    qry(int a, int b, pii pf, pii pt) : typ(a, b), pf(pf), pt(pt) {}
};
// dfs2's stack of (depth, color) entries; only ct[0..idx) is live.
pii ct[mxN]; ///dep, col
// Number of live entries in ct.
int idx;
// Undo log of qry records; dfs2 pops it to restore ci, ct and idx.
vector <qry> ht;
// ci[c]: ct slot last assigned to color c, or -1 after init(). It may be stale,
// so dfs2 re-checks it against idx and ct[ci[c]].se before trusting it.
int ci[mxN];
// One pass over the tree rooted at `root` (called from main with root = r1, then r2).
// now: current vertex; pre: its parent (-1 at the root); nd: depth of now.
// ct[0..idx) holds (depth, color) entries pushed by ancestors of now, in increasing depth.
// A color is pushed only when ci[] does not already point at a live entry of that color.
// Every change to ci/ct/idx is logged in ht and undone after each child, so each child
// starts from the same stack.
void dfs2(int now, int pre, int nd, int root)
{
//printf("now=%d, pre=%d, nd=%d\n", now, pre, nd);
// Only the pass whose root is farther from now adds to ans[now]. r1+r2-root is the
// other endpoint; on a tie, the smaller-numbered endpoint's pass is used.
if(dis(now, root)>dis(now, r1+r2-root) || (dis(now, root)==dis(now, r1+r2-root) && root<r1+r2-root))
{
if(v[now].size()==1) ans[now]+=idx; ///leaf or root
else
{
// ns: length of the longest branch from now that does not go back to pre.
int ns=0;
auto it=s[now].rbegin();
if(it->se!=pre) ns=it->fi;
else
{
it++;
ns=it->fi;
}
//printf("ns=%d\n", ns);
// Count the live entries before the first one with depth >= nd-ns.
ans[now]+=lower_bound(ct, ct+idx, pii(nd-ns, -1))-ct;
}
}
//for(int i=0;i<idx;i++) printf("(%d, %d) ", ct[i].fi, ct[i].se);
//printf("\n");
for(int nxt : v[now]) if(nxt!=pre)
{
// cnt1: a push was logged (3 records); cnt2: a truncation was logged (1 record).
int cnt1=0, cnt2=0;
int nc=C[now];
if(v[now].size()>=3) /// fine for the root too, because the root has degree 1
{
// ns: longest branch from now other than pre and nxt. While visiting nxt,
// drop every entry with depth >= nd-ns (log the old idx first).
int ns=0;
auto it=s[now].rbegin();
while(it->se==pre || it->se==nxt) it++;
ns=it->fi;
//printf("now=%d, nxt=%d, ns=%d\n", now, nxt, ns);
int ni=lower_bound(ct, ct+idx, pii(nd-ns, -1))-ct;
ht.emplace_back(2, 0, idx, ni);
cnt2++;
idx=ni;
}
// Push (nd, nc) for now unless color nc already has a live entry in ct[0..idx).
// Log old ci[nc], old ct[idx] and old idx so they can be restored below.
if(ci[nc]==-1 || ci[nc]>=idx || ct[ci[nc]].se!=nc)
{
ht.emplace_back(0, nc, ci[nc], idx);
ht.emplace_back(1, idx, ct[idx], pii(nd, nc));
ht.emplace_back(2, 0, idx, idx+1);
cnt1+=3;
ci[nc]=idx;
ct[idx]=pii(nd, nc);
idx++;
}
dfs2(nxt, now, nd+1, root);
// Undo in reverse logging order: idx, the ct slot, ci[nc]; then the truncation.
if(cnt1)
{
idx=ht.back().f;
ht.pop_back();
ct[idx]=ht.back().pf;
ht.pop_back();
ci[nc]=ht.back().f;
ht.pop_back();
}
if(cnt2)
{
idx=ht.back().f;
ht.pop_back();
}
//printf("now=%d, nxt=%d\n", now, nxt);
//for(int i=0;i<idx;i++) printf("(%d, %d) ", ct[i].fi, ct[i].se);
//printf("\n");
}
}
void init()
{
for(int i=1;i<=M;i++) ci[i]=-1;
idx=0;
}
int main()
{
gibon
cin >> N >> M;
for(int i=1;i<N;i++)
{
int a, b;
cin >> a >> b;
v[b].push_back(a);
v[a].push_back(b);
}
for(int i=1;i<=N;i++) cin >> C[i];
dfs1(1);
r1=1;
for(int i=1;i<=N;i++) if(dep[i]>dep[r1]) r1=i;
dep[r1]=0;
par[r1][0]=0;
dfs1(r1);
r2=r1;
for(int i=1;i<=N;i++) if(dep[i]>dep[r2]) r2=i;
make_sps();
make_s();
/*for(int i=1;i<=N;i++)
{
printf("%d: ", i);
for(auto [x, y] : s[i]) printf("(%d, %d) ", x, y);
printf("\n");
}*/
init();
dfs2(r1, -1, 0, r1);
//for(int i=1;i<=N;i++) printf("ans[%d]=%d\n", i, ans[i]);
init();
dfs2(r2, -1, 0, r2);
for(int i=1;i<=N;i++) cout << ans[i] << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |