답안 #801882

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
801882 2023-08-02T08:14:28 Z GEN 이지후(#10129) Cultivation (JOI17_cultivation) C++17
0 / 100
46 ms 71872 KB
#include <bits/stdc++.h>
using namespace std;
using lint = long long;                       // 64-bit signed integer alias
using pi = std::array<lint, 2>;               // fixed-size pair of lint values
#define sz(v) (static_cast<int>((v).size()))  // container size as a signed int
#define all(v) (v).begin(), (v).end()         // iterator pair for <algorithm> calls
constexpr int MAXN = 300005;                  // capacity of every per-vertex global array

namespace fastio {
// Minimal buffered I/O on top of fread/fwrite.
//  - Input: unread bytes live in ibuf[pil, pir); load() refills the buffer.
//  - Output: pending bytes live in obuf[0, por); flush() writes them to stdout
//    and is registered with atexit (see Dummy) so buffered output is emitted.
// The functions carry no target("avx2")/optimize("O3") attributes. GCC
// documents optimize() as a debugging aid only. target("avx2") may emit AVX2
// instructions that would run unconditionally on CPUs lacking them.
static constexpr uint32_t SZ = 1 << 17;
char ibuf[SZ];
char obuf[SZ];
uint32_t pil = 0, pir = 0, por = 0;

// Compile-time digit table: num[4 * i .. 4 * i + 3] holds i (0 <= i < 10000)
// as four zero-padded ASCII digits, so four digits are emitted per memcpy.
struct Pre {
	char num[40000];
	constexpr Pre() : num() {
		for (int i = 0; i < 10000; i++) {
			int n = i;
			for (int j = 3; j >= 0; j--) {
				num[i * 4 + j] = n % 10 + '0';
				n /= 10;
			}
		}
	}
} constexpr pre;

// Moves the unread bytes to the front of ibuf and tops the buffer up from stdin.
inline void load() {
	// Source [pil, pir) and destination [0, pir - pil) overlap whenever
	// pil < pir - pil, so this must be memmove: memcpy on overlap is undefined.
	std::memmove(ibuf, ibuf + pil, pir - pil);
	pir = pir - pil + std::fread(ibuf + pir - pil, 1, SZ - pir + pil, stdin);
	pil = 0;
	// A short read means end of input. Append a whitespace sentinel so that a
	// number ending exactly at EOF stops there. Otherwise parsing would run on
	// into stale bytes left in ibuf by an earlier, larger fill.
	if (pir < SZ)
		ibuf[pir++] = '\n';
}

// Writes all pending output bytes to stdout and empties obuf.
inline void flush() {
	std::fwrite(obuf, 1, por, stdout);
	por = 0;
}

// Reads one raw byte. There is no refill or bounds check; callers refill first.
inline void rd(char &c) { c = ibuf[pil++]; }

// Parses the next integer into x:
//  - skips bytes below '-' (whitespace/control characters);
//  - accepts a leading '-' when T is signed;
//  - consumes bytes >= '0' as digits (the input is assumed to be numeric).
// It refills only when fewer than 32 bytes remain. This assumes that a token
// plus its leading whitespace fits in 32 bytes.
template <typename T> inline void rd(T &x) {
	if (pil + 32 > pir)
		load();
	char c;
	do
		rd(c);
	while (c < '-');
	bool minus = 0;
	if constexpr (std::is_signed<T>::value) {
		if (c == '-') {
			minus = 1;
			rd(c);
		}
	}
	x = 0;
	while (c >= '0') {
		x = x * 10 + (c & 15);
		rd(c);
	}
	if constexpr (std::is_signed<T>::value) {
		if (minus)
			x = -x;
	}
}

// Appends one byte. There is no capacity check; callers check first.
inline void wt(char c) { obuf[por++] = c; }

// Appends the decimal representation of x.
//  - Values >= 1e16 take an unrolled path that writes 17, 18 or 19 digits.
//  - Smaller values emit groups of four digits into a scratch buffer.
// Limitations: values >= 1e19 (possible only for unsigned 64-bit T) would be
// mis-printed, and negating the minimum value of a signed T overflows.
template <typename T> inline void wt(T x) {
	if (por + 32 > SZ)
		flush();
	if (!x) {
		wt('0');
		return;
	}
	if constexpr (std::is_signed<T>::value) {
		if (x < 0) {
			wt('-');
			x = -x;
		}
	}
	if (x >= 10000000000000000) {
		uint32_t r1 = x % 100000000;
		uint64_t q1 = x / 100000000;
		if (x >= 1000000000000000000) {
			// 19 digits: 3 leading digits followed by four 4-digit groups.
			uint32_t n1 = r1 % 10000;
			uint32_t n2 = r1 / 10000;
			uint32_t n3 = q1 % 10000;
			uint32_t r2 = q1 / 10000;
			uint32_t n4 = r2 % 10000;
			uint32_t q2 = r2 / 10000;
			std::memcpy(obuf + por + 15, pre.num + (n1 << 2), 4);
			std::memcpy(obuf + por + 11, pre.num + (n2 << 2), 4);
			std::memcpy(obuf + por + 7, pre.num + (n3 << 2), 4);
			std::memcpy(obuf + por + 3, pre.num + (n4 << 2), 4);
			std::memcpy(obuf + por, pre.num + (q2 << 2) + 1, 3);
			por += 19;
		} else if (x >= 100000000000000000) {
			// 18 digits: 2 leading digits followed by four 4-digit groups.
			uint32_t n1 = r1 % 10000;
			uint32_t n2 = r1 / 10000;
			uint32_t n3 = q1 % 10000;
			uint32_t r2 = q1 / 10000;
			uint32_t n4 = r2 % 10000;
			uint32_t q2 = r2 / 10000;
			uint32_t q3 = (q2 * 205) >> 11; // q2 / 10 for this small range
			uint32_t r3 = q2 - q3 * 10;
			std::memcpy(obuf + por + 14, pre.num + (n1 << 2), 4);
			std::memcpy(obuf + por + 10, pre.num + (n2 << 2), 4);
			std::memcpy(obuf + por + 6, pre.num + (n3 << 2), 4);
			std::memcpy(obuf + por + 2, pre.num + (n4 << 2), 4);
			obuf[por + 1] = '0' + r3;
			obuf[por + 0] = '0' + q3;
			por += 18;
		} else {
			// 17 digits: 1 leading digit followed by four 4-digit groups.
			uint32_t n1 = r1 % 10000;
			uint32_t n2 = r1 / 10000;
			uint32_t n3 = static_cast<uint32_t>(q1) % 10000;
			uint32_t r2 = static_cast<uint32_t>(q1) / 10000;
			uint32_t n4 = r2 % 10000;
			uint32_t q2 = r2 / 10000;
			std::memcpy(obuf + por + 13, pre.num + (n1 << 2), 4);
			std::memcpy(obuf + por + 9, pre.num + (n2 << 2), 4);
			std::memcpy(obuf + por + 5, pre.num + (n3 << 2), 4);
			std::memcpy(obuf + por + 1, pre.num + (n4 << 2), 4);
			obuf[por + 0] = '0' + q2;
			por += 17;
		}
	} else {
		// Low-order 4-digit groups are staged right-to-left in buf[i + 4 .. 11].
		int i = 8;
		char buf[12];
		while (x >= 10000) {
			std::memcpy(buf + i, pre.num + (x % 10000) * 4, 4);
			x /= 10000;
			i -= 4;
		}
		// Leading group (1-4 digits) without zero padding.
		if (x < 100) {
			if (x < 10) {
				wt(char('0' + x));
			} else {
				obuf[por + 0] = '0' + x / 10;
				obuf[por + 1] = '0' + x % 10;
				por += 2;
			}
		} else {
			if (x < 1000) {
				std::memcpy(obuf + por, pre.num + (x << 2) + 1, 3);
				por += 3;
			} else {
				std::memcpy(obuf + por, pre.num + (x << 2), 4);
				por += 4;
			}
		}
		std::memcpy(obuf + por, buf + i + 4, 8 - i);
		por += 8 - i;
	}
}

// Registers flush() at static-initialization time so pending output is written at exit.
struct Dummy {
	Dummy() { std::atexit(flush); }
} dummy;

} // namespace fastio

using fastio::rd;
using fastio::wt;

// Input and derived tables, filled in main().
int n;                 // number of vertices
int a[MAXN];           // a[i], 0-indexed, read for i in [0, n)
int f[MAXN];           // f[k], 1-indexed, read for k in [1, n]
lint sm[MAXN];         // suffix minima: sm[k] = min(f[k..n]); sm[n + 1] = 1e18
// Tree structure, filled by dfs().
int par[MAXN];         // parent of each vertex; -1 for the DFS start vertex
int root[MAXN];        // nonzero if the vertex is solved as its own component
vector<int> gph[MAXN]; // adjacency lists, 0-indexed

// Walks the tree hanging below x (entered from p) and, for every vertex y
// reached from a parent u, records par[y] = u. It sets root[y] = 1 when
// (0 < a[y] <= a[u]) or (a[y] > a[u] + 1), i.e. when y must be solved as the
// start of its own component. If p == -1, x itself is marked as a root with
// par[x] = -1.
// Uses an explicit stack instead of recursion: a path-shaped tree with up to
// MAXN vertices would otherwise nest that many calls and risk overflowing the
// call stack. Each assignment depends only on the (parent, child) pair, so
// the final par[]/root[] contents match the recursive traversal.
void dfs(int x, int p) {
	if (p == -1)
		root[x] = 1, par[x] = -1;
	vector<pair<int, int>> stk; // (vertex, vertex it was entered from)
	stk.emplace_back(x, p);
	while (!stk.empty()) {
		auto [cur, from] = stk.back();
		stk.pop_back();
		for (int y : gph[cur]) {
			if (y == from)
				continue;
			if (a[y] <= a[cur] && a[y] > 0)
				root[y] = 1;
			if (a[y] > a[cur] + 1)
				root[y] = 1;
			par[y] = cur;
			stk.emplace_back(y, cur);
		}
	}
}

struct DP {
	map<lint, lint> incrBy;
	map<lint, lint> equals;
	pair<lint, lint> totMin() {
		lint totmin = 1e18;
		lint sum = 0;
		auto it1 = equals.end();
		auto it2 = incrBy.end();
		while (it1 != equals.begin()) {
			it1--;
			while (it2 != incrBy.begin() && (*prev(it2)).first >= (*it1).first) {
				it2--;
				totmin = min(totmin, sum + sm[(*it2).first + 1]);
				sum += it2->second;
			}
			totmin = min(totmin, f[it1->first] - it1->second + sum);
		}
		while (it2 != incrBy.begin()) {
			it2--;
			totmin = min(totmin, sum + sm[(*it2).first + 1]);
			sum += it2->second;
		}
		totmin = min(totmin, sum + sm[1]);
		return make_pair(totmin, sum);
	}
	void print() {
		return;
		cout << "incrBy" << endl;
		for (auto &[k, v] : incrBy)
			cout << k << " " << v << endl;
		cout << "equals" << endl;
		for (auto &[k, v] : equals)
			cout << k << " " << v << endl;
		auto [tot, sum] = totMin();
		cout << "totmin = " << tot << " sum = " << sum << endl;
	}
};

int idx[MAXN]; // idx[x]: slot of v[] holding vertex x's DP state (set by solve())
DP v[MAXN];    // DP state storage, addressed through idx[]

void solve(int x, int p) {
	idx[x] = x;
	vector<int> down;
	for (auto &y : gph[x]) {
		if (y != p && !root[y]) {
			solve(y, x);
			if (a[x] + 1 == a[y]) {
				auto [totmin, sum] = v[idx[y]].totMin();
				auto it1 = v[idx[y]].incrBy.begin();
				auto it2 = v[idx[y]].equals.begin();
				while (totmin < sum) {
					lint delta = sum - totmin;
					auto val = *it1;
					it1 = v[idx[y]].incrBy.erase(it1);
					sum -= min(val.second, delta);
					val.second -= min(val.second, delta);
					if (val.second)
						v[idx[y]].incrBy.insert(val);
					while (it2 != v[idx[y]].equals.end() && it2->first <= val.first) {
						if (it2->second <= delta) {
							it2 = v[idx[y]].equals.erase(it2);
						} else {
							auto val = *it2;
							val.second -= delta;
							v[idx[y]].equals.erase(it2);
							v[idx[y]].equals.insert(val);
							it2 = v[idx[y]].equals.upper_bound(val.first);
						}
					}
				}
			} else {
				auto [totmin, sum] = v[idx[y]].totMin();
				lint costEq = 0;
				{
					auto it = v[idx[y]].incrBy.end();
					while (it != v[idx[y]].incrBy.begin()) {
						it--;
						if (it->first >= a[x] + 1)
							costEq += it->second;
						else
							break;
					}
				}
				if (v[idx[y]].equals.count(a[x] + 1))
					costEq -= v[idx[y]].equals[a[x] + 1];
				v[idx[y]].incrBy.clear();
				v[idx[y]].equals.clear();
				if (totmin > costEq)
					v[idx[y]].equals[a[x] + 1] = totmin - costEq;
				v[idx[y]].incrBy[n] = totmin;
			}
			down.push_back(idx[y]);
		}
	}
	if (sz(down) == 0) {
		v[idx[x]].incrBy[a[x]] = 1e18;
		return;
	}
	sort(all(down), [&](int p, int q) { return sz(v[p].equals) + sz(v[p].incrBy) > sz(v[q].equals) + sz(v[q].incrBy); });
	idx[x] = down[0];
	// generate geqs
	{
		for (int i = 1; i < sz(down); i++) {
			for (auto &[k, va] : v[down[i]].incrBy) {
				if (k > a[x])
					v[idx[x]].incrBy[k] += va;
			}
			for (auto &[k, va] : v[down[i]].equals) {
				if (k > a[x])
					v[idx[x]].equals[k] += va;
			}
		}
	}
	// generate eqs (mutate if u like it)
	{
		auto it = v[idx[x]].incrBy.begin();
		while (it != v[idx[x]].incrBy.end()) {
			if (it->first <= a[x]) {
				it = v[idx[x]].incrBy.erase(it);
			} else
				break;
		}
	}
	{
		auto it = v[idx[x]].equals.begin();
		while (it != v[idx[x]].equals.end()) {
			if (it->first <= a[x]) {
				it = v[idx[x]].equals.erase(it);
			} else
				break;
		}
	}
	v[idx[x]].incrBy[a[x]] = 1e18;
}

int main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	rd(n);
	for (int i = 0; i < n; i++)
		rd(a[i]);
	for (int i = 1; i <= n; i++) {
		rd(f[i]);
	}
	sm[n + 1] = 1e18;
	for (int i = n; i; i--) {
		sm[i] = min(sm[i + 1], 1ll * f[i]);
	}
	for (int i = 0; i < n - 1; i++) {
		int u, v;
		rd(u);
		rd(v);
		u--;
		v--;
		gph[u].push_back(v);
		gph[v].push_back(u);
	}
	dfs(0, -1);
	lint ans = 0;
	for (int i = 0; i < n; i++) {
		if (root[i]) {
			solve(i, par[i]);
			lint totmin = v[idx[i]].totMin().first;
			ans += totmin;
		}
	}
	wt(ans);
}
# 결과 실행 시간 메모리 Grader output
1 Incorrect 17 ms 35540 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 17 ms 35540 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 17 ms 35540 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 46 ms 71872 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 46 ms 71872 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 17 ms 35540 KB Output isn't correct
2 Halted 0 ms 0 KB -