Submission #674359

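| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 674359 | | tibinyte | Cats or Dogs (JOI18_catdog) | C++17 | 100 / 100 | 607 ms | 40356 KiB |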
#include <bits/stdc++.h>
#include "catdog.h"
using namespace std;
// Sentinel cost for an impossible state combination (e.g. a vertex forced to
// one state being asked for the other). Two constraints apply:
//   * it must exceed every real cost, otherwise a real cost gets clamped to it;
//   * 2 * inf + 1 must still fit in an int, because aint::combine adds two
//     table entries plus one junction cost.
// The previous value 1e5 only worked while every real cost stayed below 1e5;
// 1e9 satisfies both constraints without depending on the input size.
constexpr int inf = 1000000000;

// One segment-tree value: dp[top][bottom] is the cheapest cost of a contiguous
// chain segment whose first vertex has state `top` and whose last vertex has
// state `bottom` (each 0 or 1). A freshly constructed table is all `inf`,
// i.e. "no valid assignment yet".
struct node
{
  int dp[2][2];
  node()
  {
    for (auto &row : dp)
    {
      std::fill(std::begin(row), std::end(row), inf);
    }
  }
};

// Segment tree over heavy-light positions. A leaf stores the 2x2 cost table of
// a single vertex; an internal slot stores the table of the contiguous range it
// covers, obtained by placing its left range directly above its right range.
struct aint
{
  vector<node> a;

  // Allocates 4n default-constructed slots (every entry starts at inf).
  void init(int n)
  {
    a = vector<node>(4 * n);
  }

  // Joins segment `upper` with segment `lower` placed right below it. Crossing
  // the junction costs 1 when the two touching endpoints have different states
  // and 0 otherwise. Entries never exceed inf because each cell starts at inf
  // and is only ever lowered.
  node combine(node upper, node lower)
  {
    node joined;
    for (int top = 0; top < 2; ++top)
    {
      for (int bottom = 0; bottom < 2; ++bottom)
      {
        int &cell = joined.dp[top][bottom];
        for (int up_end = 0; up_end < 2; ++up_end)
        {
          for (int low_end = 0; low_end < 2; ++low_end)
          {
            const int junction = (up_end == low_end) ? 0 : 1;
            const int cost = upper.dp[top][up_end] + lower.dp[low_end][bottom] + junction;
            if (cost < cell)
            {
              cell = cost;
            }
          }
        }
      }
    }
    return joined;
  }

  // Overwrites the leaf for position `pos`: state 0 costs val.first, state 1
  // costs val.second, and mixed entries are inf. Then rebuilds every slot on
  // the path from that leaf back up to slot `node`, which covers [left, right].
  void update(int node, int left, int right, int pos, pair<int, int> val)
  {
    // Descend to the leaf, counting levels so the climb stops exactly where
    // the descent started.
    int depth = 0;
    while (left != right)
    {
      const int mid = (left + right) / 2;
      node *= 2;
      if (pos <= mid)
      {
        right = mid;
      }
      else
      {
        ++node;
        left = mid + 1;
      }
      ++depth;
    }

    auto &leaf = a[node].dp;
    leaf[0][0] = val.first;
    leaf[0][1] = inf;
    leaf[1][0] = inf;
    leaf[1][1] = val.second;

    // Climb back up, recombining each ancestor from its two children.
    while (depth-- > 0)
    {
      node /= 2;
      a[node] = combine(a[2 * node], a[2 * node + 1]);
    }
  }

  // Returns the table for the part of [st, dr] that lies inside [left, right],
  // the range covered by slot `node`. Callers are expected to pass an
  // overlapping range.
  node query(int node, int left, int right, int st, int dr)
  {
    if (st <= left && right <= dr)
    {
      return a[node];
    }
    const int mid = (left + right) / 2;
    const bool touches_left = st <= mid;
    const bool touches_right = mid + 1 <= dr;
    if (touches_left && touches_right)
    {
      return combine(query(2 * node, left, mid, st, dr), query(2 * node + 1, mid + 1, right, st, dr));
    }
    if (touches_left)
    {
      return query(2 * node, left, mid, st, dr);
    }
    return query(2 * node + 1, mid + 1, right, st, dr);
  }
};

// Number of vertices; vertices are numbered 1..n.
int n;

// a[v]: current state of vertex v. cat() sets 0, dog() sets 1, and
// neighbor() sets 2; initialize() starts every vertex at 2.
vector<int> a;

// Adjacency lists of the tree.
vector<vector<int>> g;

// Heavy-light decomposition data, filled by init_hld().
vector<int> heavy;
vector<int> pos;
vector<int> head;
vector<int> par;
vector<int> down;

// sum[v][s]: total cost contributed by the light child chains hanging below v
// when v has state s. prv[c][s]: the part of sum[par[c]][s] currently
// contributed by the chain headed at c, kept so it can be replaced on update.
vector<vector<int>> sum;
vector<vector<int>> prv;

// Segment tree indexed by HLD position.
aint tree;

// HLD position counter (pre-incremented, so positions start at 1).
int p = 0;
void init_hld()
{
  pos = head = par = down = vector<int>(n + 1);
  heavy = vector<int>(n + 1, -1);
  function<int(int, int)> dfs = [&](int node, int parent)
  {
    int sz = 1;
    int maxi = 0;
    for (auto i : g[node])
    {
      if (i != parent)
      {
        int cine = dfs(i, node);
        sz += cine;
        if (cine > maxi)
        {
          maxi = cine;
          heavy[node] = i;
        }
      }
    }
    return sz;
  };
  dfs(1, 0);
  function<void(int, int, int)> get = [&](int node, int boss, int parent)
  {
    par[node] = parent;
    head[node] = boss;
    pos[node] = ++p;
    down[head[node]] = node;
    if (heavy[node] != -1)
    {
      get(heavy[node], boss, node);
    }
    for (auto i : g[node])
    {
      if (i != parent && i != heavy[node])
      {
        get(i, i, node);
      }
    }
  };
  get(1, 1, 0);
}

void initialize(int N, vector<int> A, vector<int> B)
{
  n = N;
  g = vector<vector<int>>(n + 1);
  a = vector<int>(n + 1, 2);
  prv = sum = vector<vector<int>>(n + 1, vector<int>(2));
  for (int i = 0; i <= n - 2; ++i)
  {
    g[A[i]].push_back(B[i]);
    g[B[i]].push_back(A[i]);
  }
  init_hld();
  tree.init(n);
  for (int i = 1; i <= n; ++i)
  {
    tree.update(1, 1, n, i, {0, 0});
  }
}
// Re-evaluates vertex nd after its state a[nd] changed, then walks up the heavy
// chains to the root. On each chain it:
//   1. rewrites the current vertex's leaf from a[v] and its light-chain totals sum[v];
//   2. re-queries the whole chain's table;
//   3. replaces that chain's cached contribution prv[head] inside its parent's sum.
// Returns the table of the chain headed at vertex 1.
node update(int nd)
{
  for (int v = nd;; v = par[head[v]])
  {
    switch (a[v])
    {
    case 0: // only state 0 allowed
      tree.update(1, 1, n, pos[v], {sum[v][0], inf});
      break;
    case 1: // only state 1 allowed
      tree.update(1, 1, n, pos[v], {inf, sum[v][1]});
      break;
    case 2: // both states allowed
      tree.update(1, 1, n, pos[v], {sum[v][0], sum[v][1]});
      break;
    }

    const int top = head[v];
    const node chain = tree.query(1, 1, n, pos[top], pos[down[top]]);
    if (top == 1)
    {
      return chain;
    }

    const int parent = par[top];
    for (int s = 0; s < 2; ++s)
    {
      // With the parent in state s, the chain's top vertex either matches it
      // (no extra cost) or differs (one extra unit across the light edge).
      int best = inf;
      for (int bottom = 0; bottom < 2; ++bottom)
      {
        best = min({best, chain.dp[s][bottom], chain.dp[s ^ 1][bottom] + 1});
      }
      sum[parent][s] += best - prv[top][s];
      prv[top][s] = best;
    }
  }
}

int cat(int v)
{
  a[v] = 0;
  node ans = update(v);
  int best = inf;
  for (int i = 0; i < 2; ++i)
  {
    for (int j = 0; j < 2; ++j)
    {
      best = min(best, ans.dp[i][j]);
    }
  }
  return best;
}

int dog(int v)
{
  a[v] = 1;
  node ans = update(v);
  int best = inf;
  for (int i = 0; i < 2; ++i)
  {
    for (int j = 0; j < 2; ++j)
    {
      best = min(best, ans.dp[i][j]);
    }
  }
  return best;
}

int neighbor(int v)
{
  a[v] = 2;
  node ans = update(v);
  int best = inf;
  for (int i = 0; i < 2; ++i)
  {
    for (int j = 0; j < 2; ++j)
    {
      best = min(best, ans.dp[i][j]);
    }
  }
  return best;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...