Submission #64954

# | Time | Username | Problem | Language | Result | Execution time | Memory
64954 | zscoder | Cats or Dogs (JOI18_catdog) | C++17
100 / 100
1446 ms | 38264 KiB
#include "catdog.h"
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
 
using namespace std;
using namespace __gnu_pbds;
 
// Shorthand macros for pair members and common container operations.
#define fi first
#define se second
#define mp make_pair
#define pb push_back

// Short type aliases used throughout the solution.
using ll = long long;
using ii = pair<int, int>;
using vi = vector<int>;
using ld = long double;
// GNU pb_ds order-statistics tree over (int, int) pairs; not referenced by the
// code visible in this file.
using pbds = tree<ii, null_type, less<ii>, rb_tree_tag, tree_order_statistics_node_update>;

// Capacity bound for every per-vertex array (vertices are 0-indexed
// internally; the 1-based input is shifted down by one).
constexpr int N = 111111;
int n;              // number of vertices in the current tree
vi adj[N];          // adjacency lists; after dfs_sz, adj[u][0] is u's heaviest child (when u has children)
int subsize[N];     // subtree sizes computed by dfs_sz
int in[N];          // preorder position of each vertex (heavy child visited first)
int out[N];         // one past the last preorder position inside the vertex's subtree
int nxt[N];         // head of the heavy chain that contains each vertex

// "Infinity" for the min-plus DP; small enough that INF + INF + 1 still fits in int.
constexpr int INF = 100000000;

void dfs_sz(int u, int p = -1)
{
	subsize[u]=1;
	if(adj[u].size()>=2&&adj[u][0]==p) swap(adj[u][0],adj[u][1]);
	for(auto &v:adj[u])
	{
		if(v==p) continue;
		dfs_sz(v,u);
		subsize[u]+=subsize[v];
		if(subsize[v]>subsize[adj[u][0]])
		{
			swap(v,adj[u][0]);
		}
	}
}

int timer = 0;      // next preorder position handed out by dfs_hld
int par[N];         // parent of each vertex; the root keeps the -1 written by initialize
int endpath[N];     // for a chain head h: the largest in[] value on h's heavy chain

// Assigns preorder positions in[]/out[] with the heavy child (adj[u][0], as
// arranged by dfs_sz) visited first, so each heavy chain occupies a
// contiguous range. Also records par[] and the chain head nxt[] of every
// non-root vertex. The root's own head nxt[0] is not written here; it relies
// on nxt[] being zero-initialised, which matches the root 0 used by initialize.
void dfs_hld(int u, int p = -1)
{
	in[u] = timer;
	++timer;
	for (const int child : adj[u])
	{
		if (child == p) continue;
		par[child] = u;
		// The heavy child extends u's chain; every light child starts a new one.
		const bool is_heavy = (child == adj[u].front());
		nxt[child] = is_heavy ? nxt[u] : child;
		dfs_hld(child, u);
	}
	out[u] = timer;
}

// 2x2 matrix for the min-plus chain DP: when combined by operator+, entry
// [i][j] is the cheapest cost of a chain segment whose first vertex is in
// state i and whose last vertex is in state j.
struct matrix
{
	int a[2][2];
	// Mutable row access: m[r][c].
	int* operator [] (int r) { return a[r]; }
	// Read-only row access, so a const matrix (or one bound to a const
	// reference) can also be indexed.
	const int* operator [] (int r) const { return a[r]; }
};

// Min-plus product of two chain-segment matrices, where `a` covers the earlier
// (upper) part of a chain and `b` the part right after it. Joining a's last
// vertex (state mid_a) to b's first vertex (state mid_b) costs one extra unit
// when the two states differ. Every entry is capped at INF, so an infeasible
// combination stays bounded.
matrix operator+(matrix a, matrix b)
{
	matrix res;
	for (int top = 0; top < 2; top++)
	{
		for (int bottom = 0; bottom < 2; bottom++)
		{
			int best = INF;
			for (int mid_a = 0; mid_a < 2; mid_a++)
			{
				for (int mid_b = 0; mid_b < 2; mid_b++)
				{
					const int joint = (mid_a != mid_b) ? 1 : 0;
					best = min(best, a[top][mid_a] + b[mid_b][bottom] + joint);
				}
			}
			res[top][bottom] = best;
		}
	}
	return res;
}

// Segment-tree value.
//   col: -2 marks the empty sentinel (emp); -1 means "no colour constraint"
//        (the default and every combined internal node); 0 or 1 is the colour
//        assigned to a single vertex's leaf by update().
//   M:   chain-segment cost matrix (see operator+).
struct node
{
	int col;
	matrix M;
};

node st[4 * N + 6]; // segment tree over HLD positions, root at index 1
node def;           // initial leaf value (set up in initialize)
ii dp[N];           // dp[h]: last (state-0 cost, state-1 cost) computed for chain head h
node emp;           // neutral value returned by query() outside the requested range

