// Submission #1150552
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1150552 | - | pinbu | Cats or Dogs (JOI18_catdog) | C++20 | 100 / 100 | 135 ms | 28232 KiB
#include "catdog.h"
#include <bits/stdc++.h>
using namespace std;

const int MxN = 100005, oo = 1e9 + 7;

// In-place running minimum: lowers X to Y when Y is smaller.
void mini(int &X, int Y) {
    if (X > Y) X = Y;
}

// Summary of one contiguous piece of a heavy chain:
// node[i][j] = minimum number of edges whose endpoints end up in different
// states inside the piece, when its top vertex is in state i and its bottom
// vertex in state j. A value of oo marks an impossible combination.
using node = std::array<std::array<int, 2>, 2>;

// Concatenates piece A (above) with piece B (below). The edge joining A's
// bottom vertex (state k) to B's top vertex (state l) costs 1 when k != l.
//
// Sums are formed in long long. An entry of A or B can already be about oo
// or 2*oo (forced vertices in forbidden states, plus oo off-diagonal leaf
// entries), so adding two or three of them overflows int, which is undefined
// behaviour and in practice yields a negative "minimum". Any sum >= oo is
// infeasible, so results are capped at oo. That keeps every stored entry
// <= oo and the next merge safe, and it leaves feasible costs untouched.
node operator + (const node &A, const node &B) {
    node res = {{{oo, oo}, {oo, oo}}};
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                for (int l = 0; l < 2; l++) {
                    const long long cost = 0LL + A[i][k] + B[l][j] + (k ^ l);
                    mini(res[i][j], static_cast<int>(std::min<long long>(cost, oo)));
                }
            }
        }
    }
    return res;
}
struct ST {
	int n;
	vector<node> nut;
	void build(int id, int l, int r) {
		if (l == r) {
			nut[id][0][0] = nut[id][1][1] = 0;
			nut[id][0][1] = nut[id][1][0] = oo;
			return;
		}
		
		int mid = l + r >> 1;
		build(id << 1, l, mid);
		build(id << 1 | 1, mid + 1, r);
		nut[id] = nut[id << 1] + nut[id << 1 | 1];
	}
	void rsz(int sz) {
		n = sz;
		nut.resize(4 * sz);
		build(1, 1, sz);
	}
	void update(int i, int v1, int v2, int id, int l, int r) {
		if (l == r) {
			nut[id][0][0] += v1;
			nut[id][1][1] += v2;
			return;
		}
		
		int mid = l + r >> 1;
		i <= mid ? update(i, v1, v2, id << 1, l, mid) : update(i, v1, v2, id << 1 | 1, mid + 1, r);
		nut[id] = nut[id << 1] + nut[id << 1 | 1];
	}
	void update(int i, int v1, int v2) {
		update(i, v1, v2, 1, 1, n);
	}
	pair<int, int> get(void) {
		node t = nut[1];
		int a = t[0][0], b = t[0][1], c = t[1][0], d = t[1][1];
		return {min({a, b, c + 1, d + 1}), min({a + 1, b + 1, c, d})};
	}
} tri[MxN];

vector<int> adj[MxN];
int par[MxN], sz[MxN];
// Post-order pass over the subtree of `u`, whose parent is `p`.
// It records par[] for u's children and fills sz[] with subtree sizes.
// It then swaps u's child with the largest subtree into adj[u][0]; ties go to
// the later neighbour. HLD treats adj[u][0] as the heavy child.
// NOTE(review): recursion depth equals the tree height, so path-like trees
// rely on a sufficiently large stack.
void DFS(int u, int p) {
	sz[u] = 1;
	int heavy_size = -1, heavy_pos = -1;
	const int deg = static_cast<int>(adj[u].size());
	for (int pos = 0; pos < deg; ++pos) {
		const int v = adj[u][pos];
		if (v == p) continue;
		par[v] = u;
		DFS(v, u);
		sz[u] += sz[v];
		// pos only grows, so ">=" keeps the last of equally heavy children.
		if (sz[v] >= heavy_size) {
			heavy_size = sz[v];
			heavy_pos = pos;
		}
	}
	if (heavy_size > 0) swap(adj[u][0], adj[u][heavy_pos]);
}
int head[MxN], st[MxN], timer = 0, cnt[MxN];
// Splits the subtree of `u` (parent `p`) into heavy chains; `h` is the head of
// the chain that u belongs to. cnt[h] counts the vertices on that chain.
// st[] numbers vertices in visit order, heavy child first, so a chain's
// vertices get consecutive numbers starting at st[h].
void HLD(int u, int p, int h) {
	head[u] = h;
	++cnt[h];
	st[u] = ++timer;
	// adj[u][0] was made the heavy child by DFS, unless u is a leaf and
	// adj[u][0] is its parent.
	const int heavy = (!adj[u].empty() && adj[u][0] != par[u]) ? adj[u][0] : -1;
	if (heavy > 0) HLD(heavy, u, h);  // continue the current chain
	for (const int v : adj[u]) {
		if (v == heavy || v == p) continue;
		HLD(v, u, v);  // each light child starts a new chain
	}
}
void initialize(int n, std::vector<int> A, std::vector<int> B) {
    // Build undirected adjacency lists from the n - 1 edges (A[e], B[e]).
    for (int e = 0; e + 1 < n; ++e) {
        const int x = A[e];
        const int y = B[e];
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    // Root the tree at vertex 1, choose heavy children, then lay out the chains.
    DFS(1, 0);
    HLD(1, 0, 1);
    // One segment tree per chain, stored under the chain's head vertex.
    for (int v = 1; v <= n; ++v) {
        if (cnt[v] != 0) tri[v].rsz(cnt[v]);
    }
}


int cur[MxN];
int update(int u, int c) {
	int uwu = u;
	int v1, v2;
	if (c == 1) tie(v1, v2) = make_tuple(0, oo);
	else if (c == 2) tie(v1, v2) = make_tuple(oo, 0);
	else tie(v1, v2) = make_tuple(-(cur[u] == 2) * oo, -(cur[u] == 1) * oo);
	while (u) {
		int h = head[u];
		pair<int, int> p1 = tri[h].get();
		tri[h].update(st[u] - st[h] + 1, v1, v2);
		pair<int, int> p2 = tri[h].get();
		v1 = p2.first - p1.first; v2 = p2.second - p1.second;
		u = par[h];
	}
	cur[uwu] = c;
	node t = tri[1].nut[1];
	return min({t[0][0], t[0][1], t[1][0], t[1][1]});
}
// Marks vertex u as a cat (constraint code 1) and returns the new optimum.
int cat(int u) {
	constexpr int kCat = 1;
	return update(u, kCat);
}
// Marks vertex u as a dog (constraint code 2) and returns the new optimum.
int dog(int u) {
	constexpr int kDog = 2;
	return update(u, kDog);
}
// Clears vertex u's constraint (code 3, handled by update's default branch)
// and returns the new optimum.
int neighbor(int u) {
	constexpr int kClear = 3;
	return update(u, kClear);
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...