제출 #1348465

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1348465 | nguyenkhangninh99 | Unique Cities (JOI19_ho_t5) | C++20
4 / 100
2096 ms | 60128 KiB
#include <bits/stdc++.h>
using namespace std;

// Reads a tree with n nodes (1-indexed) and a type typ[i] in [1, m] for each
// node, then prints one number per node: the number of distinct types among
// that node's "unique" nodes. A node a is unique for u when no other node
// lies at the same distance from u as a does.
//
// Approach (same idea as before, but linear instead of quadratic):
//   * Every unique node of u lies on the path from u to the farther end of a
//     diameter, so the tree is processed twice, once rooted at each diameter
//     endpoint, and the larger count is kept per node.
//   * For a fixed root, `pila` is a stack of the ancestors of the current
//     node that are still candidates, ordered by increasing depth. An
//     ancestor a of u stops being a candidate when a branch leaving the
//     a..u path (or the subtree of u itself) is deep enough to contain a
//     node at distance dist[u] - dist[a] from u.
//   * The previous version removed such ancestors and then pushed all of
//     them back, once per child, which is Theta(n^2) on a node that has a
//     long chain of ancestors, a long heavy branch and many short children.
//     Restoring is unnecessary: anything removed inside the subtree of u is
//     within len[u] of u, and u removes exactly that range itself before it
//     scores itself and visits its light children. Visiting the heavy child
//     first therefore lets every removal be permanent; each node is pushed
//     once per child, so the total stack work is O(n).
//
// NOTE(review): recursion depth can reach n on a path-shaped tree; this
// relies on the judge providing a large stack, exactly as the original did.
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    // Nothing sensible can be computed from missing or degenerate input.
    if (!(cin >> n >> m) || n < 1 || m < 0) return 0;

    vector<vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // assumes every typ[i] is in [1, m]; values outside index past `cnt`.
    vector<int> typ(n + 1);
    for (int i = 1; i <= n; ++i) cin >> typ[i];

    // Per-root data, rebuilt by `dfs` for every root:
    //   dist[u]   depth of u below the root
    //   len[u]    height of u's subtree (edges to its deepest descendant)
    //   sec[u]    second largest (child height + 1) over u's children, else 0
    //   heavy[u]  a child attaining len[u], or 0 when u is a leaf
    vector<int> dist(n + 1, 0), len(n + 1, 0), sec(n + 1, 0), heavy(n + 1, 0);
    vector<int> res(n + 1, 0);

    // Fills dist/len/sec/heavy for the tree rooted at the first call's u.
    auto dfs = [&](auto&& self, int u, int p, int d) -> void {
        dist[u] = d;
        len[u] = 0;
        sec[u] = 0;
        heavy[u] = 0;
        for (int v : adj[u]) {
            if (v == p) continue;
            self(self, v, u, d + 1);
            const int h = len[v] + 1;
            if (h > len[u]) {
                sec[u] = len[u];
                len[u] = h;
                heavy[u] = v;
            } else if (h > sec[u]) {
                sec[u] = h;
            }
        }
    };

    vector<int> cnt(m + 1, 0);  // cnt[t] = nodes of type t currently in pila
    int resact = 0;             // number of types t with cnt[t] > 0
    vector<int> pila;           // candidate ancestors, shallowest first
    pila.reserve(n);

    auto push = [&](int u) {
        if (++cnt[typ[u]] == 1) ++resact;
        pila.push_back(u);
    };
    auto pop = [&]() {
        if (--cnt[typ[pila.back()]] == 0) --resact;
        pila.pop_back();
    };
    // Drops every candidate whose distance from u is at most `radius`.
    // The stack is ordered by depth, so those are exactly a top suffix.
    auto pop_within = [&](int u, int radius) {
        while (!pila.empty() && dist[u] - dist[pila.back()] <= radius) pop();
    };
    // Undoes push(u) unless the child's subtree already discarded u.
    auto pop_if_top = [&](int u) {
        if (!pila.empty() && pila.back() == u) pop();
    };

    // On entry pila holds exactly the ancestors of u not yet ruled out by
    // branches above u; on exit it holds those minus everything within
    // len[u] of u.
    auto ndfs = [&](auto&& self, int u, int p) -> void {
        if (heavy[u] != 0) {
            // For the heavy child, only the other branches (depth sec[u])
            // can collide with ancestors of u.
            pop_within(u, sec[u]);
            push(u);
            self(self, heavy[u], u);
            pop_if_top(u);
            // For u itself and for every light child, the heavy branch
            // (depth len[u]) collides with all ancestors within len[u].
            pop_within(u, len[u]);
        }
        res[u] = max(res[u], resact);
        for (int v : adj[u]) {
            if (v == p || v == heavy[u]) continue;
            push(u);
            self(self, v, u);
            pop_if_top(u);
        }
    };

    // Index 0 is unused, so the search starts at node 1.
    auto farthest = [&]() {
        return static_cast<int>(max_element(dist.begin() + 1, dist.end()) - dist.begin());
    };
    // Roots the tree at `root` and merges that root's counts into res.
    auto solve_from = [&](int root) {
        fill(cnt.begin(), cnt.end(), 0);
        resact = 0;
        pila.clear();
        dfs(dfs, root, 0, 0);
        ndfs(ndfs, root, 0);
    };

    dfs(dfs, 1, 0, 0);
    const int d1 = farthest();  // one endpoint of a diameter
    solve_from(d1);
    const int d2 = farthest();  // dist is now measured from d1: other endpoint
    solve_from(d2);

    for (int i = 1; i <= n; ++i) cout << res[i] << '\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...