#include <bits/stdc++.h>
using namespace std;
#define Foxyy cin.tie(0); cout.sync_with_stdio(0);
#define ll long long
namespace HLD {
// Heavy-light decomposition of a tree rooted at vertex 0, used for LCA queries,
// together with a DSU whose representative is always the shallowest vertex of
// its set. All traversals are iterative so that path-shaped trees with many
// vertices cannot exhaust the call stack.

// Number of vertices in the tree passed to initialize().
int n;
// Private copy of the adjacency lists passed to initialize().
std::vector<std::vector<int>> adj;
// chainRootOfNode[u]: topmost vertex of the heavy chain that contains u.
std::vector<int> chainRootOfNode{};
// heavyChild[u]: child of u with the largest subtree, or -1 if u is a leaf.
std::vector<int> heavyChild{};
// parent[u]: parent of u, or -1 for the root.
std::vector<int> parent{};
// height[u]: number of vertices on the longest downward path starting at u (1 for a leaf).
std::vector<int> height{};
// depth[u]: number of edges between u and the root (the root has depth 0).
std::vector<int> depth{};
// subtreeSize[u]: number of vertices in the subtree of u.
std::vector<int> subtreeSize{};

// Fills parent, depth, subtreeSize, height and heavyChild for the subtree of u,
// where p is u's parent (-1 when u is the root).
void scanHeavy(int u, int p) {
    parent[u] = p;
    depth[u] = (p == -1) ? 0 : depth[p] + 1;

    // Pre-order traversal with an explicit stack: every vertex is appended to
    // `order` after its parent.
    std::vector<int> order;
    std::vector<int> pending{u};
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (int y : adj[x]) {
            if (y != parent[x]) {
                parent[y] = x;
                depth[y] = depth[x] + 1;
                pending.push_back(y);
            }
        }
    }

    // Reverse pre-order visits children before parents, so every child's
    // aggregates are final when its parent is processed.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        subtreeSize[x] = 1;
        height[x] = 1;  // a leaf has height 1 (no more reading height[-1])
        heavyChild[x] = -1;
        for (int y : adj[x]) {
            if (y == parent[x]) continue;
            subtreeSize[x] += subtreeSize[y];
            height[x] = std::max(height[x], height[y] + 1);
            // Choosing the heavy child by subtree size bounds the number of
            // chains on any root path by O(log n).
            if (heavyChild[x] == -1 || subtreeSize[y] > subtreeSize[heavyChild[x]]) {
                heavyChild[x] = y;
            }
        }
    }
}

// Assigns chainRootOfNode for the subtree of u, where `root` is the chain
// root that u belongs to. A heavy child continues its parent's chain; every
// other child starts a new chain rooted at itself.
void buildChains(int u, int root) {
    std::vector<std::pair<int, int>> pending{{u, root}};
    while (!pending.empty()) {
        const auto [x, top] = pending.back();
        pending.pop_back();
        chainRootOfNode[x] = top;
        for (int y : adj[x]) {
            if (y == parent[x]) continue;
            pending.emplace_back(y, y == heavyChild[x] ? top : y);
        }
    }
}

