// Submission #1167244
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1167244 | fryingduc | Tourism (JOI23_tourism) | C++20 | 100 / 100 | 3084 ms | 34216 KiB
#include "bits/stdc++.h"

using namespace std;

#ifdef duc_debug
#include "bits/debug.h"
#else
#define debug(...)
#endif

// Array bound, Mo block size, and number of sparse-table levels.
constexpr int maxn = 100005;
constexpr int B = 505;
constexpr int LG = 18;

int n, m, q;                 // vertices, length of sequence c[], number of queries
vector<int> g[maxn];         // adjacency lists of the tree (vertices 1..n)
int c[maxn];                 // the sequence c[1..m] whose ranges are queried
int lg2[maxn << 1];          // lg2[i] = floor(log2(i)); filled in main()
int h[maxn];                 // depth below the root (vertex 1); filled by pre_dfs
int tin[maxn], tout[maxn], timer, et[maxn];  // preorder entry/exit times; et[t] = vertex entered at time t
long long res[maxn];         // answers, indexed by original query id

// One offline range query [l, r] over c[], tagged with its input position id.
// Ordering for Mo's algorithm: first by block of l (block size B); inside a
// block, r ascends for odd block indices and descends for even ones, which
// reduces how far the right pointer travels between consecutive queries.
struct query {
  int l, r, id;
  // const-qualified so the comparison also works on const objects
  // (std::less<query>, const references, ranges algorithms).
  bool operator < (const query &o) const {
    const int blk = l / B;
    const int other_blk = o.l / B;
    if (blk != other_blk) return blk < other_blk;
    return (blk & 1) ? r < o.r : r > o.r;
  }
} que[maxn];

// Preorder DFS from u (whose parent is prev) over the global adjacency lists g[].
// Writes: tin[v] = entry time, tout[v] = last time assigned inside v's subtree,
// et[t] = vertex entered at time t, h[v] = h[parent] + 1 for every non-root v.
// Uses an explicit stack instead of recursion: the recursive form nested as deep
// as the tree height (up to n on a path graph), which can overflow small default
// stacks. Neighbours are visited in the same order, so all outputs are unchanged.
void pre_dfs(int u, int prev) {
  struct frame {
    int v;        // vertex being expanded
    int parent;   // its parent, skipped among its neighbours
    size_t next;  // index of the next neighbour of v in g[v] to examine
  };
  vector<frame> stk;
  tin[u] = ++timer;
  et[timer] = u;
  stk.push_back({u, prev, 0});
  while (!stk.empty()) {
    frame &top = stk.back();
    const int cur = top.v;
    if (top.next == g[cur].size()) {
      // All children finished: record the exit time, as after the recursive loop.
      tout[cur] = timer;
      stk.pop_back();
      continue;
    }
    const int v = g[cur][top.next++];
    if (v == top.parent) continue;
    h[v] = h[cur] + 1;
    tin[v] = ++timer;
    et[timer] = v;
    // push_back may reallocate; `top` is not touched after this point.
    stk.push_back({v, cur, 0});
  }
}

// O(1) lowest-common-ancestor queries: Euler tour + sparse table over depth.
// NOTE: tin, et and timer declared below shadow the global arrays of the same
// name. Inside this namespace they describe the Euler tour (length 2n - 1,
// each vertex re-recorded after every child), NOT the preorder of pre_dfs.
// Depths are read from the global h[], so pre_dfs must run before build().
// lca() also reads the global lg2[], which main() fills before solve().
namespace lca_o1 {
  // st[i][j] = vertex of minimum depth among et[i .. i + 2^j - 1].
  int st[maxn << 1][LG + 1];
  // tin[u] = first position of u in the Euler tour et[1..timer].
  int tin[maxn], et[maxn << 1], timer;
  
  // Euler tour from u (parent prev): records u on entry and again after
  // returning from each child. Recursive; depth equals the tree height.
  void dfs(int u, int prev) {
    tin[u] = ++timer;
    et[timer] = u;
    for (auto v : g[u]) {
      if (v == prev) continue;
      dfs(v, u);
      et[++timer] = u;
    }
  }
  
  // Builds the Euler tour rooted at vertex 1, then the sparse table over it.
  void build() {
    dfs(1, 0);
    for (int i = 1; i <= timer; ++i) {
      st[i][0] = et[i];
    }
    for (int j = 1; j <= LG; ++j) {
      // Only windows that fit entirely inside et[1..timer].
      for (int i = 1; i + (1 << j) <= timer + 1; ++i) {
        // Keep the shallower of the two half-windows (right half on ties).
        if (h[st[i][j - 1]] < h[st[i + (1 << (j - 1))][j - 1]]) {
          st[i][j] = st[i][j - 1];
        } else {
          st[i][j] = st[i + (1 << (j - 1))][j - 1];
        }
      }
    }
  }
  
  // LCA of u and v = minimum-depth vertex in the Euler tour between their
  // first occurrences, answered by two overlapping power-of-two windows.
  inline int lca(int u, int v) {
    if (tin[u] > tin[v]) swap(u, v);
    int p = lg2[tin[v] - tin[u] + 1];
    return h[st[tin[u]][p]] < h[st[tin[v] - (1 << p) + 1][p]] ?
    st[tin[u]][p] : st[tin[v] - (1 << p) + 1][p];
  }
  
}

