답안 #874822

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
874822 2023-11-17T21:14:37 Z tvladm2009 Construction of Highway (JOI18_construction) C++17
0 / 100
2 ms 10588 KB
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

constexpr int N = 100007;  // array capacity (1e5 vertices + slack)

int n;                      // number of vertices
int c[N];                   // colour of each vertex (rank-compressed in main)
int a[N], b[N];             // edge i: a[i] is the parent of b[i]
vector<int> g[N];           // child lists
int par[N], sz[N], top[N];  // parent, subtree size, head of the HLD chain
int tin[N], tout[N];        // DFS entry / exit times
vector<int> order;          // vertices in DFS visit order (order[tin[u]] == u)
int t = 0;                  // DFS timer

// Heavy-light DFS rooted at u.
// - Assigns entry/exit times tin[u] / tout[u] and appends u to `order`.
// - Moves the child with the largest subtree into g[u][0] so that it is
//   visited immediately after u. This keeps every heavy chain contiguous in
//   `order`.
// - Sets top[] for each child: the heavy child inherits u's chain head, and
//   light children start their own chain.
// Relies on sz[] already holding subtree sizes.
void dfs(int u) {
    tin[u] = t++;
    order.push_back(u);
    // Iterate by reference. Swapping with a by-value copy would overwrite
    // g[u][0] and lose the child previously stored there.
    for (auto &v : g[u]) {
        if (sz[v] > sz[g[u][0]]) {
            swap(v, g[u][0]);
        }
    }
    for (auto v : g[u]) {
        top[v] = (v == g[u][0] ? top[u] : v);
        dfs(v);
    }
    tout[u] = t;
}

// Segment-tree summary of a range of colours: its smallest and largest value.
struct Node {
    int mn;  // minimum colour in the range
    int mx;  // maximum colour in the range

    // Summary of a range in which every colour equals `val`.
    Node(int val = 0) : mn(val), mx(val) {}
};

// Combines the summaries of two adjacent ranges.
Node operator + (Node a, Node b) {
    Node merged(min(a.mn, b.mn));
    merged.mx = max(a.mx, b.mx);
    return merged;
}

// Segment tree over DFS positions. `tree` holds each node's min/max summary.
// `color` holds a pending lazy colour assignment, where -1 means none.
Node tree[N * 4];
int color[N * 4];

// Initialises node v, which covers DFS positions [tl, tr), and its subtree.
// Leaves take the colour of the vertex at that position. Internal nodes
// merge their children. Every visited node's lazy tag is cleared to -1.
void build(int v, int tl, int tr) {
    color[v] = -1;
    if (tr - tl != 1) {
        const int mid = (tl + tr) / 2;
        const int left = 2 * v;
        const int right = left + 1;
        build(left, tl, mid);
        build(right, mid, tr);
        tree[v] = tree[left] + tree[right];
        return;
    }
    tree[v] = Node(c[order[tl]]);
}

// Propagates node v's pending colour assignment (if any) to both children,
// then clears v's tag.
void push(int v) {
    const int pending = color[v];
    if (pending != -1) {
        for (int child = 2 * v; child <= 2 * v + 1; ++child) {
            tree[child] = Node(pending);
            color[child] = pending;
        }
        color[v] = -1;
    }
}

// Assigns colour c to every position in [l, r) inside node v's range [tl, tr).
void update(int v, int tl, int tr, int l, int r, int c) {
    const bool disjoint = tr <= l || r <= tl;
    if (disjoint) {
        return;
    }
    const bool covered = l <= tl && tr <= r;
    if (covered) {
        // Fully covered: overwrite the summary and leave a lazy tag for the children.
        color[v] = c;
        tree[v] = Node(c);
        return;
    }
    push(v);
    const int mid = (tl + tr) / 2;
    const int left = 2 * v;
    update(left, tl, mid, l, r, c);
    update(left + 1, mid, tr, l, r, c);
    tree[v] = tree[left] + tree[left + 1];
}

// Returns the RIGHTMOST position p in [x, y) whose current colour differs from
// `val`, or -1 if every position in [x, y) has colour `val`.
// get_cost calls this with y = tin[u] + 1 to locate where the run of colour
// `val` ending at tin[u] begins. Only the nearest mismatch to the left gives
// that boundary, so the right child is searched before the left one. A
// leftmost search would skip over intermediate colour changes.
// Subtrees whose min == max == val cannot contain a mismatch and are pruned.
int find_prev(int v, int tl, int tr, int x, int y, int val) {
    if (tr <= x || tl >= y || (tree[v].mn == val && tree[v].mx == val)) {
        return -1;
    }
    if (tr - tl == 1) {
        return tl;
    }
    push(v);
    int tm = (tl + tr) / 2;
    int sol = find_prev(2 * v + 1, tm, tr, x, y, val);
    if (sol != -1) {
        return sol;
    }
    return find_prev(2 * v, tl, tm, x, y, val);
}

