Submission #674354

# | Time | Username | Problem | Language | Result | Execution time | Memory
674354 | (not shown) | tibinyte | Cats or Dogs (JOI18_catdog) | C++17 | 38 / 100 | 3036 ms | 41164 KiB
#include <bits/stdc++.h>
#include "catdog.h"
using namespace std;
const int inf = 1e5;

// Segment tree over heavy-light positions whose nodes hold 2x2 min-plus
// matrices. A leaf written by update(..., {x, y}) is [[x, inf], [inf, y]];
// an inner node is join(left child, right child), where
//   join(L, R)[i][j] = min(inf, min over k, l of L[i][k] + R[l][j] + (k != l)).
// So entry [i][j] of a range is its cheapest cost when the first position
// uses slot i and the last uses slot j, paying 1 for each adjacent pair of
// positions in different slots. The all-inf matrix is the neutral element
// (returned for ranges that do not intersect a query).
//
// Matrices are fixed-size std::array values: the previous
// vector<vector<vector<int>>> layout heap-allocated on every combine and
// query, which dominated the cost of every operation.
struct aint
{
  using Mat = array<array<int, 2>, 2>;

  vector<Mat> a;

  // Allocates 4 * n zero matrices (positions 1..n, root node 1).
  void init(int n)
  {
    a.assign(4 * n, Mat{});
  }

  // The neutral / "no range" matrix: every entry equals inf.
  static Mat inf_mat()
  {
    Mat m;
    for (auto &row : m)
    {
      row.fill(inf);
    }
    return m;
  }

  static bool is_neutral(const Mat &m)
  {
    return m == inf_mat();
  }

  // Min-plus product with a +1 penalty when the slot changes between the
  // last position of x and the first position of y; a neutral operand
  // returns the other one unchanged.
  static Mat join(const Mat &x, const Mat &y)
  {
    if (is_neutral(x))
    {
      return y;
    }
    if (is_neutral(y))
    {
      return x;
    }
    Mat r = inf_mat(); // starting at inf caps every entry at inf
    for (int i = 0; i < 2; ++i)
    {
      for (int j = 0; j < 2; ++j)
      {
        for (int k = 0; k < 2; ++k)
        {
          for (int l = 0; l < 2; ++l)
          {
            r[i][j] = min(r[i][j], x[i][k] + y[l][j] + (k != l));
          }
        }
      }
    }
    return r;
  }

  static vector<vector<int>> to_vec(const Mat &m)
  {
    return {{m[0][0], m[0][1]}, {m[1][0], m[1][1]}};
  }

  static Mat from_vec(const vector<vector<int>> &v)
  {
    Mat m;
    for (int i = 0; i < 2; ++i)
    {
      for (int j = 0; j < 2; ++j)
      {
        m[i][j] = v[i][j];
      }
    }
    return m;
  }

  // Kept for compatibility: same contract as join() on 2x2 vectors.
  vector<vector<int>> combine(const vector<vector<int>> &x, const vector<vector<int>> &y)
  {
    const vector<vector<int>> neutral(2, vector<int>(2, inf));
    if (x == neutral)
    {
      return y;
    }
    if (y == neutral)
    {
      return x;
    }
    return to_vec(join(from_vec(x), from_vec(y)));
  }

  // Sets position pos to the leaf [[val.first, inf], [inf, val.second]] and
  // recomputes the ancestors on the way back up.
  void update(int node, int left, int right, int pos, pair<int, int> val)
  {
    if (left == right)
    {
      Mat &leaf = a[node];
      leaf[0][0] = val.first;
      leaf[0][1] = inf;
      leaf[1][0] = inf;
      leaf[1][1] = val.second;
      return;
    }
    int mid = (left + right) / 2;
    if (pos <= mid)
    {
      update(2 * node, left, mid, pos, val);
    }
    else
    {
      update(2 * node + 1, mid + 1, right, pos, val);
    }
    a[node] = join(a[2 * node], a[2 * node + 1]);
  }

  // Matrix of positions [st, dr] in left-to-right order (all-inf if empty).
  Mat query_mat(int node, int left, int right, int st, int dr) const
  {
    if (right < st || left > dr)
    {
      return inf_mat();
    }
    if (st <= left && dr >= right)
    {
      return a[node];
    }
    int mid = (left + right) / 2;
    return join(query_mat(2 * node, left, mid, st, dr),
                query_mat(2 * node + 1, mid + 1, right, st, dr));
  }

  // Same as query_mat(), returned in the original 2x2 vector form.
  vector<vector<int>> query(int node, int left, int right, int st, int dr)
  {
    return to_vec(query_mat(node, left, right, st, dr));
  }
};

// Number of vertices; vertices are numbered 1..n.
int n;

// Per-vertex state: cat() writes 0, dog() writes 1, neighbor() writes 2;
// initialize() starts every vertex at 2.
vector<int> a;

// Adjacency lists of the tree, indexed 1..n.
vector<vector<int>> g;

// Heavy-light decomposition tables (filled by init_hld).
vector<int> heavy; // heavy child, or -1 when the vertex has no children
vector<int> pos;   // 1-based segment-tree position
vector<int> head;  // first vertex of the heavy chain containing the vertex
vector<int> par;   // parent vertex (0 for the root, vertex 1)
vector<int> down;  // for a chain head: last vertex of its chain

// sum[v][c]: running total of prv[h][c] over chain heads h with par[h] == v.
vector<vector<int>> sum;
// prv[h][c]: the value chain head h most recently added into sum[par[h]][c].
vector<vector<int>> prv;

// Segment tree of chain matrices, indexed by pos.
aint tree;

int p = 0;
void init_hld()
{
  pos = head = par = down = vector<int>(n + 1);
  heavy = vector<int>(n + 1, -1);
  function<int(int, int)> dfs = [&](int node, int parent)
  {
    int sz = 1;
    int maxi = 0;
    for (auto i : g[node])
    {
      if (i != parent)
      {
        int cine = dfs(i, node);
        sz += cine;
        if (cine > maxi)
        {
          maxi = cine;
          heavy[node] = i;
        }
      }
    }
    return sz;
  };
  dfs(1, 0);
  function<void(int, int, int)> get = [&](int node, int boss, int parent)
  {
    par[node] = parent;
    head[node] = boss;
    pos[node] = ++p;
    down[head[node]] = node;
    if (heavy[node] != -1)
    {
      get(heavy[node], boss, node);
    }
    for (auto i : g[node])
    {
      if (i != parent && i != heavy[node])
      {
        get(i, i, node);
      }
    }
  };
  get(1, 1, 0);
}

// Grader entry point. Builds the adjacency lists from the N - 1 edges
// (A[e], B[e]), puts every vertex in state 2, zeroes the light-child
// bookkeeping (sum / prv), computes the heavy-light decomposition and writes
// the leaf {0, 0} at every segment-tree position 1..N.
void initialize(int N, vector<int> A, vector<int> B)
{
  n = N;

  g.assign(n + 1, vector<int>());
  a.assign(n + 1, 2);
  sum.assign(n + 1, vector<int>(2, 0));
  prv = sum;

  for (int e = 0; e + 1 < n; ++e)
  {
    const int u = A[e];
    const int v = B[e];
    g[u].push_back(v);
    g[v].push_back(u);
  }

  init_hld();

  tree.init(n);
  for (int at = 1; at <= n; ++at)
  {
    tree.update(1, 1, n, at, make_pair(0, 0));
  }
}
// Re-evaluates `node` after its state a[node] or its light-child sums changed,
// then propagates the change towards the root one heavy chain at a time.
// For each chain crossed, the head's contribution prv[head][c] to
// sum[par[head]][c] is recomputed: the head either takes slot c itself or the
// other slot at an extra cost of 1. Returns the matrix of the chain headed by
// vertex 1.
vector<vector<int>> update(int node)
{
  for (;;)
  {
    // Refresh the leaf: state 0 forbids slot 1, state 1 forbids slot 0.
    switch (a[node])
    {
    case 0:
      tree.update(1, 1, n, pos[node], {sum[node][0], inf});
      break;
    case 1:
      tree.update(1, 1, n, pos[node], {inf, sum[node][1]});
      break;
    case 2:
      tree.update(1, 1, n, pos[node], {sum[node][0], sum[node][1]});
      break;
    default:
      break;
    }

    const int top = head[node];
    vector<vector<int>> chain = tree.query(1, 1, n, pos[top], pos[down[top]]);
    if (top == 1)
    {
      return chain;
    }

    const int up = par[top];
    for (int c = 0; c < 2; ++c)
    {
      int cost = inf;
      for (int last = 0; last < 2; ++last)
      {
        cost = min(cost, chain[c][last]);
        cost = min(cost, chain[c ^ 1][last] + 1);
      }
      // Replace the old contribution of this chain in the parent's sum.
      sum[up][c] += cost - prv[top][c];
      prv[top][c] = cost;
    }
    node = up;
  }
}

// Shared body of cat/dog/neighbor: stores `state` for vertex v, propagates it
// with update() and returns the smallest entry of the matrix update() returns
// (the root chain's matrix), capped at inf.
static int set_state_and_answer(int v, int state)
{
  a[v] = state;
  const vector<vector<int>> root = update(v);
  int best = inf;
  for (const vector<int> &row : root)
  {
    for (int cost : row)
    {
      best = min(best, cost);
    }
  }
  return best;
}

// Puts vertex v in state 0 and returns the recomputed minimum.
int cat(int v)
{
  return set_state_and_answer(v, 0);
}

// Puts vertex v in state 1 and returns the recomputed minimum.
int dog(int v)
{
  return set_state_and_answer(v, 1);
}

// Puts vertex v back in state 2 and returns the recomputed minimum.
int neighbor(int v)
{
  return set_state_and_answer(v, 2);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...