Submission #367514

# | Time | Username | Problem | Language | Result | Execution time | Memory
367514 | (not shown) | mihai145 | Factories (JOI14_factories) | C++14 | 100 / 100 | 5919 ms | 218992 KiB
#include "factories.h"
#include <vector>
//#include <iostream>
#include <algorithm>
#include <set>
#include <stack>

using namespace std;

// Size limits: at most NMAX vertices; the Euler tour built by dfs has
// 2 * N - 1 < 2^20 entries, so sparse-table levels 0..LGMAX suffice.
constexpr int NMAX = 500000;
constexpr int LGMAX = 19;

// "Unreachable" sentinel for the per-query distance dp (equals 1e18).
constexpr long long INF = 1000000000000000000LL;

// Input tree adjacency list: (neighbour, edge length).
std::vector<std::pair<int, int> > g[NMAX + 5];

// dist[v]: weighted distance from the DFS root; lvl[v]: depth in edges.
long long dist[NMAX + 5];
int lvl[NMAX + 5];

// firstAp[v]: index of v's first occurrence in the Euler tour rmq[0].
// rmq[k][j]: shallowest vertex among tour entries j .. j + 2^k - 1.
int firstAp[NMAX + 2];
std::vector<int> rmq[LGMAX + 2];

// log2[x] = floor(log2(x)), filled by BuildRmq for 1 <= x <= 2 * NMAX.
// NOTE(review): this name collides with the builtin `log2` (GCC emits
// -Wbuiltin-declaration-mismatch); renaming it means updating every user.
int log2[2 * NMAX + 2];

// Walks the subtree of `node` (never stepping back into `parent`), setting
// lvl[] and dist[] of every descendant relative to lvl[node] / dist[node], and
// appends the Euler tour to rmq[0]: a vertex is recorded when first entered
// and again after each of its child subtrees finishes. firstAp[v] receives the
// index of v's first entry in that tour.
//
// Iterative on purpose: on a path-shaped tree of NMAX vertices a recursive walk
// is NMAX frames deep, which can overflow a default-sized stack. Children are
// visited in g[] order and `parent` is skipped, so the tour, lvl, dist and
// firstAp are exactly what the recursive formulation produced.
void dfs(int node, int parent = -1) {
    struct Frame {
        int v;        // vertex being expanded
        int par;      // its parent, skipped while scanning g[v]
        size_t next;  // next index of g[v] to examine
    };
    vector<Frame> stk;

    firstAp[node] = (int) rmq[0].size();
    rmq[0].push_back(node);
    stk.push_back({node, parent, 0});

    while (!stk.empty()) {
        Frame &top = stk.back();
        const int u = top.v;

        if (top.next == g[u].size()) {
            // u's subtree is complete: its parent re-appears in the tour.
            stk.pop_back();
            if (!stk.empty())
                rmq[0].push_back(stk.back().v);
            continue;
        }

        const pair<int, int> edge = g[u][top.next++];
        if (edge.first == top.par)
            continue;

        const int child = edge.first;
        lvl[child] = lvl[u] + 1;
        dist[child] = dist[u] + edge.second;

        firstAp[child] = (int) rmq[0].size();
        rmq[0].push_back(child);
        stk.push_back({child, u, 0});  // may invalidate `top`; it is not used again
    }
}

// Returns whichever of A and B is shallower (smaller lvl[]); B wins ties.
int GetMinLvl(int A, int B) {
    return (lvl[B] <= lvl[A]) ? B : A;
}

// Fills the floor-log table log2[1 .. 2*NMAX] and builds the sparse table over
// the Euler tour in rmq[0]: rmq[k][j] is the shallowest vertex among
// rmq[0][j .. j + 2^k - 1]. Level k is built only while 2^k does not exceed
// the tour length, and never past LGMAX.
void BuildRmq() {
    log2[1] = 0;
    for (int x = 2; x <= 2 * NMAX; ++x)
        log2[x] = 1 + log2[x >> 1];

    const int tourLen = (int) rmq[0].size();
    for (int k = 1; k <= LGMAX && (1 << k) <= tourLen; ++k) {
        const int half = 1 << (k - 1);
        const vector<int> &prev = rmq[k - 1];
        for (int j = 0; j + (1 << k) <= tourLen; ++j)
            rmq[k].push_back(GetMinLvl(prev[j], prev[j + half]));
    }
}

int LCA(int A, int B) {
    A = firstAp[A];
    B = firstAp[B];

    if (A > B)
        swap(A, B);

    int k = log2[B - A + 1];
    return GetMinLvl(rmq[k][A], rmq[k][B - (1 << k) + 1]);
}

// Grader entry point (interface from factories.h). Builds the adjacency list
// from the N - 1 edges (A[e], B[e]) of length D[e], then runs the DFS from
// vertex 0 and builds the LCA sparse table.
void Init(int N, int A[], int B[], int D[]) {
    for (int e = 0; e + 1 < N; ++e) {
        const int u = A[e];
        const int v = B[e];
        const int len = D[e];
        g[u].emplace_back(v, len);
        g[v].emplace_back(u, len);
    }

    dfs(0);
    BuildRmq();
}

// Per-query virtual tree built in Solve: (neighbour, weighted path length).
vector<pair<int, long long> > g2[NMAX + 2];
// dp1[v]: distance from v to the nearest red vertex inside v's virtual subtree.
// dp2[v]: distance from v to the nearest red vertex anywhere in the virtual tree.
long long dp1[NMAX + 2];
long long dp2[NMAX + 2];

void computeDp1(int node, int parent = -1) {
    for (auto it : g2[node]) {
        if (it.first != parent) {
            computeDp1(it.first, node);
            dp1[node] = min(dp1[node], it.second + dp1[it.first]);
        }
    }
}

void computeDp2(int node, int parent = -1) {
    dp2[node] = min(dp1[node], dp2[node]);

    for (auto it : g2[node]) {
        if (it.first != parent) {
            dp2[it.first] = min(dp2[it.first], it.second + dp2[node]);
            computeDp2(it.first, node);
        }
    }
}

// True iff A is an ancestor of B (or A == B) in the tree rooted by Init.
bool isAncestor(int A, int B) {
    const int meet = LCA(A, B);
    return meet == A;
}

// Weighted length of the tree path between A and B: each endpoint's distance
// from the root minus the shared prefix down to their LCA.
long long GetDist(int A, int B) {
    const int top = LCA(A, B);
    return (dist[A] - dist[top]) + (dist[B] - dist[top]);
}

// Answers one query on the virtual tree spanned by `base`.
//
// `base` must be sorted by firstAp and closed under pairwise LCA. Then base[0]
// is an ancestor of every other entry, and the stack construction below links
// each vertex to its nearest ancestor in `base`, giving a tree rooted at
// base[0]. Every vertex of `red` and `blue` is expected to appear in `base`.
//
// Returns the minimum tree distance between a red and a blue vertex. Returns
// INF when `base` or `blue` is empty, or when no vertex is red.
long long Solve(const vector<int> &base, const vector<int> &red, const vector<int> &blue) {
    if (base.empty())
        return INF;  // no root for the virtual tree; base[0] below would be UB

    for (const int v : base) {
        g2[v].clear();
        dp1[v] = dp2[v] = INF;
    }
    for (const int v : red)
        dp1[v] = 0;

    // In firstAp order the stack always holds a chain of ancestors; after
    // popping non-ancestors of v, the top is v's nearest ancestor in `base`.
    stack<int> st;
    for (const int v : base) {
        while (!st.empty() && !isAncestor(st.top(), v))
            st.pop();

        if (!st.empty()) {
            const int up = st.top();
            // named `len` so it does not shadow the global dist[] used by GetDist
            const long long len = GetDist(up, v);
            g2[up].push_back({v, len});
            g2[v].push_back({up, len});
        }
        st.push(v);
    }

    computeDp1(base[0]);
    computeDp2(base[0]);

    long long ans = INF;
    for (const int v : blue)
        ans = min(ans, dp2[v]);
    return ans;
}

// Orders vertices by their first occurrence in the Euler tour (DFS entry order).
inline bool cmp(const int A, const int B) {
    const int posA = firstAp[A];
    const int posB = firstAp[B];
    return posA < posB;
}

long long Query(int S, int X[], int T, int Y[]) {
    vector<int> base;
    for (int i = 0; i < S; i++)
        base.push_back(X[i]);
    for (int i = 0; i < T; i++)
        base.push_back(Y[i]);

    sort(base.begin(), base.end(), cmp);

    vector<int> aux = base;
    for (int i = 1; i < (int) base.size(); i++)
        aux.push_back(LCA(base[i - 1], base[i]));

    set<int> vertices;
    for (auto it : aux)
        vertices.insert(it);
    base.clear();
    for (auto it : vertices)
        base.push_back(it);

    sort(base.begin(), base.end(), cmp);

    vector<int> red, blue;
    for (int i = 0; i < S; i++)
        red.push_back(X[i]);
    for (int i = 0; i < T; i++)
        blue.push_back(Y[i]);

    return Solve(base, red, blue);
}

Compilation message (stderr)

factories.cpp:23:5: warning: built-in function 'log2' declared as non-function [-Wbuiltin-declaration-mismatch]
   23 | int log2[2 * NMAX + 2];
      |     ^~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...