// Submission #305918
//
// #      | Time | Username  | Problem                     | Language | Result   | Execution time | Memory
// 305918 |      | limabeans | Factories (JOI14_factories) | C++17    | 33 / 100 | 8045 ms        | 423932 KiB
#include "factories.h"
#include <bits/stdc++.h>
using namespace std;

// Debug helper: print `value` on its own line, flush, and end the program
// with exit status 0.
template<class T>
void out(T value) {
    std::cout << value << '\n' << std::flush;
    std::exit(0);
}
// Debug helper: print "<expression text> is <value>" followed by endl.
#define watch(x) std::cout << #x " is " << (x) << std::endl


// Static range-minimum structure over an array of (int, int) pairs.
// After init(v) in O(|v| log |v|), get(l, r) returns min(v[l..r]) in O(1).
// Pairs compare lexicographically, so ties on .first go to the smaller .second.
struct sparse_table {

    int LOG = 0;                    // number of levels built by init()
    int n;                          // length of the indexed array
    vector<pair<int,int>> a[100];   // a[j][i] = min of v[i .. i + 2^j - 1]
    vector<int> lg_floor;           // lg_floor[k] = floor(log2(k)) for k >= 1

    // Combine operator of the table (minimum).
    pair<int,int> eval(pair<int,int> x, pair<int,int> y) {
        return min(x, y);
    }

    // Builds every level over v.
    // v is taken by const reference: the Euler tour passed in can hold
    // millions of entries, and a by-value parameter duplicated all of it.
    void init(const vector<pair<int,int>>& v) {
        n = v.size();
        LOG = 0;
        while ((1<<LOG) <= n) LOG++;
        lg_floor.assign(n+10, 0);
        for (int i=2; i<n+10; i++) lg_floor[i] = 1 + lg_floor[i/2];

        a[0] = v;
        for (int j=1; j<LOG; j++) {
            // get() only reads windows that lie fully inside [0, n), so
            // level j needs just its n - 2^j + 1 valid start positions.
            // Storing the clipped tail too wasted memory and a branch per entry.
            const int half = 1 << (j-1);
            const int count = n - (1<<j) + 1;
            a[j].resize(count);
            for (int i=0; i<count; i++) {
                a[j][i] = eval(a[j-1][i], a[j-1][i + half]);
            }
        }
    }

    // Minimum of v[l..r] (inclusive). Requires 0 <= l <= r < n.
    pair<int,int> get(int l, int r) {
        assert(0 <= l && l <= r && r < n);
        const int lg = lg_floor[r - l + 1];
        return eval(a[lg][l], a[lg][r-(1<<lg)+1]);
    }
};


using ll = long long;

// Problem-wide constants.
constexpr ll inf = 1000000000000000000LL;  // 1e18, used as "nothing seen yet"
constexpr int maxn = 500100;               // 5e5 + 100 slack
constexpr int LOG = 20;                    // binary-lifting levels (dfs, lcaBrute)

// Input tree: g[u] holds (edge weight, neighbour) pairs.
int n;
vector<pair<ll,int>> g[maxn];

// Rooted-tree data filled by dfs().
int par[LOG + 1][maxn];  // par[j][v] = 2^j-th ancestor of v (root 0 maps to itself)
int dep[maxn];           // number of edges from the root
ll len[maxn];            // weighted distance from the root

// Euler-tour state: running position counter, first position of each node,
// and the (tin, node) sequence indexed by the sparse table.
int cloc = 0;
int tin[maxn];
vector<pair<int,int>> ett;

// Range-minimum table over the Euler tour `ett`; lca() answers queries from it.
sparse_table tbl;

void dfs(int at, int p) {
    tin[at] = cloc++;
    ett.push_back({tin[at], at});
    
    for (int j=1; j<LOG; j++) {
	par[j][at] = par[j-1][par[j-1][at]];
    }
    for (auto ed: g[at]) {
	ll wei = ed.first;
	int to = ed.second;
	if (to == p) continue;
	dep[to] = 1+dep[at];
	len[to] = len[at] + wei;
	par[0][to] = at;
	dfs(to, at);
	ett.push_back({tin[at], at});
	cloc++;
    }

    ett.push_back({tin[at], at});
    cloc++;
}


// Lowest common ancestor of u and v, from one range-minimum query over the
// Euler tour between the nodes' first positions. Inside that range the LCA
// is the entry with the smallest tin.
int lca(int u, int v) {
    if (u == v) return u;
    const auto range = minmax(tin[u], tin[v]);
    return tbl.get(range.first, range.second).second;
}


// Reference LCA using the binary-lifting table par[][] and depths dep[].
int lcaBrute(int u, int v) {
    // Make u the deeper node, then lift it to v's depth one bit at a time.
    if (dep[u] < dep[v]) swap(u, v);
    for (int j = 0, gap = dep[u] - dep[v]; j < LOG; ++j, gap >>= 1) {
        if (gap & 1) u = par[j][u];
    }

    if (u == v) return u;

    // Climb both nodes in lockstep and stop just below their meeting point.
    for (int j = LOG - 1; j >= 0; --j) {
        const int pu = par[j][u];
        const int pv = par[j][v];
        if (pu == pv) continue;
        u = pu;
        v = pv;
    }
    return par[0][u];
}


// Weighted tree distance between u and v. The two root paths share the
// root-to-LCA prefix, which is counted twice and so subtracted twice.
ll dist(int u, int v) {
    const ll shared = len[lca(u, v)];
    return len[u] + len[v] - 2 * shared;
}


vector<int> ctree[maxn];  // centroid-tree adjacency (both directions), filled by build()
bool bad[maxn];           // true once a node has been chosen as a centroid
int siz[maxn];            // subtree sizes from the most recent dfs2() call

// Computes siz[] for the subtree of `at` (parent p) within the component of
// nodes not yet marked bad.
void dfs2(int at, int p) {
    int total = 1;
    for (const auto& ed : g[at]) {
        const int to = ed.second;
        if (to == p || bad[to]) continue;
        dfs2(to, at);
        total += siz[to];
    }
    siz[at] = total;
}

// Returns a centroid of the unremoved component containing `at`: a node whose
// removal leaves no piece with more than half of the component's nodes.
// It roots sizes at `at`, then keeps stepping into a child whose subtree
// holds more than half of the nodes.
int findCenter(int at) {
    dfs2(at, -1);
    const int total = siz[at];
    for (;;) {
        int heavy = -1;
        for (const auto& ed : g[at]) {
            const int to = ed.second;
            // Skip removed nodes and the parent (its siz exceeds ours).
            if (bad[to] || siz[to] > siz[at]) continue;
            if (2 * siz[to] > total) {
                heavy = to;
                break;
            }
        }
        if (heavy == -1) return at;
        at = heavy;
    }
}

// Recursively centroid-decomposes the component containing `at`.
// It picks that component's centroid and marks it bad (removed). It then
// decomposes each remaining neighbouring component and links each
// sub-centroid to this one in ctree, in both directions.
// Returns the centroid chosen for this component.
int build(int at) {
    const int centre = findCenter(at);
    bad[centre] = true;
    for (const auto& ed : g[centre]) {
        const int nb = ed.second;
        if (bad[nb]) continue;
        const int sub = build(nb);
        ctree[centre].push_back(sub);
        ctree[sub].push_back(centre);
    }
    return centre;
}

// cpar[v] = parent of v in the centroid tree (Init sets the root's to -1).
int cpar[maxn];

// Orients the undirected centroid tree `ctree` away from `at`, recording each
// child's parent in cpar. `dep` is the current centroid depth and is checked
// against a fixed bound.
void cdfs(int at, int p, int dep=0) {
    assert(dep <= 23);
    for (const int child : ctree[at]) {
        if (child == p) continue;
        cpar[child] = at;
        cdfs(child, at, dep + 1);
    }
}



ll nearest[maxn];

void Init(int N, int A[], int B[], int D[]) {
    n = N;
    for (int i=0; i<n-1; i++) {
	int u = A[i];
	int v = B[i];
	ll d = D[i];
	g[u].push_back({d,v});
	g[v].push_back({d,u});
    }
    
    dfs(0, -1);
    tbl.init(ett);
    
    int root = build(0);
    cpar[root] = -1;
    cdfs(root, -1);

    for (int i=0; i<n+10; i++) {
	nearest[i] = inf;
    }
}



long long Query(int S, int X[], int T, int Y[]) {
    ll res = inf;
    for (int i=0; i<S; i++) {
	int node = X[i];
	int at = node;
	int iter=0;
	while (~at) {
	    nearest[at] = min(nearest[at], dist(node, at));
	    at = cpar[at];
	    assert(++iter <= 23);
	}
    }


    for (int i=0; i<T; i++) {
	int node = Y[i];
	int at = node;
	int iter=0;
	while (~at) {
	    ll cur = nearest[at] + dist(node, at);
	    res = min(res, cur);
	    at = cpar[at];
	    assert(++iter <= 23);
	}
    }

     for (int i=0; i<S; i++) {
	int node = X[i];
	int at = node;
	int iter=0;
	while (~at) {
	    nearest[at] = inf;
	    at = cpar[at];
	    assert(++iter <= 23);
	}
    }

    
    
    return res;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...