This submission was migrated from the previous version of oj.uz, which graded on a different machine. It may receive a different result if resubmitted.
#include <bits/stdc++.h>
#define all(dataStructure) dataStructure.begin(),dataStructure.end()
#define ll long long
using namespace std;
namespace std {
// _vector<T, D>: a D-dimensional container of T made of nested std::vectors.
// Constructor arguments are the extents, outermost first, optionally followed
// by a fill value for the innermost elements:
//   _vector<int, 3> a(n, m, k);     // shape [n][m][k], elements value-initialised (T())
//   _vector<int, 3> a(n, m, k, x);  // same shape, every element equal to x
// NOTE(review): adding new declarations to namespace std is undefined
// behaviour per the C++ standard; it is left here because the rest of the
// file names this type unqualified through `using namespace std`.
template <typename T, int D>
struct _vector : public vector<_vector<T, D - 1>> {
    static_assert(D >= 1, "Dimension must be positive!");
    using layer_type = _vector<T, D - 1>;  // one slice along the outermost axis
    using base_type = vector<layer_type>;

    // n: outermost extent. args: the remaining extents (plus optional fill
    // value), used to build one prototype slice that is copied n times.
    template <typename... Args>
    _vector(int n = 0, Args... args) : base_type(n, layer_type(args...)) {}
};

// Recursion base: a plain one-dimensional vector holding n copies of val.
template <typename T>
struct _vector<T, 1> : public vector<T> {
    _vector(int n = 0, const T& val = T()) : vector<T>(n, val) {}
};
}
// Capacity for vertex-indexed global arrays (500000 plus a little slack).
constexpr int MAX = 500003;
// Common competitive-programming moduli; not referenced by any code shown here.
constexpr long long MOD[] = {1000000007, 998244353};
int n, k;                    // number of vertices, number of sheep
int a[MAX];                  // a[1..k]: vertices holding sheep
std::vector<int> adj[MAX];   // adjacency lists, vertices 1..n
void sub1() {
sort(a + 1, a + k + 1);
vector <int> dp(k + 2, n + 2);
dp[0] = 0;
for (int i = 1; i <= k; i++) {
dp[i] = dp[i - 1] + 1;
if (i >= 2 && (a[i] - a[i - 1]) % 2 == 0) dp[i] = min(dp[i], dp[i - 2] + 1);
}
cout << dp[k];
}
void sub2() {
vector <int> minDist(n + 1, n + 1);
vector <int> f(n + 1);
vector <int> g(1 << k);
vector <int> dp(1 << k, n);
for (int i = 1; i <= k; i++) {
minDist[a[i]] = 0;
f[a[i]] = 1 << (i - 1);
queue <int> q;
q.push(a[i]);
while (q.size()) {
int u = q.front();
q.pop();
for (int &v : adj[u]) {
if (minDist[v] == minDist[u] + 1) {
f[v] |= (1 << (i - 1));
q.push(v);
}
if (minDist[v] > minDist[u] + 1) {
minDist[v] = minDist[u] + 1;
f[v] = (1 << (i - 1));
q.push(v);
}
}
}
}
for (int i = 1; i <= n; i++) g[f[i]] = 1;
for (int i = 1; i < (1 << k); i++) {
if (g[i]) for (int j = i; j > 0; j = (j - 1) & i) {
g[j] = 1;
}
}
dp[0] = 0;
for (int i = 1; i < (1 << k); i++) {
for (int j = i; j > 0; j = (j - 1) & i) {
if (g[j]) {
dp[i] = min(dp[i], dp[i ^ j] + 1);
}
}
}
cout << dp[(1 << k) - 1];
}
void sub3() {
if (n > 5000) return void(cout << k);
_vector <int, 2> dp(n + 1, n + 2, n + 2);
vector <int> f(n + 1);
vector <int> best(n + 1);
vector <bool> spec(n + 1);
for (int i = 1; i <= k; i++) spec[a[i]] = 1;
function <void(int, int)> dfs = [&](int u, int pre) -> void {
if (spec[u]) {
dp[u][0] = 0;
fill(all(f), 0);
fill(all(best), 0);
for (int &v : adj[u]) if (v != pre) {
dfs(v, u);
dp[u][0] += dp[v][n + 1];
for (int i = 0; i < n; i++) {
f[i + 1] += dp[v][n + 1];
best[i + 1] = min(best[i + 1], -dp[v][n + 1] + dp[v][i]);
}
}
for (int i = 2; i <= n; i += 2) {
dp[u][n + 1] = min(dp[u][n + 1], 1 + f[i] + best[i]);
}
} else {
fill(all(dp[u]), 0);
fill(all(f), 0);
fill(all(best), n + 1);
int children = 0;
for (int &v : adj[u]) if (v != pre) {
children++;
dfs(v, u);
dp[u][n + 1] += dp[v][n + 1];
for (int i = 0; i < n; i++) {
dp[u][i + 1] += min(dp[v][n + 1], dp[v][i]);
}
}
for (int i = 1; i <= n; i++) {
dp[u][n + 1] = min(dp[u][n + 1], 1 + dp[u][i]);
}
}
};
dfs(1, 0);
cout << dp[1][n + 1];
}
void Solve() {
cin >> n >> k;
bool isSub1 = 1;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
if (abs(u - v) != 1) isSub1 = 0;
}
for (int i = 1; i <= k; i++) cin >> a[i];
if (isSub1) return sub1();
if (k <= 15) return sub2();
return sub3();
}
// Program entry: fast iostreams, optional file redirection for local
// testing, a single call to Solve(), and the elapsed CPU time on stderr.
int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
#define TASK "WHARF"
    // Local-testing convenience: if TASK".INP" exists, read from it and write
    // to TASK".OUT"; otherwise the judge's stdin/stdout are used unchanged.
    if (FILE *probe = fopen(TASK".INP", "r")) {
        fclose(probe);  // the existence probe must not leak its handle
        // Report a failed redirection instead of silently ignoring it.
        if (!freopen(TASK".INP", "r", stdin) || !freopen(TASK".OUT", "w", stdout)) {
            cerr << "Failed to redirect stdio to " TASK ".INP / " TASK ".OUT\n";
            return 1;
        }
    }
    /* int TEST = 1; cin >> TEST; while (TEST--) */ Solve();
    cerr << "\nTime elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
    return 0;
}
Compilation message (stderr)
pastiri.cpp: In function 'int32_t main()':
pastiri.cpp:159:24: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
159 | freopen(TASK".INP", "r", stdin);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
pastiri.cpp:160:24: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
160 | freopen(TASK".OUT", "w", stdout);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |