답안 #705457

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 705457 | 2023-03-04T12:36:12 Z | pakhomovee | Unique Cities (JOI19_ho_t5) | C++17 | 0 / 100 | 306 ms | 26888 KB |
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
#include <set>
#include <stdlib.h>
#include <map>
#include <random>
#include <cstring>
#include <functional>
#include <iomanip>
#include <cassert>
#include <queue>
#include <unordered_map>
#include <array>

using namespace std;

// Distance sentinel for vertices that a search never reaches.
const int inf = 1e9;

// Breadth-first search over the adjacency list `g`, starting at vertex `v`.
// Returns one entry per vertex: the number of edges on a shortest path from
// `v`, or `inf` when the vertex cannot be reached.
vector<int> bfs(int v, vector<vector<int>> &g) {
    vector<int> dist(g.size(), inf);
    deque<int> frontier;
    dist[v] = 0;
    frontier.push_back(v);
    while (!frontier.empty()) {
        const int cur = frontier.front();
        frontier.pop_front();
        const int cand = dist[cur] + 1;
        for (int nxt : g[cur]) {
            // Already reached by a path that is at least as short.
            if (dist[nxt] <= cand) continue;
            dist[nxt] = cand;
            frontier.push_back(nxt);
        }
    }
    return dist;
}

// Capacity of every fixed-size global array in this file.
const int maxn = 1 << 18;

// md[v]: greatest depth found inside v's subtree (filled by dfs).
int md[maxn];
// d[v]: depth of v below the current root; the caller sets d[root] itself.
int d[maxn];

// Walks the tree `g` from `v`, never stepping back to `p`. Assigns d[] to
// every vertex below `v` and md[] to `v` and everything below it.
void dfs(int v, vector<vector<int>> &g, int p) {
    int &deepest = md[v];
    deepest = d[v];
    for (const int child : g[v]) {
        if (child == p) continue;
        d[child] = d[v] + 1;
        dfs(child, g, v);
        if (md[child] > deepest) deepest = md[child];
    }
}
// ans[v]: value written by the most recent dfs1 pass for vertex v; solve()
// combines the passes with max().
int ans[maxn];
// One value per vertex, read from the input right after the edge list
// (presumably the vertex colours of the task — TODO confirm). Used as the
// index into lst.
vector<int> c;
// lst[k]: depth that dfs1 stored the last time it pushed a vertex whose c
// value is k, or -1 when nothing is recorded. solve() fills it with -1.
int lst[maxn];

// Storage of a 0-indexed Fenwick (binary indexed) tree over positions
// 0..maxn-1.
int f[maxn];

// Returns the sum of all updates applied to positions 0..i (inclusive).
// A negative i yields 0.
int get(int i) {
    int total = 0;
    for (int pos = i; pos >= 0; pos = (pos & (pos + 1)) - 1) {
        total += f[pos];
    }
    return total;
}

// Adds x at position i of the Fenwick tree stored in f, so that every
// get(j) with j >= i grows by x.
void upd(int i, int x) {
    for (int pos = i; pos < maxn; pos |= pos + 1) {
        f[pos] += x;
    }
}

// Treap node: ordered by key x, heap-ordered by the random priority y.
struct node {
    int x;               // key
    int y;               // priority drawn from rand() at construction
    int sz;              // number of nodes in the subtree rooted here
    node* l = nullptr;   // left child
    node* r = nullptr;   // right child

    // Builds a single-node treap holding key x.
    node(int x) : x(x), y(rand()), sz(1) {}
};

// Size of the treap rooted at v; an empty treap (null pointer) has size 0.
int s(node* v) {
    return v ? v->sz : 0;
}

// Recomputes v->sz from the sizes of its two children; does nothing for a
// null pointer.
void upd(node* v) {
    if (v) {
        v->sz = s(v->l) + s(v->r) + 1;
    }
}

// Joins two treaps and returns the new root. The result keeps every node of
// `l` before every node of `r`, so callers must pass treaps whose keys are
// already ordered that way. The root with the strictly larger priority stays
// on top; on a tie `r` wins.
node* merge(node*l, node*r) {
    if (!l || !r) return l ? l : r;
    if (l->y <= r->y) {
        r->l = merge(l, r->l);
        upd(r);
        return r;
    }
    l->r = merge(l->r, r);
    upd(l);
    return l;
}

// Splits the treap rooted at v by key: `first` receives every node with
// x <= k, `second` every node with x > k. Subtree sizes are refreshed on the
// way back up.
pair<node*, node*> split(node* v, int k) {
    pair<node*, node*> parts{ nullptr, nullptr };
    if (!v) return parts;
    if (v->x > k) {
        // v belongs to the right part; only its left subtree can hold keys <= k.
        parts = split(v->l, k);
        v->l = parts.second;
        parts.second = v;
    } else {
        // v belongs to the left part; only its right subtree can hold keys > k.
        parts = split(v->r, k);
        v->r = parts.first;
        parts.first = v;
    }
    upd(v);
    return parts;
}

// Root of the global treap used by add(), cut() and dfs1().
node* r = nullptr;

// Appends key x to the global treap. The new node is merged on the right, so
// x is expected to be larger than every key currently stored.
void add(int x) {
    r = merge(r, new node(x));
}

