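// Submission #1333811 (oj.uz)
//
// #       | Submission time  | User ID  | Problem                     | Language | Result  | Execution time | Memory
// 1333811 | (not captured)   | altern23 | Cats or Dogs (JOI18_catdog) | C++20    | 0 / 100 | 11 ms          | 16076 KiB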
#include "catdog.h"
#include <bits/stdc++.h>
using namespace std;

// Integer type used for every vertex id, size, timestamp and DP value below.
using ll = long long;

// Largest vertex count the fixed-size tables can hold.
constexpr int MAXN = 100'000;
// "Impossible" cost: far larger than any real answer (which is at most N - 1 cuts).
constexpr ll INF = 1'000'000'000;

// Per-vertex tables indexed by vertex id (1..N); the +5 is slack.
std::vector<ll> adj[MAXN + 5]; // adjacency lists (dfs reorders them)
ll sz[MAXN + 5];                // subtree sizes, filled by dfs
ll P[MAXN + 5];                 // parent of each vertex; stays 0 for the root
ll state[MAXN + 5];             // -1 = neutral, 0 = cat, 1 = dog
ll head[MAXN + 5];              // top vertex of the heavy chain containing each vertex

// Chain ranges are indexed by the chain's head vertex.
ll MN[MAXN + 5];  // smallest tin on the chain headed here
ll MX[MAXN + 5];  // largest tin on the chain headed here
ll tin[MAXN + 5]; // DFS entry order (segment-tree position) of each vertex
ll t = 0;         // running DFS timestamp counter

/**
 * Second DFS pass of the heavy-light decomposition.
 * Gives each vertex its timestamp tin[] (its segment-tree position) and
 * records, per chain head, the tin range [MN, MX] covered by that chain.
 * The child stored in adj[idx][0] continues the current chain; every other
 * child starts a new chain headed by itself. That child is visited first,
 * which keeps each chain's positions contiguous.
 * @param idx current vertex (head[idx] must already be set by the caller)
 * @param par parent of idx (-1 for the root)
 */
void construct_chain(ll idx, ll par) {
      t += 1;
      tin[idx] = t;

      const ll chain = head[idx];
      if (tin[idx] < MN[chain]) MN[chain] = tin[idx];
      if (tin[idx] > MX[chain]) MX[chain] = tin[idx];

      for (const ll nxt : adj[idx]) {
            if (nxt == par) continue;
            // adj[idx][0] is read only here, so a vertex with no neighbours never touches it.
            const bool continues_chain = (nxt == adj[idx][0]);
            head[nxt] = continues_chain ? chain : nxt;
            construct_chain(nxt, idx);
      }
}

/**
 * First DFS pass of the heavy-light decomposition.
 * Computes subtree sizes sz[] and parents P[], and reorders adj[idx] so that
 * slot 0 holds the child with the largest subtree. construct_chain continues
 * the heavy chain through exactly that slot.
 * @param idx current vertex
 * @param par parent of idx (-1 for the root)
 *
 * NOTE(review): recursion depth equals the tree height (up to N), so this
 * relies on the grader providing a large enough stack.
 */
void dfs(ll idx, ll par) {
      sz[idx] = 1;
      vector<ll> &nb = adj[idx];

      // Slot 0 must hold a child, never the parent. Otherwise the comparisons
      // below are made against the parent's *partial* subtree size: the real
      // heavy child can be missed, and sometimes no child is chained at all.
      // That breaks the O(log N) chains-per-path guarantee.
      if (nb.size() > 1 && nb[0] == par) swap(nb[0], nb.back());

      for (auto &child : nb) {
            if (child == par) continue;
            P[child] = idx;
            dfs(child, idx);
            sz[idx] += sz[child];
            // Keep the largest subtree seen so far in slot 0 (always a child here).
            if (sz[child] > sz[nb[0]]) swap(child, nb[0]);
      }
}

struct ST {
      ll l, r;
      ST *lf, *rg;
      ll DP[2][2];

      ST (ll _l, ll _r) {
            l = _l, r = _r;
            for (int i = 0; i < 2; i++) {
                  for (int j = 0; j < 2; j++) DP[i][j] = INF;
            }
      }

      void combine() {
            for (int i = 0; i < 2; i++) {
                  for (int j = 0; j < 2; j++) {
                        DP[i][j] = INF;
                        for (int k = 0; k < 2; k++) {
                              for (int l = 0; l < 2; l++) {
                                    DP[i][j] = min(DP[i][j], lf->DP[i][k] + rg->DP[l][j] + (k != l));
                              }
                        }
                  }
            }
      }

      void build() {
            if (l == r) {
                  DP[0][0] = DP[1][1] = 0;
                  return;
            }

            ll mid = (l + r) / 2;
            lf = new ST(l, mid), rg = new ST(mid + 1, r);
            lf->build(), rg->build();
            
            combine();
      }

      void update(ll idx, pair<ll, ll> val) {
            if (l == r) {
                  DP[0][0] += val.first;
                  DP[1][1] += val.second;
                  return;
            }

            ll mid = (l + r) / 2;
            if (idx <= mid) lf->update(idx, val);
            else rg->update(idx, val);

            combine();
      }

      vector <vector<ll>> query(ll x, ll y) {
            if (l > y || r < x) return {};
            if (l >= x && r <= y) {
                  return {{DP[0][0], DP[0][1]}, {DP[1][0], DP[1][1]}};
            }
            
            vector <vector<ll>> L = lf->query(x, y), R = rg->query(x, y), ret(2, vector <ll> (2, INF));

            if (L.empty()) return R;
            if (R.empty()) return L;

            for (int i = 0; i < 2; i++) {
                  for (int j = 0; j < 2; j++) {
                        for (int k = 0; k < 2; k++) {
                              for (int l = 0; l < 2; l++) {
                                    ret[i][j] = min(ret[i][j], L[i][k] + R[l][j] + (k != l));
                              }
                        }
                  }
            }

            return ret;
      }

} sg(1, 100000);

/**
 * Grader entry point: reads the tree and builds the HLD segment tree.
 * @param N number of vertices (ids 1..N; vertex 1 is used as the root)
 * @param A first endpoints of the N - 1 edges
 * @param B second endpoints of the N - 1 edges
 *
 * Every vertex starts neutral (state -1). Every leaf therefore starts with
 * DP[0][0] = DP[1][1] = 0, and every chain's contribution to its parent's
 * leaf is exactly 0. No initial propagation between chains is needed.
 * The previous version ran such a propagation: it added {0, 0} everywhere and
 * walked chains top-down, which would be the wrong order for non-zero values.
 */
void initialize(int N, std::vector<int> A, std::vector<int> B) {
      for (int i = 1; i <= N; i++) {
            state[i] = -1;
            // Empty range; construct_chain widens it for each chain head.
            MN[i] = INF;
            MX[i] = -INF;
      }
      for (int i = 0; i < N - 1; i++) {
            adj[A[i]].push_back(B[i]);
            adj[B[i]].push_back(A[i]);
      }
      dfs(1, -1);
      head[1] = 1;
      construct_chain(1, -1);
      sg.build();
}

/**
 * Adds (d, c) to the (cat, dog) cost of vertex idx and propagates the change
 * to the root chain.
 * At each chain, the chain's contribution to its parent's leaf is measured
 * before and after the update. The difference becomes the delta applied one
 * chain higher, at P[head].
 * @param idx vertex whose leaf costs change
 * @param d   amount added to the cost of idx being a cat (color 0)
 * @param c   amount added to the cost of idx being a dog (color 1)
 */
void solve(ll idx, ll d, ll c) {
      // Cost a chain imposes on its parent when the parent is a cat (first) or
      // a dog (second): the chain's top either matches the parent's color,
      // or differs and the connecting edge is cut (+1).
      auto contribution = [](ll chain_head) {
            vector <vector<ll>> tbl = sg.query(MN[chain_head], MX[chain_head]);
            ll as_cat = min({tbl[0][0], tbl[0][1], tbl[1][0] + 1, tbl[1][1] + 1});
            ll as_dog = min({tbl[1][0], tbl[1][1], tbl[0][0] + 1, tbl[0][1] + 1});
            return pair<ll, ll>(as_cat, as_dog);
      };

      pair<ll, ll> delta(d, c);
      for (ll v = idx; v != 0; ) {
            const ll h = head[v];
            const pair<ll, ll> before = contribution(h);
            sg.update(tin[v], delta);
            const pair<ll, ll> after = contribution(h);
            delta = pair<ll, ll>(after.first - before.first, after.second - before.second);
            v = P[h]; // the root's parent is 0, which ends the walk
      }
}

int cat(int v) {
      state[v] = 0;
      solve(v, 0, INF);
      return min({sg.DP[0][0], sg.DP[0][1], sg.DP[1][0], sg.DP[1][1]});
}

int dog(int v) {
      state[v] = 1;
      solve(v, INF, 0);
      return min({sg.DP[0][0], sg.DP[0][1], sg.DP[1][0], sg.DP[1][1]});
}

int neighbor(int v) {
      if (!state[v]) solve(v, 0, -INF);
      else solve(v, -INF, 0);
      state[v] = -1;
      return min({sg.DP[0][0], sg.DP[0][1], sg.DP[1][0], sg.DP[1][1]});
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...