이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
//#define _GLIBCXX_DEBUG
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("unroll-loops")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
// Order-statistics set backed by the GNU pb_ds red-black tree:
// besides the usual set operations it offers find_by_order(k) and order_of_key(x).
template<typename T>
using ordered_set = __gnu_pbds::tree<T,
                                     __gnu_pbds::null_type,
                                     std::less<T>,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;
// Min-heap: a priority_queue ordered by std::greater<>, so top() is the smallest element.
template<typename T>
using normal_queue = std::priority_queue<T, std::vector<T>, std::greater<>>;
// Mersenne Twister seeded from the steady clock, so the sequence differs between runs.
std::mt19937 rnd(std::chrono::steady_clock::now().time_since_epoch().count());
// Short type names.
#define ll long long
// Debug print "expr: value" followed by endl; the expansion already ends with ';'.
#define trace(x) cout << #x << ": " << (x) << endl;
// Whole-container iterator pairs for algorithm calls.
#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
// Drops adjacent duplicates (removes all duplicates only if x is already sorted).
#define uniq(x) x.resize(unique(all(x)) - begin(x))
#define ld long double
// Container size as a signed int.
#define sz(s) ((int) size(s))
#define pii pair<int, int>
#define mp(x, y) make_pair(x, y)
#define int128 __int128
#define pb push_back
#define eb emplace_back
// Lowers x to y when y is strictly smaller (tested as x > y).
// Returns whether x changed.
template<typename T>
bool ckmn(T &x, T y) {
    const bool improves = x > y;
    if (improves) x = y;
    return improves;
}
// Raises x to y when y is strictly larger (tested as x < y).
// Returns whether x changed.
template<typename T>
bool ckmx(T &x, T y) {
    const bool improves = x < y;
    if (improves) x = y;
    return improves;
}
// Returns bit b of x as 0 or 1.
int bit(int x, int b) {
    const int shifted = x >> b;
    return (shifted & 1) != 0 ? 1 : 0;
}
int rand(int l, int r) { return (int) ((ll) rnd() % (r - l + 1)) + l; }
// Large sentinel values.
const ll infL = 3000000000000000000LL;   // 3e18
const int infI = 1000000007;             // 1e9 + 7; also the "not yet visited" BFS distance
// Values produced by memset(..., 127, ...): every byte is 0x7F.
const int infM = 0x7F7F7F7F;             // 2139062143
const ll infML = 0x7F7F7F7F7F7F7F7FLL;   // 9187201950435737471
// Capacity of the per-vertex global arrays.
const int N = 200001;
const ll mod = 998244353;
const ld eps = 1e-9;
// Tree adjacency lists (vertices are 0-based).
vector<int> g[N];
// insus[v]: v lies on the diameter path stored in sus.
bool insus[N];
// used[x]: x is currently on the stack st (maintained by add/del).
bool used[N];
int ans[N];        // answer printed for each vertex
int cnt[N];        // cnt[col]: number of vertices of color col currently on st
int c[N];          // color of each vertex
int val = 0;       // number of colors with cnt[col] > 0
int d[2];          // the two diameter endpoints
int dist[2][N];    // dist[k][v]: BFS distance from d[k] to v
int n;             // number of vertices
int m;             // second input value; main later reuses it for the diameter path length
int pp[N];         // BFS parent when rooted at d[0] (-1 at d[0])
int mxd[N];        // height of the subtree hanging below v off the diameter path
int ww[N];         // per-path-index bounds used by the two sweeps in main
// Diameter path, from d[1] to d[0].
vector<int> sus;
// BFS over the whole tree starting at v. The last vertex reached is one farthest from v;
// it is stored in d[0] and returned.
int calc_farthest(int v) {
    vector<char> seen(n, 0);
    vector<int> order;        // BFS visiting order, doubling as the queue
    order.reserve(n);
    order.push_back(v);
    seen[v] = 1;
    for (size_t head = 0; head < order.size(); ++head) {
        for (int nxt : g[order[head]]) {
            if (!seen[nxt]) {
                seen[nxt] = 1;
                order.push_back(nxt);
            }
        }
    }
    d[0] = order.back();
    return d[0];
}
// BFS from d[0]. It does three things:
// - fills dist[0] with distances from d[0];
// - records BFS parents in pp (pp[d[0]] = -1);
// - sets d[1] to the last vertex dequeued, which is one farthest from d[0].
void bfs1() {
    int *const from0 = dist[0];
    fill(from0, from0 + n, infI);
    const int src = d[0];
    from0[src] = 0;
    pp[src] = -1;
    queue<int> frontier;
    frontier.push(src);
    while (!frontier.empty()) {
        const int cur = frontier.front();
        frontier.pop();
        d[1] = cur;
        for (int nb : g[cur]) {
            if (from0[nb] != infI) continue;
            from0[nb] = from0[cur] + 1;
            pp[nb] = cur;
            frontier.push(nb);
        }
    }
}
// BFS from d[1], filling dist[1] with distances from d[1].
void bfs2() {
    int *const from1 = dist[1];
    fill(from1, from1 + n, infI);
    from1[d[1]] = 0;
    queue<int> frontier;
    frontier.push(d[1]);
    while (!frontier.empty()) {
        const int cur = frontier.front();
        frontier.pop();
        for (int nb : g[cur]) {
            if (from1[nb] != infI) continue;
            from1[nb] = from1[cur] + 1;
            frontier.push(nb);
        }
    }
}
// Follows the BFS parents (pp, rooted at d[0]) from d[1] until the -1 sentinel.
// This records the diameter path in sus (d[1] first, d[0] last) and flags each path vertex in insus.
void calc_sus() {
    for (int cur = d[1]; cur != -1; cur = pp[cur]) {
        insus[cur] = true;
        sus.push_back(cur);
    }
}
// Computes mxd for the subtree hanging below v off the diameter path. Neighbours in insus and
// the vertex p (-1 for none) are skipped. For each such vertex u, mxd[u] becomes
// max(mxd[u], mxd[child] + 1) over its off-path children.
// Implemented iteratively (reverse pre-order) so long branches do not deepen the call stack.
void calc_mxd(int v, int p) {
    vector<pair<int, int>> preorder;            // (vertex, parent) in DFS pre-order
    vector<pair<int, int>> pending{{v, p}};
    while (!pending.empty()) {
        const auto [u, par] = pending.back();
        pending.pop_back();
        preorder.emplace_back(u, par);
        for (int to : g[u]) {
            if (!insus[to] && to != par) pending.emplace_back(to, u);
        }
    }
    // Reverse pre-order handles every child before its parent. Index 0 is v itself, and its
    // parent p must not be updated.
    for (size_t i = preorder.size(); i-- > 1;) {
        const auto [u, par] = preorder[i];
        mxd[par] = max(mxd[par], mxd[u] + 1);
    }
}
// Vertices currently counted in cnt[] / val, in push order; add and del keep it in sync with used[].
vector<int> st;
// Pushes x onto st and counts its color, bumping val on the color's first occurrence.
// Does nothing if x is already on st.
void add(int x) {
    if (used[x]) {
        return;
    }
    used[x] = true;
    val += (cnt[c[x]]++ == 0);
    st.push_back(x);
}
// Removes x from st and uncounts its color, dropping val when that color's count reaches zero.
// Does nothing if x is not on st. Asserts that x is the current top of st.
void del(int x) {
    if (used[x]) {
        used[x] = false;
        val -= (--cnt[c[x]] == 0);
        assert(!st.empty() && st.back() == x);
        st.pop_back();
    }
}
// Selects which distance table (dist[0] or dist[1]) main's del_stack compares against;
// main sets it to 1 before the second sweep.
int iteration{0};
// Don't forget to check whether this vertex is in sus.
// Reads the tree and vertex colors, finds a diameter, and runs two mirrored sweeps along the
// diameter path (one using distances from each endpoint). Then prints ans[v] for every vertex.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // NOTE: m holds the second input value only until it is overwritten with the path length below.
    cin >> n >> m;
    // n - 1 undirected edges, read 1-based and stored 0-based.
    for (int i = 1; i < n; ++i) {
        int a, b;
        cin >> a >> b;
        --a, --b;
        g[a].pb(b), g[b].pb(a);
    }
    for (int i = 0; i < n; ++i) {
        cin >> c[i];
    }
    // Diameter: d[0] = vertex farthest from 0, d[1] = vertex farthest from d[0].
    // dist[0] / dist[1] = BFS distances from d[0] / d[1]; pp = BFS parents rooted at d[0].
    calc_farthest(0);
    bfs1();
    bfs2();
//    cout << d[0] << " " << d[1] << endl;
    calc_sus();
    // mxd[v] = height of the off-path subtree hanging below each vertex.
    for (int x: sus) calc_mxd(x, -1);
    //sus = d[1], v1, v2, ..., vn, d[0]
    // From here on m is the number of vertices on the diameter path.
    // sus[i] is at distance i from d[1] and at distance m - 1 - i from d[0].
    m = sz(sus);
    // Pops vertices off the top of st while dist[iteration][u] + len >= dist[iteration][v],
    // stopping at the first vertex that fails the test.
    // Presumably this discards stacked vertices within distance len of v (TODO confirm).
    auto del_stack = [](int v, int len) {
        while (!st.empty()) {
            int u = st.back();
            if (dist[iteration][u] + len >= dist[iteration][v]) {
                del(u);
            } else {
                break;
            }
        }
    };
    // Walks the off-path subtree below v (p = parent, -1 at a path vertex).
    // - The child mx with the largest mxd is visited first, after st has been trimmed by the depths
    //   of all the other children.
    // - st is then trimmed by mxd[mx] + 1 before the remaining children are visited.
    // - ans[v] is the value of val once v has been removed again.
    function<void(int, int)> dfs = [&](int v, int p) {
        int mx = -1;
        for (int i = 0; i < sz(g[v]); ++i) {
            int to = g[v][i];
            if (insus[to] || to == p) continue;
            if (mx == -1 || mxd[mx] < mxd[to]) {
                mx = to;
            }
        }
        if (mx == -1) {
            // No off-path children: record the current number of distinct colors on st.
            ans[v] = val;
            return;
        }
        for (int to: g[v]) {
            if (!insus[to] && to != p && to != mx) {
                del_stack(v, mxd[to] + 1);
            }
        }
        add(v);
        dfs(mx, v);
        del_stack(v, mxd[mx] + 1);
        add(v);
        for (int to: g[v]) {
            if (!insus[to] && to != p && to != mx) {
                // add is a no-op if v is still on st; re-push it if a child's dfs removed it.
                add(v);
                dfs(to, v);
            }
        }
        del(v);
//        del_stack(v, mxd[v]);
        ans[v] = val;
    };
    // Sweep 1 (iteration == 0, distances from d[0]): i walks the path from the d[0] end
    // (i = m - 1) toward d[1].
    // - `last` trails behind i so that last - i <= i holds after the inner loop.
    // - When last passes path index k, rr absorbs ww[k], and sus[k] is pushed onto st only if rr > k.
    // - dfs runs for path vertices with i <= m - 1 - m / 2, i.e. those at least as far from d[0]
    //   as from d[1].
    int last = m - 1;
    int rr = m;
    // Every byte 0x7F, so each ww[k] starts at 2139062143 (== infM), larger than any i + 1 stored below.
    memset(ww, 127, sizeof(ww));
    for (int i = m - 1; i >= 0; --i) {
        int lenL = i;
        while (last - i > lenL) {
            ckmn(rr, ww[last]);
            if (rr > last) {
                add(sus[last]);
            }
            --last;
        }
        if (i <= m - 1 - m / 2) {
            dfs(sus[i], -1);
        }
        del_stack(sus[i], mxd[sus[i]]);
        // Store i + 1 at path index i + mxd[sus[i]]; if that index is >= last, fold it into rr directly.
        if (i + mxd[sus[i]] >= last) ckmn(rr, i + 1);
        else ckmn(ww[i + mxd[sus[i]]], i + 1);
    }
    // Empty st between sweeps; each del also undoes used[], cnt[] and val for the popped vertex.
    last = 0;
    while (!st.empty()) {
        del(st.back());
    }
    // Sweep 2 mirrors sweep 1 from the d[1] end, using distances from d[1] (iteration == 1).
    // ww is refilled with -infI, and LL tracks a running maximum instead of a minimum.
    // dfs runs for the remaining path vertices (i > m - 1 - m / 2).
    iteration = 1;
    fill(ww, ww + m, -infI);
    int LL = -1;
    for (int i = 0; i < m; ++i) {
        int lenR = m - i - 1;
        while (i - last > lenR) {
            ckmx(LL, ww[last]);
            if (last > LL) {
                add(sus[last]);
            }
            ++last;
        }
        if (i > m - 1 - m / 2) {
            dfs(sus[i], -1);
        }
        del_stack(sus[i], mxd[sus[i]]);
        // Mirror of sweep 1: store i - 1 at path index i - mxd[sus[i]], or fold it into LL directly.
        if (i - mxd[sus[i]] <= last) ckmx(LL, i - 1);
        else ckmx(ww[i - mxd[sus[i]]], i - 1);
    }
    for (int i = 0; i < n; ++i) {
        cout << ans[i] << '\n';
    }
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output | 
|---|---|---|---|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|---|---|---|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|---|---|---|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|---|---|---|---|
| Fetching results... |