Submission #1331905

# | Time | Username | Problem | Language | Result | Execution time | Memory
1331905 | toma_ariciu | Wombats (IOI13_wombats) | C++20
100 / 100
4423 ms | 190280 KiB
#include <bits/stdc++.h>

#include "wombats.h"

using namespace std;
using ll = long long;

const int maxR = 5005, maxC = 205;
const int inf = 1e9;
// Scratch matrices for a single segment-tree leaf are indexed by (row - st).
// A leaf may span dr - st == 10 (11 rows), so each array needs 11 slots; with
// only 10, the last row of such a leaf (e.g. an R = 11 or R = 21 grid) was
// written and read out of bounds.
const int maxLeafRows = 11;
int r, c;
int hh[maxR][maxC], vv[maxR][maxC];
int sp[maxC];  // not referenced elsewhere in this file
int aux[maxLeafRows][maxC][maxC];   // aux[i][c1][c2]: leaf DP, first leaf row -> leaf row i
int dlin[maxLeafRows][maxC][maxC];  // dlin[i][x][y]: horizontal cost inside leaf row i

struct Aint {
    struct Node {
        vector <vector <int>> dist;

        Node() : dist(c, vector <int> (c, 0)) {}
    };

    int n;
    vector <int> l, r;
    vector <Node> v;

    Node mergeNodes(Node &a, Node &b, int mid) {
        Node ans;
        vector <vector <int>> opt(c, vector <int> (c, 0));

        for (int cc = 0; cc < c; cc++) {
            ans.dist[cc][cc] = inf;
            for (int k = 0; k < c; k++) {
                int d = a.dist[cc][k] + vv[mid][k] + b.dist[k][cc];
                if (d < ans.dist[cc][cc]) {
                    ans.dist[cc][cc] = d;
                    opt[cc][cc] = k;
                }
            }
        }

        for (int d = 1; d < c; d++) {
            for (int c1 = 0; c1 + d < c; c1++) {
                int c2 = c1 + d;
                int l = opt[c1][c2 - 1], r = opt[c1 + 1][c2];

                ans.dist[c1][c2] = inf;
                for (int k = l; k <= r; k++) {
                    int dist = a.dist[c1][k] + vv[mid][k] + b.dist[k][c2];
                    if (dist < ans.dist[c1][c2]) {
                        ans.dist[c1][c2] = dist;
                        opt[c1][c2] = k;
                    }
                }
            }
            for (int c2 = 0; c2 + d < c; c2++) {
                int c1 = c2 + d;                
                int l = opt[c1 - 1][c2], r = opt[c1][c2 + 1];

                ans.dist[c1][c2] = inf;
                for (int k = l; k <= r; k++) {
                    int dist = a.dist[c1][k] + vv[mid][k] + b.dist[k][c2];
                    if (dist < ans.dist[c1][c2]) {
                        ans.dist[c1][c2] = dist;
                        opt[c1][c2] = k;
                    }
                }
            }
        }

        return ans;
    }

    int addNode() {
        int idx = v.size();
        v.emplace_back();
        l.emplace_back();
        r.emplace_back();
        return idx;
    }

    Node initInterval(int st, int dr) {
        for (int i = st; i <= dr; i++) {
            vector <int> sp(c, 0);
            for (int j = 0; j < c; j++) {
                sp[j] = hh[i][j];
                if (j) {
                    sp[j] += sp[j - 1];
                }
                // cerr << i << ' ' << j << ' ' << sp[j] << '\n';
                for (int x = 0; x <= j; x++) {
                    int val = 0;
                    if (j) {
                        val = sp[j - 1];
                    }
                    if (x) {
                        val -= sp[x - 1];
                    }
                    dlin[i - st][x][j] = val;
                    dlin[i - st][j][x] = val;
                }
            }
        }

        for (int c1 = 0; c1 < c; c1++) {
            for (int c2 = 0; c2 < c; c2++) {
                aux[0][c1][c2] = dlin[0][c1][c2];
            }
        }
        
        vector <vector <int>> opt(c, vector <int> (c, 0));

        for (int i = 1; i <= dr - st; i++) {
            for (int cc = 0; cc < c; cc++) {
                aux[i][cc][cc] = inf;
                for (int k = 0; k < c; k++) {
                    int dist = aux[i - 1][cc][k] + vv[st + i - 1][k] + dlin[i][k][cc];
                    if (dist < aux[i][cc][cc]) {
                        aux[i][cc][cc] = dist;
                        opt[cc][cc] = k;
                    }
                }
            }
    
            for (int d = 1; d < c; d++) {
                for (int c1 = 0; c1 + d < c; c1++) {
                    int c2 = c1 + d;
                    int l = opt[c1][c2 - 1], r = opt[c1 + 1][c2];
                    
                    aux[i][c1][c2] = inf;
                    for (int k = l; k <= r; k++) {
                        int dist = aux[i - 1][c1][k] + vv[st + i - 1][k] + dlin[i][k][c2];
                        if (dist < aux[i][c1][c2]) {
                            aux[i][c1][c2] = dist;
                            opt[c1][c2] = k;
                        }
                    }
                }
            }
    
            for (int d = 1; d < c; d++) {
                for (int c2 = 0; c2 + d < c; c2++) {
                    int c1 = c2 + d;
                    int l = opt[c1 - 1][c2], r = opt[c1][c2 + 1];
                    
                    aux[i][c1][c2] = inf;
                    for (int k = l; k <= r; k++) {
                        int dist = aux[i - 1][c1][k] + vv[st + i - 1][k] + dlin[i][k][c2];
                        if (dist < aux[i][c1][c2]) {
                            aux[i][c1][c2] = dist;
                            opt[c1][c2] = k;
                        }
                    }
                }
            }
        }

        Node ans;
        for (int c1 = 0; c1 < c; c1++) {
            for (int c2 = 0; c2 < c; c2++) {
                ans.dist[c1][c2] = aux[dr - st][c1][c2];
            }
        }
        return ans;
    }

    void build(int nod, int st, int dr) {
        if (dr - st <= 10) {
            // cerr << nod << ' ' << v.size() << '\n';
            v[nod] = initInterval(st - 1, dr - 1);
            return;
        }
        int med = (st + dr) / 2;
        l[nod] = addNode();
        build(l[nod], st, med);
        r[nod] = addNode();
        build(r[nod], med + 1, dr);

        v[nod] = mergeNodes(v[l[nod]], v[r[nod]], med - 1);
    }

    void update(int nod, int st, int dr, int poz) {
        if (dr - st <= 10) {
            v[nod] = initInterval(st - 1, dr - 1);
            return;
        }

        int med = (st + dr) / 2;
        if (poz <= med) {
            update(l[nod], st, med, poz);
        } else {
            update(r[nod], med + 1, dr, poz);
        }

        v[nod] = mergeNodes(v[l[nod]], v[r[nod]], med - 1);
    }

    void update(int poz) {
        update(0, 1, n, poz);
    }

    int query(int x, int y) {
        return v[0].dist[x][y];
    }

    void init(int N) {
        n = N;
        addNode();
        build(0, 1, n);
    }
}aint;

// Lowers x to y when y is strictly smaller; otherwise x is left untouched.
void minSelf(ll &x, ll y) {
    if (!(y < x)) {
        return;
    }
    x = y;
}

// Placeholder with no behavior; nothing visible in this file calls it.
void calc() {}

void init(int R, int C, int H[5000][200], int V[5000][200]) {
    r = R, c = C;
    for (int i = 0; i < r; i++) {
        for (int j = 0; j < c; j++) {
            hh[i][j] = H[i][j];
            vv[i][j] = V[i][j];
            // cerr << i << ' ' << j << ' ' << hh[i][j] << '\n';
        }
    }
    
    aint.init(r);
}

// Sets horizontal road weight hh[P][Q] to W, then refreshes the tree path for
// row P (the tree is 1-indexed, hence P + 1).
void changeH(int P, int Q, int W) {
    int &road = hh[P][Q];
    road = W;
    const int treeRow = P + 1;
    aint.update(treeRow);
}

// Sets vertical road weight vv[P][Q] (between rows P and P + 1) to W, then
// refreshes the tree path for row P (the tree is 1-indexed, hence P + 1).
void changeV(int P, int Q, int W) {
    int &road = vv[P][Q];
    road = W;
    const int treeRow = P + 1;
    aint.update(treeRow);
}

// Minimum cost from column V1 of the top row to column V2 of the bottom row,
// read from the root of the segment tree.
int escape(int V1, int V2) {
    const int best = aint.query(V1, V2);
    return best;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...