#include<bits/stdc++.h>
using namespace std;
// Expands to the begin/end iterator pair of a container, for passing whole containers to STL algorithms.
#define all(x) (x).begin(), (x).end()
// Container size as a signed int, so it can be compared with int loop indices without sign warnings.
#define sz(x) ( (int)(x).size() )
// Short alias for long long (not referenced elsewhere in the visible code).
using LL = long long;
// Lower `a` to `b` when `b` is strictly smaller (tested as a > b).
// Returns true exactly when `a` was overwritten.
template<class T>
inline bool asMn(T &a, const T &b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Raise `a` to `b` when `b` is strictly larger (tested as a < b).
// Returns true exactly when `a` was overwritten.
template<class T>
inline bool asMx(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Sentinel "infinity": main() uses inf / -inf to initialise the per-colour min / max entry times.
const int inf = 1e9;
// Time-seeded Mersenne Twister; not referenced anywhere else in the visible code.
mt19937 rng( (uint32_t)chrono::steady_clock::now().time_since_epoch().count() );
// n: number of tree vertices; k: number of colour classes (main() stores colours as 0..k-1).
int n, k;
// Adjacency lists of the undirected tree (main() inserts every edge in both directions), 0-indexed.
vector<vector<int> > gr;
// Euler-tour arrays filled by dfs():
//   be[u]  = entry time of u,
//   en[u]  = largest entry time inside u's subtree, so subtree(u) is exactly [be[u], en[u]],
//   ver[t] = vertex entered at time t.
vector<int> be, en, ver;
// Pre-order DFS from u (whose parent is pr) that assigns Euler entry times.
// Fills ver[t] = vertex entered at time t, be[x] = entry time of x, and
// en[x] = last entry time inside x's subtree, so subtree(x) == [be[x], en[x]].
// Children are visited in adjacency-list order, exactly as a recursive DFS would.
// An explicit stack replaces recursion so that deep (path-shaped) trees cannot
// overflow the call stack.
// The timer is static: a second call continues numbering where the first stopped.
void dfs(int u, int pr) {
    static int iTime = 0;
    struct Frame { int node, parent; size_t next; };
    vector<Frame> stk;
    // Record x's entry time and push it so its children are explored next.
    auto enter = [&](int x, int p) {
        ver[iTime] = x;
        be[x] = iTime++;
        stk.push_back(Frame{x, p, 0});
    };
    enter(u, pr);
    while (!stk.empty()) {
        Frame &top = stk.back();
        if (top.next < gr[top.node].size()) {
            const int v = gr[top.node][top.next++];
            // enter() may reallocate stk; `top` is not touched after this call.
            if (v != top.parent) enter(v, top.node);
        } else {
            // All children done: the subtree of top.node ends at the last assigned time.
            en[top.node] = iTime - 1;
            stk.pop_back();
        }
    }
}
// Merge two (min, max) summaries: keep the smaller .first and the larger .second.
pair<int, int> operator + (const pair<int, int> &a, const pair<int, int> &b) {
    pair<int, int> merged = a;
    if (b.first < merged.first) merged.first = b.first;
    if (merged.second < b.second) merged.second = b.second;
    return merged;
}
// Sparse table: rmq[j][i] is the operator+ fold of rmq[0][i .. i + 2^j - 1].
vector<vector<pair<int, int> > > rmq;
// Fold of rmq[0][l..r] via two overlapping power-of-two windows; needs 0 <= l <= r < row size.
pair<int, int> get(int l, int r) {
    const int len = r - l + 1;
    const int lvl = __lg(len);
    const auto &row = rmq[lvl];
    return row[l] + row[r + 1 - (1 << lvl)];
}
// Disjoint-set union over 0..nSet-1 with path compression.
// unite() links roots without any size/rank balancing, so parent chains can grow
// to O(n). findSet() is therefore iterative: a recursive find could overflow the
// call stack on such chains.
struct Dsu {
    vector<int> pSet;  // pSet[i] = parent of i; i is a root iff pSet[i] == i
    Dsu(int nSet = 0) { reset(nSet); }
    // Re-initialise to nSet singleton sets {0}, {1}, ..., {nSet-1}.
    void reset(int nSet = 0) {
        pSet.assign(nSet, 0);
        iota(pSet.begin(), pSet.end(), 0);
    }
    // Representative of i's set. Every node on the path from i is re-pointed at the root.
    int findSet(int i) {
        int root = i;
        while (pSet[root] != root) root = pSet[root];
        while (pSet[i] != root) {
            const int next = pSet[i];
            pSet[i] = root;
            i = next;
        }
        return root;
    }
    // Merge the sets containing i and j (no-op if already merged); j's root becomes the new root.
    void unite(int i, int j) {
        i = findSet(i); j = findSet(j);
        if (i == j) return ;
        pSet[i] = j;
    }
} dsu;
// (parent, child) tree edges whose child subtree is closed: every colour that occurs
// inside the subtree has all of its vertices inside it.
vector<pair<int, int> > edge;
// Classifies every tree edge below u (whose parent is pr).
// - Closed edges are appended to `edge`.
// - Every other edge has its endpoints merged in `dsu`.
// Requires dfs() and the rmq table to have been built.
// Uses an explicit stack instead of recursion, so deep trees cannot overflow the
// call stack. The resulting DSU partition and edge multiset match a recursive walk;
// only the order of `edge` may differ.
void solve(int u, int pr) {
    vector<pair<int, int> > stk(1, make_pair(u, pr));
    while (!stk.empty()) {
        const int x = stk.back().first, p = stk.back().second;
        stk.pop_back();
        for (int v : gr[x]) {
            if (v == p) continue;
            // Smallest / largest entry time of any vertex whose colour appears in subtree(v).
            const pair<int, int> bounds = get(be[v], en[v]);
            if (be[v] <= bounds.first && bounds.second <= en[v]) edge.emplace_back(x, v);
            else dsu.unite(x, v);
            stk.emplace_back(v, x);
        }
    }
}
// Reads the tree and the vertex colours, then:
//   1. builds the Euler tour and a sparse table of per-colour [min, max] entry times;
//   2. contracts every edge crossed by some colour;
//   3. prints the minimum number of merges, ceil(leaves / 2) of the contracted tree.
int main() {
    ios_base::sync_with_stdio(0); cin.tie(0);
#ifdef FourLeafClover
    freopen("input", "r", stdin);
#endif // FourLeafClover
    cin >> n >> k;
    gr.assign(n, {} );
    for (int i = 1; i < n; ++i) {
        int u, v; cin >> u >> v; --u; --v;  // input vertices are 1-indexed
        gr[u].emplace_back(v);
        gr[v].emplace_back(u);
    }
    be.assign(n, 0);
    en.assign(n, 0);
    ver.assign(n, 0);
    dfs(0, -1);
    // L[c] / R[c]: smallest / largest entry time among vertices of colour c.
    vector<int> s(n), L(k, inf), R(k, -inf);
    for (int u = 0; u < n; ++u) {
        cin >> s[u]; --s[u];  // colours are 1-indexed in the input
        asMn(L[ s[u] ], be[u]);
        asMx(R[ s[u] ], be[u]);
    }
    // Level 0 holds the [L, R] span of the colour at each Euler position; higher levels fold with operator+.
    rmq.assign(__lg(n) + 1, vector<pair<int, int> >(n, { 0, 0 } ) );
    for (int u = 0; u < n; ++u) rmq[0][ be[u] ] = { L[ s[u] ], R[ s[u] ] };
    for (int j = 1; j < sz(rmq); ++j) {
        for (int i = 0; i + (1 << j) <= n; ++i) rmq[j][i] = rmq[j - 1][i] + rmq[j - 1][i + (1 << (j - 1) )];
    }
    dsu.reset(n);
    solve(0, -1);
    // Each closed edge joins two distinct contracted components; accumulate component degrees.
    vector<int> cnt(n);
    for (const auto &e : edge) {
        ++cnt[ dsu.findSet(e.first) ];
        ++cnt[ dsu.findSet(e.second) ];
    }
    // Merging two components neutralises every edge on the tree path between them.
    // So the answer is the fewest leaf-to-leaf paths covering all contracted edges,
    // which is ceil(leaves / 2), and 0 for a single component.
    // The previous formula, leaves - 1, over-counts whenever there are 4 or more leaves.
    const int leaves = (int)count(all(cnt), 1);
    cout << (leaves + 1) / 2 << '\n';
    return 0;
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
5 ms |
376 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
5 ms |
376 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
5 ms |
376 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
72 ms |
23412 KB |
Output is correct |
2 |
Incorrect |
89 ms |
25824 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
5 ms |
376 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |