제출 #1248692

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1248692 | — | CodeLakVN | Factories (JOI14_factories) | C++20 | 15 / 100 | 8090 ms | 98664 KiB
//#include "factories.h"
#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shorthand macros.
// NOTE(review): in the code shown, only F, S, ii, FOR and FOD are actually used;
// the rest (task, no, yes, vec, _mp, sz(x), all, evoid) appear unused here.
// Beware that F and S rewrite ANY identifier spelled `F` or `S` in this file.
#define task "main"
#define no "NO"
#define yes "YES"
#define F first   // pair member shorthand
#define S second  // pair member shorthand
#define vec vector
#define _mp make_pair
#define ii pair<int, int>
#define sz(x) (int)x.size()  // function-like: does not clash with the global array `sz[...]`, but reads confusingly
#define all(x) x.begin(), x.end()
#define evoid(val) return void(std::cout << val)
#define FOR(i, a, b) for(int i = (a); i <= (b); ++i)   // inclusive ascending loop
#define FOD(i, b, a) for(int i = (b); i >= (a); --i)   // inclusive descending loop

// Assigns y to x when y is strictly smaller (i.e. x > y).
// Returns true exactly when x was changed.
template<class X, class Y>
bool minimize(X &x, Y y) {
    if (!(x > y)) return false;
    x = y;
    return true;
}

// Capacity of every per-node array: 5 * 10^5 nodes plus a little slack.
const int MAX_N = 500000 + 5;
// Sentinel "no distance yet"; far below LLONG_MAX so INF + (a tree distance)
// does not overflow, assuming tree distances stay well under ~9e18.
const long long INF = 1000000000000000000LL;

// Number of tree nodes (set by Init). numQuery is not referenced in the code shown.
int numNode, numQuery;
// Weighted tree adjacency: adj[u] holds (neighbor, edge weight) pairs.
vector<ii> adj[MAX_N];

// Centroid-decomposition state:
//   sz[u]  - subtree size of u inside the component currently being decomposed (filled by countChild)
//   par[u] - parent of u in the centroid tree; -1 for the root centroid (set in Init)
//   del[u] - true once u has been picked as a centroid and removed from later components
int sz[MAX_N], par[MAX_N];
bool del[MAX_N];

long long minDist[MAX_N]; // minDist[c]: smallest tree distance from centroid c to any node added via addNode whose centroid-ancestor chain contains c; INF when there is none

struct LCA {
    int par[20][MAX_N];
    long long dist[MAX_N];
    int high[MAX_N];

    void dfs(int u, int p) {
        for (ii e : adj[u]) {
            int v = e.F, w = e.S;
            if (v == p) continue;
            par[0][v] = u;
            dist[v] = dist[u] + w;
            high[v] = high[u] + 1;
            dfs(v, u);
        }
    }

    void setup() {
        dfs(0, -1);
        FOR(i, 1, 19) FOR(u, 0, numNode - 1)
            par[i][u] = par[i - 1][par[i - 1][u]];
    }

    int getLCA(int u, int v) {
        if (high[u] < high[v]) swap(u, v);
        FOD(i, 19, 0) if (high[par[i][u]] >= high[v])
            u = par[i][u];
        if (u == v) return u;
        FOD(i, 19, 0) if (par[i][u] != par[i][v]) {
            u = par[i][u];
            v = par[i][v];
        }
        return par[0][u];
    }

    long long getDist(int u, int v) {
        return dist[u] + dist[v] - 2LL * dist[getLCA(u, v)];
    }
} lca;

void countChild(int u, int p) {
    sz[u] = 1;
    for (ii e : adj[u]) {
        int v = e.F;
        if (v == p || del[v]) continue;
        countChild(v, u);
        sz[u] += sz[v];
    }
}

int centroid(int u, int p, int m) {
    for (ii e : adj[u]) {
        int v = e.F;
        if (del[v] || v == p) continue;
        if (sz[v] > m / 2) return centroid(v, u, m);
    }
    return u;
}

int decompose(int u) {
    countChild(u, -1);
    int m = sz[u];
    u = centroid(u, -1, m);
    del[u] = 1;
    minDist[u] = INF;

    for (ii e : adj[u]) {
        int v = e.F;
        if (del[v]) continue;
        int x = decompose(v);
        par[x] = u;
    }

    return u;
}

// Marks u as a source: every centroid ancestor c of u (u included) records the
// smaller of its current minDist and the tree distance from c to u.
void addNode(int u) {
    for (int c = u; c != -1; c = par[c])
        minimize(minDist[c], lca.getDist(c, u));
}

// Undoes addNode(u): resets minDist to INF along u's centroid-ancestor chain.
void delNode(int u) {
    for (int c = u; c != -1; c = par[c]) minDist[c] = INF;
}

// Shortest distance from u to any node currently added via addNode.
// Every u-to-source path passes through some common centroid ancestor c. That
// path's length is at most minDist[c] + dist(c, u), and the minimum over c is exact.
// Returns INF when no node has been added.
long long calcMinDist(int u) {
    long long best = INF;
    for (int c = u; c != -1; c = par[c])
        minimize(best, minDist[c] + lca.getDist(c, u));
    return best;
}

void Init(int n, int a[], int b[], int d[]) {
    numNode = n;
    FOR(i, 0, numNode - 2) {
        adj[a[i]].push_back({b[i], d[i]});
        adj[b[i]].push_back({a[i], d[i]});
    }

    int root = decompose(0);
    par[root] = -1;
    lca.setup();
}

long long Query(int s, int x[], int t, int y[]) {
    FOR(i, 0, s - 1) addNode(x[i]);
    long long ans = INF;
    FOR(i, 0, t - 1) minimize(ans, calcMinDist(y[i]));
    FOR(i, 0, s - 1) delNode(x[i]);
    return ans;
}

/* Lak lu theo dieu nhac!!!! */
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...