Submission #1242528

#TimeUsernameProblemLanguageResultExecution timeMemory
1242528errayTree (IOI24_tree)C++20
0 / 100
61 ms23856 KiB
#include <bits/stdc++.h>

#include "tree.h"

using namespace std;

#ifdef DEBUG
	#include "debug.h"
#else 
	#define debug(...) void(37)
#endif

struct DSU {
	vector<int> link, leaves;
	DSU(vector<int> _leaf_count) {
		leaves = _leaf_count;
		int n = int(leaves.size());
		link.resize(n);
		iota(link.begin(), link.end(), 0);
	}
	int get(int v) {
		return link[v] = link[v] == v ? v : get(link[v]);
	}
	bool unite(int x, int y) {
		x = get(x), y = get(y);
		if (x == y) return false;
		link[y] = x;
		leaves[x] += leaves[y] - 1;
		return true;
	}
	int leaf_count(int v) {
		return leaves[get(v)];
	}
};

struct linear_sum {
	int64_t L_coeff, R_coeff;
	int64_t eval(int L, int R) {
		return L_coeff * L + R_coeff * R;
	}
	void init() {
		L_coeff = 0, R_coeff = 0;
	}
};
linear_sum operator+(linear_sum l, linear_sum r) {
	l.L_coeff += r.L_coeff;
	l.R_coeff += r.R_coeff;
	return l;
}

vector<linear_sum> pref_sums, suf_sums;
vector<int> placer;
int64_t leaf_sums;

void init(std::vector<int> P, std::vector<int> W) {
	int N = int(P.size());
	vector<int> degree(N);
	for (int i = 1; i < N; ++i) degree[P[i]]++;
	for (int i = 0; i < N; ++i) {
		if (degree[i] == 0) {
			degree[i] = 1;
			leaf_sums += W[i];
		}
	}
	DSU dsu(degree);
	vector<int> node_ord(N); iota(node_ord.begin(), node_ord.end(), 0);
	sort(node_ord.begin(), node_ord.end(), [&](int x, int y) {
		return W[x] > W[y];
	});
	vector<vector<int>> waiting(N);
	vector<bool> act(N);
	vector<array<int, 2>> sums;
	for (auto v : node_ord) {
		act[v] = true;
		for (auto u : waiting[v]) {
			dsu.unite(v, u);
		}
		int subtree = dsu.leaf_count(v);
		if (P[v] != -1) {
			if (act[P[v]]) dsu.unite(P[v], v);
			else waiting[P[v]].push_back(v);
		}
		int root = dsu.leaf_count(v);
		sums.push_back({subtree, root});
	}
	auto Get = [&](array<int, 2> v) -> tuple<int, linear_sum, linear_sum> {
		int s = v[0], r = v[1];
		return {r - s + 1, linear_sum{r, -1}, linear_sum{s - 1, 0}};
	};
	sort(sums.begin(), sums.end(), [&](array<int, 2> x, array<int, 2> y) {
		return get<0>(Get(x)) < get<0>(Get(y));
	});
	debug(sums);
	pref_sums.resize(N + 1), suf_sums.resize(N + 1);
	pref_sums[0].init(), suf_sums[N].init();
	placer.resize(N);
	for (int i = 0; i < N; ++i) {
		auto[cp, b0, b1] = Get(sums[i]);
		placer[i] = cp;
		pref_sums[i + 1] = pref_sums[i] + b0;
	}
	for (int i = N - 1; i >= 0; --i) {
		suf_sums[i] = suf_sums[i + 1] + get<2>(Get(sums[i]));
	}
}

long long query(int L, int R) {
	int sep = int(lower_bound(placer.begin(), placer.end(), array<int, 2>{L, R}, [&](int c, array<int, 2> lr) {
		return int64_t(lr[0]) * c < lr[1];
	}) - placer.begin());
	return leaf_sums * L + pref_sums[sep].eval(L, R) + suf_sums[sep].eval(L, R);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...