Submission #556297

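| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 556297 | | FatihSolak | Star Trek (CEOI20_startrek) | C++17 | 100 / 100 | 131 ms | 41624 KiB |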
// Solution for "Star Trek" (CEOI 2020, task CEOI20_startrek), submission #556297.
//
// The pasted version of this program had been collapsed onto three physical
// lines, which cannot compile: preprocessor directives (#include, #define,
// #ifdef) each consume their whole line, and the `// cin >> t;` line comment
// swallowed the rest of main. This version restores one statement/directive
// per line and keeps the algorithm unchanged.
//
// Input (stdin): n, d, then n-1 edges "u v" of a tree on vertices 1..n.
// Output (stdout): a single count modulo 1e9+7.
#include <cstdio>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <vector>

using namespace std;

constexpr int N = 100005;  // capacity of the 1-indexed per-vertex arrays; assumes n < N
const int mod = 1e9 + 7;

// Returns a^b mod `mod` by binary exponentiation.
// Expects 0 <= a < mod so that a * a cannot overflow long long.
long long binpow(long long a, long long b) {
    long long ret = 1;
    while (b) {
        if (b & 1) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}

// Per-vertex aggregate used by the counting DP below.
// merge() adds a neighbour's aggregate with each pair swapped (a<->b,
// sz<->sz2), so the two members of a pair alternate roles from one tree level
// to the next. a/b are plain vertex counts (bounded by n); sz/sz2 are kept
// reduced mod `mod`.
// NOTE(review): the exact combinatorial meaning of each field is not spelled
// out in this file — confirm against the task's official solution.
struct node {
    long long a = 0, b = 0, sz = 0, sz2 = 0;
};

// 3x3 matrix over Z_mod; default-constructed as the zero matrix.
struct matrix {
    long long a[3][3] = {};

    matrix operator*(const matrix &other) const {
        matrix ret;
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
                for (int k = 0; k < 3; k++)
                    ret.a[i][k] = (ret.a[i][k] + a[i][j] * other.a[j][k]) % mod;
        return ret;
    }
};

// Returns a^b (matrix power, entries mod `mod`); b == 0 yields the identity.
matrix binexpo(matrix a, long long b) {
    matrix ret;
    for (int i = 0; i < 3; i++) ret.a[i][i] = 1;
    while (b) {
        if (b & 1) ret = ret * a;
        a = a * a;
        b >>= 1;
    }
    return ret;
}

// Global state. solve() is single-use: adj[] and timer are never reset.
vector<int> adj[N];   // adjacency lists, vertices 1..n
long long n, d;       // vertex count and the exponent parameter d from the input
int win[N];           // 1 iff v has a losing neighbour side when the tree is rooted at v (dfs + dfs2)
int sub[N];           // 1 iff v has a losing child in its own subtree (tree rooted at 1)
int top[N];           // 0 iff the side of v's parent (v's subtree removed) is losing; root keeps 1
int len[N];           // subtree size of v (tree rooted at 1)
int timer = 1;        // Euler-tour clock for tin/tout
int tin[N], tout[N];  // entry/exit times; only used to derive len[]

// First pass, tree rooted at vertex 1. Computes tin/tout and len[v] (subtree
// size), and the usual win/lose recurrence: sub[v] = win[v] = 1 iff some child
// u has sub[u] == 0. Every child is visited (no early exit) so the whole
// subtree gets filled. Returns sub[v].
bool dfs(int v, int par) {
    tin[v] = timer++;
    sub[v] = 0;
    for (auto u : adj[v]) {
        if (u == par) continue;
        if (dfs(u, v) == 0) {
            win[v] = 1;
            sub[v] = 1;
        }
    }
    tout[v] = timer - 1;
    len[v] = tout[v] - tin[v] + 1;
    return sub[v];
}

// Rerooting pass. cnt is the number of losing neighbour sides of v: its losing
// children plus the parent side when top[v] == 0. For each child u, if v has
// no losing side other than u's, then v-without-u is losing. In that case u
// gains a losing neighbour side: win[u] = 1 and top[u] = 0.
void dfs2(int v, int par) {
    int cnt = !top[v];
    for (auto u : adj[v]) {
        if (u == par) continue;
        cnt += !sub[u];
    }
    for (auto u : adj[v]) {
        if (u == par) continue;
        cnt -= !sub[u];  // exclude u's own side
        if (!cnt) {
            win[u] = 1;
            top[u] = 0;
        }
        cnt += !sub[u];
        dfs2(u, v);
    }
}

node subval[N];  // aggregate of v's subtree (tree rooted at 1), filled by dfs3
node val[N];     // aggregate of the whole tree as seen from v, filled by dfs4
node topval[N];  // aggregate of v's parent side (v's subtree removed), filled by dfs4

// Adds b into a with each pair swapped (a.a += b.b, a.b += b.a, ...).
// sz/sz2 stay reduced mod `mod`.
void merge(node &a, const node &b) {
    a.a += b.b;
    a.b += b.a;
    a.sz += b.sz2;
    a.sz2 += b.sz;
    a.sz %= mod;
    a.sz2 %= mod;
}

// Exact inverse of merge(a, b).
void antimerge(node &a, const node &b) {
    a.a -= b.b;
    a.b -= b.a;
    a.sz -= b.sz2;
    a.sz2 -= b.sz;
    a.sz = (a.sz + mod) % mod;
    a.sz2 = (a.sz2 + mod) % mod;
}

// Bottom-up pass computing subval[v] from the children's subval. "places"
// holds the losing children (sub[u] == 0):
//  * none -> v is losing in its subtree: b += 1 and every child is merged in;
//  * one  -> sz += n * (len[v] - len[child]) and only that child is merged;
//  * more -> sz += n * len[v] and nothing is merged.
void dfs3(int v, int par) {
    vector<int> places;
    subval[v] = node();
    for (auto u : adj[v]) {
        if (u == par) continue;
        if (!sub[u]) places.push_back(u);
        dfs3(u, v);
    }
    if (places.empty()) {
        subval[v].b++;
        for (auto u : adj[v]) {
            if (u == par) continue;
            merge(subval[v], subval[u]);
        }
    }
    if (places.size() == 1) {
        subval[v].sz += n * (len[v] - len[places[0]]);
        subval[v].sz %= mod;
        merge(subval[v], subval[places[0]]);
    }
    if (places.size() > 1) {
        subval[v].sz += n * len[v];
        subval[v].sz %= mod;
    }
}

// Top-down rerooting pass. It first builds val[v] from topval[v] and the
// children's subval, using the same three cases as dfs3. A losing parent side
// (top[v] == 0) counts as a "place" there, with par standing for it. Then, for
// each child u, it sets topval[u]: the same aggregate computed for v with u's
// side removed. It gets this by taking subval[u] out of a running total
// (antimerge), then puts it back before recursing into u.
void dfs4(int v, int par) {
    vector<int> places;
    if (!top[v]) places.push_back(par);
    val[v] = node();
    for (auto u : adj[v]) {
        if (u == par) continue;
        if (!sub[u]) places.push_back(u);
    }
    if (places.empty()) {
        val[v].b++;
        merge(val[v], topval[v]);
        for (auto u : adj[v]) {
            if (u == par) continue;
            merge(val[v], subval[u]);
        }
    }
    if (places.size() == 1) {
        if (places[0] == par) {
            val[v].sz += n * len[v];
            val[v].sz %= mod;
            merge(val[v], topval[v]);
        } else {
            val[v].sz += n * (n - len[places[0]]);
            val[v].sz %= mod;
            merge(val[v], subval[places[0]]);
        }
    }
    if (places.size() > 1) {
        val[v].sz += n * n;
        val[v].sz %= mod;
    }

    // Running total over the parent side and all children of v.
    node sum;
    merge(sum, topval[v]);
    for (auto u : adj[v]) {
        if (u == par) continue;
        merge(sum, subval[u]);
    }
    for (auto u : adj[v]) {
        if (u == par) continue;
        antimerge(sum, subval[u]);  // temporarily drop u's subtree
        node tmp = sum;
        // Losing neighbour sides of v once u's side is excluded.
        int x = static_cast<int>(places.size()) - !sub[u];
        if (x == 0) {
            tmp.b++;
        }
        if (x == 1) {
            tmp = node();
            int pos = places[0];
            if (pos == u) pos = places[1];
            if (pos == par) {
                tmp.sz += n * (len[v] - len[u]);
                tmp.sz %= mod;
                merge(tmp, topval[v]);
            } else {
                tmp.sz += n * (n - len[pos] - len[u]);
                tmp.sz %= mod;
                merge(tmp, subval[pos]);
            }
        }
        if (x > 1) {
            tmp = node();
            tmp.sz += n * (n - len[u]);
            tmp.sz %= mod;
        }
        topval[u] = tmp;
        merge(sum, subval[u]);  // restore u's subtree
        dfs4(u, v);
    }
}

// Reads the input and runs the four passes, then combines the results.
// Let winstate be the number of vertices with win[v] == 1, and let
// (suma, sumb, sumsz) be the sums of val[x].a, .b, .sz over all vertices.
// Starting from the row vector (winstate, n, 1), the step
//     (w, p, q) -> (w*(suma - sumb) + p*sumb + q*sumsz, p*n^2, q*n^2)
// is applied d-1 times as a 3x3 matrix power. A final step then uses only
// vertex 1's aggregate val[1]. Prints the result mod `mod`.
void solve() {
    cin >> n >> d;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) {
        win[i] = 0;
        sub[i] = 0;
        top[i] = 1;
    }
    dfs(1, 0);
    dfs2(1, 0);
    dfs3(1, 0);
    dfs4(1, 0);

    long long suma = 0, sumb = 0, sumsz = 0;
    for (int x = 1; x <= n; x++) {
        suma = (suma + val[x].a) % mod;
        sumb = (sumb + val[x].b) % mod;
        sumsz = (sumsz + val[x].sz) % mod;
    }

    long long winstate = 0;
    for (int i = 1; i <= n; i++) winstate += win[i];

    matrix single;
    single.a[0][0] = (suma - sumb + mod) % mod;
    single.a[1][0] = sumb % mod;
    single.a[2][0] = sumsz % mod;
    single.a[1][1] = n * n % mod;
    single.a[2][2] = n * n % mod;
    matrix total = binexpo(single, d - 1);

    // (winstate, n, 1) * total, first column only.
    long long result = 0;
    result = (result + total.a[0][0] * winstate) % mod;
    result = (result + total.a[1][0] * n) % mod;
    result = (result + total.a[2][0] * 1) % mod;
    winstate = result;

    const node &root = val[1];
    winstate = (winstate * (root.a - root.b + mod) + binpow(n, 2 * d - 1) * root.b +
                binpow(n, 2 * d - 2) * root.sz) % mod;
    cout << winstate;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef Local
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    int t = 1;
    // cin >> t;  // the task has a single test case
    while (t--) {
        solve();
    }
#ifdef Local
    cout << endl << fixed << setprecision(2) << 1000.0 * clock() / CLOCKS_PER_SEC << " milliseconds.";
#endif
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...