#include"bits/stdc++.h"
#define int long long
#define endl '\n'
#define USACO(x) ifstream cin ((string)#x + ".in"); ofstream cout((string)#x + ".out")
using namespace std;
// Capacity of every per-vertex / per-value array below; vertex ids and the
// values read into s[] are assumed (not checked) to be < maxN.
const int maxN = 1000001;
// Vertex the tree is rooted at for dfs() and calc_ans(); chosen in main().
int ROOT;
// n = number of vertices; k is read from the input but not used elsewhere
// in this file.
int n, k;
// Undirected adjacency lists, vertices numbered 1..n.
vector<int> adj[maxN];
// s[v] = value read for vertex v; used as the key of the s2b maps.
int s[maxN];
// tot[c] = number of vertices v with s[v] == c.
int tot[maxN];
// subtree[v] = size of v's subtree when the tree is rooted at vertex 1
// (filled by subtrees(1, 0)). main() uses it to choose ROOT and dfs() uses
// it to pick the child whose map is reused.
int subtree[maxN];
// s2b[v] -> map from value c to the number of vertices with s == c seen in
// v's ROOT-rooted subtree when dfs(v, ...) returns. dfs() lets a vertex take
// over one child's map and clears the others after merging, so entries can
// alias or later be emptied. Maps are allocated with new in main() and never
// freed.
map<int,int>* s2b[maxN];
// cnt[v] = number of keys c in *s2b[v] whose count has reached tot[c].
int cnt[maxN];
// bad[v] = cnt[v] equals the number of distinct keys in *s2b[v], i.e. every
// value present in v's subtree has all of its occurrences inside it.
bool bad[maxN];
// Adds `v` to the count stored for value `b` in vertex a's map. When that
// count becomes exactly tot[b] (all vertices holding b are now accounted
// for), vertex a gains one more "complete" value in cnt[a].
void inc (int a, int b, int v) {
    int& slot = (*s2b[a])[b];   // map references stay valid; look up once
    slot += v;
    if (slot == tot[b]) {
        ++cnt[a];
    }
}
// Accumulator: calc_ans() adds min(open, got) to it at each pairing step;
// main() prints it plus a term derived from calc_ans()'s return value.
int cunt = 0;
// Post-order count of subtree sizes: stores in subtree[v] the number of
// vertices below (and including) v, treating `par` as v's parent.
void subtrees (int cur, int par) {
    int total = 1;
    for (int nxt : adj[cur]) {
        if (nxt == par) continue;
        subtrees(nxt, cur);
        total += subtree[nxt];
    }
    subtree[cur] = total;
}
// Post-order pass over the ROOT-rooted tree. For each vertex it builds the
// value -> count map of its subtree. It reuses the map of the child with the
// largest subtree[] value and merges every other child's map into it,
// clearing the merged-from maps. It then adds the vertex's own value and sets
// bad[cur] when every value in the subtree is complete (cnt equals the
// number of distinct keys).
void dfs (int cur, int par) {
    // A non-root vertex with a single neighbour has no children.
    const bool isLeaf = cur != ROOT && adj[cur].size() == 1;
    if (!isLeaf) {
        int heavy = -1;
        for (int nxt : adj[cur]) {
            if (nxt == par) continue;
            dfs(nxt, cur);
            if (heavy == -1 || subtree[nxt] > subtree[heavy]) heavy = nxt;
        }
        // Take over the heavy child's map and its completed-value count.
        s2b[cur] = s2b[heavy];
        cnt[cur] = cnt[heavy];
        for (int nxt : adj[cur]) {
            if (nxt == par || nxt == heavy) continue;
            for (auto entry : *s2b[nxt]) inc(cur, entry.first, entry.second);
            s2b[nxt]->clear();
        }
    }
    inc(cur, s[cur], 1);
    bad[cur] = cnt[cur] == (int)s2b[cur]->size();
}
// Combines the values returned by cur's children. At ROOT or at a vertex
// with bad[cur] == false, each child's value is paired against the running
// value: min(open, got) is added to cunt and |open - got| carries on. At a
// non-root bad vertex the children's values are simply summed. A non-root
// bad vertex whose result would be 0 reports 1 instead.
// NOTE(review): presumably this counts unmatched "bad" endpoints of the
// compressed tree -- confirm against the problem statement.
int calc_ans (int cur, int par) {
    const bool pairUp = cur == ROOT || !bad[cur];
    int open = 0;
    for (int nxt : adj[cur]) {
        if (nxt == par) continue;
        const int got = calc_ans(nxt, cur);
        if (pairUp) {
            const int matched = min(open, got);
            cunt += matched;
            open = open + got - 2 * matched;   // == |open - got|
        } else {
            open += got;
        }
    }
    if (cur != ROOT && open == 0 && bad[cur]) open = 1;
    return open;
}
signed main() {
for (auto& i : s2b) i = new map<int,int>();
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n >> k;
ROOT = rand() % n + 1;
if (n ==1) {
cout << 0 << endl;
return 0;
}
for (int i = 1 ; i < n ; i++) {
int a, b;
cin >> a >> b;
adj[a].push_back(b);
adj[b].push_back(a);
}
for (int i = 1 ; i <= n ; i++) {
cin >> s[i];
tot[s[i]]++;
}
subtrees(1, 0);
int cur = 1;
bool good = 1;
vector<bool> vis(maxN);
do {
vis[cur] = 1;
good = 0;
int v = 0;
for (int i : adj[cur]) if (!vis[i] and subtree[i] > ceil(n / 2.0)) {v = i;good = 1;}
cur = v;
} while (good);
ROOT = cur;
dfs(ROOT, 0);
// for (int i = 1 ; i <= n ; i++) cout << bad[i] << " "; cout << endl;
int x = calc_ans(ROOT, 0);
cout << cunt + (x ? 1 + ceil((x-1)/2.0) : 0);
}
/* Judge result (pasted from the submission page):
 *   # | Result        | Time   | Memory    | Grader output
 *   1 | Runtime error | 169 ms | 191184 KB | Execution killed with signal 11
 *   2 | Halted        | 0 ms   | 0 KB      | -
 */
/* Judge result (pasted from the submission page):
 *   # | Result        | Time   | Memory    | Grader output
 *   1 | Runtime error | 169 ms | 191184 KB | Execution killed with signal 11
 *   2 | Halted        | 0 ms   | 0 KB      | -
 */
/* Judge result (pasted from the submission page):
 *   # | Result        | Time   | Memory    | Grader output
 *   1 | Runtime error | 169 ms | 191184 KB | Execution killed with signal 11
 *   2 | Halted        | 0 ms   | 0 KB      | -
 */
/* Judge result (pasted from the submission page):
 *   # | Result        | Time   | Memory    | Grader output
 *   1 | Runtime error | 212 ms | 202232 KB | Execution killed with signal 11
 *   2 | Halted        | 0 ms   | 0 KB      | -
 */
/* Judge result (pasted from the submission page):
 *   # | Result        | Time   | Memory    | Grader output
 *   1 | Runtime error | 169 ms | 191184 KB | Execution killed with signal 11
 *   2 | Halted        | 0 ms   | 0 KB      | -
 */