#include <bits/stdc++.h>
// Optimization pragma, left disabled by the author.
// #pragma GCC optimize("O3", "unroll-loops")
#define ll long long
// NOTE: from here on every plain `int` (array element types, parameters, return
// types, loop counters) is really `long long`; this is why main is spelled
// `signed main()` further down.
#define int long long
#define pb push_back
#define fi first
#define se second
// Segment-tree helpers: left child index, right child index and midpoint of [l, r].
// They only work where `id`, `l` and `r` are in scope.
#define lf (id<<1)
#define rg ((id<<1)|1)
#define md ((l+r)>>1)
#define ld long double
using namespace std;
typedef pair<int,int> pii;
typedef pair<pii, int> ipii;
// Size bound for every array indexed by node id, state id, edge id or HLD position.
const int MAXN = 5e5+10;
// None of the following constants/tables (MAXA, INF, LOG, SQRT, NOL, rng, dx, dy,
// key, mod) is referenced in the visible code -- presumably template leftovers.
const int MAXA = 9e3+20;
const ll INF = 1e9+10;
const int LOG = 13;
const int SQRT = 450;
const vector<int> NOL = {};
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
const vector<int> dx = {1, -1, 0, 0};
const vector<int> dy = {0, 0, 1, -1};
vector <int> key = {29, 31};
vector <int> mod = {998244353, 1000000007};
// Raise `a` to `b` when `b` is larger; otherwise leave `a` unchanged.
void chmx(int &a, int b){
    if (a < b) a = b;
}
// n: number of nodes; k: number of states (both read first in main).
int n, k;
// adj[x]: undirected neighbours of node x.
// vec[s]: nodes whose state a[] equals s.
vector <int> adj[MAXN], vec[MAXN];
// a[i]: state of node i.
// tot[s]: number of nodes in state s (not read again in the visible code;
//         main declares a local `tot` that shadows it).
// u[i], v[i]: endpoints of the i-th input edge.
int a[MAXN], tot[MAXN], u[MAXN], v[MAXN];
// tree[x]: children of x once dfs() roots the tree at node 1; main later moves
// the child with the largest subtree to tree[x][0].
vector <int> tree[MAXN];
struct dsu {
    // par[x]: parent link of x (x itself when x is a representative).
    // siz[r]: element count of the set whose representative is r.
    int par[MAXN], siz[MAXN];

    // Turn each of the nodes 1..n into a singleton set.
    void bd(){
        for(int node = 1; node <= n; ++node){
            par[node] = node;
            siz[node] = 1;
        }
    }

    // Representative of x's set. Afterwards every node on the walked path
    // points directly at the representative (path compression).
    int f(int x){
        int top = x;
        while(par[top] != top) top = par[top];
        while(par[x] != top){
            const int up = par[x];
            par[x] = top;
            x = up;
        }
        return top;
    }

    // True when x and y belong to the same set.
    bool con(int x, int y){
        const int rx = f(x);
        return rx == f(y);
    }

    // Union by size. The smaller set's representative is hung under the other's;
    // on a tie, x's representative goes under y's.
    void mer(int x, int y){
        int a = f(x), b = f(y);
        if(a == b) return;
        if(siz[a] > siz[b]) swap(a, b);
        par[a] = b;
        siz[b] += siz[a];
    }
} DSU;
int sub[MAXN], dep[MAXN], in[MAXN], tim, par[MAXN];
// Roots the part of the tree reachable from nw, entered from pa (pa = 0 for the
// global root). Fills:
//   dep[] : dep[nw] = dep[pa] + 1, growing by one per level
//   par[] : parent of each node
//   sub[] : subtree sizes
//   tree[]: children, in adj[] order
// Iterative on purpose: on a path-shaped tree the recursive version recurses n
// levels deep, which can overflow a default-size call stack.
void dfs(int nw, int pa){
    dep[nw] = dep[pa] + 1; par[nw] = pa; sub[nw] = 1;
    // Each frame holds a node and the index of the next adj[] entry to inspect.
    vector<pair<int, size_t>> stk;
    stk.push_back({nw, 0});
    while(!stk.empty()){
        int cur = stk.back().first;
        size_t &pos = stk.back().second;
        if(pos < adj[cur].size()){
            int nx = adj[cur][pos++];
            // `pos` is not used below, so invalidating it via push_back is harmless.
            if(nx == par[cur]) continue;
            tree[cur].push_back(nx);
            dep[nx] = dep[cur] + 1; par[nx] = cur; sub[nx] = 1;
            stk.push_back({nx, 0});
        } else {
            // cur is finished: fold its subtree size into its parent's frame.
            stk.pop_back();
            if(!stk.empty()) sub[stk.back().first] += sub[cur];
        }
    }
}
// ANS[p]: final value at segment-tree position p (an HLD position), filled by ST.out.
int ANS[MAXN];
// Range "+1" segment tree over positions 1..n with lazy propagation.
//   st[id] : sum of the values in id's segment.
//   laz[id]: number of +1 updates that covered id's whole segment and have not
//            yet been pushed to its children.
struct segtree {
    int st[4*MAXN], laz[4*MAXN];

    // Push id's pending tag down. Each child gains laz[id] per position it covers
    // (not just 1 per position), so repeated full-cover updates are not lost.
    void bnc(int id, int l, int r){
        if(laz[id] == 0) return;
        int mid = (l + r) >> 1, lc = id << 1, rc = lc | 1;
        st[lc] += laz[id] * (mid - l + 1); laz[lc] += laz[id];
        st[rc] += laz[id] * (r - mid);     laz[rc] += laz[id];
        laz[id] = 0;
    }

