제출 #644999

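| # | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory) |
|---|---|---|---|---|---|---|---|
| 644999 | — | pakhomovee | Fireworks (APIO16_fireworks) | C++17 | 100 / 100 | 986 ms | 140828 KiB |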
#include <iostream>
#include <vector>
#include <set>
#include <algorithm>
#include <cstdio>
#include <algorithm>
#include <functional>

using namespace std;

// Every `int` in this file is silently `long long` (hence `int32_t main`
// below). Presumably chosen to avoid 32-bit overflow in expressions such as
// k * x + b on breakpoint coordinates — TODO confirm against the limits.
#define int long long

// Capacity of the per-vertex `root` array, which is indexed by vertex ids
// 0 .. n+m-1; presumably the problem's bound on n + m — TODO confirm.
const int maxn = 300'000;

// A line y = k * x + b.
//
// The DP uses it for the rightmost linear piece of a convex piecewise-linear
// function, i.e. the piece to the right of its largest breakpoint.
struct func {
    int k = 0;  // slope
    int b = 0;  // intercept

    func(int k, int b): k(k), b(b) {}
    // A default-constructed line is y = 0; members are never left
    // uninitialized, even for automatic (local) objects.
    func() = default;

    // Evaluates the line at x.
    int get(int x) const {
        return k * x + b;
    }
};

// Treap node storing one breakpoint coordinate.
//
//   x    - coordinate, not yet including the pending `add` of this node
//          or of its ancestors (see push);
//   y    - random heap priority used by merge;
//   add  - lazy offset still to be applied to this node and its subtree;
//   sz   - number of nodes in the subtree rooted here;
//   l, r - children (nullptr when absent).
struct node {
    int x = 0, y = 0, add = 0, sz = 1;
    node* l = nullptr;
    node* r = nullptr;

    // Creates a single-node treap holding coordinate x with a random priority.
    explicit node(int x): x(x), y(rand()) {}
    // All members get well-defined defaults instead of indeterminate values.
    node() = default;
};

// Materializes v's lazy offset: folds it into v->x and hands it down to each
// existing child's own pending offset. Does nothing for an empty tree.
void push(node* v) {
    if (v == nullptr) return;
    const int pending = v->add;
    if (v->l != nullptr) v->l->add += pending;
    if (v->r != nullptr) v->r->add += pending;
    v->x += pending;
    v->add = 0;
}

// Number of nodes in the treap rooted at v; 0 for an empty tree.
int s(node* v) {
    return v != nullptr ? v->sz : 0;
}

// Recomputes v->sz from its children's sizes; ignores a null pointer.
void upd(node* v) {
    if (v != nullptr) {
        v->sz = s(v->l) + s(v->r) + 1;
    }
}

// Per-vertex DP state written by dfs: `first` is the treap of breakpoints and
// `second` is the rightmost linear piece k*x + b. Indexed by vertex id.
pair<node*, func> root[maxn];

// Concatenates two treaps, where every key of l precedes every key of r.
// The node with the larger priority becomes the root; the lazy offset of
// whichever root is descended into is pushed before its child is replaced.
node* merge(node* l, node* r) {
    if (!l || !r) return l ? l : r;
    if (l->y <= r->y) {
        push(r);
        r->l = merge(l, r->l);
        upd(r);
        return r;
    }
    push(l);
    l->r = merge(l->r, r);
    upd(l);
    return l;
}

// Splits treap v by position: the first k nodes in key order go to the left
// result and the remaining nodes to the right one. Lazy offsets are pushed on
// the way down.
pair<node*, node*> split(node* v, int k) {
    if (v == nullptr) {
        return { nullptr, nullptr };
    }
    push(v);
    const int leftSize = s(v->l);
    if (k <= leftSize) {
        auto [lo, hi] = split(v->l, k);
        v->l = hi;
        upd(v);
        return { lo, v };
    }
    auto [lo, hi] = split(v->r, k - leftSize - 1);
    v->r = lo;
    upd(v);
    return { v, hi };
}

// Splits treap v by key: nodes whose coordinate is <= k go to the left
// result and nodes with a coordinate > k go to the right one.
pair<node*, node*> split1(node* v, int k) {
    if (v == nullptr) {
        return { nullptr, nullptr };
    }
    push(v);
    if (k < v->x) {
        auto [lo, hi] = split1(v->l, k);
        v->l = hi;
        upd(v);
        return { lo, v };
    }
    auto [lo, hi] = split1(v->r, k);
    v->r = lo;
    upd(v);
    return { v, hi };
}

// Inserts a fresh node with coordinate x into the treap `root`, placing it
// after every existing node whose coordinate is <= x.
void add(node*& root, int x) {
    auto [atMost, above] = split1(root, x);
    root = merge(atMost, merge(new node(x), above));
}

// Returns the largest coordinate in the non-empty treap v, pushing lazy
// offsets along the right spine so that the returned value is final.
int back(node* v) {
    push(v);
    while (v->r) {
        v = v->r;
        push(v);
    }
    return v->x;
}

// Appends every coordinate of treap v to x in increasing order. Each node is
// pushed before its subtrees are visited, so the emitted values include all
// pending offsets.
void walk(node* v, vector<int> &x) {
    vector<node*> pending;
    node* cur = v;
    while (cur != nullptr || !pending.empty()) {
        for (; cur != nullptr; cur = cur->l) {
            push(cur);
            pending.push_back(cur);
        }
        node* top = pending.back();
        pending.pop_back();
        x.push_back(top->x);
        cur = top->r;
    }
}

void dfs(int v, vector<vector<pair<int, int>>> &g, int l) {
    sort(g[v].begin(), g[v].end(), [&] (pair<int, int> u, pair<int, int> v) {
        return s(root[u.first].first) > s(root[v.first].first);
    });
    root[v] = { nullptr, func(0, 0) };
    if (g[v].size() == 0) {
        root[v] = { nullptr, func(1, -l) };
        add(root[v].first, l);
        add(root[v].first, l);
    } else {
        for (auto [u, c] : g[v]) {
            dfs(u, g, c);
        }
        root[v] = root[g[v][0].first];
        for (int i = 1; i < g[v].size(); ++i) {
            root[v].second.k += root[g[v][i].first].second.k;
            root[v].second.b += root[g[v][i].first].second.b;
            vector<int> x;
            walk(root[g[v][i].first].first, x);
            for (int c : x) {
                add(root[v].first, c);
            }
            x.clear();
        }
        if (1) {
            while (root[v].second.k > 1) {
                int x = back(root[v].first);
                int val = root[v].second.k * x + root[v].second.b;
                --root[v].second.k;
                root[v].second.b = val - root[v].second.k * x;
                pair<node*, node*> q = split(root[v].first, s(root[v].first) - 1);
                root[v].first = q.first;
            }
        }
        if (1) {
            pair<node*, node*> q = split(root[v].first, s(root[v].first) - 2);
            if (q.second) q.second->add += l;
            root[v].second.b -= l;
            root[v].first = merge(q.first, q.second);
        }
    }
    return;
}

int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n, m;
    cin >> n >> m;
    vector<vector<pair<int, int>>> g(n + m);
    for (int i = 1; i < n + m; ++i) {
        int p, c;
        cin >> p >> c;
        --p;
        g[p].push_back({ i, c });
    }
    dfs(0, g, 0);
    int curr = root[0].second.k;
    vector<int> x;
    walk(root[0].first, x);
    auto pt = x.rbegin();
    while (curr--) {
        int x = *pt;
        int val = root[0].second.k * x + root[0].second.b;
        --root[0].second.k;
        root[0].second.b = val - root[0].second.k * x;
        ++pt;
    }
    cout << root[0].second.b;
}
/*
 4 6
 1 5
 2 5
 2 8
 3 3
 3 2
 3 3
 2 9
 4 4
 4 3
 */

컴파일 시 표준 에러 (stderr) 메시지

fireworks.cpp: In function 'void dfs(long long int, std::vector<std::vector<std::pair<long long int, long long int> > >&, long long int)':
fireworks.cpp:144:27: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  144 |         for (int i = 1; i < g[v].size(); ++i) {
      |                         ~~^~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...