// Merges two adjacent segment-tree values: `a` covers the positions just
// before `b`. The empty sentinel (col == -2) acts as an identity. A value with
// a fixed colour (only leaves carry one) keeps just its [col][col] entry before
// the min-plus product. The result is always unconstrained (col == -1).
// Both arguments are copies, so the stored leaf matrices are never modified.
node combine(node a, node b)
{
	if (a.col == -2) return b;
	if (b.col == -2) return a;

	auto restrict_to_colour = [](node &x)
	{
		if (x.col < 0) return;
		for (int r = 0; r < 2; r++)
			for (int c = 0; c < 2; c++)
				if (r != x.col || c != x.col) x.M[r][c] = INF;
	};
	restrict_to_colour(a);
	restrict_to_colour(b);

	node merged;
	merged.col = -1;
	merged.M = a.M + b.M;
	return merged;
}

// Builds the segment-tree node `id` covering positions [l, r): each leaf
// starts as `def`, and internal nodes are combined bottom-up.
void build(int id, int l, int r)
{
	if (r - l >= 2)
	{
		const int mid = (l + r) >> 1;
		build(2 * id, l, mid);
		build(2 * id + 1, mid, r);
		st[id] = combine(st[2 * id], st[2 * id + 1]);
	}
	else
	{
		st[id] = def;
	}
}

// Sets the colour constraint of the leaf at position `pos` to v (-1 clears it)
// and recombines its ancestors. Only `col` changes at the leaf. Its matrix,
// which holds the light-subtree costs, is left intact, because combine()
// applies the colour on copies only.
void update(int id, int l, int r, int pos, int v)
{
	const bool outside = pos < l || pos >= r;
	if (outside) return;
	if (r - l < 2)
	{
		st[id].col = v;
		return;
	}
	const int mid = (l + r) >> 1;
	if (pos < mid) update(2 * id, l, mid, pos, v);
	else update(2 * id + 1, mid, r, pos, v);
	st[id] = combine(st[2 * id], st[2 * id + 1]);
}

void increment(int id, int l, int r, int pos, ii pre, ii nw) //replace pre to nw
{
	if(pos>=r||pos<l) return ;
	if(r-l<2)
	{
		st[id].M[0][0]-=min(pre.fi,pre.se+1); st[id].M[1][1]-=min(pre.fi+1,pre.se);
		st[id].M[0][0]+=min(nw.fi,nw.se+1); st[id].M[1][1]+=min(nw.fi+1,nw.se);
		return ;
	}
	int mid=(l+r)>>1;
	increment(id*2,l,mid,pos,pre,nw); increment(id*2+1,mid,r,pos,pre,nw);
	st[id] = combine(st[id*2], st[id*2+1]);
}

// Returns the combined value of positions [ql, qr) inside node `id`'s range
// [l, r), or the neutral `emp` when the two ranges do not overlap.
node query(int id, int l, int r, int ql, int qr)
{
	if (qr <= l || r <= ql) return emp;
	if (ql <= l && r <= qr) return st[id];
	const int mid = (l + r) >> 1;
	const node left = query(2 * id, l, mid, ql, qr);
	const node right = query(2 * id + 1, mid, r, ql, qr);
	return combine(left, right);
}

int color[N];

void initialize(int N, std::vector<int> A, std::vector<int> B) 
{
	emp.col=-2;
	def.col=-1; def.M[0][0]=0; def.M[0][1]=INF; def.M[1][0]=INF; def.M[1][1]=0;
	n=N; 
	for(int i=0;i<N-1;i++) 
	{
		adj[A[i]-1].pb(B[i]-1); adj[B[i]-1].pb(A[i]-1);
	}
	memset(par,-1,sizeof(par));
	dfs_sz(0); dfs_hld(0);
	build(1,0,n); 
	memset(color,-1,sizeof(color));
	for(int i=0;i<n;i++)
	{
		dp[i]=mp(0,0);
	}
	for(int i=0;i<n;i++)
	{
		endpath[nxt[i]] = max(endpath[nxt[i]], in[i]);
	}
}

// For a heavy-chain head u, returns (best cost of u's subtree with u in state 0,
// best cost with u in state 1), honouring u's own colour. u must be a chain
// head: endpath[] is only filled in for heads.
ii get_value(int u) 
{
	node chain = query(1, 0, n, in[u], endpath[u] + 1);
	int if0 = min(chain.M[0][0], chain.M[0][1]);
	int if1 = min(chain.M[1][0], chain.M[1][1]);
	// A one-vertex chain comes back as a raw leaf whose colour combine() never
	// applied, so enforce u's colour explicitly.
	if (color[u] == 0) if1 = INF;
	else if (color[u] == 1) if0 = INF;
	return ii(if0, if1);
}

void change_color(int u, int c)
{
	color[u]=c;
	update(1,0,n,in[u],c);
	while(u!=-1)
	{
		//[nxt[u],u]
		u=nxt[u];
		ii res = get_value(u);
		if(par[u]!=-1)
		{
			increment(1,0,n,in[par[u]],dp[u],res);
		}
		dp[u]=res;
		u=par[u];
	}
}

int solve()
{
	ii tmp = get_value(0);
	return min(tmp.fi,tmp.se);
}

// Gives 1-based vertex v colour 0 and returns the updated optimum.
int cat(int v) 
{
	change_color(v - 1, 0);
	return solve();
}

// Gives 1-based vertex v colour 1 and returns the updated optimum.
int dog(int v) 
{
	change_color(v - 1, 1);
	return solve();
}

// Clears the colour of 1-based vertex v and returns the updated optimum.
int neighbor(int v) 
{
	change_color(v - 1, -1);
	return solve();
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...