Submission #630511

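| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 630511 | | Arnch | Mergers (JOI19_mergers) | C++17 | 70 / 100 | 1516 ms | 170552 KiB |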
// Author's motto (translated from Persian): "take whatever your heart wants ~
// just don't take the guitar with you!" ~ Amoo_Hasan
//
// Solution outline (as implemented below):
//  1. Root the tree at vertex 0. Build a binary-lifting table (for LCA) and a
//     heavy-light decomposition (HLD), so that every ancestor path splits into
//     O(log n) contiguous ranges of HLD positions.
//  2. For every state, build the virtual tree of its vertices. Stamp the state
//     onto each virtual-tree path in a segment tree over HLD positions. Two
//     states stamped on the same segment-tree node are united in a DSU over states.
//  3. relax() pushes every stamp down to the leaves and unites it with the state
//     of the vertex at that leaf position.
//  4. For each DSU component, collect the components adjacent to it through tree
//     edges. Count the components with exactly one distinct neighbour and print
//     ceil(count / 2).
#include <bits/stdc++.h>
#pragma GCC optimize("O3,no-stack-protector,unroll-loops")
#pragma GCC target("avx2,fma")
using namespace std;

#define Sz(x) int((x).size())
#define All(x) (x).begin(), (x).end()

// N: segment-tree node capacity (4 * max n); MAXN: max vertices / states;
// LOG: binary-lifting levels.
constexpr int N = 2e6 + 2, MAXN = 5e5 + 2, LOG = 22;

int n, k;
int s[MAXN];                  // s[v]: 0-based state of vertex v
int sub[MAXN];                // sub[v]: size of v's subtree
int head[MAXN];               // head[v]: top vertex of v's heavy chain
int at_pos[MAXN];             // at_pos[t]: vertex whose HLD position is t
                              // (renamed from `link`, which clashes with POSIX ::link())
int par[MAXN][LOG];           // par[v][i]: 2^i-th ancestor of v, or -1 above the root
int h[MAXN];                  // h[v]: depth of v (root has depth 0)
int st[MAXN], fn[MAXN], tim;  // HLD position of v, and one past its subtree's last position
int sz[MAXN], pv[MAXN];       // DSU over states: component size / parent pointer
int seg[N];                   // state stamped on a segment-tree node, or -1
vector<int> vc[MAXN], adj[MAXN], nei[MAXN];

// Roots the tree at `root` and fills par[][], h[] and sub[].
// Iterative, so path-shaped inputs cannot overflow the call stack.
void dfs(int root = 0) {
    vector<int> order;  // pre-order: every vertex appears after its parent
    order.reserve(n);
    vector<int> stk{root};
    par[root][0] = -1;
    h[root] = 0;
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (int j : adj[x]) {
            if (j == par[x][0]) continue;
            par[j][0] = x;
            h[j] = h[x] + 1;
            stk.push_back(j);
        }
    }
    for (int x : order) {
        sub[x] = 1;
        for (int i = 1; i < LOG; i++) {
            // Propagate the -1 sentinel instead of indexing par[-1] (out of bounds).
            const int mid = par[x][i - 1];
            par[x][i] = (mid == -1) ? -1 : par[mid][i - 1];
        }
    }
    // Reverse pre-order visits children before parents; order[0] is the root.
    for (int i = Sz(order) - 1; i > 0; i--) {
        const int x = order[i];
        sub[par[x][0]] += sub[x];
    }
}

// Heavy-light decomposition: assigns st[] / fn[] / head[] / at_pos[].
// The heavy child always gets the position right after its parent, so each chain
// is a contiguous range; each subtree occupies [st[v], fn[v]). Iterative for the
// same stack-depth reason as dfs(). Requires dfs() to have run.
void hld(int root = 0) {
    vector<int> stk{root};
    head[root] = root;
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        at_pos[tim] = x;
        st[x] = tim++;
        fn[x] = st[x] + sub[x];
        const int p = par[x][0];
        int heavy = -1;
        for (int j : adj[x]) {
            if (j != p && (heavy == -1 || sub[heavy] < sub[j])) heavy = j;
        }
        // Light children go on the stack first, so they are visited after the whole
        // heavy subtree; each one starts a new chain.
        for (int j : adj[x]) {
            if (j == p || j == heavy) continue;
            head[j] = j;
            stk.push_back(j);
        }
        if (heavy != -1) {
            head[heavy] = head[x];
            stk.push_back(heavy);
        }
    }
}

// Returns the y-th ancestor of x (caller guarantees y <= h[x]).
inline int get_par(int x, int y) {
    for (int i = 0; i < LOG; i++)
        if ((y >> i) & 1) x = par[x][i];
    return x;
}

// Lowest common ancestor via binary lifting. Jumps past the root yield -1 for
// both x and y (they are at equal depth), so those levels are skipped.
inline int lca(int x, int y) {
    if (h[x] > h[y]) swap(x, y);
    y = get_par(y, h[y] - h[x]);
    if (x == y) return x;
    for (int i = LOG - 1; i >= 0; i--)
        if (par[x][i] != par[y][i]) x = par[x][i], y = par[y][i];
    return par[x][0];
}

// DSU find with path compression.
int root_of(int x) {
    if (pv[x] == x) return x;
    return pv[x] = root_of(pv[x]);
}

// DSU union by size.
inline void unite(int x, int y) {
    assert(max(x, y) < MAXN);
    int X = root_of(x), Y = root_of(y);
    if (X == Y) return;
    if (sz[X] < sz[Y]) swap(X, Y);
    pv[Y] = X;
    sz[X] += sz[Y];
}

// Clears every segment-tree node covering [l, r) to -1 (no stamp).
void build(int l = 0, int r = n, int v = 1) {
    seg[v] = -1;
    if (r - l < 2) return;
    const int mid = (l + r) >> 1;
    build(l, mid, 2 * v);
    build(mid, r, 2 * v + 1);
}

// Stamps state `val` on the canonical nodes covering [s, e). If a node already
// carries a stamp, the two states are united instead.
void upd(int s, int e, int val, int l = 0, int r = n, int v = 1) {
    if (r <= s || l >= e) return;
    if (l >= s && r <= e) {
        if (seg[v] == -1)
            seg[v] = val;
        else
            unite(seg[v], val);
        return;
    }
    const int mid = (l + r) >> 1;
    upd(s, e, val, l, mid, 2 * v);
    upd(s, e, val, mid, r, 2 * v + 1);
}

// Pushes every stamp down to the leaves, uniting along the way. At a leaf, the
// stamp is united with the state of the vertex at that HLD position.
void relax(int l = 0, int r = n, int v = 1) {
    if (r - l < 2) {
        if (seg[v] != -1) unite(seg[v], s[at_pos[l]]);
        return;
    }
    const int mid = (l + r) >> 1;
    if (seg[v] != -1) {
        for (int c : {2 * v, 2 * v + 1}) {
            if (seg[c] == -1)
                seg[c] = seg[v];
            else
                unite(seg[v], seg[c]);
        }
    }
    relax(l, mid, 2 * v);
    relax(mid, r, 2 * v + 1);
}

// Orders vertices by HLD position.
bool cmp(int i, int j) { return st[i] < st[j]; }

// Stamps state x on every vertex of the minimal subtree spanning x's vertices.
// It builds the virtual tree of those vertices, then splits each virtual edge
// (p -> v) into HLD chain ranges.
inline void solve(int x) {
    vector<int> ver;
    ver.swap(vc[x]);         // take ownership; leaves vc[x] empty
    if (ver.empty()) return;  // state with no vertices: nothing to stamp
    sort(All(ver), cmp);
    const int cnt = Sz(ver);
    for (int i = 1; i < cnt; i++) ver.push_back(lca(ver[i - 1], ver[i]));
    sort(All(ver), cmp);
    ver.erase(unique(All(ver)), ver.end());

    // ver[0] is the LCA of all vertices (smallest position), so the stack never empties.
    vector<int> mt{ver[0]};
    upd(st[ver[0]], st[ver[0]] + 1, x);
    for (int i = 1; i < Sz(ver); i++) {
        const int v = ver[i];
        // Pop until mt.back() is an ancestor of v (v lies inside its subtree range).
        while (fn[mt.back()] < fn[v]) mt.pop_back();
        const int p = mt.back();
        upd(st[v], st[v] + 1, x);
        int u = v;
        // Cover whole chains while the chain head is still at or below p.
        while (u != -1 && h[head[u]] >= h[p]) {
            upd(st[head[u]], st[u] + 1, x);
            u = par[head[u]][0];
        }
        // p and u now share a chain. u == -1 means the chain through the root was
        // already fully covered; the original read h[-1] here.
        if (u != -1 && h[u] >= h[p]) upd(st[p], st[u] + 1, x);
        mt.push_back(v);
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    for (int i = 0; i < MAXN; i++) sz[i] = 1, pv[i] = i;

    cin >> n >> k;
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs();
    hld();
    for (int i = 0; i < n; i++) cin >> s[i];

    build();
    for (int i = 0; i < n; i++) {
        --s[i];
        vc[s[i]].push_back(i);
    }
    for (int i = 0; i < k; i++) solve(i);
    relax();
    for (int i = 0; i < MAXN; i++) pv[i] = root_of(pv[i]);

    // Record adjacency between DSU components along tree edges.
    for (int i = 0; i < n; i++) {
        const int u = s[i];
        for (int j : adj[i]) {
            const int v = s[j];
            if (pv[u] == pv[v]) continue;
            nei[pv[u]].push_back(pv[v]);
            nei[pv[v]].push_back(pv[u]);
        }
    }

    // A component is a leaf when it has exactly one distinct neighbour. Its list
    // then holds a single repeated value, so an unsorted unique() reduces it to
    // size 1. Lists with >= 2 distinct values keep size >= 2.
    int ans = 0;
    for (int i = 0; i < k; i++) {
        nei[i].erase(unique(All(nei[i])), nei[i].end());
        if (Sz(nei[i]) == 1) ans++;
    }
    cout << (ans + 1) / 2 << '\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...