Submission #473940

# | Time | Username | Problem | Language | Result | Execution time | Memory
473940 | (not shown) | model_code | Wells (CEOI21_wells) | C++17 | 100 / 100 | 3790 ms | 256344 KiB
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Loop shorthands: FOR(i, a, b) iterates int i over [a, b); REP(i, n) over [0, n).
#define FOR(i, a, b) for (int i=(a); i<(b); i++)
#define REP(i, n) FOR(i, 0, n)
// Debug helpers: TRACE(x) prints "x = <value>" to stderr; inside a TRACE
// argument, `_` chains several values separated by " _ ".
#define TRACE(x) cerr << #x << " = " << x << endl
#define _ << " _ " <<
// Short names for std::pair members.
#define X first
#define Y second

using namespace std;

typedef pair<int, int> P;
typedef long long ll;

// MAX sizes every per-vertex / per-residue global array in this file
// (presumably the input bound on n plus slack — TODO confirm against the statement).
// MOD = 1e9 + 7 is a prime; all counts are reported modulo it.
const int MAX = 3000050, MOD = 1e9 + 7;

// Modular addition. Expects both operands already reduced into [0, MOD);
// the result is then also in [0, MOD).
int add(int a, int b) {
  const int sum = a + b;
  return (sum >= MOD) ? sum - MOD : sum;
}

// Modular subtraction. Expects both operands already reduced into [0, MOD);
// the result is then also in [0, MOD).
int sub(int a, int b) {
  const int diff = a - b;
  return (diff < 0) ? diff + MOD : diff;
}

int mul(int a, int b) {
  return (int) (((ll) a * b) % MOD);
}

// Modular inverse via Fermat's little theorem: a^(MOD-2) mod MOD, computed by
// square-and-multiply. Returns 0 when a is 0 (mod MOD), which has no inverse.
int inverse(int a) {
  int result = 1;
  int base = a;
  for (int exponent = MOD - 2; exponent > 0; exponent >>= 1) {
    if (exponent % 2 == 1)
      result = mul(result, base);
    base = mul(base, base);
  }
  return result;
}

// Adjacency lists of the tree, 0-indexed vertices; filled by load().
vector <int> V[MAX];
// n: vertex count; k: second header value. Diameter positions are grouped by
// residue mod k throughout the solution.
int n, k;

void load() {
  scanf("%d%d", &n, &k);
  REP(i, n-1) {
    int a, b;
    scanf("%d%d", &a, &b); a--; b--;
    V[a].push_back(b);
    V[b].push_back(a);
  }
}

// Fills dist[v] with the number of edges from the start vertex to v for every
// vertex v reachable from `node` without stepping back into `pr`.
// Call with pr == -1 to make `node` the source (distance 0).
void dfs_dist(int node, int pr, vector<int> &dist) {
  dist[node] = (pr == -1) ? 0 : dist[pr] + 1;
  for (int nxt : V[node]) {
    if (nxt == pr) continue;
    dfs_dist(nxt, node, dist);
  }
}

// Searches from `node` (arrived from `pr`) for vertex `fin`, keeping the
// current root-to-node path in D. On success, D holds the vertex path ending
// at fin and true is returned. On failure, D is restored to its state on entry.
bool dfs_diam(int node, int pr, int fin, vector<int> &D) {
  D.push_back(node);
  bool found = (node == fin);
  for (size_t i = 0; !found && i < V[node].size(); ++i) {
    const int nxt = V[node][i];
    if (nxt != pr)
      found = dfs_diam(nxt, node, fin, D);
  }
  if (!found) D.pop_back();
  return found;
}

// Returns the vertices of a longest path (diameter) of the tree, in order
// from one end to the other. It uses two farthest-vertex sweeps: the vertex
// farthest from vertex 0 is one end, and the vertex farthest from that end is
// the other. Ties resolve to the smallest vertex index.
vector <int> get_diam() {
  vector<int> dist(n);

  dfs_dist(0, -1, dist);
  const int end_a = static_cast<int>(max_element(dist.begin(), dist.end()) - dist.begin());

  dfs_dist(end_a, -1, dist);
  const int end_b = static_cast<int>(max_element(dist.begin(), dist.end()) - dist.begin());

  vector<int> path;
  dfs_diam(end_a, -1, end_b, path);
  return path;
}

// Per-vertex data filled by dfs_from_root():
//   dist_to_root[v]  - edges from v to the diameter vertex its branch hangs from
//                      (0 for diameter vertices);
//   ind_diam_root[v] - index of that diameter vertex within the diameter path;
//   height[v]        - max edges from v down to a descendant, never crossing
//                      onto the diameter.
int dist_to_root[MAX], ind_diam_root[MAX], height[MAX];
// on_diam[v] is true iff v lies on the diameter path chosen by get_diam().
bool on_diam[MAX];

// Walks the off-diameter subtree hanging below `node`. For every visited
// vertex it records the diameter index of its root (ind_diam_rt) and its
// distance to that root (dst). It stores each vertex's height and returns
// height[node].
int dfs_from_root(int node, int pr, int ind_diam_rt, int dst) {
  ind_diam_root[node] = ind_diam_rt;
  dist_to_root[node] = dst;

  int deepest = 0;
  for (int ch : V[node]) {
    if (ch == pr || on_diam[ch]) continue;
    const int through_child = 1 + dfs_from_root(ch, node, ind_diam_rt, dst + 1);
    if (through_child > deepest) deepest = through_child;
  }

  height[node] = deepest;
  return deepest;
}

// forb[r]: residue r (mod k) has been ruled out. Presumably residue r selects
// the diameter positions congruent to r mod k — TODO confirm.
bool forb[MAX];
// Residues still allowed after main()'s first pass over off-diameter vertices.
vector <int> not_forb;
// irrelevant[v]: neither diameter endpoint reaches k vertices along a path
// into v's subtree; set in main().
bool irrelevant[MAX];

// Counts, modulo MOD, the configurations of the off-diameter subtree rooted at
// `node` (entered from `pr`) within a remaining depth budget of depth_left
// vertices. Returns 1 once the budget is exhausted or the vertex is
// irrelevant. Otherwise returns 1 + product of the children's counts, each
// child getting the budget minus one.
int get_ways(int node, int pr, int depth_left) {
  assert(depth_left >= 0);
  if (depth_left == 0 || irrelevant[node]) return 1;

  int product = 1;
  for (int child : V[node]) {
    if (on_diam[child] || child == pr) continue;
    product = mul(product, get_ways(child, node, depth_left - 1));
  }
  return add(product, 1);
}

// Number of irrelevant vertices; main() multiplies the answer by 2 once per vertex.
int no_irrel=0;
// Multiplicative difference array over residues (see update_interval);
// main() converts it to prefix products over [0, k).
int pref_mult[MAX];
// Number of vertices on the diameter path.
int dsize;

// Multiplies the factor `val` into every residue of the cyclic half-open range
// [a, b) over residues 0..k-1. It uses a multiplicative difference array:
// pref_mult[i] holds a ratio, and main() later replaces pref_mult[i] with the
// prefix product pref_mult[0] * ... * pref_mult[i], which is the total factor
// for residue i.
//   a <= b : plain range [a, b) (b == k means "up to the last residue");
//   a >  b : wrapped range [a, k) U [0, b).
// NOTE(review): closing a range needs inverse(val). If val == 0 (mod MOD),
// no inverse exists and the resulting prefix products are wrong.
void update_interval(int a, int b, int val) { //[, )
  assert(a >= 0 && a < k && b >= 0 && b <= k);
  const int inv_val = inverse(val);
  if (a <= b) {
    pref_mult[a] = mul(pref_mult[a], val);
    pref_mult[b] = mul(pref_mult[b], inv_val);
  }
  else {
    // [0, b): open at 0, close at b. The close must scale pref_mult[b]
    // itself; copying the freshly scaled pref_mult[0] into it would wipe out
    // whatever pref_mult[b] already held.
    pref_mult[0] = mul(pref_mult[0], val);
    pref_mult[b] = mul(pref_mult[b], inv_val);
    // [a, k): open at a; it runs to the last residue, so nothing to close.
    pref_mult[a] = mul(pref_mult[a], val);
  }
}

// Second pass over the off-diameter forest; main() starts it at every diameter
// vertex with pr == -1. Diameter vertices only forward the recursion to their
// off-diameter children. For an off-diameter vertex `node` (entered from `pr`):
//   d_l / d_r   - edges from node to the diameter endpoints D[0] / D[dsize-1];
//   mx_l / mx_r - edges from those endpoints to the deepest vertex of node's
//                 subtree, so mx_l + 1 / mx_r + 1 are vertex counts.
// Cases:
//   * only the left count reaches k: if d_l < k, multiply get_ways() for this
//     subtree into residues [ind+1, k); return without descending further;
//   * only the right count reaches k: if d_r < k, multiply get_ways() into the
//     residues [dsize % k, ind); return without descending further;
//   * both reach k: if node's two tallest child branches joined through node
//     give at least k vertices, mark residues from not_forb as forb according
//     to the two checks below; then keep descending;
//   * neither: keep descending.
void dfs_ways(int node, int pr) {
  if (!on_diam[node]) {
    int d_l = dist_to_root[node] + ind_diam_root[node];
    int d_r = dist_to_root[node] + dsize-1 - ind_diam_root[node];
    int mx_l = d_l + height[node];
    int mx_r = d_r + height[node];

    if (mx_l + 1 >= k && mx_r + 1 < k) {
      if (d_l < k) {
        // The path D[0]..node already has d_l + 1 vertices; the budget is the
        // k - (d_l + 1) further vertices below node.
        int ways = get_ways(node, pr, k - d_l - 1);
        update_interval(ind_diam_root[node]+1, k, ways);
        //TRACE(ind_diam_root[node]+1 _ k _ ways);
      }
      return;
    }
    else if (mx_l + 1 < k && mx_r + 1 >= k) {
      if (d_r < k) {
        // Mirror of the case above, measured from D[dsize-1].
        int ways = get_ways(node, pr, k - d_r - 1);
        update_interval(dsize % k, ind_diam_root[node], ways);
        //TRACE(dsize % k _ ind_diam_root[node] _ ways);
      }
      return;
    }
    else if (mx_l + 1 >= k && mx_r + 1 >= k) {

      // Two largest child heights, dep1 >= dep2. The sentinel -2 * MAX
      // stands for "no such child".
      int dep1 = -2 * MAX, dep2 = -2 * MAX;
      for (auto ch : V[node]) {
        if (!on_diam[ch] && ch != pr) {
          if (height[ch] > dep1) {
            dep2 = dep1;
            dep1 = height[ch];
          }
          else dep2 = max(dep2, height[ch]);
        }
      }

      // 1+dep1 + 1+dep2 + 1 = vertex count of the longest path through node
      // that uses its two tallest child branches.
      if (1+dep1 + 1+dep2 + 1 >= k) {
        // Assumption documented by the original author: at most two residues
        // survive main()'s first pass.
        assert(not_forb.size() <= 2);
        // Residue of node's distance from D[0].
        int my_res = d_l % k;

        for (auto residue : not_forb) {
          if (my_res == residue) continue;
          // Forward steps (mod k) from my_res to `residue`.
          int to_next = (k - my_res + residue) % k;

          //TRACE(node _ to_next _ dep1 _ dep2);
          //TRACE(forb[0] _ forb[1]);
          if (to_next <= 1 + dep2 && 2 * to_next + 1 <= k) forb[residue] = true; //two reds
          if (min(1+dep1, to_next-1) + min(1+dep2, to_next-1) + 1 >= k) forb[residue] = true; //no reds


          //TRACE("AFTER" _ forb[0] _ forb[1]);
        }
      }
    }
  }

  // Descend into off-diameter children (never back to pr or onto the diameter).
  for (auto ch : V[node])
    if (!on_diam[ch] && ch != pr)
      dfs_ways(ch, node);
}

// Reads the tree and prints two lines: "YES"/"NO" (whether some residue
// mod k survives all checks), then the count modulo MOD.
int main()
{
  load();

  // Pick one diameter path D and mark its vertices.
  vector <int> D = get_diam();
  for (auto it : D) on_diam[it] = true;

  // Annotate every vertex with its diameter root index, distance to that root,
  // and height within its off-diameter branch.
  dsize = (int) D.size();
  REP(i, dsize)
    dfs_from_root(D[i], -1, i, 0);

//  REP(i, dsize) TRACE(i _ D[i]);

  // The longest path has fewer than k vertices: print YES and 2^n.
  if (dsize < k) {
    printf("YES\n");
    int cnt=1;
    REP(i, n) cnt = mul(cnt, 2);
    printf("%d\n", cnt);
    return 0;
  }

  // Identity element for the multiplicative difference array.
  REP(i, MAX) pref_mult[i] = 1;


  // First pass over off-diameter vertices.
  REP(node, n) {
    if (on_diam[node]) continue;

    int d_l = dist_to_root[node] + ind_diam_root[node];
    int d_r = dist_to_root[node] + dsize-1 - ind_diam_root[node];
    int mx_l = d_l + height[node];
    int mx_r = d_r + height[node];

    if (mx_l+1 < k && mx_r+1 < k) {
      // Neither endpoint reaches k vertices into this subtree: count it
      // separately (factor 2 each at the end).
      no_irrel++;
      irrelevant[node] = true;
    }
    else if (mx_l+1 >= k && mx_r+1 >= k) {
      // colored_l: d_l mod k. colored_r: dsize-1-d_r = ind - dist, i.e. d_r
      // measured back from D[dsize-1], reduced into [0, k).
      int colored_l = d_l % k;
      int colored_r = ((ind_diam_root[node] - dist_to_root[node]) % k + k) % k;

      if (colored_l != colored_r)
        forb[colored_l] = forb[colored_r] = true;
    }
  }

  REP(i, k)
    if (!forb[i]) not_forb.push_back(i);

//  REP(i, k) TRACE("ASDASD" _ i _ forb[i]);

  // Second pass: fills pref_mult and may forbid more residues.
  for (auto it : D)
    dfs_ways(it, -1);

  bool can = false;
  REP(i, k) can |= !forb[i];

  printf("%s\n", (can ? "YES" : "NO"));
  int ways = 0;

  // Turn pref_mult into prefix products (the factor for residue i), and sum
  // them over the allowed residues.
  REP(i, k) {
    if (i) pref_mult[i] = mul(pref_mult[i], pref_mult[i-1]);
    if (!forb[i]) {
      ways = add(ways, pref_mult[i]);
      //TRACE(i _ pref_mult[i]);
    }

    //TRACE(i _ forb[i] _ pref_mult[i]);
  }

  //TRACE(no_irrel);
  // Each irrelevant vertex doubles the count.
  REP(i, no_irrel) ways = mul(ways, 2);
  printf("%d\n", ways);

  return 0;
}

Compilation message (stderr)

wells.cpp: In function 'void load()':
wells.cpp:53:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   53 |   scanf("%d%d", &n, &k);
      |   ~~~~~^~~~~~~~~~~~~~~~
wells.cpp:56:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   56 |     scanf("%d%d", &a, &b); a--; b--;
      |     ~~~~~^~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...