제출 #1159100

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1159100 | MPG | Railway (BOI17_railway) | C++20
100 / 100
372 ms | 50152 KiB
// Solution for BOI 2017 "Railway".
//
// Input: n, q, k; then n - 1 tree edges (numbered 1..n-1 in input order);
// then q queries, each a count m followed by m node ids.
// For every query the code builds the virtual tree of its nodes (nodes sorted
// by DFS entry time plus the LCAs of adjacent nodes) and, for every
// virtual-tree edge (x, v), adds 1 to every original tree edge on the path
// v -> x.  A tree edge is stored at its lower endpoint's heavy-light position,
// so "add to edges on v -> x" is "add to nodes v .. (child of x toward v)".
// Finally it prints how many edges received at least k, and their indices.

// #pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
// #pragma GCC optimize("O3")
// #pragma GCC target("avx2")
// #pragma GCC target("sse,sse2,sse4.1,sse4.2")
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define max_heap priority_queue<pair <ll, pair <ll, ll>>>
#define min_heap priority_queue<pair <ll, ll>, vector<pair <ll, ll>>, greater<pair <ll, ll>>>
// #define min_heap priority_queue<ll, vector<ll>, greater<ll>>
#define sariE cin.tie(NULL); cout.tie(NULL); ios_base::sync_with_stdio(false);
#define filE freopen("in.txt", "r", stdin); freopen("out1.txt", "w", stdout);
#define endl '\n'
// Argument is parenthesized so that md(x + y) reduces the whole sum.
#define md(a) (((a) % mod + mod) % mod)
#define pb push_back
// cout << setprecision(5) << fixed << f;
// hash prime = 769
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
ll const maxn = 2e5 + 123;
ll const inf = 2e18;
ll const loG = 23;
ll const mod = 1e9 + 7;
// ll const mod = 998244353;
ll const sq = 400;

// a^b modulo `mod` by recursive squaring.
ll power(ll a, ll b, ll mod) {
    a %= mod;  // reduce first: keeps the b == 1 result and the final product in range
    if (b == 0) return 1;
    if (b == 1) return a;
    ll x = power(a, b / 2, mod);
    return (((x * x) % mod) * (b % 2 ? a : 1)) % mod;
}

ll n, q, k;
ll st[maxn];            // heavy-light position of a node (1..n, preorder, heavy child first)
ll stt[maxn], fn[maxn]; // Euler entry / exit times from dfs(), used for ancestor tests
ll h[maxn];             // depth (root has depth 1)
ll par[loG][maxn];      // binary-lifting table: par[i][v] is the 2^i-th ancestor of v
ll sz[maxn];            // subtree size
ll big[maxn];           // heavy child (0 when the node is a leaf)
ll head[maxn];          // top node of the heavy chain containing the node
ll tim, tim2, segsz = 2;
vector<ll> g[maxn];     // adjacency lists
vector<ll> seg;         // segment tree of range sums over HLD positions
vector<ll> op;          // pending lazy additions, sized together with seg
vector<pair<pair<ll, ll>, ll>> yal;  // input edges ((a, b), edge index)

// First DFS: depth, parent, subtree size, heavy child and Euler times.
void dfs(ll v, ll p) {
    sz[v] = 1;
    par[0][v] = p;
    h[v] = h[p] + 1;
    big[v] = 0;
    stt[v] = ++tim2;
    for (ll u : g[v]) {
        if (u == p) continue;
        dfs(u, v);
        if (sz[u] > sz[big[v]]) big[v] = u;
        sz[v] += sz[u];
    }
    fn[v] = ++tim2;
}

// Second DFS: assigns HLD positions, visiting the heavy child first so that
// every heavy chain occupies a contiguous range of positions.
void dfs_hld(ll v, ll p, ll baba) {
    st[v] = ++tim;
    head[v] = baba;
    if (big[v]) dfs_hld(big[v], v, baba);
    for (ll u : g[v]) {
        if (u == p || u == big[v]) continue;
        dfs_hld(u, v, u);  // a light child starts a new chain
    }
}

// Pushes node x's pending addition down to its two children.
void propagate(ll x, ll lx, ll rx) {
    if (rx - lx == 1) return;
    if (op[x] == 0) return;
    ll mid = (lx + rx) / 2, a = 2 * x + 1, b = a + 1;
    op[a] += op[x];
    op[b] += op[x];
    seg[a] += (mid - lx) * op[x];
    seg[b] += (rx - mid) * op[x];
    op[x] = 0;
}

// Returns the value stored at position i (node x covers [lx, rx)).
ll getter(ll i, ll x, ll lx, ll rx) {
    propagate(x, lx, rx);
    if (rx - lx == 1) return seg[x];
    ll mid = (lx + rx) / 2, a = 2 * x + 1, b = a + 1;
    ll val = (i < mid) ? getter(i, a, lx, mid) : getter(i, b, mid, rx);
    seg[x] = seg[a] + seg[b];
    return val;
}

// Adds 1 to every position in [l, r) (node x covers [lx, rx)).
void adder(ll l, ll r, ll x, ll lx, ll rx) {
    propagate(x, lx, rx);
    if (l >= rx || lx >= r) return;
    if (l <= lx && rx <= r) {
        seg[x] += (rx - lx);
        op[x] += 1;
        return;
    }
    ll mid = (lx + rx) / 2, a = 2 * x + 1, b = a + 1;
    adder(l, r, a, lx, mid);
    adder(l, r, b, mid, rx);
    seg[x] = seg[a] + seg[b];
}

// Returns the d-th ancestor of v (d must be below 2^loG).
ll jump(ll v, ll d) {
    for (int i = 0; i < loG; i++)
        if ((d >> i) & 1) v = par[i][v];
    return v;
}

// Lowest common ancestor by binary lifting.
ll lca(ll u, ll v) {
    if (h[u] < h[v]) swap(u, v);
    u = jump(u, h[u] - h[v]);
    if (u == v) return u;
    for (int i = loG - 1; i >= 0; i--)
        if (par[i][u] != par[i][v]) u = par[i][u], v = par[i][v];
    return par[0][u];
}

// Orders nodes by Euler entry time.
bool cmp(ll v, ll u) { return stt[v] < stt[u]; }

// True when v is an ancestor of u (or v == u).
bool ispar(ll v, ll u) { return (stt[v] <= stt[u]) && (fn[v] >= fn[u]); }

// Adds 1 to every node on the tree path between v and u, both ends included,
// one heavy-chain segment at a time.
void add(ll v, ll u) {
    while (head[v] != head[u]) {
        if (h[head[v]] < h[head[u]]) swap(u, v);  // lift the deeper chain
        adder(st[head[v]], st[v] + 1, 0, 0, segsz);
        v = par[0][head[v]];
    }
    if (h[v] > h[u]) swap(v, u);
    adder(st[v], st[u] + 1, 0, 0, segsz);
}

void Solve() {
    cin >> n >> q >> k;
    // HLD positions are 1..n, so the tree needs strictly more than n leaves.
    while (segsz <= n) segsz *= 2;
    seg.assign(2 * segsz, 0);
    op.assign(2 * segsz, 0);

    for (int i = 1; i < n; i++) {
        ll a, b;
        cin >> a >> b;
        yal.pb({{a, b}, i});
        g[a].pb(b);
        g[b].pb(a);
    }
    dfs(1, 1);
    dfs_hld(1, 1, 1);
    for (int i = 1; i < loG; i++)
        for (int j = 1; j <= n; j++)
            par[i][j] = par[i - 1][par[i - 1][j]];

    vector<ll> tmp, stk;
    while (q--) {
        ll m;
        cin >> m;
        tmp.clear();
        stk.clear();
        for (ll i = 0; i < m; i++) {
            ll x;
            cin >> x;
            tmp.pb(x);
        }

        // Virtual tree nodes: the query nodes plus LCAs of entry-time neighbours.
        sort(tmp.begin(), tmp.end(), cmp);
        // `i + 1 < size()` instead of `i < size() - 1`: the latter wraps when m == 0.
        for (size_t i = 0; i + 1 < tmp.size(); i++) stk.pb(lca(tmp[i], tmp[i + 1]));
        for (ll x : stk) tmp.pb(x);
        sort(tmp.begin(), tmp.end(), cmp);
        tmp.erase(unique(tmp.begin(), tmp.end()), tmp.end());

        // Walk the nodes in entry order, keeping the current root-to-node chain.
        stk.clear();
        for (ll v : tmp) {
            while (!stk.empty() && !ispar(stk.back(), v)) stk.pop_back();
            if (!stk.empty()) {
                ll x = stk.back();                  // virtual-tree parent of v
                ll xx = jump(v, h[v] - h[x] - 1);   // child of x on the path to v
                add(v, xx);                         // edges v -> x, stored at lower endpoints
            }
            stk.pb(v);
        }
    }

    vector<ll> ans;
    for (const auto& p : yal) {
        ll a = p.first.first, b = p.first.second;
        // The deeper endpoint has the larger HLD position and represents the edge.
        ll x = getter(max(st[a], st[b]), 0, 0, segsz);
        if (x >= k) ans.pb(p.second);
    }
    cout << ans.size() << endl;
    for (ll x : ans) cout << x << ' ';
    cout << endl;
}

int main() {
    sariE;  // filE;
    int test = 1;
    // cin >> test;
    while (test--) Solve();
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...