// Returns the current colour at DFS position `pos`.
// It walks down from node v and stops at the first node carrying a lazy tag.
// If no tag is found, it returns the leaf's initial colour.
int get_color(int v, int tl, int tr, int pos) {
    while (color[v] == -1) {
        if (tr - tl == 1) {
            return c[order[tl]];
        }
        const int mid = (tl + tr) / 2;
        if (pos < mid) {
            v = 2 * v;
            tr = mid;
        } else {
            v = 2 * v + 1;
            tl = mid;
        }
    }
    return color[v];
}

// Fenwick (binary indexed) tree. Callers use 0-based positions; storage is
// 1-based internally.
int fen[N];

// Adds `val` at 0-based position i.
void add(int i, int val) {
    for (int k = i + 1; k < N; k += k & -k) {
        fen[k] += val;
    }
}

// Sum of positions [0, i] (0-based). Yields 0 when i < 0.
int get(int i) {
    int total = 0;
    for (int k = i + 1; k > 0; k -= k & -k) {
        total += fen[k];
    }
    return total;
}

// Counts pairs of vertices (x, y) on the root -> u path where x lies above y
// and colour(x) > colour(y) strictly.
// The path is split into maximal runs of equal colour, collected from u
// upward (deepest run first) using the HLD chains and find_prev. The runs are
// then swept with the Fenwick tree `fen`, which counts the deeper vertices
// seen so far per colour. `fen` is restored to all zeros before returning.
ll get_cost(int u) {
    vector<pair<int, int>> segs;  // (colour, run length), deepest run first
    while (u != -1) {
        int col = get_color(1, 0, n, tin[u]);
        // Boundary of the run of `col` ending at tin[u]. find_prev is expected
        // to return the rightmost differing position, or -1.
        int pos = find_prev(1, 0, n, 0, tin[u] + 1, col);
        if (pos < tin[top[u]]) {
            // The whole chain piece top[u]..u has colour col: jump to the parent chain.
            segs.push_back({col, tin[u] - tin[top[u]] + 1});
            u = par[top[u]];
        } else {
            // The run is (pos, tin[u]]. Continue from the vertex at pos on this chain.
            segs.push_back({col, tin[u] - pos});
            u = order[pos];
        }
    }
    ll ans = 0;
    for (const auto &it : segs) {
        // Only strictly smaller deeper colours form a pair, so query [0, col - 1].
        // get(col) would also count equal colours. get(-1) returns 0.
        ans += get(it.first - 1) * 1ll * it.second;
        add(it.first, it.second);
    }
    // Undo the insertions so the Fenwick tree is clean for the next query.
    for (const auto &it : segs) {
        add(it.first, -it.second);
    }
    return ans;
}

// Reads n, the initial colours and the n - 1 edges (1-based, converted to
// 0-based). For each edge i in input order, main prints get_cost(a[i]). It
// then recolours every vertex on the root -> b[i] path with b[i]'s colour.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n;
    if (n <= 0) {
        // No vertices: build(1, 0, 0) would recurse without ever reaching a leaf.
        return 0;
    }
    for (int i = 0; i < n; ++i) {
        cin >> c[i];
    }
    // Compress colours to ranks in [0, n) so they can index the Fenwick tree.
    // Equal colours map to the same rank.
    vector<int> sorted(c, c + n);
    sort(sorted.begin(), sorted.end());
    for (int i = 0; i < n; ++i) {
        c[i] = lower_bound(sorted.begin(), sorted.end(), c[i]) - sorted.begin();
    }
    for (int i = 0; i < n - 1; ++i) {
        cin >> a[i] >> b[i];
        --a[i];
        --b[i];
    }
    par[0] = -1;
    top[0] = 0;  // the root heads its own chain
    for (int i = 0; i < n - 1; ++i) {
        par[b[i]] = a[i];
        g[a[i]].push_back(b[i]);
    }
    for (int i = 0; i < n; ++i) {
        sz[i] = 1;
    }
    // Only edges 0 .. n-2 exist. Starting at n - 1 would read the never-written
    // (zeroed) slots a[n-1] / b[n-1] and double sz[0].
    // Assumes each a[i] already exists when edge i is read (the root or an
    // earlier b[j]). A reverse sweep then finalises sz[b[i]] before it is
    // added to its parent. TODO confirm against the problem statement.
    for (int i = n - 2; i >= 0; --i) {
        sz[a[i]] += sz[b[i]];
    }
    dfs(0);
    build(1, 0, n);
    for (int i = 0; i < n - 1; ++i) {
        cout << get_cost(a[i]) << "\n";
        // Paint each HLD chain piece on the root -> b[i] path.
        for (int u = b[i]; u != -1; u = par[top[u]]) {
            update(1, 0, n, tin[top[u]], tin[u] + 1, c[b[i]]);
        }
    }
    return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 10584 KB Output is correct
2 Incorrect 2 ms 10588 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 10584 KB Output is correct
2 Incorrect 2 ms 10588 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 10584 KB Output is correct
2 Incorrect 2 ms 10588 KB Output isn't correct
3 Halted 0 ms 0 KB -