void cut(int x) {
    pair<node*, node*> q = split(r, x - 1);
    r = q.first;
}

// Second pass over the tree rooted as in the preceding dfs() call (d[] and
// md[] must already be filled). Fills ans[] for v and everything below it.
//
// State shared across the recursion:
//   * treap r  - depths of ancestors currently kept; an ancestor is pushed
//                only if no kept ancestor with the same c value exists, so
//                s(r) is a count of distinct c values.
//   * Fenwick f - marks depth ranges that are currently hidden: get(x) > 0
//                exactly when depth x lies in a hidden range.
//   * lst[k]   - depth of the last pushed ancestor whose c value is k.
//
// For a vertex v with a child subtree reaching depth md[u], all ancestors at
// depth >= 2*d[v] - md[u] are hidden while working in the *other* children.
// Everything changed here is restored before returning.
void dfs1(int v, vector<vector<int>> &g, int p) {
    // mx.first: child with the deepest subtree, mx.second: runner-up
    // (-1 when absent).
    pair<int, int> mx = { -1, -1 };
    for (int u : g[v]) {
        if (u != p) {
            if (mx.first == -1) {
                mx.first = u;
            } else {
                if (md[mx.first] < md[u]) {
                    mx.second = mx.first;
                    mx.first = u;
                } else if (mx.second == -1 || md[mx.second] < md[u]) {
                    mx.second = u;
                }
            }
        }
    }
    const int memlst = lst[c[v]];
    pair<node*, node*> q = { nullptr, nullptr };

    // Pushes v's depth unless a kept ancestor already carries the same c value.
    auto push_self = [&]() {
        if (lst[c[v]] == -1 || get(lst[c[v]])) {
            add(d[v]);
            lst[c[v]] = d[v];
        }
    };
    // Undoes push_self. lst must be reset together with the treap entry:
    // leaving lst[c[v]] == d[v] behind made the later push_self see
    // get(d[v]) == 0 and skip v for all non-deepest children.
    auto pop_self = [&]() {
        cut(d[v]);
        lst[c[v]] = memlst;
    };

    // Phase 1: descend into the deepest child; only the runner-up subtree
    // (if any) hides ancestors.
    if (mx.second != -1) {
        const int lo = d[v] * 2 - md[mx.second];
        upd(max(0, lo), 1);
        upd(d[v], -1);
        q = split(r, lo - 1);
        r = q.first;
        push_self();
        dfs1(mx.first, g, v);
        pop_self();
        upd(max(0, lo), -1);
        upd(d[v], 1);
        r = merge(r, q.second);
    } else if (mx.first != -1) {
        push_self();
        dfs1(mx.first, g, v);
        pop_self();
    }

    // Phase 2: hide ancestors according to the deepest subtree, record the
    // answer for v, then make v visible to the remaining children.
    const int lo_first = (mx.first != -1) ? d[v] * 2 - md[mx.first] : 0;
    if (mx.first != -1) {
        upd(max(0, lo_first), 1);
        upd(d[v], -1);
        q = split(r, lo_first - 1);
        r = q.first;
        ans[v] = s(r);
        push_self();
    } else {
        ans[v] = s(r);
    }
    for (int u : g[v]) {
        if (u != p) {
            if (u == mx.first) continue;
            dfs1(u, g, v);
        }
    }
    if (mx.first != -1) {
        upd(max(0, lo_first), -1);
        upd(d[v], 1);
        cut(d[v]);
        r = merge(r, q.second);
    }
    lst[c[v]] = memlst;
}

// Reads one test case from stdin (n, m, n-1 edges with 1-based endpoints,
// then n values for c) and prints one number per vertex, separated by
// spaces. The tree is processed twice, once from each end of a longest path,
// and the per-vertex results of both passes are combined with max().
void solve() {
    int n, m;
    cin >> n >> m;
    vector<vector<int>> g(n);
    for (int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        g[a - 1].push_back(b - 1);
        g[b - 1].push_back(a - 1);
    }
    c.resize(n);
    for (int k = 0; k < n; ++k) {
        cin >> c[k];
    }

    // Index of the first vertex at maximum BFS distance from `from`.
    auto farthest = [&g](int from) {
        const vector<int> dist = bfs(from, g);
        return static_cast<int>(max_element(dist.begin(), dist.end()) - dist.begin());
    };
    const int rootA = farthest(farthest(0));
    const int rootB = farthest(rootA);

    vector<int> answ(n, 0);
    fill(lst, lst + maxn, -1);
    for (int root : { rootA, rootB }) {
        d[root] = 0;
        dfs(root, g, root);
        fill(f, f + maxn, 0);
        dfs1(root, g, root);
        for (int k = 0; k < n; ++k) {
            answ[k] = max(answ[k], ans[k]);
        }
    }
    for (int val : answ) {
        cout << val << ' ';
    }
}

// Entry point: switches the C++ streams to unsynchronised, untied mode and
// runs solve() for the single test case.
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    for (int t = 1; t > 0; --t) {
        solve();
    }
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2388 KB Output is correct
2 Incorrect 3 ms 2616 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 178 ms 16748 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 306 ms 26888 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2388 KB Output is correct
2 Incorrect 3 ms 2616 KB Output isn't correct
3 Halted 0 ms 0 KB -