Submission #532281

#TimeUsernameProblemLanguageResultExecution timeMemory
532281Alex_tz307Cats or Dogs (JOI18_catdog)C++17
100 / 100
693 ms24028 KiB
#include <bits/stdc++.h>
#include "catdog.h"
#define INF 0x3f3f3f3f

using namespace std;

/*
5
1 2
2 3
2 4
4 5
5
1 3
2 5
1 2
2 1
3 2
*/

/// dp[node][c] := minimum cost so that the component containing `node` consists only of
///                vertices colored c or uncolored
/// On a single path (one chain) this is easy to solve with a segment tree.
/// With heavy-light decomposition, the dp is merged along each chain, storing in every
/// segment-tree node the states of the two endpoints of the chain segment it covers.
/// For light edges, the dp is updated "normally" and we continue with the next chains
/// up to the root.

// Maximum number of vertices; per-vertex arrays are 1-indexed, hence the size 1 + kN.
constexpr int kN = 100000;

std::vector<int> g[1 + kN];  // adjacency lists of the tree (filled by initialize)
int labels;                  // last heavy-light label handed out by dfs2
int sz[1 + kN];              // subtree sizes (dfs1)
int p[1 + kN];               // parent of each vertex; the root's entry is never assigned (stays 0)
int heavySon[1 + kN];        // child with the largest subtree; 0 when there is none
int chainTop[1 + kN];        // topmost vertex of the heavy chain containing the vertex
int label[1 + kN];           // position of the vertex in the segment tree
int down[1 + kN];            // for a chain top: the deepest vertex of its chain
int last[1 + kN][2];         // for a chain top: its latest contribution to the parent's sum
int sum[1 + kN][2];          // indexed by label: summed light-child costs for each state
short col[1 + kN];           // 0 = cat, 1 = dog, 2 = uncolored

/// Lowers x to y when y is strictly smaller; otherwise x keeps its value.
void minSelf(int &x, int y) {
  x = (y < x) ? y : x;
}

/// Segment-tree value for a contiguous piece of a heavy chain.
/// dp[a][b] = minimum cost for the piece when its top vertex is in state a and its bottom
/// vertex is in state b (states 0 / 1 are the two colors). Joining two adjacent pieces whose
/// touching endpoints are in different states costs 1 more.
/// Infeasible states hold exactly INF, and combining never produces a value above INF, so
/// the additions below cannot overflow `int`.
struct node {
  int dp[2][2];

  /// Every state starts infeasible.
  node() {
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        dp[i][j] = INF;
      }
    }
  }

  /// Re-initializes this node as a single-vertex leaf.
  /// @param pos  segment-tree position (heavy-light label) of the vertex; indexes sum[]
  /// @param c    0 or 1 forces that state; 2 means uncolored (both states allowed)
  void init(int pos, int c) {
    // A single vertex has identical top and bottom states.
    dp[0][1] = dp[1][0] = INF;
    for (int i = 0; i < 2; ++i) {
      if (i != c && c != 2) {
        dp[i][i] = INF;
      } else {
        dp[i][i] = sum[pos][i];
      }
    }
  }

  /// Concatenates this piece (upper) with rhs (lower).
  /// @return the summary of the combined piece; entries without a feasible split stay INF.
  node operator + (const node &rhs) const {
    node ret;  // the constructor already fills every entry with INF
    for (int a = 0; a < 2; ++a) {
      for (int b = 0; b < 2; ++b) {
        // Skip infeasible operands: adding INF-valued states (possibly several times over
        // deeper merges) exceeded INT_MAX and wrapped to negative costs.
        if (dp[a][b] >= INF) {
          continue;
        }
        for (int c = 0; c < 2; ++c) {
          for (int d = 0; d < 2; ++d) {
            if (rhs.dp[c][d] >= INF) {
              continue;
            }
            minSelf(ret.dp[a][d], dp[a][b] + rhs.dp[c][d] + (b != c));
          }
        }
      }
    }
    return ret;
  }
};

/// Recursive segment tree over positions 1..n (heavy-light labels). A leaf holds the node of
/// one vertex; an internal slot holds node::operator+ of its left and right children, i.e.
/// the summary of its label range read from the smaller label to the larger one.
struct ST {
  int n;
  vector<node> t;

  /// Records the size N and allocates 2 * (smallest power of two >= N) slots.
  void init(int N) {
    n = N;
    int leaves = 1;
    while (leaves < n) {
      leaves <<= 1;
    }
    t.resize(2 * leaves);
  }

  /// Recomputes slot x from its two children (left segment placed above the right one).
  void pull(int x) {
    t[x] = t[2 * x] + t[2 * x + 1];
  }

  /// Builds slot x covering [lx, rx] with every leaf initialized as uncolored (2).
  void build(int x, int lx, int rx) {
    if (lx != rx) {
      const int half = (lx + rx) / 2;
      build(2 * x, lx, half);
      build(2 * x + 1, half + 1, rx);
      pull(x);
    } else {
      t[x].init(lx, 2);
    }
  }

  /// Re-initializes the leaf at position pos with color c and refreshes its ancestors.
  void update(int x, int lx, int rx, int pos, int c) {
    if (lx == rx) {
      t[x].init(pos, c);
      return;
    }
    const int half = (lx + rx) / 2;
    if (pos > half) {
      update(2 * x + 1, half + 1, rx, pos, c);
    } else {
      update(2 * x, lx, half, pos, c);
    }
    pull(x);
  }

  void update(int pos, int c) {
    update(1, 1, n, pos, c);
  }

  /// Summary of the inclusive position range [st, dr] within slot x covering [lx, rx].
  node query(int x, int lx, int rx, int st, int dr) {
    if (st <= lx && rx <= dr) {
      return t[x];
    }
    const int half = (lx + rx) / 2;
    const bool needLeft = st <= half;
    const bool needRight = half < dr;
    if (needLeft && needRight) {
      return query(2 * x, lx, half, st, dr) + query(2 * x + 1, half + 1, rx, st, dr);
    }
    return needLeft ? query(2 * x, lx, half, st, dr)
                    : query(2 * x + 1, half + 1, rx, st, dr);
  }

  node query(int st, int dr) {
    return query(1, 1, n, st, dr);
  }
} t;

/// First DFS from u. Sets p[] for every descendant, sz[] (subtree sizes), and heavySon[]
/// (the child with the largest subtree; on ties the earliest such child in g[] wins). It
/// also sets chainTop[x] = x for every visited vertex; dfs2 later overwrites it for heavy
/// children. Recursion depth equals the depth of the tree below u.
void dfs1(int u) {
  chainTop[u] = u;
  sz[u] = 1;
  for (const int child : g[u]) {
    if (child == p[u]) {
      continue;
    }
    p[child] = u;
    dfs1(child);
    sz[u] += sz[child];
    if (sz[child] > sz[heavySon[u]]) {
      heavySon[u] = child;
    }
  }
}

void dfs2(int u) {
  label[u] = ++labels;
  down[chainTop[u]] = u;
  if (heavySon[u] == 0) {
    return;
  }
  chainTop[heavySon[u]] = chainTop[u];
  dfs2(heavySon[u]);
  for (int v : g[u]) {
    if (v != p[u] && v != heavySon[u]) {
      dfs2(v);
    }
  }
}

/// Cost contributed by a chain summary x when the vertex directly above the chain's top is
/// in state i. The top may also be in state i (no extra cost) or in state i ^ 1 (cost + 1).
/// The result never exceeds INF.
int getDp(node x, int i) {
  const int sameState = min(x.dp[i][0], x.dp[i][1]);
  const int otherState = min(x.dp[i ^ 1][0], x.dp[i ^ 1][1]) + 1;
  return min(INF, min(sameState, otherState));
}

/// Applies col[v] to v's leaf, then repairs every chain on the way to the root. For each
/// chain, the top's old contribution last[top][*] is removed from the parent's sum[]. The
/// fresh one is added, and the parent's leaf is refreshed next.
/// @return the summary of the chain whose top is vertex 1.
node update(int v) {
  for (;;) {
    t.update(label[v], col[v]);
    const int top = chainTop[v];
    const node chain = t.query(label[top], label[down[top]]);
    if (top == 1) {
      return chain;
    }
    const int parentLabel = label[p[top]];
    for (int s = 0; s < 2; ++s) {
      const int fresh = getDp(chain, s);
      sum[parentLabel][s] += fresh - last[top][s];
      last[top][s] = fresh;
    }
    v = p[top];
  }
}

/// Initialization entry point. Builds the adjacency lists from the N - 1 edges (A[i], B[i])
/// and marks vertices 1..N as uncolored (2). Then it runs both decomposition passes from
/// vertex 1 and builds the segment tree over labels 1..N.
void initialize(int N, vector<int> A, vector<int> B) {
  for (int e = 0; e + 1 < N; ++e) {
    const int x = A[e];
    const int y = B[e];
    g[x].emplace_back(y);
    g[y].emplace_back(x);
  }
  for (int v = N; v >= 1; --v) {
    col[v] = 2;
  }
  dfs1(1);
  dfs2(1);
  t.init(N);
  t.build(1, 1, N);
}

/// Shared body of cat/dog/neighbor: stores color c for vertex v (0 = cat, 1 = dog,
/// 2 = uncolored) and propagates the change up to the root.
/// @return the minimum cost over both possible states of the root chain's top.
static int recolor(int v, short c) {
  col[v] = c;
  node ret = update(v);
  return min(getDp(ret, 0), getDp(ret, 1));
}

/// Marks v as a cat and returns the new minimum cost.
int cat(int v) {
  return recolor(v, 0);
}

/// Marks v as a dog and returns the new minimum cost.
int dog(int v) {
  return recolor(v, 1);
}

/// Clears v's color and returns the new minimum cost.
int neighbor(int v) {
  return recolor(v, 2);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...