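// Submission #848255
//
// Columns: # | Time | Username | Problem | Language | Result | Execution time | Memory
// 848255 | damot67679 | Cats or Dogs (JOI18_catdog) | C++14 | 100 / 100 | 324 ms | 27732 KiB
// (The Time value is missing from the captured row.)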
#include "catdog.h"
#include <bits/stdc++.h>
using namespace std;

// Upper bound (with a little slack) on the number of vertices; sizes every static array below.
const int N = 1e5 + 5;

// Adjacency lists of the tree, indexed by vertex label; filled by initialize().
vector<int> adj[N];

// Per-vertex data written by dfs_size() / dfs():
//   par[v]   - parent of v (0 for the root, which is vertex 1)
//   chain[v] - id of the heavy chain containing v (ids start at 1; 0 means "not assigned yet")
//   tin[v]   - 1-based DFS visiting order of v; used as v's position in the segment tree
//   sz[v]    - number of vertices in the subtree of v
// Per-chain data, indexed by chain id:
//   head[c]  - first vertex visited on chain c
//   tail[c]  - last vertex visited on chain c
int par[N], chain[N], head[N], tail[N], tin[N], sz[N];
// tchain / tcur: how many chain ids / DFS positions have been handed out so far.
// n: number of vertices, set by initialize().
int tchain = 0, tcur = 0, n;

// First DFS pass: records the parent of every vertex reachable from `x` and
// the size of each subtree.
//   x - vertex being visited
//   p - vertex we came from (stored as par[x]; the root is called with 0)
void dfs_size(int x, int p) {
	par[x] = p;
	int total = 1;  // counts x itself
	for (const int child : adj[x]) {
		if (child != p) {
			dfs_size(child, x);
			total += sz[child];
		}
	}
	sz[x] = total;
}

// Second DFS pass: heavy-light decomposition.
// Assigns tin[x] (1-based visiting order), puts x on a chain and maintains
// head[]/tail[] for that chain. The child with the largest subtree (first one
// wins on ties) continues x's chain and is visited immediately after x, so the
// vertices of one chain occupy consecutive tin values. Every other child
// starts a chain of its own.
// Requires par[] and sz[] to be filled in already (see dfs_size).
void dfs(int x) {
	tin[x] = ++tcur;

	if (chain[x] == 0) {
		// Nobody extended a chain into x, so x becomes the head of a new one.
		++tchain;
		chain[x] = tchain;
		head[tchain] = x;
	}
	// The deepest vertex visited on a chain is the last one to write its tail.
	tail[chain[x]] = x;

	int heavy = -1;
	for (const int child : adj[x]) {
		if (child == par[x]) continue;
		if (heavy == -1 || sz[heavy] < sz[child])
			heavy = child;
	}
	if (heavy == -1) return;  // leaf: nothing below to visit

	// No new chain id has been issued since x received its own, so chain[x]
	// still equals tchain here.
	chain[heavy] = chain[x];
	dfs(heavy);

	for (const int child : adj[x]) {
		if (child == par[x] || child == heavy) continue;
		dfs(child);
	}
}

// Sentinel for "impossible". Two sentinels plus one (INF + INF + 1) still fit
// in a 32-bit int, and operator+ never produces a value above INF.
const int INF = 1e9;

// Segment-tree value: a 2x2 table over a contiguous range of positions.
// val[a][b] is the best cost of the range when its first position is in
// state a and its last position is in state b (INF when impossible).
struct Node {
	int val[2][2];

	// Leaves `val` uninitialised; callers must assign every entry before reading.
	Node() = default;

	// With set_inf == true every entry becomes INF; with false the entries are
	// left uninitialised, exactly like the default constructor.
	// `explicit` so that a stray bool/int can never silently turn into a Node.
	explicit Node(bool set_inf) {
		if (!set_inf) return;
		for (int a = 0; a < 2; a++)
			for (int b = 0; b < 2; b++)
				val[a][b] = INF;
	}

	// Dumps the four entries to stderr (indices are printed 1-based).
	void debug() const {
		const std::string rule(10, '=');
		std::cerr << rule << '\n';
		for (int a = 0; a < 2; a++)
			for (int b = 0; b < 2; b++)
				fprintf(stderr, "val[%d][%d] = %d\n", a+1, b+1, val[a][b]);
		std::cerr << rule << "\n";
	}

	// Smallest of the four entries, capped at INF.
	int res() const {
		int best = INF;
		for (int a = 0; a < 2; a++)
			for (int b = 0; b < 2; b++)
				best = std::min(best, val[a][b]);
		return best;
	}

	// Concatenates this range (left) with `other` (right).
	// The result keeps the left range's first state `a` and the right range's
	// last state `d`; the two states that meet in the middle (b and c) cost one
	// extra unit when they differ. Each entry starts from INF, so results stay
	// capped at INF and repeated merging cannot overflow.
	Node operator + (const Node &other) const {
		Node merged;
		for (int a = 0; a < 2; a++) {
			for (int d = 0; d < 2; d++) {
				int best = INF;
				for (int b = 0; b < 2; b++)
					for (int c = 0; c < 2; c++)
						best = std::min(best, val[a][b] + other.val[c][d] + (b != c));
				merged.val[a][d] = best;
			}
		}
		return merged;
	}
} IT[N << 2];  // segment-tree storage, root at index 1, children at 2*id and 2*id+1

// dp[p][s]: base cost stored in the segment-tree leaf at DFS position p (the index is
// tin[v], not the vertex label v) for state s. get() adds to it the contributions of the
// chains that hang below that vertex; the segment tree combines the leaves along a chain.
int dp[N][2]; // only consider vertices NOT in heavy chain, segment tree will consider the rest + dp
// up_par[p][s]: for the chain head at DFS position p, the amount that chain last contributed
// to its parent's dp[..][s]; get() subtracts it before adding the refreshed amount.
int up_par[N][2]; // for chain_head to update dp[par]'s values

// Leaf template whose four entries are all INF; invalid_color() copies it into a leaf
// before re-enabling the states that are still allowed.
Node node_inf(true);

void build(int id, int l, int r) {
	if (l == r) {
		for (int a: {0, 1})
			for (int b: {0, 1})
				IT[id].val[a][b] = (a != b ? INF : 0);
		return;
	}
	int mid = (l + r) / 2;
	build(id << 1, l, mid);
	build(id << 1 | 1, mid + 1, r);
	IT[id] = IT[id << 1] + IT[id << 1 | 1];
}

void initialize(int n, vector<int> A, vector<int> B) {
	::n = n;
	for (int i = 0; i < n - 1; i++) {
		int x = A[i], y = B[i];
		adj[x].emplace_back(y);
		adj[y].emplace_back(x);
	}
	dfs_size(1, 0);
	dfs(1);
	build(1, 1, n);
	// for (int i = 2; i <= n; i++) {
		// cerr << tin[par[i]] << ' ' << tin[i] << '\n';
	// }
}

void invalid_color(int x, int c, int id, int l, int r) {
	if (l == r) {
		// if (c ^ color) == 0 => skip
		IT[id] = node_inf;
		for (int col: {0, 1})
			if ((col + 1) ^ c)
				IT[id].val[col][col] = dp[x][col];
		// if (x == 2) cerr << string(25, '?') << " " << dp[2][0] << '\n';
		// IT[id].debug();
		return;
	}
	int mid = (l + r) / 2;
	if (x <= mid) invalid_color(x, c, id << 1, l, mid);
	else invalid_color(x, c, id << 1 | 1, mid + 1, r);
	IT[id] = IT[id << 1] + IT[id << 1 | 1];
}

// Point update: adds w[0] / w[1] to the diagonal entries of the leaf at
// position `x`, then recomputes the ancestors on the way back up.
// Entries that are currently INF (forbidden states) are left untouched.
//   x        - segment-tree position (a tin value)
//   w        - two deltas, one per state; only read, never written
//   id, l, r - current tree node and the range it covers
void update_val(int x, int w[], int id, int l, int r) {
	if (l != r) {
		const int mid = (l + r) / 2;
		if (mid < x)
			update_val(x, w, id << 1 | 1, mid + 1, r);
		else
			update_val(x, w, id << 1, l, mid);
		IT[id] = IT[id << 1] + IT[id << 1 | 1];
		return;
	}
	for (int state = 0; state < 2; ++state) {
		int &cell = IT[id].val[state][state];
		if (cell == INF) continue;
		cell += w[state];
	}
}

// Range query: returns the merged Node for positions [x, y].
//   x, y     - inclusive query range (tin values)
//   id, l, r - current tree node and the range it covers
// The tree itself is not modified.
Node get_range(int x, int y, int id, int l, int r) {
	const bool fully_covered = (x <= l) && (r <= y);
	if (fully_covered) return IT[id];

	const int mid = (l + r) / 2;
	const int left = id << 1;
	const int right = id << 1 | 1;

	// Query lies entirely in one half.
	if (y <= mid) return get_range(x, y, left, l, mid);
	if (x > mid) return get_range(x, y, right, mid + 1, r);

	// Query straddles the midpoint: merge both halves, left first.
	return get_range(x, y, left, l, mid) + get_range(x, y, right, mid + 1, r);
}

int get(int x) {
	int xchain = chain[x], head_ = head[xchain], tail_ = tail[xchain];
	// cerr << "head tail = " << head_ << ' ' << tail_ << '\n';
	while (xchain != 1) {
		auto S = get_range(tin[head_], tin[tail_], 1, 1, n);
		int val[] = {min(S.val[0][0], S.val[0][1]), min(S.val[1][0], S.val[1][1])};
		int upd[] = {-up_par[tin[head_]][0] + min(val[0], val[1] + 1), 
								 -up_par[tin[head_]][1] + min(val[0] + 1, val[1])};
		// S.debug();
		// fprintf(stderr, "update_val(%d, {%d, %d})\n", tin[par[head_]], upd[0], upd[1]);
		update_val(tin[par[head_]], upd, 1, 1, n);
		dp[tin[par[head_]]][0] += upd[0];
		dp[tin[par[head_]]][1] += upd[1];
		up_par[tin[head_]][0] = min(val[0], val[1] + 1);
		up_par[tin[head_]][1] = min(val[0] + 1, val[1]);
		xchain = chain[par[head_]];
		tail_ = tail[xchain];
		head_ = head[xchain];
	}
	// cerr << "hehehehe " << tin[head_] << ' ' << tin[tail_] << '\n';
	return get_range(tin[head_], tin[tail_], 1, 1, n).res();
}

// Shared implementation of the three grader callbacks.
//   s - vertex whose status changes
//   c - code passed to invalid_color: 1 forbids state 0, 2 forbids state 1,
//       0 forbids nothing
// Rewrites the leaf of `s`, then lets get() propagate the change and compute
// the new answer.
int solve(int s, int c) {
	const int pos = tin[s];
	invalid_color(pos, c, 1, 1, n);
	return get(s);
}

// Grader callback (presumably: a cat now lives at vertex v - confirm against catdog.h).
// Uses code 2, which makes invalid_color forbid state 1 at v. Returns the new answer.
int cat(int v) {
	const int code = 2;
	return solve(v, code);
}

// Grader callback (presumably: a dog now lives at vertex v - confirm against catdog.h).
// Uses code 1, which makes invalid_color forbid state 0 at v. Returns the new answer.
int dog(int v) {
	const int code = 1;
	return solve(v, code);
}

// Grader callback (presumably: the pet at vertex v leaves - confirm against catdog.h).
// Uses code 0, so invalid_color allows both states at v again. Returns the new answer.
int neighbor(int v) {
	const int code = 0;
	return solve(v, code);
}

#ifdef LOCAL
// Reads one decimal integer from stdin.
// If no integer can be parsed, reports the problem on stderr and terminates
// the program with exit status 1.
int readInt(){
	int value;
	const int matched = scanf("%d", &value);
	if (matched == 1) {
		return value;
	}
	fprintf(stderr,"Error while reading input\n");
	exit(1);
}

// Local test driver (this part of the file is only compiled when LOCAL is defined).
// Input format: vertex count, then that many minus one edges "a b", then the
// query count, then one "type vertex" pair per query.
// Query type 1 calls cat(), type 2 calls dog(), anything else calls neighbor().
// All answers are printed, one per line, after every query has been processed.
int main(){
	// Named vertex_count rather than N so it does not shadow the file-level constant N.
	const int vertex_count=readInt();

	std::vector<int> A(vertex_count-1),B(vertex_count-1);
	for(int i=0;i<vertex_count-1;i++)
	{
		A[i]=readInt();
		B[i]=readInt();
	}

	// Read through readInt() instead of assert(scanf(...)): the argument of
	// assert is not evaluated when NDEBUG is defined, which would skip the read
	// entirely and leave the query count uninitialised.
	const int Q=readInt();
	std::vector <int> T(Q),V(Q);
	for(int i=0;i<Q;i++)
	{
		T[i]=readInt();
		V[i]=readInt();
	}

	initialize(vertex_count,A,B);

	std::vector<int> res(Q);
	for(int j=0;j<Q;j++)
	{
		if(T[j]==1) res[j]=cat(V[j]);
		else if(T[j]==2) res[j]=dog(V[j]);
		else res[j]=neighbor(V[j]);
	}
	for(int j=0;j<Q;j++)
		printf("%d\n",res[j]);
	return 0;
}
#endif
// Columns: # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// Columns: # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// Columns: # | Verdict | Execution time | Memory | Grader output
// Fetching results...