Submission #1300316

#TimeUsernameProblemLanguageResultExecution timeMemory
1300316avighnaCapital City (JOI20_capital_city)C++20
0 / 100
179 ms56852 KiB
#include <algorithm>
#include <iostream>
#include <numeric>
#include <set>
#include <vector>

using namespace std;

class dsu {
private:
  int n;
  vector<int> par;

public:
  dsu(int n) : n(n), par(n) {
    iota(par.begin(), par.end(), 0);
  }

  int root(int u) { return u == par[u] ? u : par[u] = root(par[u]); }

  void merge(int u, int v) { // first retains its identity
    u = root(u), v = root(v);
    if (u != v) {
      par[v] = u;
    }
  }
};

const int inf = 1e9;

class segment_tree {
private:
  int n;
  vector<int> seg;

public:
  segment_tree(int n) : n(n), seg(2 * n) {}

  void add(int i, int x) {
    for (seg[i += n] += x, i >>= 1; i > 0; i >>= 1) {
      seg[i] = seg[2 * i] + seg[2 * i + 1];
    }
  }

  int query(int l, int r) {
    int ans = 0;
    for (l += n, r += n + 1; l < r; l >>= 1, r >>= 1) {
      if (l & 1)
        ans += seg[l++];
      if (r & 1)
        ans += seg[--r];
    }
    return ans;
  }
};

class segment_tree2 {
private:
  int n;
  vector<int> seg;

public:
  segment_tree2(int n) : n(n), seg(2 * n, -inf) {}

  void set(int i, int x) {
    for (seg[i += n] = x, i >>= 1; i > 0; i >>= 1) {
      seg[i] = max(seg[2 * i], seg[2 * i + 1]);
    }
  }

  int query(int l, int r) {
    int ans = -inf;
    for (l += n, r += n + 1; l < r; l >>= 1, r >>= 1) {
      if (l & 1)
        ans = max(ans, seg[l++]);
      if (r & 1)
        ans = max(ans, seg[--r]);
    }
    return ans;
  }
};

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n, k;
  cin >> n >> k;
  vector<vector<int>> adj(n + 1);
  vector<int> deg(n + 1);
  for (int i = 0, u, v; i < n - 1; ++i) {
    cin >> u >> v;
    adj[u].push_back(v);
    adj[v].push_back(u);
    ++deg[u], ++deg[v];
  }

  vector<int> line, loc(n + 1);
  auto dfs = [&](auto &&self, int u, int p) -> void {
    loc[u] = line.size();
    line.push_back(u);
    for (int &i : adj[u]) {
      if (i != p) {
        self(self, i, u);
      }
    }
  };
  for (int i = 1; i <= n; ++i) {
    if (deg[i] == 1) {
      dfs(dfs, i, 0);
      break;
    }
  }

  vector<int> c(n + 1), city_st(k + 1, inf), city_en(k + 1, -inf);
  for (int i = 1; i <= n; ++i) {
    cin >> c[i];
    city_st[c[i]] = min(city_st[c[i]], loc[i]);
    city_en[c[i]] = max(city_en[c[i]], loc[i]);
  }
  vector<int> c_right(n + 1);
  for (int i = 1; i <= n; ++i) {
    c_right[i] = c[line[i - 1]];
  }

  dsu dsu(k + 1);
  vector<pair<int, int>> ints;
  vector<int> int_of(n + 1);
  for (int i = 1; i <= k; ++i) {
    if (city_st[i]==inf){continue;}
    ints.push_back({city_st[i], city_en[i]});
    int_of[city_st[i]] = i;
  }

  // merge
  // 1 2 3 4 5
  //     |---|
  // |-----|

  sort(ints.begin(), ints.end(), [&](pair<int, int> a, pair<int, int> b) { return a.second > b.second; });
  set<int> lefts;
  for (int i = 0; i < ints.size(); ++i) {
    auto [l, r] = ints[i];
    auto it = lefts.upper_bound(l);
    if (it != lefts.end() && *it < r) { // have intersection
      int u = int_of[l], v = int_of[*it];
      dsu.merge(u, v);
      lefts.erase(it);
    }
    lefts.insert(l);
  }

  // rebuild ints
  ints.clear();
  int_of.assign(n + 1, 0);
  vector<vector<int>> comps(k + 1);
  for (int i = 1; i <= k; ++i) {
    comps[dsu.root(i)].push_back(i);
  }
  for (auto &comp : comps) {
    if (comp.empty()) {
      continue;
    }
    int l = city_st[comp.front()], r = city_en[comp.front()];
    for (int &x : comp) {
      l = min(l, city_st[x]);
      r = max(r, city_en[x]);
    }
    ints.push_back({l, r});
    int_of[r] = comp.front();
  }

  // merge
  // 1 2 3 4 5
  // |---|
  //   |-----|

  sort(ints.begin(), ints.end(), [&](pair<int, int> a, pair<int, int> b) { return a.first < b.first; });
  set<int> rights;
  for (int i = 0; i < ints.size(); ++i) {
    auto [l, r] = ints[i];
    auto it = rights.lower_bound(r);
    if (it != rights.begin()) {
      --it;
      if (*it > l) {
        int u = int_of[r], v = int_of[*it];
        dsu.merge(u, v);
        rights.erase(it);
      }
    }
    rights.insert(r);
  }

  // now do complete interval intersection thing
  struct interval {
    int l, r, c, root;
  };
  comps = vector<vector<int>>(k + 1);
  for (int i = 1; i <= k; ++i) {
    comps[dsu.root(i)].push_back(i);
  }
  vector<interval> fin; // final intervals
  for (auto &comp : comps) {
    if (comp.empty()) {
      continue;
    }
    int l = city_st[comp.front()], r = city_en[comp.front()];
    for (int &x : comp) {
      l = min(l, city_st[x]);
      r = max(r, city_en[x]);
    }
    fin.push_back({l, r, int(comp.size()), dsu.root(comp.front())});
  }

  // cerr << "so far, we have:\n";
  // for (auto &[l, r, c, root] : fin) {
  //   cerr << "interval [" << l << ", " << r << "] with cost " << c << '\n';
  // }

  segment_tree st(n); // add at r, iterate over reverse l
  vector<int> ans_to_roots(k + 1);
  sort(fin.begin(), fin.end(), [&](interval a, interval b) { return a.l > b.l; });
  for (auto &[l, r, c, root] : fin) {
    ans_to_roots[dsu.root(root)] = st.query(0, r) + c;
    st.add(r, c);
  }

  segment_tree2 rmq_trust(n);
  for (int i = 1; i <= n; ++i) {
    // cerr << ans_to_roots[dsu.root(c_right[i])] << ' ';
    rmq_trust.set(i - 1, ans_to_roots[dsu.root(c_right[i])]);
  }
  // cerr << '\n';

  int ans = inf;
  for (auto &[l, r, c, root] : fin) {
    int qval = rmq_trust.query(l, r);
    if (qval != ans_to_roots[dsu.root(root)]) {
      continue;
    }
    ans = min(ans, qval);
  }

  cout << ans - 1 << '\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...