#include <bits/stdc++.h>
using namespace std;
#define int long long
// Array bounds: maxn for vertex-indexed arrays, maxk for colour-indexed arrays
// (mem), lg = number of binary-lifting levels stored per vertex in st / used by lca.
const int maxn = 5e5 + 5, maxk = 5e5+5, lg = 20;
// n = number of vertices, k = number of colours (both read in main).
int n, k;
// Input tree in ORIGINAL vertex ids. After DFS0 each list holds only the
// children of its vertex, sorted by decreasing subtree size.
vector<int> ADJ[maxn];
// A[v] = colour of original vertex v.
int A[maxn];
// SZ = subtree size and FA = parent (original ids, filled by DFS0);
// newid = preorder number (DFS1); lim[newid[u]] = largest preorder number in
// u's subtree; cnt = running preorder counter used by DFS1.
int SZ[maxn], FA[maxn], newid[maxn], lim[maxn], cnt = 0;
bool cmp(int x, int y) {
return SZ[x] > SZ[y];
}
// Roots the original-id tree at u. For every neighbour of u it records u as the
// parent (FA), deletes the back-edge to u from the neighbour's list, recurses,
// and accumulates subtree sizes into SZ. Finally it orders u's children by
// decreasing subtree size.
// NOTE(review): recursion depth equals the tree height (up to n, about 5e5 here);
// confirm the judge's stack limit allows path-shaped trees.
void DFS0(int u) {
    SZ[u] = 1;
    for (size_t idx = 0; idx < ADJ[u].size(); ++idx) {
        int child = ADJ[u][idx];
        FA[child] = u;
        auto backEdge = find(ADJ[child].begin(), ADJ[child].end(), u);
        ADJ[child].erase(backEdge);
        DFS0(child);
        SZ[u] += SZ[child];
    }
    sort(ADJ[u].begin(), ADJ[u].end(), cmp);
}
// Assigns preorder numbers starting after the current value of cnt.
// newid[u] is u's number, and lim[newid[u]] is the largest number given out
// inside u's subtree, so the subtree is exactly [newid[u], lim[newid[u]]].
// Children are visited in ADJ order.
void DFS1(int u) {
    const int myId = ++cnt;
    newid[u] = myId;
    for (size_t j = 0; j < ADJ[u].size(); ++j) DFS1(ADJ[u][j]);
    lim[myId] = cnt;
}
// Tree re-indexed by preorder number (ids from DFS1): adj = child lists,
// fa = parent id (0 for the root), sz = subtree size (not read elsewhere in the
// visible code).
vector<int> adj[maxn];
int fa[maxn], sz[maxn];
// mem[c] = preorder ids of all vertices with colour c (sorted in main before lca).
vector<int> mem[maxk];
// a[id] = colour of the vertex with preorder id `id` (not read elsewhere in the
// visible code).
int a[maxn];
// Binary-lifting table: st[u][i] = 2^i-th ancestor of preorder vertex u,
// 0 when that ancestor does not exist.
int st[maxn][lg];
// Fills the binary-lifting table st for u's subtree (preorder ids).
// st[u][0] must already hold u's parent; it is set by the parent's loop below,
// or left 0 for the root. Higher levels are filled until the first missing
// ancestor; the remaining levels stay 0.
void dfs0(int u) {
    int level = 1;
    while (level < lg) {
        const int half = st[u][level - 1];
        if (half == 0) break;
        const int full = st[half][level - 1];
        if (full == 0) break;
        st[u][level] = full;
        ++level;
    }
    for (int child : adj[u]) {
        st[child][0] = u;
        dfs0(child);
    }
}
int psum[maxn];
int lca(int u, int c) {
int l = mem[c][0], r = mem[c].back();
for (int i=lg-1;i>=0;i--) if (st[u][i]) {
int v = st[u][i];
if (!(v<=l && r<=lim[v])) u = v;
}
// cout << c << " " << l << " " << r << endl;
if (!(u<=l && r<=lim[u])) u = fa[u];
return u;
}
// sep[id]: the edge from preorder vertex `id` to its parent is separating, i.e.
// every colour present in id's subtree lies entirely inside that subtree.
bool sep[maxn];
// sepleaf[id]: sep[id] holds and no other separating edge lies in id's subtree.
bool sepleaf[maxn];
// sum[i] = sep[1] + ... + sep[i]; the count in a subtree is sum[lim[id]] - sum[id-1].
int sum[maxn];
// State used only by dfs1, which has no live call site in the visible code.
// NOTE(review): main declares its own local `ans`, which shadows this one.
int dead = 0, ans = 0;
bool frst = false;
// NOTE(review): no live call to dfs1 exists in the visible code; it looks like an
// earlier greedy approach kept for reference. Confirm before relying on it.
// Walks down from preorder vertex u. `minus` is subtracted from the number of
// separating edges in u's subtree (its exact meaning is unclear: TODO confirm).
// Reads sep/sum/lim/adj and updates the globals frst, dead and ans.
void dfs1(int u, int minus) {
// cout << u << " " << minus << endl;
// Flag u when its own edge separates and nothing has been subtracted yet.
if (minus==0 && sep[u]) frst = true;
// Count u when its edge separates and exactly one separating edge remains in its
// subtree after subtracting `minus`.
if (sum[lim[u]] - sum[u-1] - minus == 1 && sep[u]) dead++;
if (adj[u].size()==0) return;
// (number of separating edges in child's subtree, child id) for every child.
vector<pair<int,int>> vec;
for (int v:adj[u]) {
vec.push_back({sum[lim[v]] - sum[v-1], v});
}
sort(vec.begin(), vec.end());
// Binary search the smallest cap l such that trimming every child count down to
// l removes at most `minus` in total.
// NOTE(review): r = 1e18 relies on `#define int long long`; it would overflow a
// 32-bit int.
int l = 0, r = 1e18;
while (l<r) {
int mid = (l+r)/2, used = 0;
for (auto [x, y]:vec) used += max(x - mid, (int)0);
if (used <= minus) r = mid;
else l = mid+1;
}
// Apply the cap (this keeps vec in ascending order).
for (auto &[x, y]:vec) x = min(x, l);
// Separating edges strictly below u (sum[u] excludes u's own edge), minus `minus`.
int total = sum[lim[u]] - sum[u] - minus;
// NOTE(review): vec is sorted ascending, so vec[0] is the child with the SMALLEST
// capped count. A "dominant child" test would usually use vec.back(), so confirm
// which one is intended.
if (vec[0].first <= (total+1)/2) {
ans += (total+1)/2;
return;
}
int rest = total - vec[0].first;
vec[0].first -= rest;
int v = vec[0].second;
ans += rest;
// Continue into that child, treating the already matched edges as removed.
dfs1(v, sum[lim[v]] - sum[v-1] - vec[0].first);
}
// Reads a tree of n vertices with colours 1..k and finds the separating edges:
// edges whose lower side contains every vertex of each colour it touches.
// Contracting all non-separating edges leaves a tree of components. The program
// prints the minimum number of extra edges, ceil(L / 2), where L is the number of
// leaf components (0 when there is no separating edge).
signed main() {
    ios::sync_with_stdio(0); cin.tie(0);
    cin >> n >> k;
    for (int i=1;i<=n-1;i++) {
        int u, v;
        cin >> u >> v;
        ADJ[u].push_back(v), ADJ[v].push_back(u);
    }
    for (int i=1;i<=n;i++) cin >> A[i];

    // Root at vertex 1 and renumber in preorder so every subtree is an interval.
    DFS0(1);
    DFS1(1);
    for (int i=1;i<=n;i++) {
        for (int j:ADJ[i]) adj[newid[i]].push_back(newid[j]);
        fa[newid[i]] = (i != 1) ? newid[FA[i]] : 0;
        sz[newid[i]] = SZ[i];
    }
    for (int i=1;i<=n;i++) {
        a[newid[i]] = A[i];
        mem[A[i]].push_back(newid[i]);
    }
    dfs0(1);

    // psum[x] = number of vertices whose colour has its LCA at x, then prefix sums.
    for (int i=1;i<=k;i++) {
        // A colour with no vertices has no LCA; reading mem[i][0] would be UB.
        if (mem[i].empty()) continue;
        sort(mem[i].begin(), mem[i].end());
        psum[lca(mem[i][0], i)] += mem[i].size();
    }
    for (int i=1;i<=n;i++) psum[i] += psum[i-1];

    // Edge (i, parent) separates iff every vertex in i's subtree has its colour's
    // LCA inside that subtree, i.e. the subtree is closed under colours.
    for (int i=2;i<=n;i++)
        sep[i] = (psum[lim[i]] - psum[i-1] == lim[i] - i + 1);
    for (int i=1;i<=n;i++) sum[i] = sum[i-1] + sep[i];
    // A non-root component is a leaf iff no other separating edge lies below its
    // own separating edge.
    for (int i=1;i<=n;i++) sepleaf[i] = (sep[i] && sum[lim[i]] - sum[i-1] == 1);

    int leaves = 0;
    for (int i=1;i<=n;i++) leaves += sepleaf[i];
    // The root component is a leaf iff exactly one separating edge hangs off it,
    // i.e. some separating edge has every separating edge in its subtree.
    bool rootLeaf = false;
    for (int i=2;i<=n;i++)
        if (sep[i] && sum[lim[i]] - sum[i-1] == sum[n]) rootLeaf = true;
    leaves += rootLeaf;

    // Pairing the leaves needs ceil(L / 2) new edges. The root leaf must be
    // counted BEFORE rounding: the old "(L+1)/2 + 1" was one too large whenever
    // the non-root leaf count was odd.
    cout << (leaves + 1) / 2 << '\n';
}
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 19 ms | 35672 KB | Output is correct |
| 2 | Correct | 16 ms | 35676 KB | Output is correct |
| 3 | Incorrect | 16 ms | 35676 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 19 ms | 35672 KB | Output is correct |
| 2 | Correct | 16 ms | 35676 KB | Output is correct |
| 3 | Incorrect | 16 ms | 35676 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 19 ms | 35672 KB | Output is correct |
| 2 | Correct | 16 ms | 35676 KB | Output is correct |
| 3 | Incorrect | 16 ms | 35676 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 57 ms | 66632 KB | Output is correct |
| 2 | Correct | 75 ms | 68936 KB | Output is correct |
| 3 | Correct | 18 ms | 36700 KB | Output is correct |
| 4 | Correct | 17 ms | 36700 KB | Output is correct |
| 5 | Correct | 16 ms | 35676 KB | Output is correct |
| 6 | Correct | 18 ms | 35672 KB | Output is correct |
| 7 | Correct | 18 ms | 36656 KB | Output is correct |
| 8 | Correct | 97 ms | 68432 KB | Output is correct |
| 9 | Correct | 18 ms | 36700 KB | Output is correct |
| 10 | Correct | 68 ms | 67616 KB | Output is correct |
| 11 | Correct | 16 ms | 35672 KB | Output is correct |
| 12 | Correct | 89 ms | 68180 KB | Output is correct |
| 13 | Correct | 93 ms | 68432 KB | Output is correct |
| 14 | Correct | 98 ms | 69712 KB | Output is correct |
| 15 | Correct | 61 ms | 66372 KB | Output is correct |
| 16 | Correct | 19 ms | 36700 KB | Output is correct |
| 17 | Correct | 19 ms | 35676 KB | Output is correct |
| 18 | Correct | 69 ms | 68468 KB | Output is correct |
| 19 | Correct | 117 ms | 74576 KB | Output is correct |
| 20 | Correct | 18 ms | 36700 KB | Output is correct |
| 21 | Incorrect | 17 ms | 35672 KB | Output isn't correct |
| 22 | Halted | 0 ms | 0 KB | - |
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 19 ms | 35672 KB | Output is correct |
| 2 | Correct | 16 ms | 35676 KB | Output is correct |
| 3 | Incorrect | 16 ms | 35676 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |