// This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
#pragma GCC target ("sse4")
using namespace std;
// Shorthands used throughout the file.
#define X first
#define Y second
#define pb push_back
using ii = pair<int, int>;
using ll = long long;

// Capacity bound for every per-vertex array (vertex ids must stay below maxn).
constexpr int maxn = 100000 + 5;

// Number of vertices in the current tree (assigned by initialize()).
int n;
// Fenwick tree used as "range add / point query" over positions 1..n:
// update(a, b, dx) adds dx to every position in [a, b] (via a difference
// array) and ask(x) returns the current value at position x.
// Positions passed to change() must be >= 1.
struct fenwick
{
    ll ft[maxn];

    // Prefix sum ft[1..x] of the difference array.
    ll sum(int x)
    {
        ll total = 0;
        while (x != 0)
        {
            total += ft[x];
            x &= x - 1;   // strip the lowest set bit
        }
        return total;
    }

    // Add dx at position x of the difference array.
    void change(int x, int dx)
    {
        while (x <= n)
        {
            ft[x] += dx;
            x += x & -x;
        }
    }

    // Current value at position x.
    ll ask(int x)
    {
        const ll value = sum(x);
        return value;
    }

    // Add dx to every position in [a, b] (b + 1 > n is a harmless no-op).
    void update(int a, int b, int dx)
    {
        change(b + 1, -dx);
        change(a, dx);
    }
};
// Lazy segment tree over HLD positions 1..n.
// Each node keeps a 3-bucket histogram: vec[b] counts the positions in its
// range whose tracked value equals b-1 (the window is -1, 0, +1).
// A pending shift lz moves every value in the range by lz; values pushed out
// of the window are dropped and are NOT restored by a later opposite shift.
// Buckets live in a std::array rather than a std::vector so nodes are plain
// values: the vector version heap-allocated for each of the 4*maxn nodes and
// again on every push() and every ask() copy.
struct segtree
{
    struct node
    {
        std::array<int, 3> vec{{0, 0, 0}};
        int lz = 0;   // pending shift not yet applied to vec
        node() {}
        // Kept for compatibility with the former vector-based constructor;
        // copies at most the first three entries.
        node(const std::vector<int> &v)
        {
            for (int b = 0; b < 3 && b < (int) v.size(); b++) vec[b] = v[b];
        }
    };
    node st[4*maxn];

    // Apply node p's pending shift to its own histogram and defer it to its
    // children (a leaf has none).
    void push(int p, int L, int R)
    {
        int &lz = st[p].lz;
        if (!lz) return;
        std::array<int, 3> shifted{{0, 0, 0}};
        for (int b = 0; b < 3; b++)
        {
            const int from = b - lz;
            if (0 <= from && from < 3) shifted[b] = st[p].vec[from];
        }
        st[p].vec = shifted;
        if (L != R)
        {
            st[2*p].lz += lz;
            st[2*p+1].lz += lz;
        }
        lz = 0;
    }

    // Bucket-wise sum of two histograms; the result has no pending shift.
    node pull(const node &x, const node &y)
    {
        node res;
        for (int b = 0; b < 3; b++) res.vec[b] = x.vec[b] + y.vec[b];
        return res;
    }

    // Every position starts with tracked value 0.
    void build(int p = 1, int L = 1, int R = n)
    {
        if (L == R)
        {
            st[p].vec = {{0, 1, 0}};
            return;
        }
        int M = (L + R) / 2;
        build(2*p, L, M);
        build(2*p+1, M+1, R);
        st[p] = pull(st[2*p], st[2*p+1]);
    }

    // Histogram of positions [i, j].
    node ask(int i, int j, int p = 1, int L = 1, int R = n)
    {
        if (i > R || j < L) return node();
        push(p, L, R);
        if (i <= L && R <= j) return st[p];
        int M = (L + R) / 2;
        node lhs = ask(i, j, 2*p, L, M);
        node rhs = ask(i, j, 2*p+1, M+1, R);
        return pull(lhs, rhs);
    }

