답안 #459100

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
459100 2021-08-08T08:40:14Z AmShZ Star Trek (CEOI20_startrek) C++11 0 / 100 2 ms 2688 KB
//khodaya khodet komak kon
# include <bits/stdc++.h>

using namespace std;

// Type shorthands used throughout the file.
using ll = long long;
using ld = long double;
using pii = pair <int, int>;
using ppi = pair <pii, int>;
using pip = pair <int, pii>;
using ppp = pair <pii, pii>;
using pll = pair <ll, ll>;

// Shorthand macros:
//   A / B    - the .first / .second members of a pair
//   endl     - a plain newline character (no stream flush)
//   sep      - a single space character
//   all(x)   - x.begin(), x.end()
//   kill(x)  - print x and a newline, then return 0 from the enclosing function
//   SZ(x)    - x.size() converted to int
//   lc / rc  - id << 1 and (id << 1) | 1 for a variable named `id` in scope
//   fast_io  - disable C stdio synchronisation and untie cin / cout
#define A first
#define B second
#define endl '\n'
#define sep ' '
#define all(x) x.begin(), x.end()
#define kill(x) return cout << x << endl, 0
#define SZ(x) int(x.size())
#define lc id << 1
#define rc id << 1 | 1
#define fast_io ios::sync_with_stdio(0);cin.tie(0); cout.tie(0);

// Modular exponentiation: returns a^b mod md, always in [0, md), for b >= 0.
// The base is reduced (and made non-negative) before any multiplication, so
// any long long base is safe as long as (md - 1)^2 fits in a long long.
// For b <= 0 the result is 1 % md.
long long power(long long a, long long b, long long md) {
	a %= md;
	if (a < 0)
		a += md;               // C++ '%' keeps the sign of the dividend
	long long result = 1 % md; // 1 % md handles md == 1 (everything is 0)
	for (; b > 0; b >>= 1) {
		if (b & 1)
			result = result * a % md;
		a = a * a % md;
	}
	return result;
}

// Compile-time limits and constants (most are not referenced by the code
// visible in this file).
constexpr int xn = 1e5 + 10;          // capacity of every per-vertex array below
constexpr int xm = - 20 + 10;
constexpr int sq = 320;
constexpr int inf = 1e9 + 10;
constexpr long long INF = 1e18 + 10;
constexpr long double eps = 1e-15;
constexpr int mod = 1e9 + 7;          // answer modulus (998244353 was the alternative)
constexpr int base = 257;

// Global solver state; globals are zero-initialised before main runs.
int n;                 // number of vertices (main reads n - 1 edges)
int dp[2][xn];
int sz[2][xn];
int par[xn];
int ans;
bool f[2][xn];
long long D;           // second number read in main, right after n
vector <int> adj[xn];  // adjacency lists; vertices are numbered from 1

void DFS(int v, int p = - 1){
	sz[0][v] = 1;
	int ted = 0;
	for (int u : adj[v]){
		if (u == p)
			continue;
		par[u] = v;
		DFS(u, v);
		f[0][v] |= !(f[0][u]);
		sz[0][v] += sz[0][u];
		ted += !f[0][u];
	}
	dp[0][v] = sz[0][v];
	for (int u : adj[v])
		if (u != p && ted == !f[0][u])
			dp[0][v] -= dp[0][u];
}
void DFS2(int v, int p = - 1){
	int ted = !f[1][v];
	for (int u : adj[v])
		if (u != p)
			ted += !f[0][u], sz[1][v] += sz[0][u];
	++ sz[1][v];
	int sum0 = 0, sum1 = 0;
	for (int u : adj[v]){
		if (u == p)
			continue;
		sz[1][u] = sz[1][v] - sz[0][u];
		dp[1][u] = sz[1][u];
		f[1][u] = (0 < ted - !f[0][u]);
		if (f[0][u])
			sum1 += dp[0][u];
		else
			sum0 += dp[0][u];
	}
	if (f[1][v])
		sum1 += dp[1][v];
	else
		sum0 += dp[1][v];
	for (int u : adj[v]){
		if (u == p)
			continue;
		if (f[0][u])
			sum1 -= dp[0][u];
		else
			sum0 -= dp[0][u];
		if (ted - !f[0][u] == 0)
			dp[1][u] -= sum1;
		else if (ted - !f[0][u] == 1)
			dp[1][u] -= sum0;
		if (f[0][u])
			sum1 += dp[0][u];
		else
			sum0 += dp[0][u];
		DFS2(u, v);
	}
}

struct matrix{
    int a[2][2];
    matrix operator * (const matrix &t){
        matrix r;
        for (int i = 0; i < 2; ++ i)
        	for (int j = 0; j < 2; ++ j)
        		for (int k = 0; k < 2; ++ k)
        			r.a[i][k] = (r.a[i][k] + 1ll * a[i][j] * t.a[!j][k] % mod) % mod;
        return r;
    }
};
matrix Pow(matrix a, ll b){
    matrix res = a;
    if (b < 0)
    	return res;
    for (; b; b = b >> 1){
        if (b & 1)
            res = res * a;
        a = a * a;
    }
    return res;
}

int main(){
	fast_io;

	cin >> n >> D;
	for (int i = 0; i < n - 1; ++ i){
		int v, u;
		cin >> v >> u;
		adj[v].push_back(u);
		adj[u].push_back(v);
	}
	DFS(1);
	f[1][1] = 1;
	DFS2(1);
	matrix M;
	for (int i = 0; i < 2; ++ i)
		for (int j = 0; j < 2; ++ j)
			M.a[i][j] = 0;
	pii last;
	for (int v = 1; v <= n; ++ v){
		bool fl = !f[1][v];
		int ted = !f[1][v];
		for (int u : adj[v]){
			if (u == par[v])
				continue;
			fl |= !f[0][u];
			ted += !f[0][u];
		}
		int val = n;
		if (ted == !f[1][v])
			val -= dp[1][v];
		for (int u : adj[v])
			if (u != par[v] && ted == !f[0][u])
				val -= dp[0][u];
		M.a[1][1] = (M.a[1][1] + val) % mod;
		M.a[0][1] += (M.a[0][1] + n - val) % mod;
		if (fl)
			M.a[1][0] = (M.a[1][0] + n) % mod;
		else
			M.a[0][0] = (M.a[0][0] + n) % mod;
		if (v == 1)
			last = {fl, val};
	}
	M = Pow(M, D - 1);
	if (last.A)
		ans = (ans + M.a[1][0]) % mod;
	ans = (ans + 1ll * last.B * M.a[0][0] % mod) % mod;
	ans = 1ll * ans * power(n, mod - 2, mod) % mod;
	cout << ans << endl;

	return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2688 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct
# 결과 실행 시간 메모리 Grader output
1 Incorrect 2 ms 2636 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2636 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2636 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2636 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2636 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2636 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2688 KB Output is correct
2 Incorrect 2 ms 2636 KB Output isn't correct