Submission #558466

# | Submission time | User | Problem | Language | Result | Runtime | Memory
558466 | 2022-05-07T12:20:05Z | DanShaders | Sprinkler (JOI22_sprinkler) | C++17 | 0 / 100 | 4000 ms | 636004 KB
//bs:sanitizers
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
using namespace std;

// Short alias for the GNU policy-based data structures namespace.
namespace x = __gnu_pbds;
// Order-statistics tree (find_by_order / order_of_key); not used by the code
// visible in this file.
template <typename T>
using ordered_set = x::tree<T, x::null_type, less<T>, x::rb_tree_tag, x::tree_order_statistics_node_update>;

// Min-heap priority queue: top() is the smallest element.
template <typename T>
using normal_queue = priority_queue<T, vector<T>, greater<>>;

#define all(x) begin(x), end(x)
#define sz(x) ((int) (x).size())
// NOTE(review): these object-like macros rewrite EVERY later identifier named
// x / y (including function parameters such as factor(int x)) to first /
// second. The code relies on this for pair access (.x / .y); do not introduce
// unrelated names x or y after this point.
#define x first
#define y second
using ll = long long;
using ld = long double;

// N: capacity of per-vertex arrays. EULER: capacity of the Euler tour, which
// has 2n - 1 entries. LOG: number of sparse-table levels. DIFF: maximum number
// of distinct prime factors of l stored per node (presumably sized for
// l <= 1e9, which has at most 9 distinct primes; confirm against the limits).
const int N = 2e5 + 10, EULER = 2 * N, LOG = 19, DIFF = 9;

// Adjacency lists of the input tree (0-indexed vertices).
vector<int> g[N];
// h: per-vertex initial values; sz: subtree sizes written by dfs_sz.
int h[N], sz[N];
// used[v] becomes 1 once v has been taken as a centroid (removed from the tree).
char used[N];

// Data stored at one centroid, describing the component C it was chosen from.
//   parent - parent centroid in the centroid tree, -1 for the root.
//   od     - (distance from this centroid, vertex) for every vertex of C, sorted.
//   pd     - (distance from the parent centroid, vertex) for every vertex of C,
//            sorted.
//   op[i][j] / pp[i][j] - accumulated exponent of the i-th prime factor of l for
//            entry od[j] / pd[j]. pp receives subtractions that cancel the
//            parent-level contribution.
//   oi[j] / pi[j] - accumulated product (mod l) of the factors coprime to l for
//            od[j] / pd[j]. pi receives modular inverses.
struct CentroidNode {
	int parent;
	vector<pair<int, int>> od, pd;
	vector<int> op[DIFF], pp[DIFF];
	vector<ll> oi, pi;
} tc[N];

// Computes sz[] over the subtree of u (rooted away from p), skipping vertices
// that were already removed as centroids. Returns sz[u], or 0 when u itself is
// removed (in that case sz[u] is left untouched).
int dfs_sz(int u, int p = -1) {
	if (used[u]) return 0;
	int total = 1;
	for (int nxt : g[u])
		if (nxt != p) total += dfs_sz(nxt, u);
	return sz[u] = total;
}

// Starting at u, repeatedly steps to the first unused neighbour (other than the
// vertex just left) whose sz[] exceeds half of csz. Returns the vertex where no
// such neighbour exists. sz[] must come from dfs_sz rooted at the starting
// vertex, as dfs_centroid arranges.
int dfs_find_centroid(int u, int csz, int p = -1) {
	for (bool moved = true; moved;) {
		moved = false;
		for (int nxt : g[u]) {
			if (nxt == p || used[nxt] || 2 * sz[nxt] <= csz) continue;
			p = u;
			u = nxt;
			moved = true;
			break;
		}
	}
	return u;
}

// Root of the centroid tree (the first centroid found); set by dfs_centroid.
int croot;

// Builds the centroid decomposition of the component containing `root`, i.e.
// the vertices with used[] == 0 reachable from it.
//   root   - any vertex of the component. In recursive calls it is a neighbour
//            of `parent`.
//   parent - the centroid whose removal produced this component, or -1 for the
//            whole tree.
// Fills tc[centroid]: the parent link, the sorted distance lists od/pd, and the
// accumulators (exponents start at 0, products at 1). It then marks the
// centroid used and recurses into every neighbouring sub-component.
void dfs_centroid(int root, int parent = -1) {
	int centroid = dfs_find_centroid(root, dfs_sz(root));
	auto &node = tc[centroid];
	node.parent = parent;
	if (parent == -1) {
		croot = centroid;
	}

	// pd: (distance from `parent`, vertex) for the whole component. `root` is one
	// edge away from `parent`, hence the starting distance 1. A min-heap keyed by
	// (distance, vertex, prev) pops entries in nondecreasing (distance, vertex)
	// order, so pd ends up sorted, as the binary searches in apply_for and
	// count_for require.
	normal_queue<tuple<int, int, int>> bfsp;
	bfsp.push({1, root, -1});
	while (sz(bfsp)) {
		auto [d, u, p] = bfsp.top();
		bfsp.pop();
		node.pd.push_back({d, u});
		for (int v : g[u]) {
			if (!used[v] && v != p) {
				bfsp.push({d + 1, v, u});
			}
		}
	}
	// od: (distance from this centroid, vertex), built and sorted the same way.
	normal_queue<tuple<int, int, int>> bfso;
	bfso.push({0, centroid, -1});
	while (sz(bfso)) {
		auto [d, u, p] = bfso.top();
		bfso.pop();
		node.od.push_back({d, u});
		for (int v : g[u]) {
			if (!used[v] && v != p) {
				bfso.push({d + 1, v, u});
			}
		}
	}
	
	// od and pd cover the same component, so they have equal length. That is
	// why every accumulator (including the pd-indexed pp/pi) is sized by od.
	for (int i = 0; i < DIFF; ++i) {
		node.op[i].resize(sz(node.od));
		node.pp[i].resize(sz(node.od));
	}
	node.oi.resize(sz(node.od), 1);
	node.pi.resize(sz(node.od), 1);

	// Remove the centroid, then decompose each remaining neighbouring piece.
	used[centroid] = 1;
	for (int v : g[centroid]) {
		if (!used[v]) {
			dfs_centroid(v, centroid);
		}
	}
}

// ipow[k] = floor(log2(k)) for k >= 1 (filled in main).
// depth[v] = depth of v in the tree rooted at vertex 0 (filled by dfs_euler).
int ipow[EULER], depth[N];
// Sparse table over the Euler tour: sp[k][i] = minimum (depth, vertex) over
// tour positions [i, i + 2^k). Level 0 is written by dfs_euler.
pair<int, int> sp[LOG][EULER];
// order[v] = first Euler-tour position of v; timer = current tour length.
int order[N], timer = 0;

// Appends the Euler tour of u's subtree to sp[0]. The entry (d, u) is written
// once on arrival and again after returning from each child. Also records
// order[u] (first tour position) and depth[u] = d.
void dfs_euler(int u, int d = 0, int p = -1) {
	const auto record = [&] { sp[0][timer++] = {d, u}; };
	order[u] = timer;
	depth[u] = d;
	record();
	for (int child : g[u]) {
		if (child == p) continue;
		dfs_euler(child, d + 1, u);
		record();
	}
}

int lca(int u, int v) {
	if (u == v) {
		return u;
	}
	u = order[u];
	v = order[v];
	if (u > v) {
		swap(u, v);
	}
	++v;
	int pw = ipow[v - u];
	return min(sp[pw][u], sp[pw][v - (1 << pw)]).y;
}

// Number of edges on the tree path between u and v.
int dist(int u, int v) {
	const int w = lca(u, v);
	return (depth[u] - depth[w]) + (depth[v] - depth[w]);
}

// Trial-division factorisation of x (expected x >= 1).
// Returns (prime, exponent) pairs in increasing prime order; returns an empty
// vector for x == 1.
// The loop condition is written as i <= x / i rather than i * i <= x. The two
// are equivalent for positive ints, but i * i overflows int (undefined
// behaviour) once i exceeds 46340, which happens when x is a large prime near
// INT_MAX.
std::vector<std::pair<int, int>> factor(int x) {
	std::vector<std::pair<int, int>> res;
	for (int i = 2; i <= x / i; ++i) {
		if (x % i != 0) {
			continue;
		}
		int e = 0;
		while (x % i == 0) {
			x /= i;
			++e;
		}
		res.emplace_back(i, e);
	}
	// Whatever is left above sqrt of the original x is a single prime.
	if (x != 1) {
		res.emplace_back(x, 1);
	}
	return res;
}

// Divides out of x every prime listed in `fact`. Returns (exponents, rest):
// exponents[k] is the multiplicity of fact[k].first in x, and rest is x with
// all of those primes removed.
pair<vector<int>, int> get_pw(int x, const vector<pair<int, int>> &fact) {
	vector<int> pw(fact.size(), 0);
	for (int k = 0; k < int(fact.size()); ++k) {
		const int prime = fact[k].first;
		for (; x % prime == 0; x /= prime) {
			++pw[k];
		}
	}
	return {pw, x};
}

void exgcd(int a, int b, ll &x, ll &y) {
	if (b == 0) {
		x = 1;
		y = 0;
		return;
	}
	exgcd(b, a % b, x, y);
	ll nw = x - (a / b) * y;
	x = y;
	y = nw;
}

// Number of distinct prime factors of l (set in main after factor(l)).
int diff;
// The modulus read from input; all products are reduced modulo l.
ll l;

// Applies one update, issued at vertex x with radius d, to the data of centroid
// u. Here u is x itself or one of x's centroid-tree ancestors.
//   part - exponents of l's prime factors in the multiplier (from get_pw)
//   ineq - the part of the multiplier coprime to l
//   inv  - inverse of ineq modulo l
// First, vertices of u's component within d - dist(u, x) of u receive the
// multiplier in od order (op += part, oi *= ineq). Then the same range as the
// parent centroid's update is cancelled on pd/pp/pi (exponents subtracted,
// inverse multiplied), so vertices of u's component are not counted at both
// levels.
// NOTE(review): each call does O(diff * bound) work, i.e. linear in the
// component size at every level. This is presumably why the larger tests time
// out.
void apply_for(const vector<int> &part, int ineq, int inv, int u, int x, int d) {
	auto &node = tc[u];

	int dst = d - dist(u, x);
	// od is sorted by (distance, vertex) and vertices are >= 0, so [0, bound)
	// are exactly the entries at distance <= dst (none when dst < 0).
	int bound = int(lower_bound(all(node.od), pair{dst + 1, -1}) - begin(node.od));
	for (int i = 0; i < diff; ++i) {
		if (part[i] == 0) {
			continue;
		}
		for (int j = 0; j < bound; ++j) {
			node.op[i][j] += part[i];
		}
	}
	// Multiplying by 1 changes nothing; skip the pass.
	if (ineq != 1) {
		for (int j = 0; j < bound; ++j) {
			(node.oi[j] *= ineq) %= l;
		}
	}

	// The root centroid has no parent-level contribution to cancel.
	if (node.parent == -1) {
		return;
	}

	// Range of the parent's update: distance from the parent <= d - dist(parent, x).
	dst = d - dist(node.parent, x);
	bound = int(lower_bound(all(node.pd), pair{dst + 1, -1}) - begin(node.pd));
	for (int i = 0; i < diff; ++i) {
		if (part[i] == 0) {
			continue;
		}
		for (int j = 0; j < bound; ++j) {
			node.pp[i][j] -= part[i];
		}
	}
	if (inv != 1) {
		for (int j = 0; j < bound; ++j) {
			(node.pi[j] *= inv) %= l;
		}
	}
}

// Folds into (part, ineq) what centroid u stores for vertex x:
// - u's own accumulators at x's position in od;
// - unless u is the root, the cancelling corrections at x's position in pd.
// x must belong to u's component, i.e. u is x or a centroid-tree ancestor of x.
void count_for(vector<int> &part, ll &ineq, int u, int x) {
	const auto &node = tc[u];

	// Exact lookup: od/pd are sorted by (distance, vertex) and contain x.
	const int own = int(lower_bound(all(node.od), pair{dist(u, x), x}) - begin(node.od));
	for (int k = 0; k < diff; ++k) {
		part[k] += node.op[k][own];
	}
	ineq = ineq * node.oi[own] % l;

	if (node.parent != -1) {
		const int up = int(lower_bound(all(node.pd), pair{dist(node.parent, x), x}) - begin(node.pd));
		for (int k = 0; k < diff; ++k) {
			part[k] += node.pp[k][up];
		}
		ineq = ineq * node.pi[up] % l;
	}
}

// Returns a^b mod l by binary exponentiation; returns 1 when b <= 0.
// The products are formed in long long. The previous version multiplied two
// ints (c *= a, a *= a) with operands up to about 1e9, which overflows before
// the reduction. That is undefined behaviour and gives wrong answers, likely
// the "Output isn't correct" verdicts. Shifting b right also avoids the old
// `i *= 2` counter overflowing when b >= 2^30.
int fpow(int a, int b) {
	long long base = a, acc = 1;
	for (; b > 0; b >>= 1) {
		if (b & 1) {
			acc = acc * base % l;
		}
		base = base * base % l;
	}
	return int(acc);
}

// Entry point. The input is read in this order:
//   1. n and the modulus l;
//   2. n - 1 tree edges (1-indexed);
//   3. the n initial values h;
//   4. the number of queries, then the queries themselves.
// Query kinds:
//   1 x d w - multiplies the value of every vertex within distance d of x by w,
//             applied at x and its centroid-tree ancestors by apply_for;
//   2 x     - prints the current value of vertex x modulo l.
// Each value is kept as (exponents of l's prime factors, product of the part
// coprime to l, mod l), so that contributions can be cancelled exactly.
signed main() {
	cin.tie(0)->sync_with_stdio(0);
	int n;
	cin >> n >> l;
	for (int i = 1; i < n; ++i) {
		int u, v;
		cin >> u >> v;
		g[--u].push_back(--v);
		g[v].push_back(u);
	}
	for (int i = 0; i < n; ++i) {
		cin >> h[i];
		// Store 0 as l itself. It is still 0 mod l, but it decomposes into prime
		// exponents of l, which can later be subtracted, instead of a
		// non-invertible 0.
		if (h[i] == 0) {
			h[i] = int(l);
		}
	}
	// Euler tour and sparse table give O(1) LCA, and hence O(1) dist().
	dfs_euler(0);
	for (int i = 1; i < LOG; ++i) {
		for (int j = 0; j <= timer - (1 << i); ++j) {
			sp[i][j] = min(sp[i - 1][j], sp[i - 1][j + (1 << (i - 1))]);
		}
	}
	// ipow[k] = floor(log2(k)) for 1 <= k <= timer.
	for (int i = 2; i <= timer; ++i) {
		ipow[i] = ipow[i / 2] + 1;
	}
	dfs_centroid(0);
	int queries;
	cin >> queries;
	auto fact = factor(int(l));
	diff = sz(fact);

	// Load every initial value at the root centroid, whose od lists all vertices.
	auto &node = tc[croot];
	for (int i = 0; i < n; ++i) {
		auto [part, ineq] = get_pw(h[node.od[i].y], fact);

		for (int j = 0; j < diff; ++j) {
			node.op[j][i] += part[j];
		}
		(node.oi[i] *= ineq) %= l;
	}

	while (queries--) {
		int type;
		cin >> type;
		if (type == 1) {
			int x, d, w;
			cin >> x >> d >> w;
			--x;
			// w == 0 is represented as l, the same way as initial zeros.
			if (!w) {
				w = int(l);
			}
			// ineq has no prime factor of l, so it is invertible modulo l.
			auto [part, ineq] = get_pw(w, fact);
			ll inv, tmp;
			exgcd(ineq, int(l), inv, tmp);
			inv = (inv % l + l) % l;

			// Apply the update at x and at every centroid-tree ancestor of x.
			int curr = x;
			while (curr != -1) {
				apply_for(part, ineq, int(inv), curr, x, d);
				curr = tc[curr].parent;
			}
		} else {
			int x;
			cin >> x;
			--x;
			// Sum the contributions stored at x and at all its centroid-tree ancestors.
			vector<int> part(diff);
			ll ineq = 1;
			int curr = x;
			while (curr != -1) {
				count_for(part, ineq, curr, x);
				curr = tc[curr].parent;
			}

			// Rebuild the value: product of prime^exponent, times the coprime part.
			ll res = 1;
			for (int i = 0; i < diff; ++i) {
				(res *= fpow(fact[i].x, part[i])) %= l;
			}
			(res *= ineq) %= l;
			cout << res << "\n";
			// for (int u : part) {
			// 	cout << u << " ";
			// }
			// cout << ineq << "\n";
		}
	}
}
# Result Runtime Memory Grader output
1 Correct 53 ms 109896 KB Output is correct
2 Correct 54 ms 109936 KB Output is correct
3 Correct 53 ms 109908 KB Output is correct
4 Incorrect 57 ms 111688 KB Output isn't correct
5 Halted 0 ms 0 KB -
# Result Runtime Memory Grader output
1 Correct 50 ms 109900 KB Output is correct
2 Execution timed out 4106 ms 545948 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# Result Runtime Memory Grader output
1 Correct 50 ms 109900 KB Output is correct
2 Execution timed out 4106 ms 545948 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# Result Runtime Memory Grader output
1 Correct 59 ms 109880 KB Output is correct
2 Execution timed out 4121 ms 636004 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# Result Runtime Memory Grader output
1 Correct 59 ms 109960 KB Output is correct
2 Execution timed out 4090 ms 631544 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# Result Runtime Memory Grader output
1 Correct 53 ms 109896 KB Output is correct
2 Correct 54 ms 109936 KB Output is correct
3 Correct 53 ms 109908 KB Output is correct
4 Incorrect 57 ms 111688 KB Output isn't correct
5 Halted 0 ms 0 KB -