#include <bits/stdc++.h>
using namespace std;
// sz(x): number of elements of container x, as an int (via unqualified size()).
#define sz(x) ((int)size(x))
// all(x): expands to `begin(x), end(x)` so a whole range can be passed to an algorithm.
#define all(x) begin(x), end(x)
// trace(x): debug print "x: <value>" to cout followed by endl. The expansion
// already ends in ';', so call sites write `trace(v)` without a semicolon.
#define trace(x) cout << #x << ": " << (x) << endl;
typedef long long ll;
// Process-wide pseudo-random generator, seeded from the steady clock so that
// different runs produce different sequences.
mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());
/// Returns a uniformly distributed integer in the closed range [l, r].
/// Precondition: l <= r.
/// uniform_int_distribution avoids the modulo bias of `rnd() % len`. It also
/// avoids the signed overflow of computing `r - l + 1` in int, which breaks
/// (division by zero) for ranges such as [INT_MIN, INT_MAX].
int rand(int l, int r) { return uniform_int_distribution<int>(l, r)(rnd); }
template<typename T>
bool ckmn(T &x, T y) {
    // Keep the running minimum in x: overwrite it only when y is strictly
    // smaller, and tell the caller whether an overwrite happened.
    if (!(x > y)) {
        return false;
    }
    x = y;
    return true;
}
template<typename T>
bool ckmx(T &x, T y) {
    // Keep the running maximum in x: overwrite it only when y is strictly
    // larger, and tell the caller whether an overwrite happened.
    if (!(x < y)) {
        return false;
    }
    x = y;
    return true;
}
// Capacity of every per-vertex array below (vertices are indexed 0..N-1).
const int N = 1000001;
// Large sentinel values used as "infinity" for int and long long answers.
const int infI = 1e9 + 7;
const ll infL = 3e18;
// Vertex count and the two special vertices read from input (0-based after
// main decrements them; presumably trap and mouse rooms).
int n;
int t;
int m;
// Adjacency lists of the tree.
vector<int> g[N];
// Per-vertex values written by dfs() and by main.
int dp[N];
int parent[N];
int depth[N];
int deg[N];
// Cost for vertices hanging off the m -> t path; main resets it to -1 first.
int gg[N];
int sum_deg[N];
// Vertices in DFS post-order, appended by dfs().
vector<int> tops;
/// Roots the tree at v (whose parent is recorded as p) and fills, for every
/// vertex u reached:
///   parent[u]  - its DFS parent;
///   depth[c]   - depth[u] + 1 for each child c (the root's depth is not written);
///   deg[u]     - number of children (all neighbours for the root when p == v);
///   sum_deg[u] - sum_deg[parent[u]] + deg[u]; the root (p == v) adds 1 instead;
///   dp[u]      - second-largest dp among u's children (0 if fewer than two)
///                plus deg[u].
/// Vertices are appended to `tops` in post-order (children before parents).
///
/// Uses an explicit stack rather than recursion. N allows about 10^6 vertices,
/// so a path-shaped tree would otherwise nest that many call frames and can
/// overflow the call stack. Neighbours are scanned in g[u] order, so the visit
/// order, every array write and the contents of `tops` match the recursive
/// formulation exactly.
void dfs(int v, int p) {
    // One frame per vertex on the current root-to-vertex DFS path.
    struct Frame {
        int v, p;      // vertex and its parent
        size_t next;   // index in g[v] of the next neighbour to examine
        int mx1, mx2;  // largest and second-largest child dp seen so far
    };
    vector<Frame> stk;
    // Pre-order work for u: everything done before its children are visited.
    auto enter = [&](int u, int par) {
        parent[u] = par;
        deg[u] = (par == u ? sz(g[u]) : sz(g[u]) - 1);
        sum_deg[u] = sum_deg[par] + (par == u ? 1 : deg[u]);
        stk.push_back({u, par, 0, 0, 0});
    };
    enter(v, p);
    while (!stk.empty()) {
        Frame &f = stk.back();
        if (f.next < g[f.v].size()) {
            const int cur = f.v;
            const int to = g[cur][f.next++];
            if (to != f.p) {
                depth[to] = depth[cur] + 1;
                enter(to, cur);  // may reallocate stk; `f` is not used afterwards
            }
            continue;
        }
        // Post-order work: every child of f.v has been finished.
        const int u = f.v;
        dp[u] = f.mx2 + deg[u];
        tops.push_back(u);
        stk.pop_back();
        // Report dp[u] to the parent frame's top-two tracker.
        if (!stk.empty()) {
            Frame &up = stk.back();
            if (up.mx1 < dp[u]) {
                up.mx2 = up.mx1;
                up.mx1 = dp[u];
            } else if (up.mx2 < dp[u]) {
                up.mx2 = dp[u];
            }
        }
    }
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
auto smart_solve = [&]() -> int {
int ans = infI;
int l = infI, r = 3 * n;
int cnt = 0;
for (int i = 0; i < n; ++i) {
if (gg[i] != -1) {
l = min(l, gg[i]);
++cnt;
}
}
if (l == infI) {
return 0;
}
if (cnt == 1) {
return 1;
}
l -= 1;
while (l + 1 < r) {
int mid = (l + r) >> 1;
vector<int> pos;
vector<bool> used(n);
for (int x: tops) {
if (gg[x] > mid) {
pos.push_back(x);
}
}
int x = m;
int pnt = 0;
bool ok = true;
while (x != t) {
if (pnt < sz(pos)) {
used[pos[pnt++]] = true;
}
for (int to: g[x]) {
if (to != parent[x] && gg[to] > mid && !used[to]) {
ok = false;
break;
}
}
if (!ok) {
break;
}
x = parent[x];
}
if (ok) {
ans = min(ans, mid);
r = mid;
} else {
l = mid;
}
}
return ans;
};
auto not_very_smart_solve = [&]() -> int {
int ans = infI;
int l = infI;
int cnt = 0;
for (int i = 0; i < n; ++i) {
if (gg[i] != -1) {
l = min(l, gg[i]);
++cnt;
assert(gg[i] <= 3 * n);
}
}
if (l == infI) {
return 0;
}
if (cnt == 1) {
return 1;
}
for (int mid = l; mid <= 3 * n; ++mid) {
vector<int> pos;
vector<bool> used(n);
for (int x: tops) {
if (gg[x] > mid) {
pos.push_back(x);
}
}
int x = m;
int pnt = 0;
bool ok = true;
while (x != t) {
if (pnt < sz(pos)) {
used[pos[pnt++]] = true;
}
for (int to: g[x]) {
if (to != parent[x] && gg[to] > mid && !used[to]) {
ok = false;
break;
}
}
if (!ok) {
break;
}
x = parent[x];
}
if (ok) {
ans = min(ans, mid);
}
}
return ans;
};
auto stupid_solve= [&]() -> int {
int ans = infI;
vector<int> pos;
for (int i = 0; i < n; ++i) {
if (gg[i] != -1)
pos.push_back(i);
}
if (sz(pos) == 1)
return 1;
if (pos.empty())
return 0;
do {
vector<bool> used(n);
int x = m;
int prv = -1;
int pnt = 0;
while (x != t) {
if (pnt < sz(pos))
used[pos[pnt++]] = true;
for (int to : g[x]) {
if (to != parent[x] && prv != to && !used[to]) {
int now = dp[to] + 1;
int pp = prv;
for (int z = x; z != t; z = parent[z]) {
for (int uu : g[z]) {
if (!used[uu] && pp != uu & to != uu && uu != parent[z])
++now;
}
pp = z;
}
now += pnt;
ckmn(ans, now);
}
}
prv = x;
x = parent[x];
}
} while (next_permutation(all(pos)));
return ans;
};
if (1) {
cin >> n >> t >> m;
--t, --m;
if (t == m) {
exit(1);
}
for (int i = 1; i < n; ++i) {
int a, b;
cin >> a >> b;
--a, --b;
g[a].push_back(b);
g[b].push_back(a);
}
dfs(t, t);
memset(gg, -1, sizeof(gg));
{
int x = m, prv = -1;
while (x != t) {
for (int to: g[x]) {
if (to != parent[x] && to != prv) {
gg[to] = dp[to] + (x == m ? sum_deg[x] - depth[x] : sum_deg[x] - depth[x] - 1);
}
}
prv = x;
x = parent[x];
}
}
cout << stupid_solve();
} else {
n = 100;
t = 1;
m = 2;
--t, --m;
while (true) {
if (t == m) {
exit(1);
}
tops.clear();
for (int i = 0; i < n; ++i) {
g[i].clear();
sum_deg[i] = 0;
deg[i] = 0;
dp[i] = 0;
depth[i] = 0;
parent[i] = 0;
}
for (int i = 1; i < n; ++i) {
int a = rand(0, i - 1);
g[a].push_back(i);
g[i].push_back(a);
}
dfs(t, t);
memset(gg, -1, sizeof(gg[0]) * n);
{
int x = m, prv = -1;
while (x != t) {
for (int to: g[x]) {
if (to != parent[x] && to != prv) {
gg[to] = dp[to] + (x == m ? sum_deg[x] - depth[x] : sum_deg[x] - depth[x] - 1);
}
}
prv = x;
x = parent[x];
}
}
int smart = not_very_smart_solve();
int stupid = dp[m];
if (smart != stupid) {
cout << "NO!" << endl;
cout << n << " " << t + 1 << " " << m + 1 << endl;
for (int i = 1; i < n; ++i)
cout << parent[i] + 1 << " " << i + 1 << endl;
trace(smart)
trace(stupid)
return 0;
}
cout << "OK!" << endl;
}
}
return 0;
}
/*
Compilation message (line numbers refer to the submitted mousetrap.cpp):
mousetrap.cpp: In lambda function:
mousetrap.cpp:194:53: warning: suggest parentheses around comparison in operand of '&' [-Wparentheses]
194 | if (!used[uu] && pp != uu & to != uu && uu != parent[z])
| ~~~^~~~~
mousetrap.cpp: In function 'int main()':
mousetrap.cpp:66:10: warning: variable 'smart_solve' set but not used [-Wunused-but-set-variable]
66 | auto smart_solve = [&]() -> int {
| ^~~~~~~~~~~
*/
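/*
Grader results:
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 19 ms            | 27660 KB        | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/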
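/*
Grader results:
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Correct       | 326 ms           | 84888 KB        | Output is correct    |
| 2 | Correct       | 274 ms           | 79424 KB        | Output is correct    |
| 3 | Incorrect     | 794 ms           | 85992 KB        | Output isn't correct |
| 4 | Halted        | 0 ms             | 0 KB            | -                    |
*/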
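/*
Grader results:
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 19 ms            | 27660 KB        | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/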
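/*
Grader results:
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 19 ms            | 27660 KB        | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/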