Submission #125019

#TimeUsernameProblemLanguageResultExecution timeMemory
125019WhipppedCreamCats or Dogs (JOI18_catdog)C++17
0 / 100
47 ms27896 KiB
#include <bits/stdc++.h> #pragma GCC optimize ("O3") #pragma GCC target ("sse4") using namespace std; #define X first #define Y second #define pb push_back typedef pair<int, int> ii; typedef long long ll; const int maxn = 1e5+5; int n; struct fenwick { ll ft[maxn]; ll sum(int x) { ll res = 0; for(; x; x -= x&(-x)) res += ft[x]; return res; } void change(int x, int dx) { for(; x<= n; x += x&(-x)) ft[x] += dx; } ll ask(int x) { return sum(x); } void update(int a, int b, int dx) { change(a, dx); change(b+1, -dx); } }; struct segtree { struct node { vector<int> vec = vector<int>(3, 0); int lz = 0; node(){} node(vector<int> vec) : vec(vec) {} }; node st[4*maxn]; void push(int p, int L, int R) { int &lz = st[p].lz; if(!lz) return; vector<int> &vec = st[p].vec; vector<int> nou(3); for(int i = 0; i< 3; i++) { nou[i] = (0<= i-lz && i-lz< 3)?vec[i-lz]:0; } vec = nou; if(L != R) { st[2*p].lz += lz; st[2*p+1].lz += lz; } lz = 0; } node pull(node &x, node &y) { node res; for(int i = 0; i< 3; i++) res.vec[i] = x.vec[i]+y.vec[i]; return res; } void build(int p = 1, int L = 1, int R = n) { if(L == R) { st[p].vec = {0, 1, 0}; return; } int M = (L+R)/2; build(2*p, L, M); build(2*p+1, M+1, R); st[p] = pull(st[2*p], st[2*p+1]); } node ask(int i, int j, int p = 1, int L = 1, int R = n) { if(i> R || j< L) return node(); push(p, L, R); if(i<= L && R<= j) return st[p]; int M = (L+R)/2; node x = ask(i, j, 2*p, L, M); node y = ask(i, j, 2*p+1, M+1, R); node res = pull(x, y); return res; } void update(int i, int j, int dx, int p = 1, int L = 1, int R = n) { push(p, L, R); if(i> R || j< L) return; if(i<= L && R<= j) { st[p].lz += dx; push(p, L, R); return; } int M = (L+R)/2; update(i, j, dx, 2*p, L, M); update(i, j, dx, 2*p+1, M+1, R); st[p] = pull(st[2*p], st[2*p+1]); } void point(int x, int dx, int p = 1, int L = 1, int R = n) { push(p, L, R); if(x> R || x< L) return; if(x<= L && R<= x) { st[p].vec = {0, 0, 0}; if(0<= dx+1 && dx+1< 3) st[p].vec[dx+1] = 1; return; } int M = (L+R)/2; point(x, dx, 2*p, L, M); point(x, dx, 2*p+1, M+1, R); st[p] = pull(st[2*p], st[2*p+1]); } }; vector<int> adj[maxn]; int par[22][maxn]; int pos[maxn]; int head[maxn]; int prf[maxn]; int cnt[maxn]; int dep[maxn]; void dfs(int u = 1, int p = 0) { dep[u] = dep[p]+1; par[0][u] = p; for(int i = 1; i<= 20; i++) par[i][u] = par[i-1][par[i-1][u]]; cnt[u] = 1; ii best = {0, -1}; for(int v : adj[u]) { if(v == p) continue; dfs(v, u); best = max(best, {cnt[v], v}); cnt[u] += cnt[v]; } prf[u] = best.Y; } void hld() { int tim = 1; for(int i = 1; i<= n; i++) { if(prf[par[0][i]] == i) continue; for(int j = i; j != -1; j = prf[j]) { head[j] = i; pos[j] = tim++; } } } fenwick Cat, Dog; ll gimme(fenwick &ft, int x) { return ft.ask(pos[x]); } void rangeplus(fenwick &ft, int u, int v, int dx) { if(v == 0) v = 1; if(u == 0) return; while(head[u] != head[v]) { ft.update(pos[head[u]], pos[u], dx); u = par[0][head[u]]; } ft.update(pos[v], pos[u], dx); } segtree foo; vector<int> gim3(int u, int v) { if(u == 0) return {0, 0, 0}; if(v == 0) v = 1; vector<int> res(3, 0); while(head[u] != head[v]) { auto tmp = foo.ask(pos[head[u]], pos[u]); for(int i = 0; i< 3; i++) res[i] += tmp.vec[i]; u = par[0][head[u]]; } auto tmp = foo.ask(pos[v], pos[u]); for(int i = 0; i< 3; i++) res[i] += tmp.vec[i]; return res; } void shift(int u, int v, int dx) { if(v == 0) v = 1; while(head[u] != head[v]) { foo.update(pos[head[u]], pos[u], dx); u = par[0][head[u]]; } foo.update(pos[v], pos[u], dx); } void spec(int u, int dx) { foo.point(pos[u], dx); } int stat[maxn]; void diffcat(int u, int dc, int dd) { if(u == 0) return; if(u == 1) { rangeplus(Cat, u, u, dc); rangeplus(Dog, u, u, dd); spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u)); return; } if(dd-dc == 2) { // printf("KUY\n"); int cur = u; for(int i = 20; i>= 0; i--) { if(gim3(u, par[i][cur])[2] == dep[u]-dep[par[i][cur]]+1) { cur = par[i][cur]; } } int bad = par[0][cur]; if(gim3(u, cur)[2] != dep[u]-dep[cur]+1) bad = cur; shift(u, bad, dc-dd); if(bad) { ll c = gimme(Cat, bad), d = gimme(Dog, bad); if(stat[bad] == 1) d = 1e9; if(stat[bad] == 2) c = 1e9; // printf("diff = %d\n", (int) (c-d)); if(c-d> 2) diffcat(par[0][bad], dd, dd); else if(c-d == 2) diffcat(par[0][bad], dc+1, dd); else if(c-d == 0) diffcat(par[0][bad], dc, dc+1); else diffcat(par[0][bad], dc, dc); } rangeplus(Cat, u, bad, dc); rangeplus(Dog, u, bad, dd); spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u)); return; } if(dd-dc == 1) { int cur = u; for(int i = 20; i>= 0; i--) { auto tmp = gim3(u, par[i][cur]); if(tmp[1]+tmp[2] == dep[u]-dep[par[i][cur]]+1) { cur = par[i][cur]; } } int bad = par[0][cur]; auto ff = gim3(bad, bad); // printf("bad1 = %d\n", bad); // printf("%d %d %d\n", ff[0], ff[1], ff[2]); auto tmp = gim3(u, cur); if(tmp[1]+tmp[2] != dep[u]-dep[cur]+1) bad = cur; // printf("bad2 = %d\n", bad); shift(u, bad, dc-dd); if(bad) { ll c = gimme(Cat, bad), d = gimme(Dog, bad); if(stat[bad] == 1) d = 1e9; if(stat[bad] == 2) c = 1e9; if(c-d> 1) { diffcat(par[0][bad], dd, dd); // printf("KUY\n"); } if(c-d< 0) diffcat(par[0][bad], dc, dc); } rangeplus(Cat, u, bad, dc); rangeplus(Dog, u, bad, dd); spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u)); return; } if(dd-dc == 0) { rangeplus(Cat, u, 1, dc); rangeplus(Dog, u, 1, dd); return; } if(dd-dc == -1) { int cur = u; for(int i = 20; i>= 0; i--) { auto tmp = gim3(u, par[i][cur]); if(tmp[0]+tmp[1] == dep[u]-dep[par[i][cur]]+1) { cur = par[i][cur]; } } int bad = par[0][cur]; auto tmp = gim3(u, cur); if(tmp[0]+tmp[1] != dep[u]-dep[cur]+1) bad = cur; shift(u, bad, dc-dd); if(bad) { ll c = gimme(Cat, bad), d = gimme(Dog, bad); if(stat[bad] == 1) d = 1e9; if(stat[bad] == 2) c = 1e9; if(c-d< -1) diffcat(par[0][bad], dc, dc); if(c-d> 0) diffcat(par[0][bad], dd, dd); } rangeplus(Cat, u, bad, dc); rangeplus(Dog, u, bad, dd); spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u)); return; } if(dd-dc == -2) { int cur = u; for(int i = 20; i>= 0; i--) { if(gim3(u, par[i][cur])[0] == dep[u]-dep[par[i][cur]]+1) { cur = par[i][cur]; } } int bad = par[0][cur]; if(gim3(u, cur)[0] != dep[u]-dep[cur]+1) bad = cur; shift(u, bad, dc-dd); if(bad) { ll c = gimme(Cat, bad), d = gimme(Dog, bad); if(stat[bad] == 1) d = 1e9; if(stat[bad] == 2) c = 1e9; if(c-d< -2) diffcat(par[0][bad], dc, dc); else if(c-d == -2) diffcat(par[0][bad], dc, dd+1); else if(c-d == 0) diffcat(par[0][bad], dd+1, dd); else diffcat(par[0][bad], dd, dd); } rangeplus(Cat, u, bad, dc); rangeplus(Dog, u, bad, dd); spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u)); return; } } ll getanswer() { // auto tmp = foo.ask(1, n); // printf("%d %d %d\n", tmp.vec[0], tmp.vec[1], tmp.vec[2]); // for(int i = 1; i<= n; i++) printf("%d ", (int) gimme(Cat, i)); printf("\n"); // for(int i = 1; i<= n; i++) printf("%d ", (int) gimme(Dog, i)); printf("\n"); // for(int i = 1; i<= n; i++) // { // int go = -1; // auto tmp = gim3(i, i); // for(int j = 0; j< 3; j++) // { // if(tmp[j]) // { // go = j; // break; // } // } // if(go == -1) printf("? "); // else if(go == 0) printf("- "); // else if(go == 1) printf("0 "); // else printf("1 "); // } // printf("\n"); // for(int i = 1; i<= n; i++) printf("%c ", stat[i]?'*':'-'); // printf("\n"); // puts("---"); ll c = gimme(Cat, 1); ll d = gimme(Dog, 1); if(stat[1] == 1) return c; if(stat[1] == 2) return d; return min(c, d); } int cat(int u) { // printf("cat(%d)\n", u); stat[u] = 1; ll c = gimme(Cat, u), d = gimme(Dog, u); if(c == d) diffcat(par[0][u], 0, 1); if(c> d) diffcat(par[0][u], c-d-1, c-d-1+2); spec(u, 1e9); return (int) getanswer(); } int dog(int u) { // printf("dog(%d)\n", u); stat[u] = 2; ll c = gimme(Cat, u), d = gimme(Dog, u); if(c == d) diffcat(par[0][u], 1, 0); if(c< d) diffcat(par[0][u], d-c-1+2, d-c-1); spec(u, 1e9); return (int) getanswer(); } int neighbor(int u) { // printf("neigh(%d)\n"s, u); ll c = gimme(Cat, u), d = gimme(Dog, u); if(stat[u] == 1) { if(c == d) diffcat(par[0][u], 0, -1); if(c> d) diffcat(par[0][u], -(c-d-1), -(c-d-1+2)); } if(stat[u] == 2) { if(c == d) diffcat(par[0][u], -1, 0); if(c< d) diffcat(par[0][u], -(d-c-1+2), -(d-c-1)); } stat[u] = 0; spec(u, c-d); return (int) getanswer(); } void initialize(int N, vector<int> A, vector<int> B) { n = N; for(int i = 0; i< n; i++) { adj[A[i]].pb(B[i]); adj[B[i]].pb(A[i]); } dfs(); hld(); foo.build(); }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...