Submission #623102

#TimeUsernameProblemLanguageResultExecution timeMemory
623102radalMergers (JOI19_mergers)C++17
0 / 100
3075 ms39424 KiB
#include <bits/stdc++.h> #pragma GCC target("sse,sse2,avx2") #pragma GCC optimize("unroll-loops,O2") #define rep(i,l,r) for (int i = l; i < r; i++) #define repr(i,r,l) for (int i = r; i >= l; i--) #define X first #define Y second #define all(x) (x).begin() , (x).end() #define pb push_back #define endl '\n' #define debug(x) cerr << #x << " : " << x << endl; using namespace std; typedef long long ll; typedef long double ld; typedef pair<int,int> pll; constexpr int N = 5e5+10,mod = 998244353,inf = 1e9+10,sq = 700; inline int mkay(int a,int b){ if (a+b >= mod) return a+b-mod; if (a+b < 0) return a+b+mod; return a+b; } inline int poww(int a,int k){ if (k < 0) return 0; int z = 1; while (k){ if (k&1) z = 1ll*z*a%mod; a = 1ll*a*a%mod; k /= 2; } return z; } vector<int> adj[N],col[N]; int par[N][20],h[N],T,tin[N],calc[N]; void pre(int v,int p){ par[v][0] = p; tin[v] = T++; for (int u : adj[v]){ if (u == p) continue; h[u] = h[v]+1; pre(u,v); } } bool cmp(int u,int v){ return (tin[u] < tin[v]); } int lca(int u,int v){ if (h[u] < h[v]) swap(u,v); repr(i,19,0){ if ((1 << i) <= h[u]-h[v]) u = par[u][i]; } if (u == v) return u; repr(i,19,0){ if (par[u][i] != par[v][i]){ v = par[v][i]; u = par[u][i]; } } return par[v][0]; } void dfs(int v,int p){ for (int u : adj[v]){ if (u != p){ dfs(u,v); calc[v] += calc[u]; } } } int main(){ ios :: sync_with_stdio(0); cin.tie(0); mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); int n,k; cin >> n >> k; rep(i,1,n){ int u,v; cin >> u >> v; adj[u].pb(v); adj[v].pb(u); } rep(i,1,n+1){ int c; cin >> c; col[c].pb(i); } rep(i,1,k+1) sort(all(col[i]),cmp); pre(1,0); rep(j,1,20){ rep(i,2,n+1) par[i][j] = par[par[i][j-1]][j-1]; } rep(i,1,k+1){ int sz = col[i].size(); if (sz < 2) continue; rep(j,1,sz){ calc[col[i][j]]++; calc[col[i][j-1]]++; calc[lca(col[i][j],col[i][j-1])] -= 2; } } dfs(1,0); int cnt = 0,ans = 0; rep(i,2,n+1) cnt += (calc[i] == 0); while (cnt){ ans++; int mi = inf,x = -1,y = -1; rep(i,1,k){ if (rng()%3 == 0) continue; if (col[i].empty()) continue; rep(j,i+1,k+1){ if (col[j].empty()) continue; rep(f,1,n+1) calc[f] = 0; rep(f,1,k+1){ if (f == j || f == i) continue; int sz = col[f].size(); if (sz < 2) continue; rep(g,1,sz){ calc[col[f][g]]++; calc[col[f][g-1]]++; calc[lca(col[f][g],col[f][g-1])] -= 2; } } vector<int> tmp; tmp.resize(col[i].size()+col[j].size()); merge(all(col[i]),all(col[j]),tmp.begin(),cmp); int sz = tmp.size(); rep(f,1,sz){ calc[tmp[f]]++; calc[tmp[f-1]]++; calc[lca(tmp[f],tmp[f-1])] -= 2; } dfs(1,0); int fuck = 0; rep(i,2,n+1) fuck += (calc[i] == 0); if (fuck < mi){ mi = fuck; x = i; y = j; } } } int i = x,j = y; rep(f,1,n+1) calc[f] = 0; rep(f,1,k+1){ if (f == j || f == i) continue; int sz = col[f].size(); if (sz < 2) continue; rep(g,1,sz){ calc[col[f][g]]++; calc[col[f][g-1]]++; calc[lca(col[f][g],col[f][g-1])] -= 2; } } vector<int> tmp; tmp.resize(col[i].size()+col[j].size()); merge(all(col[i]),all(col[j]),tmp.begin(),cmp); int sz = tmp.size(); rep(f,1,sz){ calc[tmp[f]]++; calc[tmp[f-1]]++; calc[lca(tmp[f],tmp[f-1])] -= 2; } col[i] = tmp; col[j].clear(); dfs(1,0); cnt = 0; rep(i,2,n+1) cnt += (calc[i] == 0); } cout << ans; }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...