Submission #124984

| # | Submission time | Handle | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 124984 | 2019-07-04T09:30:59Z | WhipppedCream | Cats or Dogs (JOI18_catdog) | C++17 | 0 / 100 | 41 ms | 27896 KB |
#include <bits/stdc++.h>
#pragma GCC optimize ("O3")
#pragma GCC target ("sse4")
using namespace std;
// Shorthands used throughout the file: X / Y access pair members,
// pb appends to a vector.
#define X first
#define Y second
#define pb push_back
// Type aliases.
using ii = std::pair<int, int>;
using ll = long long;

// Capacity of every per-vertex array (100000 + 5 slots).
constexpr int maxn = 100000 + 5;

// Number of vertices in the current tree; assigned by initialize().
int n = 0;

// Fenwick (binary indexed) tree over positions 1..n used as a range-add /
// point-query structure: update(a, b, dx) adds dx to every position in
// [a, b], and ask(x) returns the accumulated value at position x.
struct fenwick
{
	ll ft[maxn];

	// Prefix sum over the internal difference array, positions 1..x.
	ll sum(int x)
	{
		ll total = 0;
		while (x != 0)
		{
			total += ft[x];
			x &= x - 1;	// clear the lowest set bit
		}
		return total;
	}

	// Adds dx at position x of the difference array (no-op once x > n).
	void change(int x, int dx)
	{
		while (x <= n)
		{
			ft[x] += dx;
			x += x & -x;
		}
	}

	// Point query: value currently stored at position x.
	ll ask(int x)
	{
		return sum(x);
	}

	// Range add: dx on every position in [a, b].
	void update(int a, int b, int dx)
	{
		change(a, dx);
		change(b + 1, -dx);
	}
};

// Lazy segment tree over HLD positions 1..n. Every node keeps a 3-bucket
// histogram `vec`; bucket k counts leaves whose value is k - 1 (so the
// buckets stand for -1, 0 and +1), and a leaf whose value is anything else
// holds all zeros. Internal nodes store the bucket-wise sum of their children.
// A shift by dx moves the count in bucket k to bucket k + dx and drops counts
// that leave [0, 3).
// NOTE(review): dropped counts are not recovered by a later opposite shift,
// so a leaf whose value may re-enter [-1, 1] has to be re-set with point().
struct segtree
{
	struct node
	{
		vector<int> vec = vector<int>(3, 0);	// bucket counts for values -1, 0, +1
		int lz = 0;				// shift not yet applied to vec
		node(){}
		node(vector<int> vec) : vec(vec) {}
	};
	node st[4*maxn];

	// Applies st[p]'s pending shift to its own histogram and defers it to
	// the children (when [L, R] is more than one position).
	void push(int p, int L, int R)
	{
		int &pending = st[p].lz;
		if (pending == 0) return;
		vector<int> moved(3, 0);
		for (int from = 0; from < 3; ++from)
		{
			const int to = from + pending;
			if (0 <= to && to < 3) moved[to] = st[p].vec[from];
		}
		st[p].vec = moved;
		if (L != R)
		{
			st[2*p].lz += pending;
			st[2*p+1].lz += pending;
		}
		pending = 0;
	}

	// Bucket-wise sum of two nodes; the result carries no pending shift.
	node pull(node &x, node &y)
	{
		node total;
		for (int k = 0; k < 3; ++k) total.vec[k] = x.vec[k] + y.vec[k];
		return total;
	}

	// Initialises every leaf to value 0 (bucket 1) and sums upward.
	void build(int p = 1, int L = 1, int R = n)
	{
		if (L == R)
		{
			st[p].vec = {0, 1, 0};
			return;
		}
		const int mid = (L + R) / 2;
		build(2*p, L, mid);
		build(2*p+1, mid+1, R);
		st[p] = pull(st[2*p], st[2*p+1]);
	}

	// Bucket totals over positions [i, j].
	node ask(int i, int j, int p = 1, int L = 1, int R = n)
	{
		const bool disjoint = R < i || L > j;
		if (disjoint) return node();
		push(p, L, R);
		if (i <= L && R <= j) return st[p];
		const int mid = (L + R) / 2;
		node lhs = ask(i, j, 2*p, L, mid);
		node rhs = ask(i, j, 2*p+1, mid+1, R);
		return pull(lhs, rhs);
	}

	// Shifts the histograms of positions [i, j] by dx.
	void update(int i, int j, int dx, int p = 1, int L = 1, int R = n)
	{
		push(p, L, R);
		const bool disjoint = R < i || L > j;
		if (disjoint) return;
		if (i <= L && R <= j)
		{
			st[p].lz += dx;
			push(p, L, R);
			return;
		}
		const int mid = (L + R) / 2;
		update(i, j, dx, 2*p, L, mid);
		update(i, j, dx, 2*p+1, mid+1, R);
		st[p] = pull(st[2*p], st[2*p+1]);
	}

	// Overwrites leaf x with value dx: one-hot bucket dx + 1 when dx is in
	// [-1, 1], all zeros otherwise. Both children are visited so that the
	// sibling of every node on the path is pushed before it is pulled.
	void point(int x, int dx, int p = 1, int L = 1, int R = n)
	{
		push(p, L, R);
		if (x < L || x > R) return;
		if (L == R)
		{
			st[p].vec = {0, 0, 0};
			if (-1 <= dx && dx <= 1) st[p].vec[dx + 1] = 1;
			return;
		}
		const int mid = (L + R) / 2;
		point(x, dx, 2*p, L, mid);
		point(x, dx, 2*p+1, mid+1, R);
		st[p] = pull(st[2*p], st[2*p+1]);
	}
};

// Adjacency lists of the tree, indexed by vertex (1-based).
vector<int> adj[maxn];
// Binary lifting: par[k][u] is the 2^k-th ancestor of u (0 above the root).
int par[22][maxn]{};
// Position of each vertex in the HLD order (consecutive along a heavy chain).
int pos[maxn]{};
// Topmost vertex of the heavy chain that contains each vertex.
int head[maxn]{};
// Heavy (preferred) child of each vertex, -1 when it has no children.
int prf[maxn]{};
// Subtree size of each vertex.
int cnt[maxn]{};
// Depth of each vertex: the root has depth 1, the sentinel vertex 0 depth 0.
int dep[maxn]{};

// Recursive DFS of the subtree of u (whose parent is p). Fills depth, the
// binary-lifting table, subtree sizes and the heavy child prf[u]: the child
// with the largest subtree, ties going to the larger vertex id, or -1 when u
// has no children.
void dfs(int u = 1, int p = 0)
{
	par[0][u] = p;
	dep[u] = 1 + dep[p];
	for (int k = 1; k <= 20; ++k)
	{
		const int mid = par[k-1][u];
		par[k][u] = par[k-1][mid];
	}
	cnt[u] = 1;
	int heavySize = 0;
	int heavyChild = -1;
	for (const int child : adj[u])
	{
		if (child == p) continue;
		dfs(child, u);
		cnt[u] += cnt[child];
		if (make_pair(cnt[child], child) > make_pair(heavySize, heavyChild))
		{
			heavySize = cnt[child];
			heavyChild = child;
		}
	}
	prf[u] = heavyChild;
}

// Heavy-light decomposition: each vertex that is not the heavy child of its
// parent starts a chain. The chain follows prf[] links down to a leaf and
// occupies consecutive positions, with head[] set to the chain's first vertex.
void hld()
{
	int nextPos = 1;
	for (int top = 1; top <= n; ++top)
	{
		const bool continuesParentChain = prf[par[0][top]] == top;
		if (continuesParentChain) continue;
		int v = top;
		while (v != -1)
		{
			head[v] = top;
			pos[v] = nextPos;
			++nextPos;
			v = prf[v];
		}
	}
}

// Per-vertex values stored as range-add / point-query Fenwick trees indexed by
// HLD position. getanswer() reports min(Cat, Dog) at the root, or the one
// forced by the root's animal. Presumably Cat(u)/Dog(u) is the best cost of
// u's subtree with u coloured cat/dog; confirm against the task.
fenwick Cat;
fenwick Dog;

// Current value stored in `ft` for vertex x.
ll gimme(fenwick &ft, int x)
{
	const int at = pos[x];
	return ft.sum(at);
}

// Adds dx in `ft` to every vertex on the path from u up to v (inclusive),
// one heavy-chain segment at a time. v is expected to be an ancestor of u;
// v == 0 means "up to the root", and u == 0 makes the call a no-op.
void rangeplus(fenwick &ft, int u, int v, int dx)
{
	if (u == 0) return;
	const int top = (v == 0) ? 1 : v;
	int cur = u;
	while (head[cur] != head[top])
	{
		const int chainTop = head[cur];
		ft.update(pos[chainTop], pos[cur], dx);
		cur = par[0][chainTop];
	}
	ft.update(pos[top], pos[cur], dx);
}

// Bucket-histogram segment tree over HLD positions.
segtree foo;

// Sums foo's bucket counts over the path from u up to its ancestor v
// (inclusive). v == 0 means "up to the root"; u == 0 yields all zeros.
vector<int> gim3(int u, int v)
{
	vector<int> total(3, 0);
	if (u == 0) return total;
	const int top = (v == 0) ? 1 : v;
	auto addSegment = [&](int from, int to)
	{
		const auto part = foo.ask(from, to);
		for (int k = 0; k < 3; ++k) total[k] += part.vec[k];
	};
	int cur = u;
	while (head[cur] != head[top])
	{
		addSegment(pos[head[cur]], pos[cur]);
		cur = par[0][head[cur]];
	}
	addSegment(pos[top], pos[cur]);
	return total;
}

// Shifts foo's bucket histograms by dx on the path from u up to its ancestor
// v (inclusive); v == 0 means "up to the root".
void shift(int u, int v, int dx)
{
	const int top = (v == 0) ? 1 : v;
	int cur = u;
	for (; head[cur] != head[top]; cur = par[0][head[cur]])
	{
		foo.update(pos[head[cur]], pos[cur], dx);
	}
	foo.update(pos[top], pos[cur], dx);
}

// Sets vertex u's leaf in foo to value dx. A value outside [-1, 1], such as
// the 1e9 sentinel passed for vertices holding an animal, clears every bucket.
void spec(int u, int dx)
{
	const int leaf = pos[u];
	foo.point(leaf, dx);
}

// Animal placed at each vertex: 0 = none, 1 = cat, 2 = dog.
int stat[maxn] = {};

void diffcat(int u, int dc, int dd)
{
	if(u == 0) return;
	if(u == 1)
	{
		rangeplus(Cat, u, u, dc);
		rangeplus(Dog, u, u, dd);
		spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
		return;
	}
	if(dd-dc == 2)
	{
		// printf("KUY\n");
		int cur = u;
		for(int i = 20; i>= 0; i--)
		{
			if(gim3(u, par[i][cur])[2] == dep[u]-dep[par[i][cur]]+1)
			{
				cur = par[i][cur];
			}
		}
		int bad = par[0][cur];
		if(gim3(u, cur)[2] != dep[u]-dep[cur]+1) bad = cur;
		shift(u, bad, dc-dd);
		if(bad)
		{
			ll c = gimme(Cat, bad), d = gimme(Dog, bad);
			if(stat[u] == 1) d = 1e9;
			if(stat[u] == 2) c = 1e9;
			// printf("diff = %d\n", (int) (c-d));
			if(c-d> 2) diffcat(par[0][bad], dd, dd);
			else if(c-d == 2) diffcat(par[0][bad], dc+1, dd);
			else if(c-d == 0) diffcat(par[0][bad], dc, dc+1);
			else diffcat(par[0][bad], dc, dc);
		}
		rangeplus(Cat, u, bad, dc);
		rangeplus(Dog, u, bad, dd);
		spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
		return;
	}
	if(dd-dc == 1)
	{
		int cur = u;
		for(int i = 20; i>= 0; i--)
		{
			auto tmp = gim3(u, par[i][cur]);
			if(tmp[1]+tmp[2] == dep[u]-dep[par[i][cur]]+1)
			{
				cur = par[i][cur];
			}
		}
		int bad = par[0][cur];
		auto ff = gim3(bad, bad);
		// printf("bad1 = %d\n", bad);
		// printf("%d %d %d\n", ff[0], ff[1], ff[2]);
		auto tmp = gim3(u, cur);
		if(tmp[1]+tmp[2] != dep[u]-dep[cur]+1) bad = cur;
		// printf("bad2 = %d\n", bad);
		shift(u, bad, dc-dd);
		if(bad)
		{
			ll c = gimme(Cat, bad), d = gimme(Dog, bad);
			if(stat[u] == 1) d = 1e9;
			if(stat[u] == 2) c = 1e9;
			if(c-d> 1) diffcat(par[0][bad], dd, dd);
			if(c-d< 0) diffcat(par[0][bad], dc, dc);
		}
		rangeplus(Cat, u, bad, dc);
		rangeplus(Dog, u, bad, dd);
		spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
		return;
	}
	if(dd-dc == 0)
	{
		rangeplus(Cat, u, 1, dc);
		rangeplus(Dog, u, 1, dd);
		return;
	}
	if(dd-dc == -1)
	{
		int cur = u;
		for(int i = 20; i>= 0; i--)
		{
			auto tmp = gim3(u, par[i][cur]);
			if(tmp[0]+tmp[1] == dep[u]-dep[par[i][cur]]+1)
			{
				cur = par[i][cur];
			}
		}
		int bad = par[0][cur];
		auto tmp = gim3(u, cur);
		if(tmp[0]+tmp[1] != dep[u]-dep[cur]+1) bad = cur;
		shift(u, bad, dc-dd);
		if(bad)
		{
			ll c = gimme(Cat, bad), d = gimme(Dog, bad);
			if(stat[u] == 1) d = 1e9;
			if(stat[u] == 2) c = 1e9;
			if(c-d< -1) diffcat(par[0][bad], dc, dc);
			if(c-d> 0) diffcat(par[0][bad], dd, dd);
		}
		rangeplus(Cat, u, bad, dc);
		rangeplus(Dog, u, bad, dd);
		spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
		return;
	}
	if(dd-dc == -2)
	{
		int cur = u;
		for(int i = 20; i>= 0; i--)
		{
			if(gim3(u, par[i][cur])[0] == dep[u]-dep[par[i][cur]]+1)
			{
				cur = par[i][cur];
			}
		}
		int bad = par[0][cur];
		if(gim3(u, cur)[0] != dep[u]-dep[cur]+1) bad = cur;
		shift(u, bad, dc-dd);
		if(bad)
		{
			ll c = gimme(Cat, bad), d = gimme(Dog, bad);
			if(stat[u] == 1) d = 1e9;
			if(stat[u] == 2) c = 1e9;
			if(c-d< -2) diffcat(par[0][bad], dc, dc);
			else if(c-d == -2) diffcat(par[0][bad], dc, dd+1);
			else if(c-d == 0) diffcat(par[0][bad], dd+1, dd);
			else diffcat(par[0][bad], dd, dd);
		}
		rangeplus(Cat, u, bad, dc);
		rangeplus(Dog, u, bad, dd);
		spec(u, stat[u]?1e9:gimme(Cat, u)-gimme(Dog, u));
		return;
	}
}

// Current answer read at the root: the value forced by the root's animal
// (Cat for a cat, Dog for a dog), otherwise the smaller of the two.
ll getanswer()
{
	const ll c = gimme(Cat, 1);
	const ll d = gimme(Dog, 1);
	if (stat[1] == 1) return c;
	if (stat[1] == 2) return d;
	return min(c, d);
}

int cat(int u)
{
	// printf("cat(%d)\n", u);
	stat[u] = 1;
	ll c = gimme(Cat, u), d = gimme(Dog, u);
	if(c == d) diffcat(par[0][u], 0, 1);
	if(c> d) diffcat(par[0][u], c-d-1, c-d-1+2);
	spec(u, 1e9);
	return (int) getanswer();
}

int dog(int u)
{
	// printf("dog(%d)\n", u);
	stat[u] = 2;
	ll c = gimme(Cat, u), d = gimme(Dog, u);
	if(c == d) diffcat(par[0][u], 1, 0);
	if(c< d) diffcat(par[0][u], d-c-1+2, d-c-1);
	spec(u, 1e9);
	return (int) getanswer();
}

int neighbor(int u)
{
	// printf("neigh(%d)\n"s, u);
	ll c = gimme(Cat, u), d = gimme(Dog, u);
	if(stat[u] == 1)
	{
		if(c == d) diffcat(par[0][u], 0, -1);
		if(c> d) diffcat(par[0][u], -(c-d-1), -(c-d-1+2));
	}
	if(stat[u] == 2)
	{
		if(c == d) diffcat(par[0][u], -1, 0);
		if(c< d) diffcat(par[0][u], -(d-c-1+2), -(d-c-1));
	}
	stat[u] = 0;
	spec(u, c-d);
	return (int) getanswer();
}

// Builds the tree from its edge list (A[i] -- B[i]) and prepares all
// structures: rooted DFS, heavy-light decomposition and the bucket tree.
// A tree on N vertices has N - 1 edges, so the loop runs over the edge
// vectors' own length. Looping to N would read A[N-1] / B[N-1] out of bounds.
void initialize(int N, vector<int> A, vector<int> B)
{
	n = N;
	const size_t edges = min(A.size(), B.size());
	for (size_t i = 0; i < edges; i++)
	{
		adj[A[i]].pb(B[i]);
		adj[B[i]].pb(A[i]);
	}
	dfs();
	hld();
	foo.build();
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 41 ms | 27896 KB | Output is correct |
| 2 | Incorrect | 41 ms | 27896 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 41 ms | 27896 KB | Output is correct |
| 2 | Incorrect | 41 ms | 27896 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 41 ms | 27896 KB | Output is correct |
| 2 | Incorrect | 41 ms | 27896 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |