Submission #1316481

# | Time | Username | Problem | Language | Result | Execution time | Memory
1316481 |  | pvpro | JOI tour (JOI24_joitour) | C++20 | 48 / 100 | 2918 ms | 374432 KiB
#ifndef LOCAL
// GCC pragma names are case-sensitive: they must be spelled `optimize` and
// `target`. Capitalised spellings are treated as unknown pragmas and ignored.
#pragma GCC optimize("O3,Ofast,unroll-loops")
// These instruction sets exist only on x86; other targets reject the pragma.
#if defined(__x86_64__) || defined(__i386__)
#pragma GCC target("bmi,bmi2,avx,avx2")
#endif
#endif
#include <bits/stdc++.h>

using namespace std;
using ll = long long;
using ld = long double;

// Short-hand macros used throughout the solution.
// NOTE(review): `f` and `s` are single-letter object-like macros, so every
// token `f`/`s` later in the file (including the global colour array) is
// rewritten to `first`/`second`. Kept because the rest of the file relies on it.
#define f first
#define s second
#define mp make_pair
#define pb push_back
#define pii pair<int, int>
#define all(x) (x).begin(), (x).end()
// Both iterators are comma-separated so this can initialise a range,
// e.g. vector<int> r(rall(v)).
#define rall(x) (x).rbegin(), (x).rend()
#ifndef LOCAL
// "\n" avoids the stream flush that std::endl performs on every line.
#define endl "\n"
#endif

// Fixed-seed generator (deterministic across runs); unused in the visible code.
mt19937 rnd(11);

// Number of binary-lifting levels and of centroid-decomposition layers.
// NOTE(review): presumably chosen so that 2^LOG_N exceeds the maximum n —
// confirm against the problem constraints.
const int LOG_N = 18;

// 0-indexed Fenwick (binary indexed) tree over 64-bit values.
// The `two`/`zero` instances accumulate pair counts that can exceed INT_MAX
// (on the order of (n/2)^2). Stored cells, the returned prefix sums and the
// update delta are therefore all long long. This doubles the memory of every F.
struct F {
	std::vector<long long> t;
	F() = default;
	// Sum of the values at positions [0, r]; r < 0 yields 0.
	long long get(int r) {
		long long sum = 0;
		for (; r >= 0; r = (r & (r + 1)) - 1) {
			sum += t[r];
		}
		return sum;
	}
	// Adds x to the value at position i (0 <= i < t.size()).
	void upd(int i, long long x) {
		const int n = static_cast<int>(t.size());
		for (; i < n; i |= i + 1) {
			t[i] += x;
		}
	}
};

// Adjacency lists of the tree.
vector<vector<int>> graph;
// lvl[v]: centroid layer at which v became a centroid.
// CP/fst/tin1/tout1[lg][v]: centroid owning v in layer lg, the centroid's
// child whose branch holds v, and v's entry/exit stamps inside that component.
// tin/tout/binup: stamps and binary-lifting ancestors of the DFS from vertex 0.
// act[v]: whether v is currently inserted (see add/del).
vector<int> CP[LOG_N], lvl, sz, f, tin, tout, binup[LOG_N], tin1[LOG_N], tout1[LOG_N], fst[LOG_N], act;
// Per-centroid count of (colour 0, colour 2) pairs in different branches.
// It can reach ~(n/2)^2, which overflows int, so it is 64-bit.
vector<ll> zerotwo;
// Per-centroid Fenwick trees: T[c] / T[c + 3] mark entry / exit stamps of
// colour-c vertices; `two` and `zero` hold ancestor-pair counts.
vector<F> T[6], two, zero;
ll ans = 0;
int Tm = 0;

void dfs(int v, int prev = 0) {
	binup[0][v] = prev;
	tin[v] = Tm++;
	for (auto &u : graph[v]) {
		if (tin[u] == -1) {
			dfs(u, v);
		}
	}
	tout[v] = Tm;
}

// True iff b lies in the DFS subtree of a (a counts as inside its own subtree).
// b's stamp interval [tin[b], tout[b]) must nest inside a's. The second
// comparison previously read `tout[b] <= tout[b]`, which is always true.
bool inside(int a, int b) {
	return tin[a] <= tin[b] && tout[b] <= tout[a];
}

// Lowest common ancestor of a and b by binary lifting over binup[].
// Lifts a while its 2^lg-th ancestor is still not an ancestor of b, then
// returns a's parent. The root's binup entries point to the root itself.
int lca(int a, int b) {
	if (inside(a, b)) return a;
	int lg = LOG_N;
	while (lg-- > 0) {
		const int up = binup[lg][a];
		if (!inside(up, b)) a = up;
	}
	return binup[0][a];
}

// Number of colour-t vertices on the path from a up to the root 0, both ends
// included. Walks parent pointers one step at a time.
int calcUp(int a, int t) {
	int cnt = 0;
	for (int cur = a;; cur = binup[0][cur]) {
		if (f[cur] == t) ++cnt;
		if (cur == 0) break;
	}
	return cnt;
}

int calcWay(int a, int b, int t) {
	int lc = lca(a, b);
	int ans = calcUp(a, t) + calcUp(b, t) - calcUp(lc, t) * 2;
	if (f[lc] == t) {
		++ans;
	}
	return ans;
}

// Active colour-t vertices that are proper ancestors of v inside layer lg's
// centroid component (centroid CP[lg][v] included, v excluded).
// T[t] marks each vertex's entry stamp and T[t + 3] its exit stamp. A vertex
// entered by tout1[v] but not yet exited by then is exactly one whose
// interval encloses v's.
ll sumup(int v, int lg, int t) {
	const int c = CP[lg][v];
	const int stamp = tout1[lg][v];
	return T[t][c].get(stamp) - T[t + 3][c].get(stamp);
}

// Active colour-t vertices in v's subtree (v included) inside layer lg's
// centroid component. Their entry stamps lie in [tin1[lg][v], tout1[lg][v]).
ll sumsub(int v, int lg, int t) {
	auto &fw = T[t][CP[lg][v]];
	const int lo = tin1[lg][v];
	const int hi = tout1[lg][v];
	return fw.get(hi - 1) - fw.get(lo - 1);
}

// Number of tours that contain v, counting only vertices that are currently
// active (inserted into T / zero / two / zerotwo). add() adds this after
// inserting v; del() subtracts it before removing v.
// NOTE(review): a "tour" is presumably a (colour 0, colour 1, colour 2) triple
// whose colour-1 vertex lies on the tree path between the other two. This is
// inferred from the case split below — confirm against the problem statement.
// The loop handles every centroid layer strictly above v's own layer, counting
// tours whose path passes through centroid CP[lg][v] with v in branch
// fst[lg][v]. The trailing block handles v being the centroid itself.
ll calc(int v) {
	ll ans = 0;
	for (int lg = lvl[v] - 1; lg >= 0; --lg) {
		if (f[v] == 0) {
			// Colour-1 ancestors of v (centroid included) times colour-2
			// vertices outside v's branch (centroid included).
			ll onenum = sumup(v, lg, 1);
			ll twonum = sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2);
			ans += onenum * twonum;
			// Plus the (1, 2) ancestor pairs stored in two[] over the whole
			// component minus v's own branch.
			ans += two[CP[lg][v]].get(tout1[lg][CP[lg][v]] - 1);
			ans -= two[CP[lg][v]].get(tout1[lg][fst[lg][v]] - 1) - two[CP[lg][v]].get(tin1[lg][fst[lg][v]] - 1);
		} else if (f[v] == 1) {
			// v is the middle vertex: one end lies in v's subtree, the other
			// outside v's branch (centroid included).
			ans += sumsub(v, lg, 0) * (sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2));
			ans += sumsub(v, lg, 2) * (sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0));
		} else {
			// Mirror of the colour-0 case with colours 0 and 2 swapped.
			ll onenum = sumup(v, lg, 1);
			ll zeronum = sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0);
			ans += onenum * zeronum;
			ans += zero[CP[lg][v]].get(tout1[lg][CP[lg][v]] - 1);
			ans -= zero[CP[lg][v]].get(tout1[lg][fst[lg][v]] - 1) - zero[CP[lg][v]].get(tin1[lg][fst[lg][v]] - 1);
		}
	}
	// v is the centroid of layer lvl[v].
	if (f[v] == 0) {
		// All stored (1, 2) ancestor pairs in v's component.
		ans += two[v].get(tout1[lvl[v]][v] - 1);
	} else if (f[v] == 2) {
		// All stored (1 ancestor, 0) pairs in v's component.
		ans += zero[v].get(tout1[lvl[v]][v] - 1);
	} else {
		// (0, 2) pairs lying in different branches of v.
		ans += zerotwo[v];
	}
	return ans;
}

// Removes vertex v (with its current colour f[v]) from every structure of
// every centroid layer that contains it, then marks it inactive.
// calc(v) is subtracted first, while v is still active, so it matches the
// amount that add(v) contributed under the same state.
void del(int v) {
	ans -= calc(v);
	// From v's own layer (where v is the centroid) up to the top layer 0.
	for (int lg = lvl[v]; lg >= 0; --lg) {
		// Entry stamp in T[colour], exit stamp in T[colour + 3].
		T[f[v]][CP[lg][v]].upd(tin1[lg][v], -1);
		T[f[v] + 3][CP[lg][v]].upd(tout1[lg][v], -1);
		// Pair structures are kept only for centroids strictly above v
		// (at lg == lvl[v], CP[lg][v] is v itself).
		if (lg != lvl[v]) {
			if (f[v] == 0) {
				// Retract v's (0, 2) pairs with colour-2 vertices in other branches
				// (centroid excluded), and v's count of colour-1 ancestors
				// (centroid excluded) stored at v's stamp in zero[].
				zerotwo[CP[lg][v]] -= sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2) - (f[CP[lg][v]] == 2 && act[CP[lg][v]]);
				zero[CP[lg][v]].upd(tin1[lg][v], -(sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
			} else if (f[v] == 2) {
				// Same as the colour-0 case with colours 0 and 2 swapped.
				zerotwo[CP[lg][v]] -= sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0) - (f[CP[lg][v]] == 0 && act[CP[lg][v]]);
				two[CP[lg][v]].upd(tin1[lg][v], -(sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
			} else {
				// Colour 1: retract the colour-0 / colour-2 counts of v's subtree.
				zero[CP[lg][v]].upd(tin1[lg][v], -sumsub(v, lg, 0));
				two[CP[lg][v]].upd(tin1[lg][v], -sumsub(v, lg, 2));
			}
		}
	}
	act[v] = false;
}

// Inserts vertex v with colour f[v] into every structure of every centroid
// layer that contains it, marks it active, then adds calc(v) to ans.
// This is the exact inverse of del(v).
void add(int v) {
	// From v's own layer (where v is the centroid) up to the top layer 0.
	for (int lg = lvl[v]; lg >= 0; --lg) {
		// Entry stamp in T[colour], exit stamp in T[colour + 3].
		T[f[v]][CP[lg][v]].upd(tin1[lg][v], 1);
		T[f[v] + 3][CP[lg][v]].upd(tout1[lg][v], 1);
		// Pair structures are kept only for centroids strictly above v.
		if (lg != lvl[v]) {
			if (f[v] == 0) {
				// (0, 2) pairs with colour-2 vertices in other branches (centroid
				// excluded); colour-1 ancestors of v (centroid excluded) go into zero[].
				zerotwo[CP[lg][v]] += sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2) - (f[CP[lg][v]] == 2 && act[CP[lg][v]]);
				zero[CP[lg][v]].upd(tin1[lg][v], (sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
			} else if (f[v] == 2) {
				// Same as the colour-0 case with colours 0 and 2 swapped.
				zerotwo[CP[lg][v]] += sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0) - (f[CP[lg][v]] == 0 && act[CP[lg][v]]);
				two[CP[lg][v]].upd(tin1[lg][v], (sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
			} else {
				// Colour 1: record the colour-0 / colour-2 counts of v's subtree,
				// so each ancestor pair is stored by whichever end was added later.
				zero[CP[lg][v]].upd(tin1[lg][v], sumsub(v, lg, 0));
				two[CP[lg][v]].upd(tin1[lg][v], sumsub(v, lg, 2));
			}
		}
	}
	act[v] = true;
	ans += calc(v);
}

// Builds all structures for a tree with n vertices, colours F and edges
// (u[i], v[i]), i < n - 1, then inserts every vertex via add().
// q (number of queries) is not used here.
void init(int n, vector<int> F, vector<int> u, vector<int> v, int q) {
	f = F;
	zerotwo.resize(n);
	act.resize(n);
	graph.resize(n);
	lvl.assign(n, -1);
	binup[0].resize(n);
	tin.assign(n, -1);
	tout.resize(n);
	for (int i = 0; i < 6; ++i) {
		T[i].resize(n);
	}
	two.resize(n);
	zero.resize(n);
	for (int i = 0; i < n - 1; ++i) {
		graph[u[i]].pb(v[i]);
		graph[v[i]].pb(u[i]);
	}
	// DFS from vertex 0 plus the binary-lifting table (used by lca/calcUp/calcWay).
	dfs(0);
	for (int l = 1; l < LOG_N; ++l) {
		binup[l].resize(n);
		for (int i = 0; i < n; ++i) {
			binup[l][i] = binup[l - 1][binup[l - 1][i]];
		}
	}
	int lg = 0;
	// Subtree sizes of the not-yet-decomposed component (lvl == -1) rooted at v.
	auto calcSz = [&](int v, int prev, auto &&self) -> void {
		sz[v] = 1;
		for (auto &u : graph[v]) {
			if (lvl[u] == -1 && u != prev) {
				self(u, v, self);
				sz[v] += sz[u];
			}
		}
	};
	// Walks towards any child holding more than half of the Tsz vertices.
	auto findCenter = [&](int v, int prev, int Tsz, auto &&self) -> int {
		for (auto &u : graph[v]) {
			if (lvl[u] == -1 && u != prev && sz[u] * 2 > Tsz) {
				return self(u, v, Tsz, self);
			}
		}
		return v;
	};
	int Tm1;
	// Records, for layer lg: the owning centroid (CP), the centroid's child that
	// starts v's branch (fst), and entry/exit stamps tin1/tout1 from counter Tm1.
	auto paint = [&](int v, int prev, int center, auto &&self) -> void {
		if (prev == center) {
			fst[lg][v] = v;
		} else {
			fst[lg][v] = fst[lg][prev];
		}
		tin1[lg][v] = Tm1++;
		CP[lg][v] = center;
		for (auto &u : graph[v]) {
			if (lvl[u] == -1 && u != prev) {
				self(u, v, center, self);
			}
		}
		tout1[lg][v] = Tm1++;
	};
	// Centroid decomposition, one layer per iteration; sz == -1 marks vertices
	// whose component has not been handled in this layer yet.
	for (; lg < LOG_N; ++lg) {
		CP[lg].resize(n);
		fst[lg].resize(n);
		tin1[lg].resize(n);
		tout1[lg].resize(n);
		sz.assign(n, -1);
		for (int i = 0; i < n; ++i) {
			if (lvl[i] == -1 && sz[i] == -1) {
				Tm1 = 0;
				calcSz(i, i, calcSz);
				int center = findCenter(i, i, sz[i], findCenter);
				lvl[center] = lg;
				calcSz(center, center, calcSz);
				// Each vertex gets two stamps, so the Fenwicks need 2 * size cells.
				for (int j = 0; j < 6; ++j) {
					T[j][center].t.resize(sz[center] * 2);
				}
				two[center].t.resize(sz[center] * 2);
				zero[center].t.resize(sz[center] * 2);
				paint(center, center, center, paint);
			}
		}
	}
	for (int i = 0; i < n; ++i) {
		add(i);
	}
}

// Recolours vertex v to x. The vertex is removed with its old colour (its
// tours are retracted), recoloured, and re-inserted with the new colour.
void change(int v, int x) {
	del(v);
	int &colour = f[v];
	colour = x;
	add(v);
}

// Current total, maintained incrementally by add() and del().
long long num_tours() {
	const long long current = ans;
	return current;
}

#ifdef LOCAL
#include <cassert>
#include <cstdio>

// Reads one int from stdin and returns false on EOF or malformed input.
// The read is kept outside assert(), because defining NDEBUG would
// otherwise compile the scanf call away together with the check.
static bool readInt(int &x) {
	return scanf("%d", &x) == 1;
}

// Local test harness, compiled only with -DLOCAL. It reads the instance from
// in.txt and writes num_tours() to out.txt after init() and after every
// change(), one value per line. It returns 1 if a file cannot be opened or
// the input is malformed.
int main() {
	if (freopen("in.txt", "r", stdin) == nullptr || freopen("out.txt", "w", stdout) == nullptr) {
		return 1;
	}

	int N;
	if (!readInt(N) || N < 1) {
		return 1;
	}

	std::vector<int> F(N);
	for (int i = 0; i < N; i++) {
		if (!readInt(F[i])) {
			return 1;
		}
	}

	std::vector<int> U(N - 1), V(N - 1);
	for (int j = 0; j < N - 1; j++) {
		if (!readInt(U[j]) || !readInt(V[j])) {
			return 1;
		}
	}

	int Q;
	if (!readInt(Q)) {
		return 1;
	}

	init(N, F, U, V, Q);
	printf("%lld\n", num_tours());
	fflush(stdout);

	for (int k = 0; k < Q; k++) {
		int X, Y;
		if (!readInt(X) || !readInt(Y)) {
			return 1;
		}
		change(X, Y);
		printf("%lld\n", num_tours());
		fflush(stdout);
	}
	return 0;
}
#endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...