// Number of tree edges on the path between u and v (depths from h[]).
inline int dist(int u, int v) {
  const int anc = lca_o1::lca(u, v);
  return (h[u] - h[anc]) + (h[v] - h[anc]);
}

int freq[maxn];  // freq[v] = occurrences of vertex v among c[l..r] of the current window
int bit[maxn];   // Fenwick tree indexed by preorder time tin[]; marks vertices in the window

// Fenwick point update: adds val at 1-based position i (positions 1..n).
void update(int i, int val) {
  while (i <= n) {
    bit[i] += val;
    i += i & -i;
  }
}

// Fenwick prefix sum over positions 1..i.
int get(int i) {
  int sum = 0;
  while (i > 0) {
    sum += bit[i];
    i &= i - 1;  // drop the lowest set bit (same as i -= i & -i)
  }
  return sum;
}

// Fenwick binary-lifting descent: returns the smallest position whose prefix
// sum reaches v (assumes non-negative entries). With 0/1 entries this is the
// position of the v-th marked element.
int lower_bound(long long v) {
  int pos = 0;
  long long acc = 0;
  for (int step = 1 << LG; step > 0; step >>= 1) {
    const int nxt = pos + step;
    if (nxt > n) continue;
    if (acc + bit[nxt] >= v) continue;
    acc += bit[nxt];
    pos = nxt;
  }
  return pos + 1;
}

int cnt;          // number of distinct vertices currently in the window
long long total;  // sum of dist() between cyclically consecutive active vertices, ordered by tin

// Cyclic predecessor l and successor r (as preorder times, i.e. indices into
// et[]) of the active vertex whose preorder time is pos. That vertex must
// already be marked in the Fenwick tree. With cnt == 1 both are pos itself.
static void cyclic_neighbors(int pos, int &l, int &r) {
  const int rank = get(pos);
  l = lower_bound(rank == 1 ? cnt : rank - 1);
  r = lower_bound(rank == cnt ? 1 : rank + 1);
}

// Mo's "add": position p of c[] enters the window.
void add(int p) {
  const int x = c[p];
  if (!freq[x]) {
    update(tin[x], 1);
    ++cnt;
    if (cnt == 2) {
      // Two active vertices: the cycle goes a -> b -> a.
      total = 2LL * dist(et[lower_bound(1)], et[lower_bound(2)]);
    } else {
      // Splice x between its cyclic neighbours. For cnt == 1 both neighbours
      // are x itself, so total is left unchanged.
      int l, r;
      cyclic_neighbors(tin[x], l, r);
      total += dist(et[l], x) + dist(et[r], x) - dist(et[l], et[r]);
    }
  }
  ++freq[x];
}

// Mo's "del": position p of c[] leaves the window.
void del(int p) {
  const int x = c[p];
  --freq[x];
  if (!freq[x]) {
    if (cnt < 3) {
      // At most one vertex remains: the cycle is empty.
      total = 0;
    } else {
      // Unsplice x (still marked in the Fenwick tree while neighbours are found).
      int l, r;
      cyclic_neighbors(tin[x], l, r);
      total += dist(et[l], et[r]) - dist(et[l], x) - dist(et[r], x);
    }
    --cnt;
    update(tin[x], -1);
  }
}

void solve() {
  cin >> n >> m >> q;
  for (int i = 1; i < n; ++i) {
    int u, v; cin >> u >> v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  for (int i = 1; i <= m; ++i) {
    cin >> c[i];
  }
  pre_dfs(1, 0);
  lca_o1::build();
  for (int i = 1; i <= q; ++i) {
    cin >> que[i].l >> que[i].r;
    que[i].id = i;
  }
  sort(que + 1, que + q + 1);
  int l = 1, r = 0;
  for (int i = 1; i <= q; ++i) {
    while (l > que[i].l) add(--l);
    while (r < que[i].r) add(++r);
    while (l < que[i].l) del(l++);
    while (r > que[i].r) del(r--);
//    debug(que[i].id, l, r, total);
    res[que[i].id] = total;
  }
  for (int i = 1; i <= q; ++i) {
    cout << (res[i] >> 1) + 1 << '\n';
  }
}

signed main() {
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);

  // floor(log2(i)) table used by lca_o1::lca; lg2[0] = lg2[1] = 0 by zero-init.
  for (int i = 2; i < (maxn << 1); ++i) lg2[i] = 1 + lg2[i / 2];

  solve();
  return 0;
}


// Verdict | Execution time | Memory | Grader output
// Fetching results...
// Verdict | Execution time | Memory | Grader output
// Fetching results...
// Verdict | Execution time | Memory | Grader output
// Fetching results...
// Verdict | Execution time | Memory | Grader output
// Fetching results...
// Verdict | Execution time | Memory | Grader output
// Fetching results...
// Verdict | Execution time | Memory | Grader output
// Fetching results...