// Disjoint-set union whose representative is the shallowest vertex (smallest
// depth) of each set, so find(x) is the topmost vertex of x's set.
class DSU {
private:
    int n = 0;
    std::vector<int> p;
public:
    DSU() {}
    DSU(int _n) : n(_n) {
        p.resize(n);
        std::iota(p.begin(), p.end(), 0);
    }
    // Returns the representative of x, compressing the path iteratively.
    int find(int x) {
        int root = x;
        while (p[root] != root) {
            root = p[root];
        }
        while (p[x] != root) {
            const int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }
    // Merges the sets of x and y; the shallower representative survives.
    void unite(int x, int y) {
        int px = find(x);
        int py = find(y);
        if (px != py) {
            if (depth[px] > depth[py]) {
                std::swap(px, py);
            }
            p[py] = px;
        }
    }
} dsu;

// Builds the decomposition for the tree given by _adj (rooted at vertex 0)
// and resets dsu to singletons. Safe to call repeatedly: all state is
// reassigned, not just resized.
void initialize(const std::vector<std::vector<int>>& _adj) {
    n = static_cast<int>(_adj.size());
    adj = _adj;
    chainRootOfNode.assign(n, -1);
    parent.assign(n, -1);
    heavyChild.assign(n, -1);
    height.assign(n, 0);
    depth.assign(n, 0);
    subtreeSize.assign(n, 0);
    dsu = DSU(n);
    if (n == 0) {
        return;
    }
    scanHeavy(0, -1);
    buildChains(0, 0);
}

// Returns the lowest common ancestor of a and b.
int getLCAOf(int a, int b) {
    while (chainRootOfNode[a] != chainRootOfNode[b]) {
        // Lift the vertex whose chain starts deeper; that chain root is never
        // the tree root, so its parent exists.
        if (depth[chainRootOfNode[a]] < depth[chainRootOfNode[b]]) {
            std::swap(a, b);
        }
        // Jump above the chain (the original `a = chainRootOfNode[a]` never
        // made progress and looped forever).
        a = parent[chainRootOfNode[a]];
    }
    // Same chain now: the shallower vertex is the LCA.
    return depth[a] > depth[b] ? b : a;
}

// Returns the top vertex of the heavy chain containing u.
int getChainRootOfNode(int u) {
    return chainRootOfNode[u];
}
} // namespace HLD
// Merges, for every state, all cities on tree paths between cities of that
// state into one DSU component. It then prints ceil(L / 2), where L is the
// number of leaf components (components adjacent to exactly one other
// component) of the contracted tree.
struct Solver {
    int n;
    int k;
    std::vector<std::vector<int>>& adj;
    std::vector<int>& s;  // s[i]: 0-indexed state of city i, assumed in [0, k)

    void solve() {
        HLD::initialize(adj);
        std::vector<std::vector<int>> stateCities(k);
        for (int i = 0; i < n; i++) {
            stateCities[s[i]].push_back(i);
        }
        for (int i = 0; i < k; i++) {
            // Was `stateCities[k]`, an out-of-bounds read that also let empty
            // states reach `[0]` below.
            if (!stateCities[i].empty()) {
                mergeState(stateCities[i]);
            }
        }
        const int leaves = countLeafComponents();
        std::cout << (leaves + 1) / 2 << '\n';
    }

    // Unites every component on the tree paths between cities[0] and each
    // other city in `cities`.
    void mergeState(const std::vector<int>& cities) {
        const int anchor = cities[0];
        for (int city : cities) {
            // DSU representatives are the topmost vertex of each component.
            const int u = HLD::dsu.find(anchor);
            const int v = HLD::dsu.find(city);
            const int lca = HLD::dsu.find(HLD::getLCAOf(u, v));
            absorbPathUpTo(u, lca);
            absorbPathUpTo(v, lca);
        }
    }

    // Starting from component representative x, walks upward one component at
    // a time, merging each into lca's component until x's component is lca's.
    // Each iteration performs a real merge, so total work over all calls is O(n).
    void absorbPathUpTo(int x, int lca) {
        while (x != lca) {
            HLD::dsu.unite(x, lca);
            x = HLD::dsu.find(HLD::parent[x]);
        }
    }

    // Counts components with exactly one neighbouring component. Contracting
    // connected subtrees of a tree leaves a tree, so two components share at
    // most one edge; the number of crossing edges is therefore the number of
    // neighbouring components.
    int countLeafComponents() {
        std::vector<int> degree(n, 0);
        for (int i = 0; i < n; i++) {
            const int ci = HLD::dsu.find(i);
            for (int j : adj[i]) {
                if (ci != HLD::dsu.find(j)) {
                    degree[ci]++;
                }
            }
        }
        return static_cast<int>(std::count(degree.begin(), degree.end(), 1));
    }
};
signed main() {
    // Untie cin from cout and drop C-stdio synchronisation for faster bulk I/O.
    std::cin.tie(0);
    std::cout.sync_with_stdio(0);
    int testCount = 1;
    // Multi-test input is disabled; read testCount here to enable it.
    for (int tc = 0; tc < testCount; ++tc) {
        int n, k;
        std::cin >> n >> k;
        std::vector<std::vector<int>> adj(n);
        // n - 1 undirected edges, given 1-indexed.
        for (int e = 1; e < n; ++e) {
            int a, b;
            std::cin >> a >> b;
            --a;
            --b;
            adj[a].push_back(b);
            adj[b].push_back(a);
        }
        // The state of every city, converted to 0-indexed.
        std::vector<int> s(n);
        for (int& state : s) {
            std::cin >> state;
            --state;
        }
        Solver solver{n, k, adj, s};
        solver.solve();
    }
}
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 3073 ms | 212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 3073 ms | 212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 3073 ms | 212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 3033 ms | 15320 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 3073 ms | 212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |