#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
// Contest-style shorthand macros.
// String constants.
#define task "main"
#define no "NO"
#define yes "YES"
// Pair field access: p.F is p.first, p.S is p.second.
#define F first
#define S second
// Type / utility abbreviations.
#define vec vector
#define _mp make_pair
// ii: pair of ints; il: (int, long long) pair.
#define ii pair<int, int>
#define il pair<int, long long>
// sz(x): container size as int; all(x): begin/end iterator pair for algorithms.
// NOTE(review): sz does not parenthesize x, so pass only simple expressions.
#define sz(x) (int)x.size()
#define all(x) x.begin(), x.end()
// evoid(val): print val to std::cout and return from a void function.
#define evoid(val) return void(std::cout << val)
// Inclusive loops: FOR counts a..b upward, FOD counts b..a downward.
#define FOR(i, a, b) for(int i = (a); i <= (b); ++i)
#define FOD(i, b, a) for(int i = (b); i >= (a); --i)
// Lowers x to y when y is strictly smaller (tested as x > y).
// Returns true exactly when x was changed.
template<class X, class Y>
bool minimize(X &x, Y y) {
    if (!(x > y)) return false;
    x = y;
    return true;
}
const int MAX_N = (int)5e5 + 5;
// The Euler tour of an n-node tree has 2n - 1 entries (stored 1-based), so the
// sparse table needs about twice as many columns as there are nodes. Sizing it
// by MAX_N alone overflowed it for large trees.
const int MAX_EULER = 2 * MAX_N;
const long long INF = (long long)1e18;
int numNode, numQuery;
// Tree adjacency: adj[u] holds (neighbour, edge weight) pairs.
std::vector<std::pair<int, int>> adj[MAX_N];
// DFS entry time, largest entry time inside the subtree, first Euler-tour index.
int in[MAX_N], out[MAX_N], firstPos[MAX_N];
int timer = 0, len = 0;
// Sparse table over the Euler tour of (entry time, node) pairs. Level k holds
// minima of windows of length 2^k. 20 levels suffice because 2n - 1 < 2^20.
// This is about 160 MB with 8-byte pairs.
std::pair<int, int> LCA[20][MAX_EULER];
// Weighted distance from the DFS root.
long long dist[MAX_N];
// Per-node tag for the current query: -1 unmarked, 0 for x[] nodes, 1 for y[] nodes.
int mode[MAX_N];
void dfs(int u, int p) {
in[u] = ++timer;
LCA[0][++len] = {in[u], u};
firstPos[u] = len;
for (auto [v, w] : adj[u]) {
if (v == p) continue;
dist[v] = dist[u] + w;
dfs(v, u);
LCA[0][++len] = {in[u], u};
}
out[u] = timer;
}
void setupLCA() {
dfs(0, -1);
FOR(k, 1, 19) FOR(i, 1, len - (1 << k) + 1)
LCA[k][i] = min(LCA[k - 1][i], LCA[k - 1][i + (1 << (k - 1))]);
}
int getLCA(int u, int v) {
if (firstPos[u] > firstPos[v]) swap(u, v);
int l = firstPos[u], r = firstPos[v];
int k = __lg(r - l + 1);
return min(LCA[k][l], LCA[k][r - (1 << k) + 1]).S;
}
// True iff u is an ancestor of v, or v itself: v's entry time lies inside
// u's subtree interval [in[u], out[u]].
bool isAncestor(int u, int v) {
    const int t = in[v];
    return t >= in[u] && t <= out[u];
}
void Init(int n, int a[], int b[], int d[]) {
numNode = n;
FOR(i, 0, numNode - 2) {
adj[a[i]].push_back({b[i], d[i]});
adj[b[i]].push_back({a[i], d[i]});
}
setupLCA();
memset(mode, -1, sizeof(mode));
}
// Orders node ids by DFS entry time (preorder position).
// The parameters are const references: a std::sort comparator must not modify
// its arguments, and const refs also accept const or temporary operands.
bool cmp(const int &x, const int &y) {
    return in[x] < in[y];
}
// Adjacency of the per-query virtual tree: (neighbour, weight) pairs.
vector<il> newADJ[MAX_N];
// Adds an undirected virtual-tree edge of weight w, mirrored on both endpoints.
void addNewEdge(int u, int v, long long w) {
    newADJ[u].emplace_back(v, w);
    newADJ[v].emplace_back(u, w);
}
long long res = INF;
// DP over the virtual tree rooted at u (p = u's parent, or -1 for the root).
// For each node it computes dp = {distance to the closest type-0 node,
// distance to the closest type-1 node} among the node and its virtual-tree
// descendants. It lowers the global res to dp[0] + dp[1] at every node, because
// the closest 0/1 pair is joined at its topmost node, where both halves are
// counted.
// Returns u's dp as a 2-element vector, with INF where no such node exists.
// Uses an explicit stack instead of recursion, since a virtual tree can be a
// path of hundreds of thousands of nodes. Keeps dp in one flat buffer instead
// of allocating a vector per node.
vector<long long> calc(int u, int p) {
    // A visited node, its tree parent, the index of that parent in 'order'
    // (-1 for the root), and the weight of the edge to the parent.
    struct Item { int node, parent, up; long long w; };
    vector<Item> order;
    vector<Item> todo;
    todo.push_back({u, p, -1, 0});
    while (!todo.empty()) {
        const Item cur = todo.back();
        todo.pop_back();
        const int idx = (int)order.size();
        order.push_back(cur);
        for (auto [v, w] : newADJ[cur.node]) {
            if (v == cur.parent) continue;
            todo.push_back({v, cur.node, idx, w});
        }
    }
    vector<array<long long, 2>> dp(order.size(), array<long long, 2>{INF, INF});
    for (size_t i = 0; i < order.size(); ++i)
        if (mode[order[i].node] != -1) dp[i][mode[order[i].node]] = 0;
    // Every child appears after its parent in 'order'. Walking backwards
    // therefore finishes a child before it is folded into its parent.
    for (int i = (int)order.size() - 1; i >= 0; --i) {
        minimize(res, dp[i][0] + dp[i][1]);
        const int up = order[i].up;
        if (up == -1) continue;
        FOR(c, 0, 1) minimize(dp[up][c], dp[i][c] + order[i].w);
    }
    return {dp[0][0], dp[0][1]};
}
long long Query(int s, int x[], int t, int y[]) {
// build virtual tree
vector<int> nodes;
FOR(i, 0, s - 1) nodes.push_back(x[i]), mode[x[i]] = 0;
FOR(i, 0, t - 1) nodes.push_back(y[i]), mode[y[i]] = 1;
sort(nodes.begin(), nodes.end(), cmp);
FOR(i, 1, s + t - 1) nodes.push_back(getLCA(nodes[i - 1], nodes[i]));
sort(nodes.begin(), nodes.end(), cmp);
nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());
stack<int> st;
st.push(nodes[0]);
int n = (int)nodes.size();
FOR(i, 1, n - 1) {
while (!st.empty() && !isAncestor(st.top(), nodes[i]))
st.pop();
if (!st.empty()) addNewEdge(st.top(), nodes[i], dist[nodes[i]] - dist[st.top()]);
st.push(nodes[i]);
}
res = INF;
calc(nodes[0], -1);
for (int u : nodes) {
mode[u] = -1;
newADJ[u].clear();
}
return res;
}
/* Lak lu theo dieu nhac!!!! */
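/*
 * Grader results table that appears to have been copied from the online-judge
 * submission page before the results finished loading:
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */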