Submission #102631

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 102631 | | forestryks | Construction of Highway (JOI18_construction) | C++14 | 100 / 100 | 1416 ms | 21848 KiB |
///////////////////////////////////////////////////////////////////////////////////////////////
#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shorthand macros.
// NOTE: these are plain textual substitutions that apply to every line after
// this point. In particular `f` and `s` expand to `first` / `second` (used for
// pair access throughout the file), so no identifier named `f` or `s` can be
// used as a variable name below without being rewritten.
#define mp make_pair
#define pb push_back
// Disables C-stdio synchronisation and unties cin/cout for faster stream I/O.
#define FAST_IO ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0)
// Redirects stdin/stdout to the files "<x>.in" / "<x>.out".
#define FILE_IO(x) freopen((string(x) + ".in").c_str(), "r", stdin); freopen((string(x) + ".out").c_str(), "w", stdout)
#define f first
#define s second
// Renames of common identifiers, presumably to avoid clashing with names some
// library headers declare (e.g. y1 from <math.h>, std::left / std::right).
#define x1 x1qwer
#define y1 y1qwer
#define right right123
#define left left123
// Loop helpers. foreach iterates by value (copies each element);
// rep/forin are half-open int loops [0, n) and [l, r).
#define foreach(it, v) for (auto it : v)
#define rep(it, n) for (int it = 0; it < n; ++it)
#define forin(it, l, r) for (int it = l; it < r; ++it)
// Expands to the begin/end iterator pair of a container.
#define all(x) x.begin(), x.end()

// Short aliases for frequently used types.
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;

// Shared numeric constants.
const double DINF = numeric_limits<double>::infinity();
const ll MOD = 1000000007;
const double EPS = 1e-7;

// Greatest common divisor by Euclid's algorithm (loop form).
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}

// Fixed-seed Mersenne Twister (seeded with MOD, so runs are reproducible)
// and a default-range distribution over [0, numeric_limits<ll>::max()].
mt19937 mmtw_(MOD);
uniform_int_distribution<ll> rd_;

// Returns a raw non-negative pseudo-random value.
ll randomll() {
    return rd_(mmtw_);
}

// Returns a pseudo-random value in [x, y] via modulo reduction (needs x <= y).
ll rnd(ll x, ll y) {
    const ll width = y - x + 1;
    return rd_(mmtw_) % width + x;
}
////////////////////////////////////////////////////////////////////////////////////////////////

// Capacity of all per-vertex arrays (maximum vertex count plus slack).
constexpr int MAXN = 100000 + 5;

int n;                              // number of vertices; vertex 0 is the root
int a[MAXN];                        // per-vertex value; rank-compressed by compress()
vector<int> ord;                    // child endpoints in input-edge order (queried in this order)
int p[MAXN];                        // parent of each vertex; main() sets p[0] = -1
vector<int> g[MAXN];                // children lists; rebuild_graph() puts the heavy child at [0]
int sz[MAXN];                       // subtree sizes
int top[MAXN];                      // head of the heavy chain containing each vertex
int tin[MAXN];                      // DFS entry time
int tout[MAXN];                     // one past the last entry time inside the subtree
int h[MAXN];                        // depth (root is 0)
vector<pair<int, int>> seg[MAXN];   // per chain head: run-length (value, count) list, back() = head

void rebuild_graph(int v) {
    sz[v] = 1;
    for (auto to : g[v]) {
        rebuild_graph(to);
        sz[v] += sz[to];
    }

    nth_element(g[v].begin(), g[v].begin(), g[v].end(), [&](int a, int b){
        return sz[a] > sz[b];
    });
}

int tt = 0;
void build_hld(int v) {
    tin[v] = tt++;

    if (!seg[top[v]].empty() && seg[top[v]].back().f == a[v]) {
        seg[top[v]].back().s++;
    } else {
        seg[top[v]].push_back({a[v], 1});
    }

    if (!g[v].empty()) {
        for (auto to : g[v]) {
            top[to] = to;
            h[to] = h[v] + 1;
        }
        top[g[v][0]] = top[v];
        for (auto to : g[v]) {
            build_hld(to);
        }
    }

    reverse(all(seg[v]));

    tout[v] = tt;
}

// Coordinate compression: replaces each a[i] (0 <= i < n) by its rank among
// the distinct values of a[0..n), so values become dense integers starting
// at 0 while preserving their relative order.
void compress() {
    vector<int> distinct(a, a + n);
    sort(distinct.begin(), distinct.end());
    distinct.resize(unique(distinct.begin(), distinct.end()) - distinct.begin());

    transform(a, a + n, a, [&distinct](int value) {
        return static_cast<int>(lower_bound(distinct.begin(), distinct.end(), value) - distinct.begin());
    });
}

// One-time preprocessing before answering queries: compress the values in
// a[], compute subtree sizes / heavy children, then lay out the heavy chains,
// all rooted at vertex 0.
void prepare() {
    const int root = 0;
    compress();
    rebuild_graph(root);
    build_hld(root);
}

// Sum segment tree over positions [0, n): node v covers [tl, tr) and has
// children 2v+1 (left half) and 2v+2 (right half); 4*MAXN nodes suffice for
// any n <= MAXN. Callers start at the root with (v, tl, tr) = (0, 0, n).
int t[MAXN * 4];