    // Shift every tracked value in positions [i, j] by dx.
    void update(int i, int j, int dx, int p = 1, int L = 1, int R = n)
    {
        push(p, L, R);
        if (i > R || j < L) return;
        if (i <= L && R <= j)
        {
            st[p].lz += dx;
            push(p, L, R);
            return;
        }
        int M = (L + R) / 2;
        update(i, j, dx, 2*p, L, M);
        update(i, j, dx, 2*p+1, M+1, R);
        st[p] = pull(st[2*p], st[2*p+1]);
    }

    // Overwrite the tracked value at position x with dx; a dx outside [-1, 1]
    // leaves the position in no bucket.
    void point(int x, int dx, int p = 1, int L = 1, int R = n)
    {
        push(p, L, R);
        if (x > R || x < L) return;
        if (L == R)
        {
            st[p].vec = {{0, 0, 0}};
            if (0 <= dx + 1 && dx + 1 < 3) st[p].vec[dx + 1] = 1;
            return;
        }
        int M = (L + R) / 2;
        point(x, dx, 2*p, L, M);
        point(x, dx, 2*p+1, M+1, R);
        st[p] = pull(st[2*p], st[2*p+1]);
    }
};
// Tree storage, indexed by vertex id (1..n); vertex 0 is the "no parent" sentinel.
std::vector<int> adj[maxn];   // adjacency lists (both directions of every edge)
int par[22][maxn];            // par[k][v]: 2^k-th ancestor of v, 0 above the root (k <= 20 used)
int pos[maxn];                // position of v in the HLD order (1..n)
int head[maxn];               // topmost vertex of v's heavy chain
int prf[maxn];                // heavy child of v (largest subtree), -1 for a leaf
int cnt[maxn];                // number of vertices in v's subtree
int dep[maxn];                // depth of v, root = 1 (dep[0] = 0)
void dfs(int u = 1, int p = 0)
{
dep[u] = dep[p]+1;
par[0][u] = p;
for(int i = 1; i<= 20; i++) par[i][u] = par[i-1][par[i-1][u]];
cnt[u] = 1;
ii best = {0, -1};
for(int v : adj[u])
{
if(v == p) continue;
dfs(v, u);
best = max(best, {cnt[v], v});
cnt[u] += cnt[v];
}
prf[u] = best.Y;
}
// Heavy-light decomposition: every vertex that is not its parent's heavy child
// starts a chain made of that vertex followed by repeated prf[] links. Chain
// members receive consecutive positions pos[] (starting at 1) and share head[].
void hld()
{
    int nextPos = 1;
    for (int start = 1; start <= n; start++)
    {
        const bool continuesParentChain = prf[par[0][start]] == start;
        if (continuesParentChain) continue;
        int v = start;
        while (v != -1)
        {
            head[v] = start;
            pos[v] = nextPos++;
            v = prf[v];
        }
    }
}
// Per-vertex Cat / Dog values, stored at HLD positions (range add, point query).
fenwick Cat;
fenwick Dog;

// Current value of vertex x in the given fenwick.
ll gimme(fenwick &ft, int x)
{
    const int at = pos[x];
    return ft.ask(at);
}
// Add dx to every vertex on the path from u up to v (inclusive).
// v == 0 means "up to the root"; u == 0 makes the call a no-op.
// Only u climbs, so v is presumably always an ancestor of u — callers must ensure it.
void rangeplus(fenwick &ft, int u, int v, int dx)
{
    const int stop = (v == 0) ? 1 : v;
    if (u == 0) return;
    int x = u;
    while (head[x] != head[stop])
    {
        const int h = head[x];
        ft.update(pos[h], pos[x], dx);
        x = par[0][h];
    }
    ft.update(pos[stop], pos[x], dx);
}
// Segment tree holding each vertex's tracked value at its HLD position.
segtree foo;

// 3-bucket histogram (see segtree) of the tracked values on the path from u up
// to its ancestor v, both inclusive. v == 0 means "up to the root"; u == 0
// yields an all-zero histogram.
vector<int> gim3(int u, int v)
{
    vector<int> total(3, 0);
    if (u == 0) return total;
    const int stop = (v == 0) ? 1 : v;
    auto addRange = [&](int a, int b)
    {
        auto part = foo.ask(a, b);
        for (int k = 0; k < 3; k++) total[k] += part.vec[k];
    };
    int x = u;
    for (; head[x] != head[stop]; x = par[0][head[x]])
        addRange(pos[head[x]], pos[x]);
    addRange(pos[stop], pos[x]);
    return total;
}
// Shift the tracked value of every vertex on the path from u up to its
// ancestor v (inclusive) by dx; v == 0 means "up to the root".
void shift(int u, int v, int dx)
{
    const int stop = (v == 0) ? 1 : v;
    int x = u;
    while (head[x] != head[stop])
    {
        const int h = head[x];
        foo.update(pos[h], pos[x], dx);
        x = par[0][h];
    }
    foo.update(pos[stop], pos[x], dx);
}
// Overwrite vertex u's tracked value with dx; values outside [-1, 1] are kept
// in no bucket of the histogram.
void spec(int u, int dx)
{
    const int where = pos[u];
    foo.point(where, dx);
}
int stat[maxn];
void diffcat(int u, int dc, int dd)
{
if(u == 0) return;
if(u == 1)
{
rangeplus(Cat, u, u, dc);
rangeplus(Dog, u, u, dd);
spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
return;
}
if(dd-dc == 2)
{
// printf("KUY\n");
int cur = u;
for(int i = 20; i>= 0; i--)
{
if(gim3(u, par[i][cur])[2] == dep[u]-dep[par[i][cur]]+1)
{
cur = par[i][cur];
}
}
int bad = par[0][cur];
if(gim3(u, cur)[2] != dep[u]-dep[cur]+1) bad = cur;
shift(u, bad, dc-dd);
if(bad)
{
ll c = gimme(Cat, bad), d = gimme(Dog, bad);
if(stat[bad] == 1) d = 1e9;
if(stat[bad] == 2) c = 1e9;
// printf("diff = %d\n", (int) (c-d));
if(c-d> 2) diffcat(par[0][bad], dd, dd);
else if(c-d == 2) diffcat(par[0][bad], dc+1, dd);
else if(c-d == 0) diffcat(par[0][bad], dc, dc+1);
else diffcat(par[0][bad], dc, dc);
}
rangeplus(Cat, u, bad, dc);
rangeplus(Dog, u, bad, dd);
spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
return;
}
if(dd-dc == 1)
{
int cur = u;
for(int i = 20; i>= 0; i--)
{
auto tmp = gim3(u, par[i][cur]);
if(tmp[1]+tmp[2] == dep[u]-dep[par[i][cur]]+1)
{
cur = par[i][cur];
}
}
int bad = par[0][cur];
auto ff = gim3(bad, bad);
// printf("bad1 = %d\n", bad);
// printf("%d %d %d\n", ff[0], ff[1], ff[2]);
auto tmp = gim3(u, cur);
if(tmp[1]+tmp[2] != dep[u]-dep[cur]+1) bad = cur;
// printf("bad2 = %d\n", bad);
shift(u, bad, dc-dd);
if(bad)
{
ll c = gimme(Cat, bad), d = gimme(Dog, bad);
if(stat[bad] == 1) d = 1e9;
if(stat[bad] == 2) c = 1e9;
if(c-d> 1)
{
diffcat(par[0][bad], dd, dd);
// printf("KUY\n");
}
if(c-d< 0) diffcat(par[0][bad], dc, dc);
}
rangeplus(Cat, u, bad, dc);
rangeplus(Dog, u, bad, dd);
spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
return;
}
if(dd-dc == 0)
{
rangeplus(Cat, u, 1, dc);
rangeplus(Dog, u, 1, dd);
return;
}
if(dd-dc == -1)
{
int cur = u;
for(int i = 20; i>= 0; i--)
{
auto tmp = gim3(u, par[i][cur]);
if(tmp[0]+tmp[1] == dep[u]-dep[par[i][cur]]+1)
{
cur = par[i][cur];
}
}
int bad = par[0][cur];
auto tmp = gim3(u, cur);
if(tmp[0]+tmp[1] != dep[u]-dep[cur]+1) bad = cur;
shift(u, bad, dc-dd);
if(bad)
{
ll c = gimme(Cat, bad), d = gimme(Dog, bad);
if(stat[bad] == 1) d = 1e9;
if(stat[bad] == 2) c = 1e9;
if(c-d< -1) diffcat(par[0][bad], dc, dc);
if(c-d> 0) diffcat(par[0][bad], dd, dd);
}
rangeplus(Cat, u, bad, dc);
rangeplus(Dog, u, bad, dd);
spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
return;
}
if(dd-dc == -2)
{
int cur = u;
for(int i = 20; i>= 0; i--)
{
if(gim3(u, par[i][cur])[0] == dep[u]-dep[par[i][cur]]+1)
{
cur = par[i][cur];
}
}
int bad = par[0][cur];
if(gim3(u, cur)[0] != dep[u]-dep[cur]+1) bad = cur;
shift(u, bad, dc-dd);
if(bad)
{
ll c = gimme(Cat, bad), d = gimme(Dog, bad);
if(stat[bad] == 1) d = 1e9;
if(stat[bad] == 2) c = 1e9;
if(c-d< -2) diffcat(par[0][bad], dc, dc);
else if(c-d == -2) diffcat(par[0][bad], dc, dd+1);
else if(c-d == 0) diffcat(par[0][bad], dd+1, dd);
else diffcat(par[0][bad], dd, dd);
}
rangeplus(Cat, u, bad, dc);
rangeplus(Dog, u, bad, dd);
spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
return;
}
}
// Current answer, read off the root (vertex 1): its Cat value if the root is
// marked cat, its Dog value if marked dog, otherwise the smaller of the two.
ll getanswer()
{
    const ll rootCat = gimme(Cat, 1);
    const ll rootDog = gimme(Dog, 1);
    switch (stat[1])
    {
    case 1:
        return rootCat;
    case 2:
        return rootDog;
    default:
        return rootCat < rootDog ? rootCat : rootDog;
    }
}
int cat(int u)
{
// printf("cat(%d)\n", u);
stat[u] = 1;
ll c = gimme(Cat, u), d = gimme(Dog, u);
if(c == d) diffcat(par[0][u], 0, 1);
if(c> d) diffcat(par[0][u], c-d-1, c-d-1+2);
spec(u, 1e9);
return (int) getanswer();
}
int dog(int u)
{
// printf("dog(%d)\n", u);
stat[u] = 2;
ll c = gimme(Cat, u), d = gimme(Dog, u);
if(c == d) diffcat(par[0][u], 1, 0);
if(c< d) diffcat(par[0][u], d-c-1+2, d-c-1);
spec(u, 1e9);
return (int) getanswer();
}
int neighbor(int u)
{
// printf("neigh(%d)\n"s, u);
ll c = gimme(Cat, u), d = gimme(Dog, u);
if(stat[u] == 1)
{
if(c == d) diffcat(par[0][u], 0, -1);
if(c> d) diffcat(par[0][u], -(c-d-1), -(c-d-1+2));
}
if(stat[u] == 2)
{
if(c == d) diffcat(par[0][u], -1, 0);
if(c< d) diffcat(par[0][u], -(d-c-1+2), -(d-c-1));
}
stat[u] = 0;
spec(u, c-d);
return (int) getanswer();
}
// Grader entry point: builds the tree on vertices 1..N from the edges
// A[i]-B[i], then prepares depths/binary lifting (dfs), the heavy-light
// decomposition (hld) and the segment tree (all values 0).
void initialize(int N, vector<int> A, vector<int> B)
{
    n = N;
    // A tree on N vertices has N-1 edges; iterate over the edges actually
    // supplied instead of N, which read one element past the end of A and B.
    const size_t edges = min(A.size(), B.size());
    for (size_t i = 0; i < edges; i++)
    {
        adj[A[i]].pb(B[i]);
        adj[B[i]].pb(A[i]);
    }
    dfs();
    hld();
    foo.build();
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |