#include <bits/stdc++.h>
//#define int long long
using namespace std;
// Fixed-size tuple alias, e.g. ar<int, 3> for the (l, r, id) query triples.
// An alias template instead of `#define ar array` so the name is scoped
// like any other type and cannot silently rewrite unrelated identifiers.
template <class T, size_t N> using ar = array<T, N>;
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
const int mod = 1e9 + 7;
// Number of binary-lifting levels in up[][]; lca() lifts by bits 0..19,
// so tree depths must stay below 2^LOG.
const int LOG = 20;
const int maxn = 1e5 + 5;
// G: adjacency lists (1-indexed vertices); E: vertices in DFS preorder,
// so that E[in[v]] == v.
vector<int> G[maxn], E;
// in[v]: preorder index of v; up[v][k]: 2^k-th ancestor of v (0 once past
// the root); dep[v]: depth of v; n, m, q: input sizes; timer: next preorder
// index to assign.
int in[maxn], up[maxn][LOG], dep[maxn], n, m, q, timer=0;
// Mo's algorithm block size (about sqrt(1e5)).
const int B = 310;
void dfs(int u, int p) {
in[u] = timer++; E.push_back(u);
for(int i=1; i<20; i++) up[u][i] = up[up[u][i-1]][i-1];
for(int v : G[u]) {
if(v == p) continue;
dep[v] = dep[u] + 1;
up[v][0] = u;
dfs(v, u);
}
}
// Lowest common ancestor of a and b using the binary-lifting table up[][]:
// bring the deeper vertex up to the shallower one's depth, then climb both
// in lock-step while their 2^k-th ancestors still differ.
int lca(int a, int b) {
    if (dep[b] > dep[a]) swap(a, b);  // make a the deeper vertex
    const int gap = dep[a] - dep[b];
    for (int k = 0; k < 20; ++k) {
        if ((gap >> k) & 1) a = up[a][k];
    }
    if (a == b) return a;
    for (int k = 19; k >= 0; --k) {
        const int pa = up[a][k];
        const int pb = up[b][k];
        if (pa != pb) {
            a = pa;
            b = pb;
        }
    }
    return up[a][0];
}
// Number of edges on the tree path between a and b.
int dist(int a, int b) {
    const int meet = lca(a, b);
    return (dep[a] - dep[meet]) + (dep[b] - dep[meet]);
}
// Mo's ordering for (l, r, id) triples: group by block of l (size B);
// inside a block, order by r ascending.
bool cmp(ar<int, 3> a, ar<int, 3> b) {
    const int blockA = a[0] / B;
    const int blockB = b[0] / B;
    return blockA == blockB ? a[1] < b[1] : a[0] < b[0];
}
// Mo's-window state maintained by add() / rem():
//   cnt[v] - how many times vertex v occurs in the current window,
//   st     - preorder indices in[v] of the distinct vertices present,
//   sum    - total dist() between cyclically consecutive members of st.
int cnt[maxn] = {};
set<int> st{};
ll sum{0};
// Removes one occurrence of vertex u from the Mo's window. When the last
// occurrence goes, u leaves st and the cyclic preorder tour is patched:
// L -> u -> R becomes L -> R, where L / R are u's cyclic predecessor and
// successor in st (both equal u when u is the only member, contributing 0).
// Precondition: cnt[u] > 0, i.e. u is currently in the window.
void rem(int u) {
    if (--cnt[u] > 0) return;
    const auto it = st.find(in[u]);
    // Cyclic neighbours in preorder, wrapping around both ends of the set.
    const int L = E[it == st.begin() ? *st.rbegin() : *prev(it)];
    const int R = E[next(it) == st.end() ? *st.begin() : *next(it)];
    sum -= dist(u, L) + dist(u, R) - dist(L, R);
    st.erase(it);  // reuse the iterator instead of a second key lookup
}
void add(int u) {
cnt[u]++;
if(cnt[u] > 1) return ;
st.insert(in[u]);
int L=0, R=0;
auto it = st.find(in[u]);
if(it != st.begin()) L = E[*prev(it)];
else L = E[*st.rbegin()];
if(it != --st.end()) R = E[*next(it)];
else R = E[*st.begin()];
// cout << u << ": " << L << " " << R << '\n';
if(L) sum += dist(u, L);
if(R) sum += dist(u, R);
if(L && R) sum -= dist(L, R);
}
// Reads a tree on n vertices (rooted at 1), a sequence a[1..m] of vertices
// and q queries (l, r). For each query it prints sum / 2 + 1, where sum is
// the cyclic preorder tour length of the distinct vertices in a[l..r];
// queries are answered offline with Mo's algorithm.
// Assumes 1 <= l <= r <= m for every query — TODO confirm against the statement.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> q;
    for (int i = 1; i < n; i++) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    dfs(1, 1);

    vector<int> a(m + 1), ans(q + 1);
    for (int i = 1; i <= m; i++) cin >> a[i];

    vector<ar<int, 3> > qus;
    qus.reserve(q);
    for (int i = 1; i <= q; i++) {
        int l, r;
        cin >> l >> r;
        qus.push_back({l, r, i});
    }
    sort(qus.begin(), qus.end(), cmp);

    int l = 1, r = 0;  // current window a[l..r], initially empty
    for (auto [ql, qr, id] : qus) {
        // Grow the window before shrinking it. The window then never inverts
        // (l > r + 1), whatever order cmp produces; an inverted window would
        // make rem() look up a vertex that is not in st.
        while (l > ql) add(a[--l]);
        while (r < qr) add(a[++r]);
        while (l < ql) rem(a[l++]);
        while (r > qr) rem(a[r--]);
        ans[id] = sum / 2 + 1;
    }
    for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
    return 0;
}
/*
 * Online-judge submission-status residue that was pasted after the program.
 * It is not C++; it is kept verbatim inside this comment so the file compiles.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */