#include <bits/stdc++.h>
using namespace std;
// Debug printer for pairs: writes "(first, second)" and returns the stream.
template <typename First, typename Second>
ostream &operator<<(ostream &out, const pair<First, Second> &pr)
{
    out << '(' << pr.first << ", " << pr.second << ')';
    return out;
}

// Debug printer for any iterable container except std::string (which keeps
// its normal stream output): writes "{a, b, c}" and returns the stream.
// The defaulted second template parameter removes this overload via SFINAE
// for std::string and for types without a nested value_type.
template <typename Container,
          typename Elem = typename enable_if<!is_same<Container, string>::value,
                                             typename Container::value_type>::type>
ostream &operator<<(ostream &out, const Container &items)
{
    out << '{';
    bool firstItem = true;
    for (const Elem &item : items)
    {
        if (!firstItem) out << ", ";
        out << item;
        firstItem = false;
    }
    out << '}';
    return out;
}
// ---- shorthand types and helper macros ----
using ll = long long;
using ld = long double;
#define all(a) (a).begin(), (a).end()
// Fast iostream setup; used as a statement ("sui;") at the top of main.
#define sui cout.tie(NULL); cin.tie(NULL); ios_base::sync_with_stdio(false)

// ---- limits and tuning constants ----
constexpr int MAX_N = 1e5 + 5;   // capacity of every per-vertex array
constexpr int MOD = 1e9 + 7;
constexpr ll INF = 1e9;
constexpr ld EPS = 1e-9;
constexpr int LOG = 30;          // binary-lifting levels (jumps of 2^0 .. 2^29)
constexpr int SQ = 320;          // Mo's algorithm block size over query left ends

// ---- tree storage and binary lifting (filled by dfs) ----
vector<int> adj[MAX_N];          // adjacency lists
int st[MAX_N];                   // st[v]: DFS entry time of v
int rev[MAX_N];                  // rev[t]: vertex whose entry time is t
int timer = 0;                   // next entry time to hand out
int h[MAX_N];                    // h[v]: depth of v (root 1 stays at 0)
int par[MAX_N][LOG];             // par[v][i]: 2^i-th ancestor of v, 0 above the root

// ---- Mo's algorithm window state (maintained by add/rem) ----
int cnt[MAX_N];                  // multiplicity of each vertex in the current window
int sum = 0;                     // sum of distances between entry-time-consecutive window vertices
set<int> al;                     // distinct entry times of vertices in the window
void dfs(int u, int p)
{
rev[timer] = u;
st[u] = timer++;
par[u][0] = p;
for (int i = 1; i < LOG; i++) par[u][i] = par[par[u][i - 1]][i - 1];
for (int v : adj[u]) if (v != p)
{
h[v] = h[u] + 1;
dfs(v, u);
}
}
// Returns the vertex reached by climbing k levels above u via the
// binary-lifting table: for each set bit among bits 0..LOG-1 of k, jump to
// the corresponding 2^bit-th ancestor.
int getpar(int u, int k)
{
    int node = u;
    for (int bit = 0; bit < LOG; ++bit)
    {
        const int mask = 1 << bit;
        if (k & mask) node = par[node][bit];
    }
    return node;
}
// Lowest common ancestor of u and v by binary lifting:
// 1. Lift the deeper vertex to the other's depth.
// 2. If the two vertices now coincide, that vertex is the answer.
// 3. Otherwise jump both by 2^i from large to small i while their ancestors
//    differ; the parent of the final position is the LCA.
int lca(int u, int v)
{
    int deep = u;
    int shallow = v;
    if (h[deep] < h[shallow]) std::swap(deep, shallow);
    deep = getpar(deep, h[deep] - h[shallow]);
    if (deep == shallow) return deep;
    for (int i = LOG - 1; i >= 0; --i)
    {
        const int a = par[deep][i];
        const int b = par[shallow][i];
        if (a != b)
        {
            deep = a;
            shallow = b;
        }
    }
    return par[deep][0];
}
int dis(int u, int v)
{
int lc = lca(u, v);
return h[u] + h[v] - 2 * h[lc];
}
/*
 * Adds one occurrence of vertex x to the Mo window.
 *
 * Only the first occurrence changes anything: x's entry time is spliced into
 * `al` between its predecessor and successor. In `sum`, the pred->succ hop is
 * replaced by pred->x plus x->succ; any missing neighbour drops its terms.
 */
void add(int x)
{
    if (++cnt[x] > 1) return;   // already represented in the window

    const int key = st[x];
    const auto succ = al.lower_bound(key);   // key itself is not in al yet
    const bool hasSucc = succ != al.end();
    const bool hasPred = succ != al.begin();

    if (hasPred) sum += dis(rev[*std::prev(succ)], x);
    if (hasSucc) sum += dis(x, rev[*succ]);
    if (hasPred && hasSucc) sum -= dis(rev[*std::prev(succ)], rev[*succ]);

    al.insert(key);
}
/*
 * Removes one occurrence of vertex x from the Mo window.
 *
 * When the last occurrence goes, x's entry time leaves `al`. The two hops
 * pred->x and x->succ are then replaced in `sum` by the direct pred->succ
 * hop; any missing neighbour drops its terms.
 */
void rem(int x)
{
    if (--cnt[x] != 0) return;   // other occurrences keep x in the window

    al.erase(st[x]);
    const auto succ = al.lower_bound(st[x]);   // first entry time after x
    const bool hasSucc = succ != al.end();
    const bool hasPred = succ != al.begin();

    if (hasPred) sum -= dis(rev[*std::prev(succ)], x);
    if (hasSucc) sum -= dis(x, rev[*succ]);
    if (hasPred && hasSucc) sum += dis(rev[*std::prev(succ)], rev[*succ]);
}
void solve() {
int n, m, q;
cin >> n >> m >> q;
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
int c[m];
for (int i = 0; i < m; i++) cin >> c[i];
vector<pair<pair<int, int>, int>> qs;
int l[q];
int r[q];
int ans[q];
for (int i = 0; i < q; i++)
{
cin >> l[i] >> r[i];
l[i]--, r[i]--;
qs.push_back({{l[i] / SQ, (l[i] / SQ) % 2 ? -r[i] : r[i]}, i});
}
sort(all(qs));
int l1 = 0;
int r1 = 0;
dfs(1, 0);
al.insert(st[c[0]]);
cnt[c[0]]++;
for (int i = 0; i < q; i++)
{
while (l1 > l[qs[i].second]) add(c[l1 - 1]), l1--;
while (r1 < r[qs[i].second]) add(c[r1 + 1]), r1++;
while (l1 < l[qs[i].second]) rem(c[l1]), l1++;
while (r1 > r[qs[i].second]) rem(c[r1]), r1--;
ans[qs[i].second] = (sum + dis(rev[*al.begin()], rev[*al.rbegin()])) / 2 + 1;
}
for (int i = 0; i < q; i++) cout << ans[i] << "\n";
}
int main() {
    sui;   // fast, untied iostreams
    int tc = 1;
    // For multi-test input, read the number of test cases here: cin >> tc;
    for (int caseNo = 0; caseNo < tc; ++caseNo) solve();
}
/*
Judge submission tables captured with this source (page residue, results still loading):
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
*/