제출 #1248675

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1248675 | CodeLakVN | Factories (JOI14_factories) | C++20
15 / 100
8058 ms | 397228 KiB
#include "factories.h"
#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shorthands.
// Function-like macros parenthesize their argument so that expressions such as
// sz(*p) or all(*p) expand with the intended precedence (previously sz(*p)
// became (int)*p.size(), which applies '.' before '*').
#define task "main"
#define no "NO"
#define yes "YES"
#define F first
#define S second
#define vec vector
#define _mp make_pair
#define ii pair<int, int>
#define sz(x) ((int)(x).size())
#define all(x) (x).begin(), (x).end()
// 'val' is intentionally not parenthesized: callers may pass a chain such as a << " " << b.
#define evoid(val) return void(std::cout << val)
#define FOR(i, a, b) for(int i = (a); i <= (b); ++i)
#define FOD(i, b, a) for(int i = (b); i >= (a); --i)

// Lowers x to y when y is strictly smaller.
// Returns true if x was changed, false if x was already <= y.
template<class X, class Y>
bool minimize(X &x, Y y) {
    if (!(x > y)) return false; // not an improvement: leave x untouched
    x = y;
    return true;
}

const int MAX_N = (int)5e5 + 5;
const long long INF = (long long)1e18;
// Every child component in the centroid decomposition has at most half the
// nodes of its parent, so a tree with fewer than 2^(MAX_LOG - 1) nodes has
// fewer than MAX_LOG centroid-tree levels.
const int MAX_LOG = 20;
static_assert((1 << (MAX_LOG - 1)) >= MAX_N, "MAX_LOG too small for MAX_N");

int numNode, numQuery;
std::vector<std::pair<int, int>> adj[MAX_N]; // adj[u] holds (neighbour, edge length)

int sz[MAX_N];   // subtree sizes inside the component currently being decomposed
int par[MAX_N];  // parent in the centroid tree (-1 for the root)
int lvl[MAX_N];  // depth of each centroid in the centroid tree (root = 0)
bool del[MAX_N]; // true once a node has been chosen as a centroid

// distAnc[v][k] = tree distance from v to its centroid ancestor at depth k.
// This flat array replaces the former map<int, long long> dist[MAX_N]: lookups
// are O(1) with no per-entry heap node. The map version needed an extra log
// factor on every query step and ~397 MB on the judge.
long long distAnc[MAX_N][MAX_LOG];

// minDist[c] = minimum distance from centroid c to any node currently marked by
// addNode inside c's component (INF when none is marked).
long long minDist[MAX_N];

// Fills sz[] for the undeleted component containing u, rooted at u
// (p is u's parent in that rooting, -1 for none).
// NOTE(review): recursion depth equals the component's depth (up to N on a
// path); relies on the grader providing a large enough stack.
void countChild(int u, int p) {
    sz[u] = 1;
    for (const auto &e : adj[u]) {
        int v = e.first;
        if (v == p || del[v]) continue;
        countChild(v, u);
        sz[u] += sz[v];
    }
}

// Walks from u towards the heavy side until every remaining child subtree has
// at most m / 2 nodes, where m is the component size; returns that centroid.
// Requires sz[] to have been filled by countChild from the same root.
int centroid(int u, int p, int m) {
    for (const auto &e : adj[u]) {
        int v = e.first;
        if (del[v] || v == p) continue;
        if (sz[v] > m / 2) return centroid(v, u, m);
    }
    return u;
}

// For every node v reachable from u without crossing deleted nodes, records
// distAnc[v][lvl[root]] = distance(v, root). Expects distAnc[u][lvl[root]]
// to already hold u's distance to root.
void calcDist(int u, int p, int root) {
    const int k = lvl[root];
    for (const auto &e : adj[u]) {
        int v = e.first;
        if (v == p || del[v]) continue;
        distAnc[v][k] = distAnc[u][k] + e.second;
        calcDist(v, u, root);
    }
}

// Decomposes the component containing u, whose centroid will sit at `depth`
// in the centroid tree. Sets par[] for every centroid below it and returns
// this component's centroid (the caller sets that centroid's own par[]).
int decompose(int u, int depth = 0) {
    countChild(u, -1);
    u = centroid(u, -1, sz[u]);
    del[u] = true;
    lvl[u] = depth;
    minDist[u] = INF;

    distAnc[u][depth] = 0; // distance from the centroid to itself
    calcDist(u, -1, u);
    for (const auto &e : adj[u]) {
        int v = e.first;
        if (del[v]) continue;
        par[decompose(v, depth + 1)] = u;
    }
    return u;
}

// Marks u: every centroid ancestor c of u (including u) learns u's distance.
void addNode(int u) {
    for (int c = u; c != -1; c = par[c])
        minDist[c] = std::min(minDist[c], distAnc[u][lvl[c]]);
}

// Undoes addNode(u) by resetting the ancestors u touched back to INF.
void delNode(int u) {
    for (int c = u; c != -1; c = par[c])
        minDist[c] = INF;
}

// Distance from u to the nearest marked node (INF if nothing is marked).
// The shortest path passes through some common centroid ancestor, and routing
// through any other common ancestor can only be longer, so the minimum over
// ancestors is exact.
long long calcMinDist(int u) {
    long long res = INF;
    for (int c = u; c != -1; c = par[c])
        res = std::min(res, minDist[c] + distAnc[u][lvl[c]]);
    return res;
}

// Grader entry point: builds the tree on n nodes from the n - 1 edges
// (a[i], b[i]) of length d[i] and precomputes the centroid decomposition.
void Init(int n, int a[], int b[], int d[]) {
    numNode = n;
    for (int i = 0; i + 1 < numNode; ++i) {
        adj[a[i]].push_back({b[i], d[i]});
        adj[b[i]].push_back({a[i], d[i]});
    }

    int root = decompose(0);
    par[root] = -1;
}

// Grader entry point: minimum distance between any x[i] (i < s) and any
// y[j] (j < t). Runs in O((s + t) * depth of the centroid tree).
long long Query(int s, int x[], int t, int y[]) {
    for (int i = 0; i < s; ++i) addNode(x[i]);
    long long ans = INF;
    for (int i = 0; i < t; ++i) ans = std::min(ans, calcMinDist(y[i]));
    for (int i = 0; i < s; ++i) delNode(x[i]); // leave minDist clean for the next query
    return ans;
}

/* Lak lu theo dieu nhac!!!! */
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...