// Adds x to the value stored at position pos.
void add(int v, int tl, int tr, int pos, int x) {
    if (tr - tl != 1) {
        const int mid = tl + (tr - tl) / 2;
        if (pos >= mid) {
            add(2 * v + 2, mid, tr, pos, x);
        } else {
            add(2 * v + 1, tl, mid, pos, x);
        }
        t[v] = t[2 * v + 1] + t[2 * v + 2];
        return;
    }
    t[v] += x;
}

// Returns the sum of the values at positions [l, r) (0 for an empty range).
int get(int v, int tl, int tr, int l, int r) {
    const bool inside = l <= tl && tr <= r;
    if (inside) {
        return t[v];
    }
    const bool outside = r <= tl || tr <= l;
    if (outside) {
        return 0;
    }
    const int mid = tl + (tr - tl) / 2;
    const int left_sum = get(2 * v + 1, tl, mid, l, r);
    const int right_sum = get(2 * v + 2, mid, tr, l, r);
    return left_sum + right_sum;
}

// void add(int v, int tl, int tr, int pos, int x) {
//     t[pos] += x;
// }

// int get(int v, int tl, int tr, int l, int r) {
//     int res = 0;
//     for (int i = l; i < r; ++i) {
//         res += t[i];
//     }
//     return res;
// }

// Counts inversions of a run-length encoded sequence: the number of index
// pairs i < j of the expanded sequence with value_i > value_j.
// Each element of `a` is (value, multiplicity); values are expected to lie in
// [0, n), which compress() guarantees for the vertex values.
// Uses the global segment tree t[] as a frequency table: it must be all-zero
// on entry, and every insertion is undone before returning so it is all-zero
// again for the next call.
ll solve(const vector<pair<int, int>> &a) {
    ll res = 0;
    for (const auto &run : a) {
        // Every earlier element with a strictly larger value forms an
        // inversion with each of this run's copies.
        res += 1LL * get(0, 0, n, run.first + 1, n) * run.second;
        add(0, 0, n, run.first, run.second);
    }

    // Restore the frequency table.
    for (const auto &run : a) {
        add(0, 0, n, run.first, -run.second);
    }

    return res;
}

// Processes vertex v (main() calls this for each vertex in `ord`):
// every vertex on the path root..p[v] is repainted with value a[v], and the
// return value is the number of pairs (ancestor x, descendant y) on that path
// whose values, taken BEFORE repainting, satisfy value(x) > value(y).
//
// Representation: for each heavy-chain head u, seg[u] is a run-length encoded
// list of (value, count) pairs covering the chain from its head downward,
// stored so that seg[u].back() is the run at the head (topmost end).
ll query(int v) {
    // Collect, bottom-up, the deepest path vertex on every heavy chain crossed
    // between p[v] and the root.
    vector<int> vc;
    int w = p[v];
    while (w != -1) {
        vc.push_back(w);
        w = p[top[w]];
    }

    // Handle chains from the root's chain downward so that vec ends up in
    // root-to-p[v] order.
    reverse(all(vc));

    vector<pair<int, int>> vec;
    for (auto w : vc) {
        // Number of chain vertices (counted from the head) consumed so far.
        // NOTE: shadows the global sz[] array inside this loop.
        int sz = 0;
        while (true) {
            if (seg[top[w]].empty()) break;

            if (h[top[w]] + sz + seg[top[w]].back().s - 1 <= h[w]) {
                // The whole topmost run lies at or above w: take it entirely.
                sz += seg[top[w]].back().s;
                vec.push_back(seg[top[w]].back());
                seg[top[w]].pop_back();
            } else {
                // The topmost run extends below w: split off the part at or
                // above w (possibly nothing) and leave the rest on the chain.
                if (h[top[w]] + sz <= h[w]) {
                    int our = h[w] - (h[top[w]] + sz) + 1;
                    seg[top[w]].back().s -= our;
                    sz += our;
                    vec.push_back(seg[top[w]].back());
                    vec.back().s = our;
                }
                break;
            }
        }
        // The head..w prefix of this chain now holds the single value a[v].
        seg[top[w]].push_back({a[v], sz});
    }

    // vec holds the old path values as (value, count) runs in root-to-p[v]
    // order; an O(k^2) pairwise loop over the runs would give the same count.
    return solve(vec);
}

int main() {
    FAST_IO;
    cin >> n;
    rep(i, n) {
        cin >> a[i];
    }
    p[0] = -1;
    rep(i, n - 1) {
        int x, y;
        cin >> x >> y;
        x--; y--;
        ord.push_back(y);
        p[y] = x;
        g[x].push_back(y);
    }

    prepare();

    // query(1);
    // query(2);
    // query(3);
    // query(4);
    // query(5);

    // rep(v, n) {
    //     cout << "V " << v + 1 << endl;

    //     cout << "top " << top[v] + 1 << endl;
    //     cout << "seg ";
    //     for (auto it : seg[v]) {
    //         cout << "(" << it.f << ' ' << it.s << ") ";
    //     }
    //     cout << endl;
    //     cout << "h " << h[v] << endl;

    //     cout << endl;
    // }

    for (auto it : ord) {
        cout << query(it) << '\n';
        // query(it);
    }
}

Compilation message (stderr)

construction.cpp: In function 'll solve(const std::vector<std::pair<int, int> >&)':
construction.cpp:142:23: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int i = 0; i < a.size(); ++i) {
                     ~~^~~~~~~~~~
construction.cpp:147:23: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int i = 0; i < a.size(); ++i) {
                     ~~^~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...