Submission #1024430

#TimeUsernameProblemLanguageResultExecution timeMemory
1024430snpmrnhlolUnique Cities (JOI19_ho_t5)C++17
100 / 100
332 ms52416 KiB
#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex / per-value array below.
constexpr int N = 200000;
// Sentinel that dfs3 uses as a "pop nothing" distance limit.
constexpr int inf = 2000000000;
std::vector<int> e[N];  // adjacency lists of the tree (0-based vertices)
int v[N];               // per-vertex value read in main (made 0-based there)
int dist[N];            // depth of each vertex from the most recent DFS root
int dwn[N];             // height of the subtree below each vertex (set by dfs2)
int ans[N];             // best count recorded so far for each vertex
// Stores in dist[] the depth of every vertex reachable from `node` without
// stepping back through `p`; `node` itself gets depth `dpth`.
// Walks the tree with an explicit work stack instead of recursion.
void dfs(int node, int p, int dpth = 0){
    struct Frame {
        int vertex;
        int parent;
        int depth;
    };
    vector<Frame> work;
    work.push_back({node, p, dpth});
    while(!work.empty()){
        const Frame top = work.back();
        work.pop_back();
        dist[top.vertex] = top.depth;
        for(int next : e[top.vertex]){
            if(next != top.parent){
                work.push_back({next, top.vertex, top.depth + 1});
            }
        }
    }
}
// Roots the subtree hanging from `node` (away from `p`): records each
// vertex's depth in dist[] (starting at `dpth`) and its subtree height,
// in edges, in dwn[] (0 for a leaf).
void dfs2(int node, int p, int dpth){
    dist[node] = dpth;
    int height = 0;
    for(int child : e[node]){
        if(child != p){
            dfs2(child, node, dpth + 1);
            height = max(height, dwn[child] + 1);
        }
    }
    dwn[node] = height;
}
// Stack of vertices on the current root path that dfs3 still treats as candidates.
std::vector<int> cand;
// f[c]: number of vertices in `cand` whose value v[] equals c.
int f[N];
// Number of distinct values c with f[c] > 0.
int cur{0};
// Pushes `node` onto `cand` and counts its value; `cur` grows when the
// value was not present in `cand` before.
void add(int node){
    const int value = v[node];
    if(++f[value] == 1){
        ++cur;
    }
    cand.push_back(node);
}
// Pops the top of `cand` and uncounts its value; `cur` shrinks when that was
// the last vertex in `cand` carrying the value. Requires a non-empty `cand`.
void del(){
    const int top = cand.back();
    cand.pop_back();
    const int value = v[top];
    if(--f[value] == 0){
        --cur;
    }
}
// Second sweep from the current root; dist[] and dwn[] must already have been
// filled by dfs2 for that same root.
//
// Invariant: `cand` holds vertices of the root -> `node` path, ordered by depth
// (deepest at the back). `cur` is the number of distinct v[] values among them,
// maintained by add/del. On exit from dfs3(x), no vertex of x's subtree is left
// in `cand`. Leaves never push themselves, and every inner vertex pops itself
// in its final trim.
//
// Pruning rule. This assumes "unique" means "no other vertex lies at the same
// distance", as in the JOI 2019 Unique Cities statement (confirm).
//   - Let y be a vertex below child c of `node`.
//   - Let a be a `cand` vertex at distance k = dist[node] - dist[a] from `node`.
//   - From y, a lies at distance dist(y,node) + k.
//   - Suppose another child c' of `node` satisfies k <= dwn[c'] + 1.
//   - Then the subtree of c' contains a vertex at that same distance from y.
//   - So a cannot count for y, and it is popped permanently.
// Because `cand` is depth-ordered, popping from the back while the distance is
// <= sz removes exactly the entries within distance sz.
void dfs3(int node, int p){
    // Children of `node` in the current rooting.
    vector <int> nodes;
    for(auto i:e[node]){
        if(i == p)continue;
        nodes.push_back(i);
    }
    if(!nodes.empty()){
        // Deepest child first. The deepest child competes only with the
        // second-deepest branch; every other child competes with the deepest.
        sort(nodes.begin(),nodes.end(),[&](int a,int b){
             return dwn[a] > dwn[b];
        });
        // Pop every `cand` entry at distance <= sz from `node`.
        int sz;
        for(int i = 0;i < (int)nodes.size();i++){
            if(i == 0){
                // The tallest competing branch is nodes[1]. With a single child
                // nothing competes, so -inf disables popping.
                if(nodes.size() > 1)sz = dwn[nodes[1]] + 1;
                else sz = -inf;
            }else{
                // The tallest competing branch is the deepest child, nodes[0].
                sz = dwn[nodes[0]] + 1;
            }
            // sz >= 1 here except in the single-child case. This also pops
            // `node` itself if an earlier iteration left it on top.
            while(!cand.empty() && dist[node] - dist[cand.back()] <= sz){
                del();
            }
            // `node` is pushed so descendants can count it, then we descend.
            add(node);
            dfs3(nodes[i], node);
        }
        // Trim for `node` itself. Path vertices within dwn[nodes[0]] + 1 are
        // matched by a vertex at the same depth inside the deepest child's
        // subtree. Distance 0 is included, so `node` is not counted for itself.
        sz = dwn[nodes[0]] + 1;
        while(!cand.empty() && dist[node] - dist[cand.back()] <= sz){
            del();
        }
    }
    // A leaf keeps every path vertex still in `cand`; its parent, at distance 1,
    // is always present because it was pushed just before this call.
    //cout<<"answer:"<<node<<' '<<cur<<' '<<dist[3]<<' '<<dist[9]<<' '<<dwn[6] + 1<<'\n';
    ans[node] = max(ans[node],cur);
}
// Roots the tree at `node`: dfs2 fills dist[] (depth) and dwn[] (height) for
// this rooting, then dfs3 sweeps it and folds its counts into ans[].
void solve(int node){
    constexpr int kNoParent = -1;
    dfs2(node, kNoParent, 0);
    dfs3(node, kNoParent);
}
// Index of the vertex in [0, n) with the largest dist[]; the smallest index
// wins ties. Returns -1 when n <= 0.
static int farthest_vertex(int n){
    int best = -1;
    int mx = -1;
    for(int i = 0;i < n;i++){
        if(mx < dist[i]){
            mx = dist[i];
            best = i;
        }
    }
    return best;
}
// Steps:
//   1. Read n and m, then n - 1 edges (1-based), then one value per vertex
//      (1-based; converted to 0-based).
//   2. Find two vertices diam1/diam2 with two farthest-vertex DFS passes.
//   3. Run the dfs2/dfs3 sweep rooted at each of them, keeping the maximum
//      per vertex in ans[].
//   4. Print ans[i] one per line.
// Rooting at both ends presumably relies on the standard fact that every
// vertex's farthest vertex is an endpoint of some diameter (confirm).
int main(){
    // About 3n integers are read; decouple iostreams from C stdio and untie
    // cin from cout so that input is not slowed down by synchronization/flushes.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n,m;
    // m is read only to consume it; no array here is sized by it.
    if(!(cin>>n>>m) || n <= 0){
        // Nothing to root the tree at; the previous code indexed e[-1] here.
        return 0;
    }
    for(int i = 0;i < n - 1;i++){
        int u,w;
        cin>>u>>w;
        u--;w--;
        e[u].push_back(w);
        e[w].push_back(u);
    }
    for(int i = 0;i < n;i++){
        cin>>v[i];
        v[i]--;
    }
    dfs(0, -1);
    const int diam1 = farthest_vertex(n);
    dfs(diam1, -1);
    const int diam2 = farthest_vertex(n);
    solve(diam1);
    solve(diam2);
    for(int i = 0;i < n;i++){
        cout<<ans[i]<<'\n';
    }
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...