/*
Submission #1275477

#       | Time | Username     | Problem                     | Language | Result   | Execution time | Memory
1275477 |      | MisterReaper | Unique Cities (JOI19_ho_t5) | C++20    | 32 / 100 | 335 ms         | 42412 KiB
*/
// File uniquecities.cpp created on 02.10.2025 at 15:34:56
#include <bits/stdc++.h>

using i64 = long long;

// Local debugging helper: the real implementation lives in a private header;
// in normal builds every debug(...) call collapses to a no-op expression.
#ifndef DEBUG
    #define debug(...) void(23)
#else
    #include "/home/ahmetalp/Desktop/Workplace/debug.h"
#endif

// Capacity of every per-vertex array below (200000 plus a little slack).
constexpr int max_N = 200'000 + 5;

int N;                        // number of vertices, read first in main
int M;                        // second header value; read but not used here
std::vector<int> adj[max_N];  // adjacency lists over 0-based vertices
int C[max_N];                 // per-vertex colour, shifted to 0-based in main

int dep[max_N];               // depth (in edges) under the most recent rooting

// Assigns dep[] to every vertex reachable from v without stepping onto pr,
// one more than its parent's depth. The caller must set dep[v] first.
// Uses an explicit work list instead of recursion; the final dep[] values are
// the same as a recursive walk would produce.
void dfs1(int v, int pr) {
    std::vector<std::pair<int, int>> todo{{v, pr}};  // (vertex, its parent)
    while (!todo.empty()) {
        const auto [cur, parent] = todo.back();
        todo.pop_back();
        for (const int nxt : adj[cur]) {
            if (nxt != parent) {
                dep[nxt] = dep[cur] + 1;
                todo.emplace_back(nxt, cur);
            }
        }
    }
}

// Lazy segment tree over positions [0, n): range addition, plus queries that
// return the minimum over a range together with how many positions attain it.
// Nodes are stored in pre-order: for node v covering [l, r] with
// mid = (l + r) >> 1, the left child is v + 1 and the right child is
// v + 2 * (mid - l + 1), so 2n - 1 slots are enough.
struct segtree {
    struct node {
        int mn = 0;   // minimum value inside the node's range
        int cnt = 1;  // number of positions whose value equals mn
        int lz = 0;   // addition not yet handed down to the children
        // Add x to every position covered by this node.
        void modify(int x) {
            lz += x;
            mn += x;
        }
    };

    // Merge two adjacent ranges: the smaller minimum wins; ties add counts.
    node unite(const node& lhs, const node& rhs) {
        node merged;
        merged.mn = std::min(lhs.mn, rhs.mn);
        merged.cnt = 0;
        if (lhs.mn == merged.mn) {
            merged.cnt += lhs.cnt;
        }
        if (rhs.mn == merged.mn) {
            merged.cnt += rhs.cnt;
        }
        return merged;
    }

    int n;
    std::vector<node> tree;

    // Index of the left child of node v.
    static int left_child(int v) {
        return v + 1;
    }
    // Index of the right child of node v, which covers [l, r].
    static int right_child(int v, int l, int r) {
        const int mid = (l + r) >> 1;
        return v + ((mid - l + 1) << 1);
    }

    // Reset the subtree rooted at v (covering [l, r]) to all-zero positions.
    void build(int v, int l, int r) {
        if (l == r) {
            tree[v] = node{};
            return;
        }
        const int mid = (l + r) >> 1;
        build(left_child(v), l, mid);
        build(right_child(v, l, r), mid + 1, r);
        tree[v] = unite(tree[left_child(v)], tree[right_child(v, l, r)]);
    }

    // (Re)initialise the tree to n_ positions, all holding 0.
    void init(int n_) {
        n = n_;
        tree.resize(n << 1);
        build(0, 0, n - 1);
    }

    // Hand node v's pending addition down to its two children.
    void push(int v, int l, int r) {
        const int add = tree[v].lz;
        tree[left_child(v)].modify(add);
        tree[right_child(v, l, r)].modify(add);
        tree[v].lz = 0;
    }

