#include <bits/stdc++.h>
using namespace std;

// Short aliases for frequently used numeric and pair/tuple types.
using ll = long long;
using ld = long double;
using pl = pair<ll,ll>;
using pii = pair<int,int>;
using tpl = tuple<int,int,int>;

// Expands to the (begin, end) iterator pair of the container expression `a`.
// The argument is parenthesised so that expressions such as `*ptr` expand
// correctly; note that `a` is evaluated twice.
#define all(a) (a).begin(), (a).end()

// Removes consecutive duplicate elements from `a` in place (sort `a` first to
// drop every duplicate). `a` is evaluated several times, so it must be free of
// side effects.
#define filter(a) (a).erase(unique(all(a)), (a).end())

// Capacity used for the per-vertex arrays and the segment tree (1e5 plus slack).
const int mn = 1e5 + 5;
// Sentinel for "infeasible"; all real costs stay far below it.
const int INF = 1e9;

/**
 * 2x2 cost matrix summarising a contiguous run of elements.
 *
 * Entry [i][j] is the minimum cost of the run when its first element is in
 * state i and its last element is in state j (states are 0 or 1); INF marks
 * an infeasible combination. A matrix whose four entries are all INF is the
 * neutral element of operator+.
 *
 * Storage is a fixed-size std::array rather than nested vectors, so building,
 * copying and combining nodes never touches the heap. Indexing with [i][j],
 * comparison with == and iteration keep working as before.
 */
struct node : std::array<std::array<int, 2>, 2> {
    using Row = std::array<int, 2>;
    using Base = std::array<Row, 2>;

    /// Neutral element: every entry is INF.
    node() : Base{{Row{INF, INF}, Row{INF, INF}}} {}

    /// Builds the matrix {{a, b}, {c, d}}.
    node (int a, int b, int c, int d) : Base{{Row{a, b}, Row{c, d}}} {}

    /// True when all four entries are INF, i.e. the node equals node().
    bool isNeutral() const {
        for (const Row &r : *this)
            for (int x : r)
                if (x != INF) return false;
        return true;
    }

    /**
     * Concatenates this run (left) with `o` (right).
     *
     * For every pair of boundary states (s1 = last state of the left run,
     * s2 = first state of the right run) the costs are added, plus 1 when the
     * two boundary states differ. Combinations involving an INF entry are
     * skipped, so INF is never added to anything.
     */
    node operator + (const node &o) const {
        if (isNeutral()) return o;
        if (o.isNeutral()) return *this;
        node ans;
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int s1 = 0; s1 < 2; s1++)
                    for (int s2 = 0; s2 < 2; s2++) {
                        const int lhs = (*this)[i][s1];
                        const int rhs = o[s2][j];
                        if (std::max(lhs, rhs) != INF)
                            ans[i][j] = std::min(ans[i][j], lhs + rhs + (s1 ^ s2));
                    }
        return ans;
    }
};
// Segment tree of `node` values combined with node::operator+.
// Nodes are stored heap-style: the children of index k are 2k and 2k+1.
// Callers supply the root index and the range it covers on every call
// (elsewhere in this file always k = 1, l = 1, r = n).
struct IT {
    vector<node> tr;

    // Allocates 4 * sz default-constructed (neutral) nodes.
    IT (int sz) : tr(4 * sz) {}

    // Stores `cur` at leaf `pos` of the subtree rooted at k covering [l, r],
    // then recomputes every ancestor up to index 1.
    void update (int pos, node cur, int k, int l, int r) {
        // Walk down to the leaf that holds `pos`.
        while (l < r) {
            const int mid = (l + r) >> 1;
            k <<= 1;
            if (pos <= mid) {
                r = mid;
            } else {
                k |= 1;
                l = mid + 1;
            }
        }
        tr[k] = cur;
        // Walk back up, refreshing each ancestor from its two children.
        k >>= 1;
        while (k > 0) {
            const int leftChild = k << 1;
            const int rightChild = leftChild | 1;
            tr[k] = tr[leftChild] + tr[rightChild];
            k >>= 1;
        }
    }

    // Returns the combination of all leaves in [a, b] inside the subtree
    // rooted at k covering [l, r]; a range that misses [l, r] yields node().
    node query (int a, int b, int k, int l, int r) {
        const bool disjoint = (r < a) || (b < l);
        if (disjoint) return node();
        const bool covered = (a <= l) && (r <= b);
        if (covered) return tr[k];
        const int mid = (l + r) >> 1;
        node leftPart = query(a, b, k << 1, l, mid);
        node rightPart = query(a, b, k << 1 | 1, mid + 1, r);
        return leftPart + rightPart;
    }
};
// Per-vertex bookkeeping filled in by szDfs()/dfs(). Index 0 serves as the
// "no vertex" parent of the root (dfs is started with p = 0).
int depth[mn];        // depth value handed down by dfs()
int par[mn];          // parent vertex recorded by dfs()
int num[mn];          // visit index assigned from timeDfs (first visit gets 1)
int sz[mn];           // subtree size computed by szDfs()
int timeDfs;          // counter incremented once per dfs() visit
int n;                // vertex count, set by initialize()
int chain[mn];        // head vertex of the chain that contains the vertex
int tail[mn];         // for a chain head: the last vertex dfs() put on that chain
// sumChild[s][p]: sum, over chains whose head's parent is p, of the cost of
// attaching that chain below p when p is in state s (maintained by update()).
// Column 0 collects the root chain and is what query() reads.
int sumChild[2][mn];
vector<int> adj[mn];  // adjacency lists
IT tree(mn);          // segment tree indexed by num[]
bool isDog[mn];       // for debug
bool isCat[mn];       // for debug
// Computes sz[] for the subtree of u (entered from parent p) and returns sz[u].
int szDfs (int u, int p) {
    int total = 1; // counts u itself
    for (int v : adj[u]) {
        if (v == p) continue;
        total += szDfs(v, u);
    }
    sz[u] = total;
    return total;
}
// Visits u (parent p, depth d) and records chain[], tail[], num[], depth[], par[].
// toP tells whether u continues its parent's chain or starts a new one.
// When called on vertex 1 it first runs szDfs to fill sz[].
void dfs (int u, int p, int d, bool toP) {
    if (u == 1) szDfs(u, p);

    const int top = toP ? chain[p] : u;
    chain[u] = top;
    tail[top] = u; // overwritten as the chain grows, so it ends at the last vertex

    timeDfs++;
    num[u] = timeDfs;
    depth[u] = d;
    par[u] = p;

    // Largest subtree first; the first non-parent neighbour continues the chain.
    sort(adj[u].begin(), adj[u].end(), [&] (int a, int b) { return sz[b] < sz[a]; });

    bool continuesChain = true;
    for (int v : adj[u]) {
        if (v == p) continue;
        dfs(v, u, d + 1, continuesChain);
        continuesChain = false;
    }
}
void update (int u, node upd) {
while (u) {
int head = chain[u]; node nxt = tree.query(num[par[head]], num[par[head]], 1, 1, n);
if (nxt[0][0] != INF) nxt[0][0] -= sumChild[0][par[head]];
if (nxt[1][1] != INF) nxt[1][1] -= sumChild[1][par[head]];
// remove contribution to sumChild[...][par[head]]
node cur = tree.query(num[head], num[tail[head]], 1, 1, n);
sumChild[0][par[head]] -= min(min(cur[0][0], cur[0][1]), min(cur[1][0], cur[1][1]) + 1);
sumChild[1][par[head]] -= min(min(cur[1][0], cur[1][1]), min(cur[0][0], cur[0][1]) + 1);
// update segment tree accordingly
if (upd[0][0] != INF) upd[0][0] += sumChild[0][u];
if (upd[1][1] != INF) upd[1][1] += sumChild[1][u];
tree.update(num[u], upd, 1, 1, n);
// re-add contribution to sumChild[...][par[head]]
cur = tree.query(num[head], num[tail[head]], 1, 1, n);
sumChild[0][par[head]] += min(min(cur[0][0], cur[0][1]), min(cur[1][0], cur[1][1]) + 1);
sumChild[1][par[head]] += min(min(cur[1][0], cur[1][1]), min(cur[0][0], cur[0][1]) + 1);
u = par[head], upd = nxt;
}
}
// Returns the smaller of the two totals accumulated for the root chain.
int query() {
    const int rootState0 = sumChild[0][0];
    const int rootState1 = sumChild[1][0];
    return (rootState1 < rootState0) ? rootState1 : rootState0;
}
int cat (int u) {
update(u, node(0, INF, INF, INF));
isCat[u] = 1;
return query();
}
int dog (int u) {
update(u, node(INF, INF, INF, 0));
isDog[u] = 1;
return query();
}
int neighbor (int u) {
update(u, node(0, INF, INF, 0));
isCat[u] = isDog[u] = 0;
return query();
}
void initialize (int _n, vector<int> _a, vector<int> _b) {
n = _n;
for (int i = 0; i < n - 1; i++) {
adj[_a[i]].push_back(_b[i]);
adj[_b[i]].push_back(_a[i]);
}
dfs(1, 0, 1, 0);
for (int i = 1; i <= n; i++)
tree.update(i, node(0, INF, INF, 0), 1, 1, n);
}
// Submission results table copied from the online judge page (not part of the program):
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |