이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
#define ar array
#define sz(v) int(std::size(v))
using pii = pair<int, int>;
// Array capacities: N bounds node ids, M bounds the sequence length m
// (and the Fenwick tree range), Q bounds the number of queries.
// L = number of sparse-table levels; the Euler array et has at most N - 1
// entries, so 2^L must exceed N - 1.
constexpr int N = 100000, L = 17, M = 100000, Q = 100000;
static_assert((1 << L) > N - 1, "sparse table needs 2^L > N - 1");
int n, m, q;                    // node count, sequence length, query count
vector<int> adj[N];             // tree adjacency lists (0-indexed nodes)
int c[M];                       // sequence c[0..m-1] of 0-indexed nodes (sized by M, not N)
int tin[N];                     // DFS preorder index of each node
int tt;                         // next preorder index to hand out
vector<int> et;                 // et[k] = parent of the node whose preorder index is k + 1
int depth[N], lg[N], st[N][L];  // node depth, floor(log2(k)), sparse table over et
int ans[Q];                     // per-query edge count; main prints ans[h] + 1
/// Returns the shallower of nodes i and j (smaller depth[]); on a tie, returns j.
int low(int i, int j) {
    if (depth[j] <= depth[i]) return j;
    return i;
}
/// Preorder DFS from node i (whose parent is p, or -1 for the root) over adj.
/// Assigns tin[] in preorder and sets depth[] of every descendant relative to
/// depth[i]. Just before entering each child it appends the child's parent to et,
/// so after a single call from the root (with tt == 0 and et empty),
/// et[tin[x] - 1] is the parent of x.
/// An explicit stack replaces recursion: a path-shaped tree with ~1e5 nodes
/// would otherwise recurse that deep and can overflow the call stack.
void dfs(int p, int i) {
    struct Frame {
        int parent;
        int node;
        size_t next;  // index of the next adj[node] entry to examine
    };
    vector<Frame> stk;
    tin[i] = tt++;
    stk.push_back({p, i, 0});
    while (!stk.empty()) {
        Frame& f = stk.back();
        if (f.next == adj[f.node].size()) {
            stk.pop_back();
            continue;
        }
        const int u = f.node;
        const int j = adj[u][f.next++];
        if (j == f.parent) continue;
        depth[j] = depth[u] + 1;
        et.push_back(u);
        tin[j] = tt++;                // preorder: numbered on entry, as in the recursive form
        stk.push_back({u, j, 0});     // f may dangle after this; it is not used again
    }
}
int qlow(int i, int j) {
int l = lg[j - i + 1];
return low(st[i][l], st[j - (1 << l) + 1][l]);
}
/// Lowest common ancestor of nodes i and j.
/// For distinct nodes, this is the shallowest entry of et over
/// [min(tin), max(tin) - 1], i.e. the shallowest parent among nodes whose
/// preorder index lies in (min(tin), max(tin)].
int lca_(int i, int j) {
    if (i != j) {
        const int from = min(tin[i], tin[j]);
        const int to = max(tin[i], tin[j]);
        return qlow(from, to - 1);
    }
    return i;
}
/// LCA entry point used by the rest of the file; forwards to lca_.
int lca(int i, int j) {
    const int ancestor = lca_(i, j);
    return ancestor;
}
vector<pii> aux[N];
int tl[N], tr[N], up[N];
void ae(int i, int j, int d) {
// cout << "\tlink " << i+1 << ' ' << j+1 << ' ' << d << endl;
aux[i].push_back({j, d});
aux[j].push_back({i, d});
}
void dfs_t(int p, int i) {
for (auto [j, d] : aux[i]) if (p != j) {
up[j] = d;
dfs_t(i, j);
tl[i] = max(tl[i], tl[j]);
tr[i] = min(tr[i], tr[j]);
}
}
/// Fenwick tree over positions [0, m), 0-indexed via the i |= i + 1 scheme.
/// Supports point additions and prefix-sum queries; callers in this file pass
/// positions in [0, m).
struct FT {
    int t[M];
    /// Adds x at position i.
    void upd(int i, int x) {
        for (; i < m; i |= i + 1) t[i] += x;
    }
    /// Returns the sum over positions [0, i]; 0 when i < 0.
    int qry(int i) {
        int sum = 0;
        for (; i >= 0; i = (i & (i + 1)) - 1) sum += t[i];
        return sum;
    }
} ft;
void rec(int low, int hi, vector<ar<int, 3>> qr) {
if (sz(qr)) {
int mid = (low + hi) / 2;
vector<ar<int, 3>> me;
{
vector<ar<int, 3>> one, two;
for (auto [l, r, h] : qr)
if (r < mid) one.push_back({l, r, h});
else if (l > mid) two.push_back({l, r, h});
else me.push_back({l, r, h});
rec(low, mid, one);
rec(mid + 1, hi, two);
}
// cout << low << ' ' << hi << endl;
// for (auto [l, r, h] : qr) cout << "\t" << l << ' ' << r << ' ' << h << endl;
vector<int> v(c + low, c + hi + 1);
sort(begin(v), end(v), [&](int i, int j) { return tin[i] < tin[j]; });
v.erase(unique(begin(v), end(v)), end(v));
int base = sz(v);
for (int i = 0; i < base - 1; i++) v.push_back(lca(v[i], v[i + 1]));
sort(begin(v), end(v), [&](int i, int j) { return tin[i] < tin[j]; });
v.erase(unique(begin(v), end(v)), end(v));
for (int i = 1; i < sz(v); i++) {
int p = lca(v[i - 1], v[i]);
ae(p, v[i], depth[v[i]] - depth[p]);
}
for (int i : v) tl[i] = low - 1, tr[i] = hi + 1;
for (int i = low; i <= hi; i++) {
if (i <= mid) tl[c[i]] = max(tl[c[i]], i);
if (i >= mid) tr[c[i]] = min(tr[c[i]], i);
}
dfs_t(-1, c[mid]);
// cout << "ROOT AT " << c[mid]+1 << endl;
vector<int> one(mid - low + 1), two(hi - mid + 1);
vector<ar<int, 3>> upd;
for (int i : v) if (i != c[mid]) {
// cout << i+1 << " => " << tl[i] << ' ' << tr[i] << " | val = " << up[i] << endl;
if (tl[i] >= low) one[tl[i] - low] += up[i];
if (tr[i] <= hi) two[tr[i] - mid] += up[i];
if (tl[i] >= low && tr[i] <= hi) upd.push_back({tl[i], tr[i], up[i]});
}
for (int i = mid - 1; i >= low; i--) one[i - low] += one[i - low + 1];
for (int i = mid + 1; i <= hi; i++) two[i - mid] += two[i - mid - 1];
sort(begin(upd), end(upd), greater<>());
sort(begin(me), end(me), greater<>());
int p = 0;
for (auto [l, r, h] : me) {
ans[h] += one[l - low] + two[r - mid];
while (p < sz(upd) && l <= upd[p][0]) {
auto [i, j, x] = upd[p++];
ft.upd(j, x);
}
ans[h] -= ft.qry(r);
}
while (p > 0) {
auto [i, j, x] = upd[--p];
ft.upd(j, -x);
}
for (int i : v) vector<pii>().swap(aux[i]);
// cout << endl << endl << endl;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m >> q;
for (int h = 0; h < n - 1; h++) {
int i, j;
cin >> i >> j, i--, j--;
adj[i].push_back(j);
adj[j].push_back(i);
}
dfs(-1, 0);
lg[1] = 0;
for (int i = 2; i <= sz(et); i++) lg[i] = lg[i / 2] + 1;
for (int i = 0; i < sz(et); i++) st[i][0] = et[i];
for (int l = 1; 1 << l <= sz(et); l++) for (int i = 0; i + (1 << l) <= sz(et); i++)
st[i][l] = low(st[i][l - 1], st[i + (1 << (l - 1))][l - 1]);
for (int i = 0; i < m; i++) cin >> c[i], c[i]--;
vector<ar<int, 3>> qr;
for (int h = 0; h < q; h++) {
int l, r;
cin >> l >> r, l--, r--;
qr.push_back({l, r, h});
}
rec(0, m - 1, qr);
for (int h = 0; h < q; h++) cout << ans[h] + 1 << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |