#include <bits/stdc++.h>
using namespace std;
// define
// Prints the elapsed processor time (clock() / CLOCKS_PER_SEC, in seconds) to
// stderr. The expansion already ends with ';'.
#define execute cerr << " Time: " << fixed << setprecision(6) << (1.0 * clock() / CLOCKS_PER_SEC) << "s\n";
// Type shorthands (textual macros, not type aliases).
#define ll long long
#define ii pair<int,int>
#define iii pair<int,ii>
// Pair member shorthands.
#define se second
#define fi first
// Full iterator range of a container.
#define all(v) (v).begin(), (v).end()
// Sorts v, then erases duplicate values in place.
#define Unique(v) sort(all(v)), v.resize(unique(all(v)) - v.begin())
// bit(x,i): the i-th bit of x (0 or 1); flip(x,i): x with the i-th bit toggled.
#define bit(x,i) (((x) >> (i)) & 1LL)
#define flip(x,i) ((x) ^ (1LL << (i)))
// Fills every byte of the array d with x.
#define ms(d,x) memset(d , x , sizeof(d))
//constant
const long long mod = 1e9 + 7;
// 0x3F3F3F3F3F3F3F3F: what memset(d, 0x3f, sizeof(d)) writes into a long long.
const long long linf = 4557430888798830399LL;
// 0xC1C1C1C1C1C1C1C1 read as a signed 64-bit value (memset byte 0xc1).
const long long nlinf = -4485090715960753727LL;
const int LOG = 20;
// 0x3F3F3F3F: what memset(d, 0x3f, sizeof(d)) writes into an int.
const int inf = 1061109567;
// 0xC1C1C1C1 read as a signed 32-bit value (memset byte 0xc1).
const int ninf = -1044266559;
// Capacity of the fixed-size global arrays; assumes n, m, q stay below it
// (NOTE(review): confirm against the problem limits).
const int maxn = 1e5 + 7;
// Two-level bitset over indices [0, SIZE).
//   group[g]       : bit k set  <=>  index 64*g + k is in the set
//   GroupOnline[k] : bit b set  <=>  group[64*k + b] != 0
// lower()/upper() use the summary words to skip empty groups, so a query
// touches a couple of words per level plus a scan of at most numOfnumGroup
// summary words.
// Capacity: 1605 groups (102720 indices) and 50 summary words.
struct Bitset {
    using ull = unsigned long long;
    // sizes chosen to safely handle n <= 100000
    ull group[1605];
    ull GroupOnline[50];
    // mask[i] = bits [0..i] (mask[64] = all ones). init() fills it, but the
    // queries compute their masks directly.
    ull mask[66];
    int SIZE = 0;          // number of indices
    int numGroup = 0;      // number of 64-bit groups in use
    int numOfnumGroup = 0; // number of GroupOnline words in use

    // Mask with bits [0..idx] set: 0 for idx < 0, all ones for idx >= 63.
    inline ull mask_up_to(int idx) const {
        if(idx < 0) return 0ULL;
        if(idx >= 63) return ~0ULL;
        return (1ULL << (idx + 1)) - 1ULL;
    }

    // Mask with bits [idx..63] set: all ones for idx <= 0, 0 for idx >= 64.
    inline ull mask_from(int idx) const {
        if(idx <= 0) return ~0ULL;
        if(idx >= 64) return 0ULL;
        return ~((1ULL << idx) - 1ULL);
    }

    // Absolute index of the highest / lowest set bit of group g.
    // Precondition: group[g] != 0 (__builtin_clzll/ctzll are undefined on 0).
    int highestInGroup(int g) const { return (g << 6) + (63 - __builtin_clzll(group[g])); }
    int lowestInGroup(int g) const { return (g << 6) + __builtin_ctzll(group[g]); }

    // Resets to the empty set over indices [0, sz).
    void init(int sz) {
        assert(sz >= 0 && (sz + 63) / 64 <= 1605); // must fit the fixed arrays
        SIZE = sz;
        numGroup = (SIZE + 63) / 64;          // ceil
        numOfnumGroup = (numGroup + 63) / 64; // ceil
        for(int i = 0; i < 1605; ++i) group[i] = 0ULL;
        for(int i = 0; i < 50; ++i) GroupOnline[i] = 0ULL;
        for(int i = 0; i <= 63; ++i) mask[i] = mask_up_to(i);
        mask[64] = ~0ULL;
    }

    // Inserts x (requires 0 <= x < SIZE).
    void setOn(int x) {
        int id = x >> 6;
        group[id] |= (1ULL << (x & 63));
        GroupOnline[id >> 6] |= (1ULL << (id & 63));
    }

    // Removes x (requires 0 <= x < SIZE). Clears the summary bit once the
    // group becomes empty.
    void setOff(int x) {
        int id = x >> 6;
        group[id] &= ~(1ULL << (x & 63));
        if(group[id] == 0ULL) GroupOnline[id >> 6] &= ~(1ULL << (id & 63));
    }

    // Returns the largest set index <= val, or -1 if there is none.
    int lower(int val) {
        if(val < 0 || numGroup == 0) return -1;
        int id = val >> 6;
        int off = val & 63;
        if(id >= numGroup) {
            // val lies beyond the last group, so every bit of that group is
            // <= val and must be considered.
            id = numGroup - 1;
            off = 63;
        }
        ull w = group[id] & mask_up_to(off);
        if(w) return (id << 6) + (63 - __builtin_clzll(w));
        int block = id >> 6;
        // Non-empty groups strictly before id in the same summary word.
        ull wb = GroupOnline[block] & mask_up_to((id & 63) - 1);
        if(wb) return highestInGroup((block << 6) + (63 - __builtin_clzll(wb)));
        // Otherwise: the last non-empty group of an earlier summary word.
        for(int b = block - 1; b >= 0; --b) {
            if(GroupOnline[b]) return highestInGroup((b << 6) + (63 - __builtin_clzll(GroupOnline[b])));
        }
        return -1;
    }

    // Returns the smallest set index >= val, or -1 if there is none.
    int upper(int val) {
        if(val < 0) return -1;
        int id = val >> 6;
        if(id >= numGroup) return -1;
        ull w = group[id] & mask_from(val & 63);
        if(w) return (id << 6) + __builtin_ctzll(w);
        int block = id >> 6;
        // Non-empty groups strictly after id in the same summary word.
        ull wb = GroupOnline[block] & ~mask_up_to(id & 63);
        if(wb) return lowestInGroup((block << 6) + __builtin_ctzll(wb));
        // Otherwise: the first non-empty group of a later summary word.
        for(int b = block + 1; b < numOfnumGroup; ++b) {
            if(GroupOnline[b]) {
                int g = (b << 6) + __builtin_ctzll(GroupOnline[b]);
                if(g >= numGroup) continue; // safety
                return lowestInGroup(g);
            }
        }
        return -1;
    }
} B;
// Input sizes: n tree vertices, m array entries, q range queries.
int n, m, q;
// Block length used to order the queries for Mo's algorithm.
int BlockSize = 0;
// c[1..m]: the array of (0-based) tree vertices the queries range over.
int c[maxn];
// Number of distinct DFS positions currently inside the Mo window.
int curNodes = 0;
// cnt[p]: how many times DFS position p occurs inside the Mo window.
int cnt[maxn];
// id[p]: the vertex whose 0-based DFS position is p (inverse of tin).
int id[maxn];
// One range query [l, r] over c[]. `id` is the query's input position, so
// answers can be stored in input order after the queries are sorted.
struct Query {
    int l, r;
    int id;
    // Mo's ordering: by block of l, then by r inside a block.
    // Takes a const reference and is a const member so it satisfies the
    // standard Compare requirements (usable on const objects, std::less,
    // ranges algorithms, ordered containers).
    bool operator < (const Query &other) const {
        if ((l / BlockSize) == (other.l / BlockSize)) return r < other.r;
        return l < other.l;
    }
} Q[maxn];
// ans[i]: answer to the i-th query, in input order.
int ans[maxn];
// Sum of tree distances between cyclically consecutive active DFS positions.
int sumDist;
// Adjacency lists of the tree (0-based vertices).
vector<int> adj[maxn];
// in[u] / out[u]: Euler-tour counter when u's visit starts / ends.
int in[maxn];
int out[maxn];
int depth[maxn];
int timer;
// Sparse table over the Euler tour: column 0 holds the tour itself and
// column j the shallowest vertex of a window of length 2^j.
int arr[maxn * 2][20];
// tin[u] / tout[u]: pre-order counter when u's visit starts / ends.
int tin[maxn];
int tout[maxn];
int timer2;
// Depth-first traversal from u (whose parent is p, or -1 for the root).
// It writes, in exactly the order a recursive traversal would:
//   in[x], tin[x]   - Euler-tour / pre-order counters on entering x
//   arr[.][0]       - the Euler tour: x on entry, and its parent again after
//                     each child subtree finishes
//   depth[v]        - depth of the parent + 1 for every descendant v
//   out[x], tout[x] - Euler-tour / pre-order counters after x's subtree
// An explicit stack replaces recursion, so a path-shaped tree (depth ~ n)
// cannot overflow the call stack.
void dfs(int u, int p) {
    struct Frame {
        int node;
        int parent;
        std::size_t next; // index of the next neighbour of `node` to examine
    };
    std::vector<Frame> st;
    auto enter = [&](int x, int par) {
        in[x] = ++timer;
        tin[x] = ++timer2;
        arr[timer][0] = x;
        st.push_back(Frame{x, par, 0});
    };
    enter(u, p);
    while (!st.empty()) {
        Frame &top = st.back();
        if (top.next < adj[top.node].size()) {
            int v = adj[top.node][top.next++];
            if (v == top.parent) continue;
            depth[v] = depth[top.node] + 1;
            // Copy before enter(): push_back may reallocate and invalidate `top`.
            int parentNode = top.node;
            enter(v, parentNode);
        } else {
            int x = top.node;
            st.pop_back();
            out[x] = timer;
            tout[x] = timer2;
            // Back in the parent: it reappears in the Euler tour.
            if (!st.empty()) arr[++timer][0] = st.back().node;
        }
    }
}
// Fills the sparse table over the Euler tour arr[1..timer][0]:
// arr[i][lvl] is the shallowest vertex among tour positions [i, i + 2^lvl).
// When both halves have equal depth, the right half's vertex is kept.
void BuildLca() {
    for (int lvl = 1; (1 << lvl) <= timer; lvl++) {
        const int half = 1 << (lvl - 1);
        const int lastStart = timer - (1 << lvl) + 1;
        for (int i = 1; i <= lastStart; i++) {
            const int left = arr[i][lvl - 1];
            const int right = arr[i + half][lvl - 1];
            arr[i][lvl] = (depth[right] <= depth[left]) ? right : left;
        }
    }
}
// Lowest common ancestor of u and v: the shallowest vertex on the Euler tour
// between their first occurrences, answered with two overlapping
// sparse-table windows.
int lca(int u, int v) {
    int lo = in[u];
    int hi = in[v];
    if (hi < lo) {
        const int tmp = lo;
        lo = hi;
        hi = tmp;
    }
    const int k = __lg(hi - lo + 1);
    const int leftBest = arr[lo][k];
    const int rightBest = arr[hi - (1 << k) + 1][k];
    return (depth[leftBest] < depth[rightBest]) ? leftBest : rightBest;
}
// Number of edges on the tree path between u and v.
int dist(int u, int v) {
    const int w = lca(u, v);
    return (depth[u] - depth[w]) + (depth[v] - depth[w]);
}
// Inserts DFS position u into the active set B and updates sumDist, the sum
// of tree distances between cyclically consecutive active positions.
// Called after curNodes has been incremented and while cnt[u] is still 0.
void AddSet(int u) {
    if (curNodes != 1) {
        // l: cyclic predecessor of u among active positions (wraps to the end).
        int l = B.lower(u);
        if (l == -1) l = (cnt[n - 1] != 0) ? n - 1 : B.lower(n - 1);
        // r: cyclic successor of u (wraps to the start).
        int r = B.upper(u);
        if (r == -1) r = (cnt[0] != 0) ? 0 : B.upper(0);
        assert(l != -1 && r != -1);
        const int left = id[l];
        const int mid = id[u];
        const int right = id[r];
        // Splice u into the cycle: drop l->r, add l->u and u->r.
        sumDist += dist(left, mid) + dist(mid, right) - dist(left, right);
    }
    B.setOn(u);
}
// Removes DFS position u from the active set B and updates sumDist.
// Called after curNodes has been decremented and while cnt[u] is still 1.
void DelSet(int u) {
    B.setOff(u);
    if (curNodes > 1) {
        // l / r: cyclic neighbours of u among the remaining active positions.
        int l = B.lower(u);
        if (l == -1) l = cnt[n - 1] ? n - 1 : B.lower(n - 1);
        int r = B.upper(u);
        if (r == -1) r = cnt[0] ? 0 : B.upper(0);
        assert(l != -1 && r != -1);
        // Unsplice u: drop l->u and u->r, reconnect l->r.
        sumDist += dist(id[l], id[r]) - dist(id[l], id[u]) - dist(id[r], id[u]);
    } else {
        // At most one active vertex is left, so the cycle has zero length.
        sumDist = 0;
    }
}
// Adds one occurrence of DFS position u to the Mo window. The position
// joins the active set only on its first occurrence.
void add_Node(int u) {
    const bool isNew = (cnt[u] == 0);
    if (isNew) {
        curNodes += 1;
        AddSet(u); // runs while cnt[u] is still 0
    }
    cnt[u] += 1;
}
// Removes one occurrence of DFS position u from the Mo window. The position
// leaves the active set when its last occurrence goes.
void del_Node(int u) {
    const bool isLast = (cnt[u] == 1);
    if (isLast) {
        curNodes -= 1;
        DelSet(u); // runs while cnt[u] is still 1
    }
    cnt[u] -= 1;
}
// Current Mo window [L, R] over c[]; [1, 0] is the empty window.
int L = 1;
int R = 0;
// Moves the window to [l, r]. It grows both ends before shrinking either,
// so the window never becomes inverted (assumes 1 <= l <= r <= m).
void Mo(const int &l, const int &r) {
    while (R < r) add_Node(tin[c[++R]]);
    while (L > l) add_Node(tin[c[--L]]);
    while (L < l) del_Node(tin[c[L++]]);
    while (R > r) del_Node(tin[c[R--]]);
}
// Reads one test case and prints one answer per query, in input order.
// The input is a tree on n vertices (n-1 edges, 1-based), an array
// c[1..m] of vertices, and q range queries [l, r].
//
// Vertices are relabelled by 0-based DFS pre-order (tin). The Mo window
// keeps the distinct positions of c[l..r] in B and maintains sumDist, the
// sum of distances between cyclically consecutive positions. That sum is
// twice the edge count of the minimal subtree spanning the window, so
// sumDist / 2 + 1 is that subtree's vertex count.
void solve(void) {
    cin >> n >> m >> q;
    // clear per-test state
    for (int i = 0; i < n; ++i) {
        adj[i].clear();
        cnt[i] = 0;
    }
    curNodes = 0;
    sumDist = 0;
    timer = timer2 = 0;
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        --u; --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= m; ++i) {
        cin >> c[i];
        --c[i];
    }
    for (int i = 1; i <= q; ++i) {
        cin >> Q[i].l >> Q[i].r;
        Q[i].id = i;
    }
    // root at 0 (nodes are 0-based after decrement)
    depth[0] = 0;
    dfs(0, -1);
    BuildLca();
    B.init(n);
    // Make tin 0-based and build its inverse id[].
    for (int i = 0; i < n; ++i) {
        tin[i]--;
        id[tin[i]] = i;
    }
    BlockSize = max(1, (int)sqrt(m));
    sort(Q + 1, Q + q + 1);
    for (int i = 1; i <= q; ++i) {
        Mo(Q[i].l, Q[i].r);
        ans[Q[i].id] = (sumDist / 2) + 1;
    }
    // Shrink the Mo window back to empty ([L, R] = [1, 0]) while c[] still
    // holds this test's data. The window bounds are global and are not reset
    // above, so otherwise a second solve() call (multi-test input) would
    // start by removing elements of the new c[] that were never added.
    Mo(1, 0);
    for (int i = 1; i <= q; ++i) {
        cout << ans[i] << "\n";
    }
}
/**
Sample input:
7 6 1
1 2
1 3
2 4
2 5
3 6
3 7
2 3 6 4 5 7
4 6
Expected output:
6
**/
int main() {
    // Unsynchronized, untied streams for fast bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int testCount = 1;
    // cin >> testCount;   // multi-test input (disabled)
    for (int t = 0; t < testCount; ++t) {
        solve();
    }
    // execute;   // prints elapsed time to stderr when enabled
    return 0;
}
/*
Grader result tables (results had not finished loading):
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
*/