    // Add 1 to every position in [x, y] (a no-op when x > y).
    // Returns the updated st[id].
    int upd(int id, int l, int r, int x, int y){
        if(r < x || y < l) return st[id];
        if(x <= l && r <= y){
            laz[id]++; return st[id] += (r - l + 1);
        }
        bnc(id, l, r);
        int mid = (l + r) >> 1;
        return st[id] = upd(id << 1, l, mid, x, y) + upd(id << 1 | 1, mid + 1, r, x, y);
    }

    // Push every pending tag down and copy each leaf's value into ANS[position].
    void out(int id, int l, int r){
        if(l == r){ ANS[l] = st[id]; return; }
        bnc(id, l, r);
        int mid = (l + r) >> 1;
        out(id << 1, l, mid); out(id << 1 | 1, mid + 1, r);
    }
} ST;
int root[MAXN], arr[MAXN];
void bd(int nw, int roo){
root[nw] = roo; in[nw] = ++tim; arr[tim] = nw;
if(tree[nw].size() == 0) return;
bd(tree[nw][0], roo);
for(int j=1; j<tree[nw].size(); j++){
bd(tree[nw][j], tree[nw][j]);
}
}
void query(int x, int y){
while(root[x] != root[y]){
if(dep[root[x]] > dep[root[y]]) swap(x, y); // y bawah
// root[y]+1 - y
// root->nyimpen siapa root, in[root]+1 --> idxnya
ST.upd(1, 1, n, in[root[y]]+1, in[y]);
}
if(dep[x] > dep[y]) swap(x, y); // y bawah
ST.upd(1, 1, n, in[x]+1, in[y]);
}
// edge[r]: neighbours of DSU representative r in the contracted tree (filled in main).
vector<int> edge[MAXN];
// Orders node ids by their HLD pre-order position in[]; used to sort each state's nodes.
bool cmp(int a, int b){
    const int lhs = in[a];
    const int rhs = in[b];
    return lhs < rhs;
}
signed main(){
// ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
cin >> n >> k; DSU.bd();
for(int i=1; i<=n-1; i++){
int x, y; cin >> x >> y; u[i] = x; v[i] = y;
adj[x].pb(y); adj[y].pb(x);
}
for(int i=1; i<=n; i++){
cin >> a[i]; tot[a[i]]++;
vec[a[i]].pb(i);
}
// dfs hld
dfs(1, 0);
// build hld
for(int i=1; i<=n; i++){
if(tree[i].size() == 0) continue;
for(int j=1; j<tree[i].size(); j++){
if(sub[tree[i][0]] < sub[tree[i][j]])
swap(tree[i][0], tree[i][j]);
}
}
tim = 0; bd(1, 1);
// per state
for(int i=1; i<=k; i++){
if(vec[i].size() <= 1) continue;
sort(vec[i].begin(), vec[i].end(), cmp);
for(int j=1; j<vec[i].size(); j++) query(vec[i][j-1], vec[i][j]);
query(vec[i][0], vec[i].back());
}
// OUT segtree
ST.out(1, 1, n);
for(int i=1; i<=n; i++)
if(ANS[i]) DSU.mer(par[i], i);
// ANSwer
for(int i=1; i<=n-1; i++){
int x = DSU.f(u[i]), y = DSU.f(v[i]);
if(x == y) continue;
edge[x].pb(y); edge[y].pb(x);
}
int tot = 0;
for(int i=1; i<=n; i++){
if(edge[i].size() == 1) tot++;
}
cout << (tot+1)/2 << '\n';
}
/*
Compilation message (judge output; its line numbers refer to the submitted file
and do not match this listing):
mergers.cpp: In function 'void bd(long long int, long long int)':
mergers.cpp:94:16: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   94 | for(int j=1; j<tree[nw].size(); j++){
      | ~^~~~~~~~~~~~~~~~
mergers.cpp: In function 'int main()':
mergers.cpp:131:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  131 | for(int j=1; j<tree[i].size(); j++){
      | ~^~~~~~~~~~~~~~~
mergers.cpp:142:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  142 | for(int j=1; j<vec[i].size(); j++) query(vec[i][j-1], vec[i][j]);
      | ~^~~~~~~~~~~~~~
*/
/*
Grader result (headers: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory):
# | 결과                | 실행 시간 | 메모리   | Grader output
1 | Execution timed out | 3038 ms   | 48728 KB | Time limit exceeded
2 | Halted              | 0 ms      | 0 KB     | -
*/
/*
Grader result (headers: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory):
# | 결과                | 실행 시간 | 메모리   | Grader output
1 | Execution timed out | 3038 ms   | 48728 KB | Time limit exceeded
2 | Halted              | 0 ms      | 0 KB     | -
*/
/*
Grader result (headers: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory):
# | 결과                | 실행 시간 | 메모리   | Grader output
1 | Execution timed out | 3038 ms   | 48728 KB | Time limit exceeded
2 | Halted              | 0 ms      | 0 KB     | -
*/
/*
Grader result (headers: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory):
# | 결과                | 실행 시간 | 메모리   | Grader output
1 | Execution timed out | 3022 ms   | 62880 KB | Time limit exceeded
2 | Halted              | 0 ms      | 0 KB     | -
*/
/*
Grader result (headers: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory):
# | 결과                | 실행 시간 | 메모리   | Grader output
1 | Execution timed out | 3038 ms   | 48728 KB | Time limit exceeded
2 | Halted              | 0 ms      | 0 KB     | -
*/