This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may receive a different result if resubmitted.
#include <cstdio>
#include <cassert>
#include <algorithm>
#include <vector>
#include <cstring>
#include <iostream>
#include <cmath>
#include <string>
// Loop helpers: FOR runs i over [a, b), REP runs i over [0, n).
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define REP(i, n) FOR(i, 0, n)
// Debug helpers: TRACE(x) prints "x = <value>" to stderr; `_` chains several
// values inside one TRACE with a " _ " separator.
#define TRACE(x) cerr << #x << " = " << x << endl
#define _ << " _ " <<
// Short names for pair members.
#define X first
#define Y second
using namespace std;
using P = pair<int, int>;
using ll = long long;
// Capacity of the global per-node / per-residue arrays.
const int MAX = 3000050;
// Modulus for all counting arithmetic (1e9 + 7).
const int MOD = 1000000007;
// Modular addition. Expects both operands already reduced into [0, MOD),
// so one conditional subtraction brings the sum back into range.
int add(int a, int b) {
    const int sum = a + b;
    return sum >= MOD ? sum - MOD : sum;
}
// Modular subtraction. Expects both operands already reduced into [0, MOD),
// so one conditional addition brings the difference back into range.
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + MOD : diff;
}
int mul(int a, int b) {
return (int) (((ll) a * b) % MOD);
}
// Modular inverse via Fermat's little theorem: a^(MOD-2) mod MOD, computed by
// square-and-multiply. Relies on MOD being prime and a not divisible by MOD;
// for a == 0 the result is 0.
int inverse(int a) {
    int result = 1;
    int base = a;
    for (int exponent = MOD - 2; exponent > 0; exponent >>= 1) {
        if (exponent & 1) result = mul(result, base);
        base = mul(base, base);
    }
    return result;
}
vector <int> V[MAX];
int n, k;
// Reads the header "n k" followed by n-1 undirected edges given with
// 1-indexed endpoints, and appends each edge (converted to 0-indexed) to both
// adjacency lists in V.
// Reading stops as soon as a scanf call fails to parse two integers. This
// prevents uninitialized locals from ever being used as indices into V.
// If the header itself cannot be read, n and k keep their previous values
// (zero for a fresh program). NOTE(review): callers do not check for that case.
void load() {
    if (scanf("%d%d", &n, &k) != 2) return;
    REP(i, n - 1) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return;
        a--;
        b--;
        V[a].push_back(b);
        V[b].push_back(a);
    }
}
// Recursively fills dist[v] with the number of edges from the starting node
// to every node reachable from `node` without stepping back through `pr`.
// Call with pr == -1 for the start node, which receives distance 0.
void dfs_dist(int node, int pr, std::vector<int> &dist) {
    dist[node] = (pr == -1) ? 0 : dist[pr] + 1;
    for (int next : V[node]) {
        if (next == pr) continue;
        dfs_dist(next, node, dist);
    }
}
// Depth-first search for `fin` starting at `node` (never walking back into `pr`).
// On success returns true and leaves the path node ... fin appended to D;
// on failure returns false and D is restored to its previous contents.
bool dfs_diam(int node, int pr, int fin, std::vector<int> &D) {
    D.push_back(node);
    if (node == fin) return true;
    for (int next : V[node]) {
        if (next == pr) continue;
        if (dfs_diam(next, node, fin, D)) return true;
    }
    D.pop_back();
    return false;
}
vector <int> get_diam() {
vector <int> dist(n);
dfs_dist(0, -1, dist);
int p1 = 0;
REP(i, n) if (dist[i] > dist[p1]) p1 = i;
dfs_dist(p1, -1, dist);
int p2 = p1;
REP(i, n) if (dist[i] > dist[p2]) p2 = i;
vector <int> D;
dfs_diam(p1, -1, p2, D);
return D;
}
// Per-node data written by dfs_from_root:
//   dist_to_root[v]  - depth of v below the diameter node its subtree hangs from;
//   ind_diam_root[v] - index (in the diameter path) of that diameter node;
//   height[v]        - longest downward path, in edges, inside v's off-diameter subtree.
int dist_to_root[MAX];
int ind_diam_root[MAX];
int height[MAX];
// on_diam[v] is true iff v lies on the chosen diameter path.
bool on_diam[MAX];
// Walks the subtree hanging below `node`. It never steps onto another
// diameter node and never walks back into `pr`.
// Every visited node gets:
//   ind_diam_root = ind_diam_rt;
//   dist_to_root  = its depth, starting from dst at `node`.
// Returns the number of edges on the longest downward path from `node` within
// that subtree, and also stores that value in height[node].
int dfs_from_root(int node, int pr, int ind_diam_rt, int dst) {
    ind_diam_root[node] = ind_diam_rt;
    dist_to_root[node] = dst;
    int deepest = 0;
    for (int child : V[node]) {
        if (child == pr || on_diam[child]) continue;
        const int via_child = dfs_from_root(child, node, ind_diam_rt, dst + 1) + 1;
        if (via_child > deepest) deepest = via_child;
    }
    height[node] = deepest;
    return deepest;
}
// forb[r]: residue r (mod k) is excluded from the final answer in main.
bool forb[MAX];
// Residues in [0, k) that were still allowed after main's first pass.
std::vector<int> not_forb;
// irrelevant[v]: neither v's left nor its right reach (distance to a diameter
// end plus subtree height) spans k nodes; set in main.
bool irrelevant[MAX];
// Evaluates, mod MOD, the recurrence
//   f(v, d) = 1                                if d == 0 or v is irrelevant,
//   f(v, d) = 1 + prod over children c of f(c, d - 1)   otherwise,
// over the off-diameter subtree of `node` (parent `pr`).
// NOTE(review): presumably the number of valid choices inside that subtree
// when only the top `depth_left` levels are constrained; confirm against the
// problem statement.
int get_ways(int node, int pr, int depth_left) {
    assert(depth_left >= 0);
    if (depth_left == 0 || irrelevant[node]) return 1;
    int product = 1;
    for (int child : V[node]) {
        if (child == pr || on_diam[child]) continue;
        product = mul(product, get_ways(child, node, depth_left - 1));
    }
    return add(product, 1);
}
// Number of nodes main marked irrelevant; each one doubles the final count.
int no_irrel = 0;
// Multiplicative difference array over residues [0, k). update_interval
// writes into it, and main turns it into prefix products.
int pref_mult[MAX];
// Number of nodes on the diameter path.
int dsize;
// Multiplies every residue r in the half-open range [a, b) by val, using the
// multiplicative difference array pref_mult. main later takes prefix
// products over indices [0, k).
// If a > b the range wraps around and covers [a, k) together with [0, b).
// NOTE(review): assumes val is invertible mod MOD. A val that is 0 mod MOD
// makes inverse(val) == 0, which corrupts the prefix products.
void update_interval(int a, int b, int val) { //[, )
    assert(a >= 0 && a < k && b >= 0 && b <= k);
    const int inv_val = inverse(val);
    if (a <= b) {
        pref_mult[a] = mul(pref_mult[a], val);
        pref_mult[b] = mul(pref_mult[b], inv_val);
    } else {
        // Piece [0, b): start the factor at 0 and cancel it at b. The cancel
        // must scale pref_mult[b] itself; writing pref_mult[0] * inv_val there
        // would discard the factors previously accumulated at b.
        pref_mult[0] = mul(pref_mult[0], val);
        pref_mult[b] = mul(pref_mult[b], inv_val);
        // Piece [a, k): start the factor at a. Its cancel would sit at k,
        // which the prefix pass in main never reads.
        pref_mult[a] = mul(pref_mult[a], val);
    }
}
// Second pass over one off-diameter subtree. main calls it with each diameter
// node and pr == -1; diameter nodes themselves only recurse into their
// off-diameter children.
// For each off-diameter node it computes:
//   d_l / d_r   - edges from the node to the diameter ends D[0] / D[dsize-1];
//   mx_l / mx_r - those distances extended by the node's subtree height.
// It then compares mx_l + 1 and mx_r + 1 (node counts) with k and either:
//   - folds a factor from get_ways into a residue range of pref_mult and stops
//     (exactly one side reaches k nodes), or
//   - forbids further residues in forb and keeps recursing (both sides reach
//     k nodes), or
//   - just keeps recursing (neither side reaches k nodes).
// NOTE(review): what a "residue" and the counted "ways" mean comes from the
// problem statement, which is not visible here.
void dfs_ways(int node, int pr) {
if (!on_diam[node]) {
// Edges to D[0] and to D[dsize-1]: climb to the diameter node D[ind],
// then walk along the diameter.
int d_l = dist_to_root[node] + ind_diam_root[node];
int d_r = dist_to_root[node] + dsize-1 - ind_diam_root[node];
// mx_l + 1 / mx_r + 1 = node count of the longest path that starts at a
// diameter end, passes through node and continues down its subtree.
int mx_l = d_l + height[node];
int mx_r = d_r + height[node];
// Only the left side reaches k nodes: the whole subtree is handled by
// get_ways with a depth budget of k - d_l - 1, applied to residues
// [ind_diam_root+1, k). No recursion below this node.
if (mx_l + 1 >= k && mx_r + 1 < k) {
if (d_l < k) {
int ways = get_ways(node, pr, k - d_l - 1);
update_interval(ind_diam_root[node]+1, k, ways);
//TRACE(ind_diam_root[node]+1 _ k _ ways);
}
return;
}
// Only the right side reaches k nodes: mirror case, applied to the
// (possibly wrapping) residue range [dsize % k, ind_diam_root).
else if (mx_l + 1 < k && mx_r + 1 >= k) {
if (d_r < k) {
int ways = get_ways(node, pr, k - d_r - 1);
update_interval(dsize % k, ind_diam_root[node], ways);
//TRACE(dsize % k _ ind_diam_root[node] _ ways);
}
return;
}
// Both sides reach k nodes: inspect the two deepest child branches.
else if (mx_l + 1 >= k && mx_r + 1 >= k) {
// dep1 >= dep2 are the two largest heights among off-diameter children.
// -2 * MAX means "no such child" and keeps the sum below negative.
int dep1 = -2 * MAX, dep2 = -2 * MAX;
for (auto ch : V[node]) {
if (!on_diam[ch] && ch != pr) {
if (height[ch] > dep1) {
dep2 = dep1;
dep1 = height[ch];
}
else dep2 = max(dep2, height[ch]);
}
}
// Node count of the longest path inside the subtree that passes through
// node via two different children.
if (1+dep1 + 1+dep2 + 1 >= k) {
// The author asserts at most two candidate residues remain here.
assert(not_forb.size() <= 2);
int my_res = d_l % k;
for (auto residue : not_forb) {
if (my_res == residue) continue;
// Forward distance from my_res to residue, modulo k.
int to_next = (k - my_res + residue) % k;
//TRACE(node _ to_next _ dep1 _ dep2);
//TRACE(forb[0] _ forb[1]);
if (to_next <= 1 + dep2 && 2 * to_next + 1 <= k) forb[residue] = true; //two reds
if (min(1+dep1, to_next-1) + min(1+dep2, to_next-1) + 1 >= k) forb[residue] = true; //no reds
//TRACE("AFTER" _ forb[0] _ forb[1]);
}
}
}
}
// Continue into off-diameter children (reached for diameter nodes and for
// the "both sides" / "neither side" cases).
for (auto ch : V[node])
if (!on_diam[ch] && ch != pr)
dfs_ways(ch, node);
}
// Entry point.
//   - Reads the tree and fixes a diameter path D.
//   - Hangs every off-diameter node under the diameter node it is attached to.
//   - Prints "YES"/"NO" on one line, then a count modulo MOD on the next.
int main()
{
load();
vector <int> D = get_diam();
for (auto it : D) on_diam[it] = true;
dsize = (int) D.size();
// Root each off-diameter subtree at its diameter node D[i]. This fills
// ind_diam_root, dist_to_root and height for all of its nodes.
REP(i, dsize)
dfs_from_root(D[i], -1, i, 0);
// REP(i, dsize) TRACE(i _ D[i]);
// The diameter (a longest path) has fewer than k nodes. The program prints
// YES and 2^n mod MOD.
if (dsize < k) {
printf("YES\n");
int cnt=1;
REP(i, n) cnt = mul(cnt, 2);
printf("%d\n", cnt);
return 0;
}
// Multiplicative difference array over residues: start from the identity.
REP(i, MAX) pref_mult[i] = 1;
// First pass over off-diameter nodes. mx_l + 1 / mx_r + 1 are the node
// counts of the longest paths from D[0] / D[dsize-1] that pass through node
// and continue down its subtree.
REP(node, n) {
if (on_diam[node]) continue;
int d_l = dist_to_root[node] + ind_diam_root[node];
int d_r = dist_to_root[node] + dsize-1 - ind_diam_root[node];
int mx_l = d_l + height[node];
int mx_r = d_r + height[node];
// Neither side reaches k nodes: mark irrelevant. Each such node later
// doubles the answer.
if (mx_l+1 < k && mx_r+1 < k) {
no_irrel++;
irrelevant[node] = true;
}
// Both sides reach k nodes: compare two residues, d_l mod k and
// (ind_diam_root - dist_to_root) mod k (normalised to be non-negative).
// If they differ, both are forbidden.
else if (mx_l+1 >= k && mx_r+1 >= k) {
int colored_l = d_l % k;
int colored_r = ((ind_diam_root[node] - dist_to_root[node]) % k + k) % k;
if (colored_l != colored_r)
forb[colored_l] = forb[colored_r] = true;
}
}
// Residues surviving the first pass; dfs_ways may forbid more of them.
REP(i, k)
if (!forb[i]) not_forb.push_back(i);
// REP(i, k) TRACE("ASDASD" _ i _ forb[i]);
// Second pass: per-residue factors into pref_mult, plus further forbidden residues.
for (auto it : D)
dfs_ways(it, -1);
// YES iff at least one residue is still allowed.
bool can = false;
REP(i, k) can |= !forb[i];
printf("%s\n", (can ? "YES" : "NO"));
// Turn pref_mult into prefix products and sum them over allowed residues.
int ways = 0;
REP(i, k) {
if (i) pref_mult[i] = mul(pref_mult[i], pref_mult[i-1]);
if (!forb[i]) {
ways = add(ways, pref_mult[i]);
//TRACE(i _ pref_mult[i]);
}
//TRACE(i _ forb[i] _ pref_mult[i]);
}
//TRACE(no_irrel);
// Each irrelevant node contributes an independent factor of 2.
REP(i, no_irrel) ways = mul(ways, 2);
printf("%d\n", ways);
return 0;
}
Compilation message (stderr)
wells.cpp: In function 'void load()':
wells.cpp:53:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
53 | scanf("%d%d", &n, &k);
| ~~~~~^~~~~~~~~~~~~~~~
wells.cpp:56:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
56 | scanf("%d%d", &a, &b); a--; b--;
| ~~~~~^~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |