#include <bits/stdc++.h>
using namespace std;
// maxn: vertex capacity (up to 3000 vertices plus slack).
// maxlg: binary-lifting levels; 2^(maxlg-1) = 16384 >= maxn.
// sq: block size used by the Mo's query ordering.
const int maxn = 3000 + 10, maxlg = 15, sq = 500;

// One Mo's query over the Euler tour: window [l, r] and the color index it belongs to.
struct str {
    int l, r, ind;
};

// Tree / DFS bookkeeping:
//   h[v]      - depth of v (the DFS root has depth 0)
//   st[v]     - Euler-tour position of v's entry event
//   fn[v]     - Euler-tour position of v's exit event
//   par[v][i] - 2^i-th ancestor of v (the root's parent is itself, vertex 0)
int n, k, timer, h[maxn], st[maxn], fn[maxn], par[maxn][maxlg];

// adj[v]: neighbours of v.  s[c]: vertices whose color is c (0-based).
// s is indexed by color and k is read from input, so a fixed size of maxlg (15)
// overflowed for k > 15. Sizing it by maxn assumes k <= n; confirm against the
// problem constraints.
vector <int> adj[maxn], s[maxn];

// Euler-tour events in visit order: (vertex, 0) on entry, (vertex, 1) on exit.
vector <pair <int, int>> euler;
bool cmp (str a, str b){
if (a.l / sq != b.l / sq)
return a.l < b.l;
return (a.r < b.r) ^ ((a.l / sq) & 1);
}
// Depth-first traversal from v that records the Euler tour and fills the
// binary-lifting table.
// Each event advances `timer` by one and appends one entry to `euler`. When the
// traversal starts from an empty tour with timer == 0 (as main does with dfs(0)),
// euler[st[v]] is v's entry event and euler[fn[v]] its exit event.
// Requires par[v][0] to be set before the call. The children's parents are set
// here; the root relies on the zero-initialised global (its parent is vertex 0).
void dfs (int v){
    st[v] = timer;
    euler.emplace_back(v, 0);
    ++timer;

    // The 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor.
    for (int j = 1; j < maxlg; ++j){
        const int halfway = par[v][j - 1];
        par[v][j] = par[halfway][j - 1];
    }

    // par[v][0] is never modified by the recursive calls below, so read it once.
    const int parent = par[v][0];
    for (int child : adj[v]){
        if (child == parent)
            continue;
        par[child][0] = v;
        h[child] = h[v] + 1;
        dfs(child);
    }

    fn[v] = timer;
    euler.emplace_back(v, 1);
    ++timer;
}
// Lowest common ancestor of x and y via binary lifting.
// The deeper endpoint (x after the swap) is never an ancestor of the other one.
// It is lifted as far as possible while the candidate still does not contain y
// in its subtree, using the Euler entry/exit times from dfs. The parent of the
// final node is the LCA.
int lca (int x, int y){
    if (x == y)
        return x;
    if (h[y] > h[x])
        swap(x, y);
    for (int lvl = maxlg - 1; lvl >= 0; --lvl){
        const int up = par[x][lvl];
        const bool coversY = st[up] <= st[y] && fn[up] >= fn[y];
        if (!coversY)
            x = up;
    }
    return par[x][0];
}
int main (){
ios_base::sync_with_stdio(0);
cin >> n >> k;
for (int x, y, i = 0; i < n - 1; i++)
cin >> x >> y,
adj[--x].push_back(--y),
adj[y].push_back(x);
for (int x, i = 0; i < n; i++)
cin >> x,
s[--x].push_back(i);
dfs(0);
vector <str> que;
for (int i = 0; i < k; i++){
int m = s[i].size();
for (int j = 0; j < m; j++){
int v = s[i][j], u = s[i][(j + 1) % m], l = lca(u, v);
que.push_back({st[v], st[l], i});
que.push_back({st[u], st[l], i});
}
}
sort(que.begin(), que.end(), cmp);
int L = 0, R = 0;
set <int> st; st.insert(euler[0].first);
vector <int> msk(k, 0);
for (auto [l, r, ind] : que){
while (L < l){
if (euler[L].second) st.erase(euler[L].first);
else st.insert(euler[L++].first);
}
while (L > l){
if (euler[--L].second) st.insert(euler[L].first);
else st.erase(euler[L].first);
}
while (R < r){
if (euler[++R].second) st.erase(euler[R].first);
else st.insert(euler[R].first);
}
while (R > r){
if (euler[R].second) st.erase(euler[R].first);
else st.insert(euler[R--].first);
}
for (auto u : st)
msk[ind] |= (1ll << u);
}
vector <int> dp((1ll << k), int(1e5));
dp[0] = 0;
for (int i = 0; i < k; i++)
for (int mask = 0; mask < (1ll << k); mask++)
dp[mask | msk[i]] = min(dp[mask | msk[i]], dp[mask] + 1);
cout << dp[(1ll << k) - 1];
}
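/* Grader result:
| # | 결과 (Result)       | 실행 시간 (Time) | 메모리 (Memory) | Grader output       |
|---|---------------------|------------------|-----------------|---------------------|
| 1 | Execution timed out | 3050 ms          | 336 KB          | Time limit exceeded |
| 2 | Halted              | 0 ms             | 0 KB            | -                   |
*/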
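/* Grader result:
| # | 결과 (Result)       | 실행 시간 (Time) | 메모리 (Memory) | Grader output       |
|---|---------------------|------------------|-----------------|---------------------|
| 1 | Execution timed out | 3050 ms          | 336 KB          | Time limit exceeded |
| 2 | Halted              | 0 ms             | 0 KB            | -                   |
*/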
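/* Grader result:
| # | 결과 (Result)       | 실행 시간 (Time) | 메모리 (Memory) | Grader output       |
|---|---------------------|------------------|-----------------|---------------------|
| 1 | Execution timed out | 3050 ms          | 336 KB          | Time limit exceeded |
| 2 | Halted              | 0 ms             | 0 KB            | -                   |
*/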
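/* Grader result:
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output                  |
|---|---------------|------------------|-----------------|--------------------------------|
| 1 | Runtime error | 2 ms             | 592 KB          | Execution killed with signal 11 |
| 2 | Halted        | 0 ms             | 0 KB            | -                              |
*/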
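/* Grader result:
| # | 결과 (Result)       | 실행 시간 (Time) | 메모리 (Memory) | Grader output       |
|---|---------------------|------------------|-----------------|---------------------|
| 1 | Execution timed out | 3050 ms          | 336 KB          | Time limit exceeded |
| 2 | Halted              | 0 ms             | 0 KB            | -                   |
*/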