#include <bits/stdc++.h>
// size(x): container size as a signed int, so comparisons against int loop
// counters do not mix signed and unsigned values.
// NOTE: this is a function-like macro, so any `size(` spelled after this line
// is expanded too, including member calls such as `v.size()`. Code below must
// write size(v) instead of v.size().
#define size(x) ((int) (x).size())
using namespace std;
// Sentinel cost marking a forbidden state (used for the mixed end states of an
// empty vertex in initialize()).
// NOTE(review): merge_paths can add two of these together (2e9); that only
// fits if int is at least 32 bits wide.
const int inf = 1e9;
// Cost table of a path cluster. Each field is the minimum cost for one
// combination of end states: the first letter describes the lhs end and the
// second letter the rhs end (d = dog-friendly, c = cat-friendly).
struct Path {
    int dd;  // lhs dog-friendly, rhs dog-friendly
    int dc;  // lhs dog-friendly, rhs cat-friendly
    int cd;  // lhs cat-friendly, rhs dog-friendly
    int cc;  // lhs cat-friendly, rhs cat-friendly
};
// Cost pair of a child cluster (a subtree hanging off a vertex).
struct Child {
    int d;  // min cost for the cluster to be dog-friendly
    int c;  // min cost for the cluster to be cat-friendly
};
// Joins two path clusters into one, where b is the cluster closer to the root.
// The end of `a` that touches `b` (a's second letter) must be in the same
// state as the end of `b` that touches `a` (b's first letter); for every
// combination of outer end states the cheaper of the two possible shared
// states is kept.
Path merge_paths(const Path& a, const Path& b) {
    Path joined;
    joined.dd = min(a.dd + b.dd, a.dc + b.cd);
    joined.dc = min(a.dd + b.dc, a.dc + b.cc);
    joined.cd = min(a.cd + b.dd, a.cc + b.cd);
    joined.cc = min(a.cd + b.dc, a.cc + b.cc);
    return joined;
}
// Attaches child cluster b to the unit path cluster a of a single vertex.
// The occupant of the vertex is read back from a's costs: a non-zero `cc`
// means the vertex holds a dog, otherwise a non-zero `dd` means it holds a
// cat, otherwise it is empty.
Path attach_child(const Path& a, const Child& b) {
    const bool has_dog = a.cc != 0;
    const bool has_cat = a.dd != 0;

    if (has_dog) {
        // The children are charged their dog-friendly cost.
        return {b.d, b.d + 1, b.d + 1, b.d + 2};
    }
    if (has_cat) {
        // The children are charged their cat-friendly cost.
        return {b.c + 2, b.c + 1, b.c + 1, b.c};
    }

    // Empty vertex: for the mixed end states the cheaper side is chosen.
    const int mixed = min(b.d, b.c) + 1;
    return {b.d, mixed, mixed, b.c};
}
// Combines two child clusters hanging off the same vertex: the costs simply
// add up, separately for the dog-friendly and the cat-friendly case.
Child merge_children(const Child& a, const Child& b) {
    Child total = a;
    total.d += b.d;
    total.c += b.c;
    return total;
}
// Turns a path cluster into a child cluster: only the state of the rhs end
// (second letter) is kept, the lhs end state is minimised away.
Child make_child(const Path& a) {
    const int dog_cost = min(a.dd, a.cd);  // rhs end dog-friendly
    const int cat_cost = min(a.cc, a.dc);  // rhs end cat-friendly
    return {dog_cost, cat_cost};
}
// Kind of a static-top-tree node; selects how its value is recomputed.
enum Type {
    MakeVertex,     // single vertex without attached children -> path cluster
    MergePaths,     // two path clusters joined -> path cluster
    AttachChild,    // vertex plus its merged child cluster -> path cluster
    MergeChildren,  // two child clusters combined -> child cluster
    MakeChild       // path cluster converted -> child cluster
};
// Static top tree: a binary tree of clusters built once over a rooted tree.
// After a single vertex changes, only the clusters on the path from that
// vertex's node up to the top-tree root are recomputed (see update()).
//
// Node indices 0..n-1 are the original vertices; every other node is taken
// from `nxt` upwards. All per-node arrays are sized 4 * n.
struct StaticTopTree {
    int n;                            // number of vertices of the original tree
    vector<vector<int>> adj;          // own copy of the adjacency lists; dfs() reduces each list to
                                      // the children, with the largest subtree moved to the front
    int root;                         // an index of the root in g
    int stt_root;                     // an index of the root in static top tree
    vector<int> par, lc, rc;          // parent, left child, right child (-1 = none)
    vector<Type> type;                // type of vertices
    int nxt;                          // next node to write into
    vector<Path> path;                // value of nodes that are path clusters
    vector<Child> child;              // value of nodes that are child clusters
    function<Path(int)> make_vertex;  // returns the current unit path cluster of a vertex

    // Builds the top tree for the tree given by `_adj` (undirected adjacency
    // lists over vertices 0.._n-1) rooted at `_root`, then computes all values.
    // `_make_vertex` is stored and called again on every later update().
    void init(int _n, const vector<vector<int>>& _adj, function<Path(int)> _make_vertex, int _root = 0) {
        n = _n;
        adj = _adj;
        make_vertex = std::move(_make_vertex);
        root = _root;
        type.resize(4 * n);
        par = lc = rc = vector<int>(4 * n, -1);
        path.resize(4 * n);
        child.resize(4 * n);
        nxt = n;
        build_stt();
        build(stt_root);
    }

    // Recomputes the values of the whole subtree of node u, children first.
    // Called without an argument it starts at the top-tree root.
    void build(int u = -2) {
        if (u == -2) u = stt_root;  // default argument: whole tree
        if (u == -1) return;
        build(lc[u]);
        build(rc[u]);
        pull(u);
    }

    // Recomputes node u and all of its ancestors; call after the value
    // returned by make_vertex(u) has changed.
    void update(int u) {
        while (u != -1) {
            pull(u);
            u = par[u];
        }
    }

    // Path cluster stored at the top-tree root.
    Path query() const {
        return path[stt_root];
    }

    // Debug helper: dumps node types and (child, parent) edges to cerr.
    void print_stt(int u = -2) {
        if (u == -2) {
            for (int i = 0; i <= stt_root; i++) cerr << i << ": " << type[i] << endl;
            u = stt_root;
        }
        if (u == -1) return;
        print_stt(lc[u]);
        print_stt(rc[u]);
        if (~lc[u]) cerr << lc[u] << " " << u << endl;
        if (~rc[u]) cerr << rc[u] << " " << u << endl;
    }

private:
    // Roots the tree at u: removes the parent from every child's list, moves
    // the child with the largest subtree to the front of adj[u] and returns
    // the subtree size of u.
    int dfs(int u) {
        int sz = 1, mx = 0;
        for (int& v : adj[u]) {
            adj[v].erase(find(adj[v].begin(), adj[v].end(), u));
            int res = dfs(v);
            sz += res;
            if (res > mx) {
                mx = res;
                swap(v, adj[u][0]);
            }
        }
        return sz;
    }

    // Writes node u (or a freshly allocated node when u == -1) with children
    // l and r and type t, links the children back to it and returns its index.
    int add(int u, int l, int r, Type t) {
        if (u == -1) u = nxt++;
        par[u] = -1, lc[u] = l, rc[u] = r, type[u] = t;
        if (l != -1) par[l] = u;
        if (r != -1) par[r] = u;
        return u;
    }

    // Combines a non-empty sequence of (node, size) pairs into one node of
    // type t. The sequence is split by accumulated size into a front and a
    // back part, each part is merged recursively, and the two results become
    // the left and right child of a new node. The order is preserved.
    pair<int, int> merge(const vector<pair<int, int>>& nodes, Type t) {
        if (size(nodes) == 1) return nodes[0];
        int totsz = 0;
        for (const auto& node : nodes) totsz += node.second;
        vector<pair<int, int>> lhs, rhs;
        for (const auto& [i, sz] : nodes) {
            (totsz > sz ? lhs : rhs).emplace_back(i, sz);
            totsz -= sz * 2;
        }
        auto [l, szl] = merge(lhs, t);
        auto [r, szr] = merge(rhs, t);
        return {add(-1, l, r, t), szl + szr};
    }

    // Builds the path cluster for the chain that starts at u and keeps
    // following the first child (adj[x][0]) until a leaf is reached.
    pair<int, int> _merge_path(int u) {
        vector<pair<int, int>> nodes {_add_vertex(u)};
        while (!adj[u].empty()) nodes.push_back(_add_vertex(u = adj[u][0]));
        // merge_paths expects the cluster closer to the root on the right.
        reverse(nodes.begin(), nodes.end());
        return merge(nodes, Type::MergePaths);
    }

    // Builds one child cluster out of all children of u except the first;
    // returns {-1, 0} when there are none.
    pair<int, int> _merge_children(int u) {
        vector<pair<int, int>> nodes;
        for (int j = 1; j < size(adj[u]); j++) nodes.push_back(_make_child(adj[u][j]));
        return nodes.empty() ? make_pair(-1, 0) : merge(nodes, Type::MergeChildren);
    }

    // Wraps the path cluster starting at u into a child cluster node.
    pair<int, int> _make_child(int u) {
        auto [v, szv] = _merge_path(u);
        return {add(-1, v, -1, Type::MakeChild), szv};
    }

    // Creates the node of vertex u itself (index u), with the merged cluster
    // of its non-first children, if any, as right child.
    pair<int, int> _add_vertex(int u) {
        auto [v, szv] = _merge_children(u);
        return {add(u, -1, v, v == -1 ? Type::MakeVertex : Type::AttachChild), szv + 1};
    }

    // Recomputes the value of node u from its children according to its type.
    void pull(int u) {
        switch (type[u]) {
            case MakeVertex:
                path[u] = make_vertex(u);
                break;
            case MergePaths:
                path[u] = merge_paths(path[lc[u]], path[rc[u]]);
                break;
            case AttachChild:
                path[u] = attach_child(make_vertex(u), child[rc[u]]);
                break;
            case MergeChildren:
                child[u] = merge_children(child[lc[u]], child[rc[u]]);
                break;
            case MakeChild:
                child[u] = make_child(path[lc[u]]);
                break;
        }
    }

    // Roots the tree and builds the node structure (values are not computed
    // here; init() calls build() afterwards).
    void build_stt() {
        dfs(root);
        stt_root = _merge_path(root).first;
    }
};
// Number of vertices, set by initialize().
int n;
// State of each vertex (0-indexed): 0 = empty (set by neighbor()),
// 1 = dog (set by dog()), 2 = cat (set by cat()).
vector<int> t;
// Static top tree over the input tree; its root value is what the query
// functions read their answer from.
StaticTopTree stt;
// Grader entry point: receives a tree with N vertices whose edges are
// (A[i], B[i]) for i in [0, N-2], with 1-indexed endpoints, marks every vertex
// as empty and builds the static top tree over it.
void initialize(int N, vector<int> A, vector<int> B) {
    n = N;
    // assign (not resize) so that a repeated call starts from all-empty again.
    t.assign(n, 0);

    vector<vector<int>> adj(n);
    for (int i = 0; i < n - 1; i++) {
        const int a = A[i] - 1;
        const int b = B[i] - 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Unit path cluster of vertex u for its current state. The lambda is kept
    // inside `stt` after this function returns, so it must not capture locals;
    // it only reads the globals `t` and `inf`.
    auto make_vertex = [](int u) -> Path {
        switch (t[u]) {
            case 0: return {0, inf, inf, 0};  // empty
            case 1: return {0, 1, 1, 2};      // dog
            case 2: return {2, 1, 1, 0};      // cat
        }
        assert(false && "unknown vertex state");
        // Reached only when assertions are compiled out; avoids falling off
        // the end of a value-returning function.
        return {inf, inf, inf, inf};
    };

    stt.init(n, adj, make_vertex);
}
int dog(int v) {
v--;
t[v] = 1;
stt.update(v);
Path res = stt.query();
return min({res.dd, res.dc, res.cd, res.cc});
}
int cat(int v) {
v--;
t[v] = 2;
stt.update(v);
Path res = stt.query();
return min({res.dd, res.dc, res.cd, res.cc});
}
int neighbor(int v) {
v--;
t[v] = 0;
stt.update(v);
Path res = stt.query();
return min({res.dd, res.dc, res.cd, res.cc});
}
// Submission result tables copied from the judge page (not part of the program):
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |