제출 #743347

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
743347 | myrcella | Mergers (JOI19_mergers) | C++17
100 / 100
1448 ms | 119944 KiB
// by szh
// JOI 2019 "Mergers" (JOI19_mergers).
//
// Input: n m, then n-1 undirected edges of a tree on vertices 1..n, then the
// color (state id, assumed to be in 1..m) of each vertex 1..n.
//
// Approach:
//   1. Root the tree at vertex 1 and record parent/depth of every vertex.
//   2. For every color, walk the tree path between its first vertex and each
//      of its other vertices, unioning (DSU over colors) every color met on
//      the way. A second DSU over vertices collapses already-walked path
//      segments onto their shallowest vertex; every climbing step merges two
//      such segments, so there are at most n-1 steps in total.
//   3. Treat each color group as one node. Count distinct group-to-group
//      edges to get each group's degree; the answer is ceil(#leaves / 2),
//      which is 0 when only a single group remains.
//
// The traversal and both DSU finds are iterative: a path-shaped tree with
// n = 5e5 would otherwise recurse ~n frames deep and can overflow the stack.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <set>
#include <utility>
#include <vector>

namespace {

// Returns the root of x in the forest `parent`, compressing the path so that
// every visited node points directly at the root afterwards. Iterative so
// long chains cannot exhaust the call stack.
int find_root(std::vector<int>& parent, int x) {
    int root = x;
    while (parent[root] != root) root = parent[root];
    while (parent[x] != root) {
        const int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int m = 0;
    std::cin >> n >> m;

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        std::cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Root the tree at vertex 1: par[1] = 0, dep[1] = 1 (same convention as
    // a recursive dfs(1, 0) with dep[0] = 0).
    std::vector<int> par(n + 1, 0);
    std::vector<int> dep(n + 1, 0);
    {
        std::vector<int> pending{1};
        dep[1] = 1;
        while (!pending.empty()) {
            const int u = pending.back();
            pending.pop_back();
            for (const int v : adj[u]) {
                if (v == par[u]) continue;
                par[v] = u;
                dep[v] = dep[u] + 1;
                pending.push_back(v);
            }
        }
    }

    std::vector<int> col(n + 1, 0);
    std::vector<std::vector<int>> members(m + 1);  // vertices of each color
    for (int i = 1; i <= n; ++i) {
        std::cin >> col[i];
        members[col[i]].push_back(i);
    }

    // DSU over colors: colors in one set must end up in the same merged state.
    std::vector<int> color_dsu(m + 1);
    std::iota(color_dsu.begin(), color_dsu.end(), 0);
    // DSU over vertices: each root is the shallowest vertex of an
    // already-walked path segment.
    std::vector<int> path_dsu(n + 1);
    std::iota(path_dsu.begin(), path_dsu.end(), 0);

    auto unite_colors = [&](int x, int y) {
        color_dsu[find_root(color_dsu, x)] = find_root(color_dsu, y);
    };

    // Walks the path a..b, merging the colors of every segment reached.
    auto walk = [&](int a, int b) {
        a = find_root(path_dsu, a);
        b = find_root(path_dsu, b);
        while (a != b) {
            // Advance the deeper endpoint; on equal depth advance b.
            int& deeper = (dep[a] > dep[b]) ? a : b;
            const int next = find_root(path_dsu, par[deeper]);
            path_dsu[deeper] = next;  // absorb the old segment into the upper one
            deeper = next;
            unite_colors(col[a], col[b]);
        }
    };

    for (int c = 1; c <= m; ++c) {
        const std::vector<int>& vs = members[c];
        for (std::size_t j = 1; j < vs.size(); ++j) walk(vs[0], vs[j]);
    }

    // Degree of each color group in the contracted tree (distinct edges only).
    std::set<std::pair<int, int>> group_edges;
    std::vector<int> deg(m + 1, 0);
    for (int u = 1; u <= n; ++u) {
        for (const int v : adj[u]) {
            if (v < u) continue;  // visit each tree edge once
            int gx = find_root(color_dsu, col[u]);
            int gy = find_root(color_dsu, col[v]);
            if (gx == gy) continue;
            if (gx > gy) std::swap(gx, gy);
            if (group_edges.insert({gx, gy}).second) {
                ++deg[gx];
                ++deg[gy];
            }
        }
    }

    int groups = 0;
    int leaves = 0;
    for (int c = 1; c <= m; ++c) {
        // A color with no vertex is not a node of the tree; never count it.
        if (members[c].empty()) continue;
        if (find_root(color_dsu, c) != c) continue;
        ++groups;
        if (deg[c] == 1) ++leaves;
    }
    // Contracting connected groups of a tree must yield a tree.
    assert(static_cast<int>(group_edges.size()) == groups - 1);

    std::cout << (leaves + 1) / 2 << '\n';
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...