    // Add x to positions [ql, qr], which lie inside node v's range [l, r].
    void modify(int v, int l, int r, int ql, int qr, int x) {
        if (ql == l && r == qr) {
            tree[v].modify(x);
            return;
        }
        const int mid = (l + r) >> 1;
        const int lc = left_child(v);
        const int rc = right_child(v, l, r);
        push(v, l, r);
        if (ql <= mid) {
            modify(lc, l, mid, ql, std::min(qr, mid), x);
        }
        if (qr > mid) {
            modify(rc, mid + 1, r, std::max(ql, mid + 1), qr, x);
        }
        tree[v] = unite(tree[lc], tree[rc]);
    }
    void modify(int l, int r, int x) {
        modify(0, 0, n - 1, l, r, x);
    }

    // Minimum and its multiplicity over [ql, qr] inside node v's range [l, r].
    node get(int v, int l, int r, int ql, int qr) {
        if (ql == l && r == qr) {
            return tree[v];
        }
        const int mid = (l + r) >> 1;
        push(v, l, r);
        if (qr <= mid) {
            return get(left_child(v), l, mid, ql, qr);
        }
        if (ql > mid) {
            return get(right_child(v, l, r), mid + 1, r, ql, qr);
        }
        const node lo = get(left_child(v), l, mid, ql, mid);
        const node hi = get(right_child(v, l, r), mid + 1, r, mid + 1, qr);
        return unite(lo, hi);
    }
    node get(int l, int r) {
        return get(0, 0, n - 1, l, r);
    }
} seg;

// h[v]: number of edges on the longest downward path from v, for the rooting
// built by the most recent dfs2 call.
int h[max_N];

// Walks the subtree of v (entered from pr). It gives every descendant
// dep = parent's dep + 1 and fills h[] bottom-up. The caller sets dep[v].
void dfs2(int v, int pr) {
    int tallest = 0;
    for (const int child : adj[v]) {
        if (child != pr) {
            dep[child] = dep[v] + 1;
            dfs2(child, v);
            tallest = std::max(tallest, h[child] + 1);
        }
    }
    h[v] = tallest;
}

int ans[max_N];

void dfs3(int v, int pr) {
    debug(v);
    #ifdef DEBUG
        for (int i = 0; i <= dep[v] - 1; ++i) {
            auto x = seg.get(i, i);
            std::cerr << x.mn << " \n"[i == dep[v] - 1];
        }
    #endif

    if (h[v] <= dep[v] - 1) {
        auto x = seg.get(0, dep[v] - 1 - h[v]);
        if (x.mn == 0) {
            debug(v, x.cnt);
            ans[v] = std::max(ans[v], x.cnt);
        }
    }
    int mx1 = -1, mx2 = -1;
    for (auto u : adj[v]) {
        if (u == pr) {
            continue;
        }
        int x = h[u] + 1;
        if (x > mx1) {
            mx2 = mx1;
            mx1 = x;
        } else if (x > mx2) {
            mx2 = x;
        }
    }
    for (auto u : adj[v]) {
        if (u == pr) {
            continue;
        }
        int x = (h[u] + 1 == mx1 ? mx2 : mx1);
        if (std::max(0, dep[v] - x) <= dep[v] - 1) {
            seg.modify(std::max(0, dep[v] - x), dep[v] - 1, +1);
        }
        dfs3(u, v);
        if (std::max(0, dep[v] - x) <= dep[v] - 1) {
            seg.modify(std::max(0, dep[v] - x), dep[v] - 1, -1);
        }
    }
}

// Roots the tree at v, rebuilds dep[] and h[] for that rooting, and runs the
// dfs3 sweep, which max-merges its results into ans[].
void solve(int v) {
    dep[v] = 0;      // v becomes the root
    dfs2(v, -1);     // depths and downward heights for this rooting
    seg.init(N);     // reset the shared segment tree before the sweep
    debug();
    dfs3(v, -1);
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> N >> M;

    for (int i = 1; i < N; ++i) {
        int A, B;
        std::cin >> A >> B;
        --A, --B;
        adj[A].emplace_back(B);
        adj[B].emplace_back(A);
    }

    for (int i = 0; i < N; ++i) {
        std::cin >> C[i];
        --C[i];
    }

    dep[0] = 0;
    dfs1(0, -1);
    int d0 = std::max_element(dep, dep + N) - dep;
    dep[d0] = 0;
    dfs1(d0, -1);
    int d1 = std::max_element(dep, dep + N) - dep;

    solve(d0);
    solve(d1);

    for (int i = 0; i < N; ++i) {
        std::cout << ans[i] << '\n';
    }

    return 0;
}
/*
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
*/