이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
// Shorthand macros from the author's template.
// gibon: unties C++ streams from C stdio and cin from cout (faster I/O).
#define gibon ios::sync_with_stdio(false); cin.tie(0);
// fi / se: short names for std::pair members.
#define fi first
#define se second
// NOTE(review): all, pdd, pll and pvv are not used in this file; only pii is.
#define all(x) x.begin(), x.end()
#define pdd pair<long double, long double>
#define pii pair<int, int>
#define pll pair<ll, ll>
#define pvv pair<vector<int>, vector<int>>
// NOTE(review): GCC applies `#pragma GCC optimize` to functions defined after
// the pragma, so code from the #include above is likely unaffected -- confirm
// whether these were meant to precede the include.
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
typedef long long ll;
using namespace std;
// Capacity of every per-vertex array. Vertices are numbered 1..N, so this
// assumes N < mxN -- TODO confirm against the problem limits.
const int mxN=200005;
// NOTE(review): mxM, MOD, INF, dx/dy and ddx/ddy are not referenced anywhere
// in this file (they appear to be template leftovers).
const int mxM=10000010;
// Number of binary-lifting levels for par[][]: 2^(lgN-1) = 524288 already
// exceeds mxN, so every depth difference can be expressed.
const int lgN=20;
const ll MOD=1000000007;
const ll INF=1e11;
int dx[4]={1, 0, -1, 0}, dy[4]={0, 1, 0, -1};
int ddx[8]={1, 1, 0, -1, -1, -1, 0, 1}, ddy[8]={0, -1, -1, -1, 0, 1, 1, 1};
// ---- Tree data ---------------------------------------------------------------
// N: number of vertices (numbered 1..N).
// M: number of colors (init() resets ci[1..M]).
int N, M;
// C[i]: color of vertex i (read in main).
int C[mxN];
// v[i]: adjacency list of vertex i; every edge is stored in both directions.
vector <int> v[mxN];
// r1: deepest vertex seen from vertex 1; r2: deepest vertex seen from r1
// (the double sweep in main, giving the two ends of a diameter).
int r1, r2;
// For the current dfs1 root:
//   dep[x]    - depth of x,
//   sub[x]    - number of edges on the longest downward path from x,
//   par[x][k] - 2^k-th ancestor of x, 0 above the root.
// dfs1 fills dep, sub and par[][0]; make_sps fills the higher levels.
int dep[mxN], sub[mxN], par[mxN][lgN];
// s[x]: one (length, y) pair per neighbour y of x. length is the longest path
// that leaves x through y (built by make_s, queried by dfs2).
multiset <pii> s[mxN];
bool pr2[mxN];  ///true for r2 and every ancestor of r2 (rooted at r1), i.e. the r1-r2 path
// NOTE(review): pnt is not referenced anywhere in this file.
pii pnt[mxN];
// ans[i]: answer for vertex i, accumulated over the two dfs2 passes.
int ans[mxN];
// Roots the tree at `now` and, for every vertex below it, records:
//   dep[x]    - depth (dep[now] itself must already be set by the caller),
//   par[x][0] - direct parent,
//   sub[x]    - edges on the longest downward path from x (0 for a leaf).
// `pre` is the neighbour to treat as "above" now (-1 for none).
// Uses an explicit work stack instead of recursion; the resulting arrays are
// the same as for a recursive post-order walk.
void dfs1(int now, int pre=-1)
{
    struct Frame
    {
        int node;          // vertex being expanded
        int from;          // vertex it was reached from
        std::size_t pos;   // next index into v[node] to look at
    };
    std::vector<Frame> work;
    sub[now]=0;
    work.push_back(Frame{now, pre, 0});
    while(!work.empty())
    {
        const int cur=work.back().node;
        const int from=work.back().from;
        if(work.back().pos<v[cur].size())
        {
            const int child=v[cur][work.back().pos++];
            if(child==from) continue;
            dep[child]=dep[cur]+1;
            par[child][0]=cur;
            sub[child]=0;
            work.push_back(Frame{child, cur, 0});
            continue;
        }
        // All neighbours of cur are done: fold its height into its parent.
        work.pop_back();
        if(!work.empty())
        {
            const int up=work.back().node;
            sub[up]=std::max(sub[up], sub[cur]+1);
        }
    }
}
// Builds the binary-lifting table: par[x][k] (the 2^k-th ancestor of x) is two
// consecutive 2^(k-1) jumps. Vertex 0's row is never written, so it stays 0
// and acts as the "above the root" sentinel. Each level must be complete for
// all vertices before the next level reads it, so levels form the outer loop.
void make_sps()
{
    for(int level=1; level<lgN; ++level)
    {
        for(int x=1; x<=N; ++x)
        {
            const int half=par[x][level-1];
            par[x][level]=par[half][level-1];
        }
    }
}
int lca(int a, int b)
{
    if(dep[a]<dep[b])   swap(a, b);
    for(int i=lgN-1;i>=0;i--)   if(dep[a]-(1<<i)>=dep[b])   a=par[a][i];
    if(a==b)    return a;
    for(int i=lgN-1;i>=0;i--)   if(par[a][i]!=par[b][i])    a=par[a][i], b=par[b][i];
    return par[a][0];
}
// Number of edges on the tree path between a and b.
int dis(int a, int b)
{
    const int meet=lca(a, b);
    return (dep[a]-dep[meet])+(dep[b]-dep[meet]);
}
// Returns the vertex reached by walking `val` edges from s along the tree path
// toward e. If that many steps stay on s's side of lca(s, e), the target is an
// ancestor of s. Otherwise it is the ancestor of e that lies (path length - val)
// edges above e. Not called anywhere in this file.
// NOTE(review): val is assumed to lie in [0, dis(s, e)]; outside that range
// the returned vertex is meaningless.
int mov(int s, int e, int val)
{
    const int c=lca(s, e);
    const int upFromS=dep[s]-dep[c];
    int from=s, steps=val;
    if(val>upFromS)
    {
        from=e;
        steps=upFromS+(dep[e]-dep[c])-val;
    }
    for(int k=lgN; k-->0; )
    {
        if(steps&(1<<k)) from=par[from][k];
    }
    return from;
}
void make_s()
{
    int now=r2;
    while(now)  pr2[now]=true, now=par[now][0];
    for(int i=1;i<=N;i++)
    {
        for(int nxt : v[i]) if(dep[nxt]>dep[i])
        {
            s[i].insert(pii(sub[nxt]+1, nxt));
            int udep=dep[i];
            if(!pr2[nxt])   udep=max(udep, dis(i, r2));
            s[nxt].insert(pii(udep+1, i));
        }
    }
}
// One undo record in dfs2's rollback log `ht`. typ identifies what changed:
//   (0, color) - ci[color] was overwritten; f = previous value, t = new value.
//   (1, slot)  - ct[slot] was overwritten; pf = previous value, pt = new value.
//   (2, 0)     - idx was changed;          f = previous value, t = new value.
// dfs2 reads back only f / pf when undoing.
struct qry{
    pii typ;
    int f=0;
    int t=0;
    pii pf{};
    pii pt{};
    qry() : typ() {}
    qry(int kind, int key, int before, int after) : typ(kind, key), f(before), t(after) {}
    qry(int kind, int key, pii before, pii after) : typ(kind, key), pf(before), pt(after) {}
};
// ---- State for dfs2 ----------------------------------------------------------
// ct[0..idx): stack of (depth, color) entries for the ancestors dfs2 currently
// keeps. Entries are pushed in increasing depth order, which the lower_bound
// calls in dfs2 rely on.
pii ct[mxN];    ///dep, col
// idx: number of live entries in ct (slots >= idx may hold stale data).
int idx;
// ht: undo log of qry records; dfs2 pops them in reverse order when backtracking.
vector <qry> ht;
// ci[c]: slot in ct holding color c, or -1. May be stale, so dfs2 checks it
// against idx and ct[ci[c]].se before trusting it.
int ci[mxN];
// Second pass. Walks the tree from `root` and adds to ans[] for the vertices
// this pass is responsible for: those strictly farther from `root` than from
// the other endpoint r1+r2-root, with ties going to the endpoint whose vertex
// id is smaller.
//   now  : current vertex
//   pre  : vertex we arrived from (-1 at the root)
//   nd   : depth of `now` below `root`
//   root : r1 or r2; r1+r2-root is the other endpoint
// Invariant on entry: ct[0..idx) holds (depth, color) entries for ancestors of
// `now`, in increasing depth order, with at most one live entry per color
// (ci[color] is its slot). Every change to ct/ci/idx is logged in ht before it
// is applied and is undone after the child call returns.
// NOTE(review): recursion depth equals the tree height (up to N-1).
void dfs2(int now, int pre, int nd, int root)
{
    //printf("now=%d, pre=%d, nd=%d\n", now, pre, nd);
    // Answer `now` only if `root` is its farther endpoint (ties -> smaller id).
    if(dis(now, root)>dis(now, r1+r2-root) || (dis(now, root)==dis(now, r1+r2-root) && root<r1+r2-root))
    {
        // Only neighbour is pre: there is no side branch, so every live entry counts.
        if(v[now].size()==1)    ans[now]+=idx;  ///leaf or root
        else
        {
            // ns: longest path leaving `now` through a neighbour other than pre
            // (the largest entry of s[now] that does not point at pre).
            int ns=0;
            auto it=s[now].rbegin();
            if(it->se!=pre) ns=it->fi;
            else
            {
                it++;
                ns=it->fi;
            }
            //printf("ns=%d\n", ns);
            // An ancestor at depth d is nd-d edges away. Those with d >= nd-ns
            // are no farther than ns and are not counted, so count the entries
            // with depth < nd-ns. The -1 sorts before every color, assuming
            // colors are >= 0.
            ans[now]+=lower_bound(ct, ct+idx, pii(nd-ns, -1))-ct;
        }
    }
    //for(int i=0;i<idx;i++)  printf("(%d, %d) ", ct[i].fi, ct[i].se);
    //printf("\n");
    for(int nxt : v[now])   if(nxt!=pre)
    {
        // cnt1: records logged for pushing now's color (0 or 3).
        // cnt2: records logged for truncating the stack (0 or 1).
        int cnt1=0, cnt2=0;
        int nc=C[now];
        // Before entering nxt, drop ancestors within distance ns of `now`.
        // Here ns is the longest branch avoiding both pre and nxt. A vertex
        // whose only neighbours are pre and nxt has no such branch.
        if(v[now].size()>=3)    ///safe because the root has degree 1
        {
            int ns=0;
            auto it=s[now].rbegin();
            while(it->se==pre || it->se==nxt)   it++;
            ns=it->fi;
            //printf("now=%d, nxt=%d, ns=%d\n", now, nxt, ns);
            int ni=lower_bound(ct, ct+idx, pii(nd-ns, -1))-ct;
            ht.emplace_back(2, 0, idx, ni);
            cnt2++;
            idx=ni;
        }
        // Push now's color unless a live entry of that color already exists.
        // ci[nc] may be stale: unset, beyond idx, or overwritten by another color.
        if(ci[nc]==-1 || ci[nc]>=idx || ct[ci[nc]].se!=nc)
        {
            // Log the old ci[nc], old ct[idx] and old idx, then apply the push.
            ht.emplace_back(0, nc, ci[nc], idx);
            ht.emplace_back(1, idx, ct[idx], pii(nd, nc));
            ht.emplace_back(2, 0, idx, idx+1);
            cnt1+=3;
            ci[nc]=idx;
            ct[idx]=pii(nd, nc);
            idx++;
        }
        dfs2(nxt, now, nd+1, root);
        // Undo the push in reverse logging order: idx, then ct[idx], then ci[nc].
        if(cnt1)
        {
            idx=ht.back().f;
            ht.pop_back();
            ct[idx]=ht.back().pf;
            ht.pop_back();
            ci[nc]=ht.back().f;
            ht.pop_back();
        }
        // Undo the truncation by restoring the idx it replaced.
        if(cnt2)
        {
            idx=ht.back().f;
            ht.pop_back();
        }
        //printf("now=%d, nxt=%d\n", now, nxt);
        //for(int i=0;i<idx;i++)  printf("(%d, %d) ", ct[i].fi, ct[i].se);
        //printf("\n");
    }
}
void init()
{
    for(int i=1;i<=M;i++)   ci[i]=-1;
    idx=0;
}
int main()
{
    gibon
    cin >> N >> M;
    for(int i=1;i<N;i++)
    {
        int a, b;
        cin >> a >> b;
        v[b].push_back(a);
        v[a].push_back(b);
    }
    for(int i=1;i<=N;i++)   cin >> C[i];
    dfs1(1);
    r1=1;
    for(int i=1;i<=N;i++)   if(dep[i]>dep[r1])  r1=i;
    dep[r1]=0;
    par[r1][0]=0;
    dfs1(r1);
    r2=r1;
    for(int i=1;i<=N;i++)   if(dep[i]>dep[r2])  r2=i;
    make_sps();
    make_s();
    /*for(int i=1;i<=N;i++)
    {
        printf("%d: ", i);
        for(auto [x, y] : s[i]) printf("(%d, %d) ", x, y);
        printf("\n");
    }*/
    init();
    dfs2(r1, -1, 0, r1);
    //for(int i=1;i<=N;i++)   printf("ans[%d]=%d\n", i, ans[i]);
    init();
    dfs2(r2, -1, 0, r2);
    for(int i=1;i<=N;i++)   cout << ans[i] << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |