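제출 (Submission) #947360

# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
947360 | GrindMachine | Star Trek (CEOI20_startrek) | C++17
65 / 100
90 ms | 32852 KiB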
// CEOI 2020 "Star Trek" (CEOI20_startrek).
//
// Pipeline implemented below:
//   dfs1 / dfs2 : for every node r, dp2[r] = number of neighbours of r that are
//                 losing when the tree is rooted at r (r is winning iff > 0).
//   dfs3 / dfs4 : for every root r, dp4[r][j] = count of nodes at depth parity j
//                 selected by the recurrence documented on dfs3.
//   dp5[r]      : derived from dp2/dp4; used as the number of portal sources in
//                 r's copy of the tree that make r winning when the portal leads
//                 to a losing position.
//   A 2x2 matrix of transition counts is raised to the power d-1 so that d can
//   be as large as a 64-bit integer allows.
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T> void amin(T &a, T b) { a = min(a,b); }
template<typename T> void amax(T &a, T b) { a = max(a,b); }

#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*

refs:
edi
https://youtu.be/Tjv78ZThV5c

*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

// Small dense matrix with arithmetic modulo MOD.
struct Matrix {
    vector<vector<ll>> a;
    int n, m;

    Matrix() { }

    Matrix(int row, int col) {
        n = row, m = col;
        a = vector<vector<ll>>(row, vector<ll>(col));
    }

    // (n x m) * (m x m2) product; every result entry is reduced modulo MOD.
    Matrix operator*(const Matrix &mat2) const {
        int m2 = mat2.m;
        Matrix res(n, m2);
        rep(i, n) {
            rep(j, m2) {
                rep(k, m) {
                    // Reduce both factors first so the product stays below
                    // MOD^2 (~1e18) even if an entry was stored unreduced.
                    ll temp = (a[i][k] % MOD) * (mat2.a[k][j] % MOD) % MOD;
                    res.a[i][j] = (res.a[i][j] + temp) % MOD;
                }
            }
        }
        return res;
    }

    // Replaces this square matrix by its b-th power (binary exponentiation).
    void exp(ll b) {
        Matrix res(n, m);
        Matrix curr = *this;
        rep(i, n) res.a[i][i] = 1;
        while (b) {
            if (b & 1) res = res * curr;
            curr = curr * curr;
            b /= 2;
        }
        a = res.a;
    }
};

vector<ll> adj[N];
vector<ll> dp1(N), dp2(N);

// dp1[u] = number of children v of u (tree rooted at node 1) with dp1[v] == 0,
// i.e. the number of losing children; u is winning iff dp1[u] > 0.
void dfs1(ll u, ll p){
    dp1[u] = 0;
    trav(v,adj[u]){
        if(v == p) conts;
        dfs1(v,u);
        dp1[u] += (dp1[v] == 0);
    }
}

// Rerooting of dfs1. Expects dp2[] == dp1[] on entry; afterwards dp2[u] is the
// number of losing neighbours of u when the whole tree is rooted at u.
void dfs2(ll u, ll p){
    trav(v,adj[u]){
        if(v == p) conts;
        // Losing neighbours of u other than v: zero means u is losing as v's child.
        ll val = dp2[u]-(dp1[v] == 0);
        dp2[v] += (val == 0);
        dfs2(v,u);
    }
}

ll dp3[N][2], dp4[N][2];
vector<ll> dp5(N);
vector<ll> win_cnt(N), lose_cnt(N);
ll win_sum[N][2], lose_sum[N][2];

// Subtree DP (rooted at node 1). dp3[u][j] counts nodes of u's subtree at depth
// parity j relative to u, selected as follows: u itself always counts; if u has
// no losing child, the counts of all children are added with parity flipped; if
// u has exactly one losing child, only that child's counts are added (flipped);
// with two or more losing children nothing below u is counted.
// Also accumulates per-node child aggregates win_cnt/lose_cnt/win_sum/lose_sum.
void dfs3(ll u, ll p){
    dp3[u][0] = 1;
    trav(v,adj[u]){
        if(v == p) conts;
        dfs3(v,u);
    }
    trav(v,adj[u]){
        if(v == p) conts;
        if(dp1[v]){
            win_cnt[u]++;
            rep(j,2) win_sum[u][j] += dp3[v][j];
        }
        else{
            lose_cnt[u]++;
            rep(j,2) lose_sum[u][j] += dp3[v][j];
        }
    }
    // lose_cnt[u] holds exactly the number of losing children counted above.
    if(lose_cnt[u] == 0){
        rep(j,2) dp3[u][j] += win_sum[u][j^1];
    }
    else if(lose_cnt[u] == 1){
        rep(j,2) dp3[u][j] += lose_sum[u][j^1];
    }
    rep(j,2) dp4[u][j] = dp3[u][j];
}

// Rerooting of dfs3. For each edge u -> v it computes dpu (dp3 of u when v is the
// root), adds it into v's aggregates and recomputes dp4[v] from them, so that
// dp4[r] ends up being dp3 evaluated with the tree rooted at r.
// Must run before dp2 is clamped to {0, 1}: it reads dp2[u] as a count.
void dfs4(ll u, ll p){
    trav(v,adj[u]){
        if(v == p) conts;
        ll lose = lose_cnt[u];
        array<ll,2> wsum, lsum;
        rep(j,2){
            wsum[j] = win_sum[u][j];
            lsum[j] = lose_sum[u][j];
        }
        // Remove v's own contribution from u's aggregates.
        if(dp1[v]){
            rep(j,2) wsum[j] -= dp3[v][j];
        }
        else{
            lose--;
            rep(j,2) lsum[j] -= dp3[v][j];
        }
        array<ll,2> dpu;
        dpu.fill(0);
        dpu[0] = 1;
        if(lose == 0){
            rep(j,2) dpu[j] += wsum[j^1];
        }
        else if(lose == 1){
            rep(j,2) dpu[j] += lsum[j^1];
        }
        // Losing neighbours of u excluding v: non-zero means u is winning as v's child.
        ll val = dp2[u];
        val -= (dp1[v] == 0);
        if(val){
            win_cnt[v]++;
            rep(j,2) win_sum[v][j] += dpu[j];
        }
        else{
            lose_cnt[v]++;
            rep(j,2) lose_sum[v][j] += dpu[j];
        }
        dp4[v][0] = 1, dp4[v][1] = 0;
        if(lose_cnt[v] == 0){
            rep(j,2) dp4[v][j] += win_sum[v][j^1];
        }
        else if(lose_cnt[v] == 1){
            rep(j,2) dp4[v][j] += lose_sum[v][j^1];
        }
        dfs4(v,u);
    }
}

// Reads n, d and the n-1 tree edges; prints the answer modulo MOD.
void solve(int test_case)
{
    ll n,d;
    cin >> n >> d;
    rep1(i,n-1){
        ll u,v;
        cin >> u >> v;
        adj[u].pb(v), adj[v].pb(u);
    }

    dfs1(1,-1);
    rep1(i,n) dp2[i] = dp1[i];
    dfs2(1,-1);
    dfs3(1,-1);
    dfs4(1,-1);

    // From here on dp1/dp2 are only used as booleans (1 = winning, 0 = losing).
    rep1(i,n){
        amin(dp1[i],1ll);
        amin(dp2[i],1ll);
    }

    // Winning r: n - dp4[r][1]; losing r: dp4[r][0].
    rep1(i,n){
        dp5[i] = n*dp2[i];
        dp5[i] -= (dp4[i][0]+dp4[i][1])*dp2[i];
        dp5[i] += dp4[i][0];
    }

    ll win_ways_w = 0, lose_ways_w = 0;
    rep1(r,n){
        if(dp2[r]) win_ways_w += n;
        else lose_ways_w += n;
    }
    ll win_ways_l = 0, lose_ways_l = 0;
    rep1(r,n){
        win_ways_l += dp5[r];
        lose_ways_l += n-dp5[r];
    }

    // Each count above can reach n*n (~1e10). Reduce before it enters the matrix:
    // squaring an unreduced entry during exponentiation overflows 64-bit ll.
    win_ways_w %= MOD, lose_ways_w %= MOD;
    win_ways_l %= MOD, lose_ways_l %= MOD;

    // base = (number of losing nodes, number of winning nodes).
    Matrix base(1,2);
    rep1(i,n) base.a[0][dp2[i]]++;

    Matrix mat(2,2);
    mat.a = {
        {lose_ways_l, win_ways_l},
        {lose_ways_w, win_ways_w}
    };
    mat.exp(d-1);
    base = base*mat;

    ll ans = 0;
    if(dp2[1]) ans += n*base.a[0][1] % MOD;
    ans += dp5[1]*base.a[0][0] % MOD;
    ans %= MOD;
    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    // cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지 (Compiler standard error (stderr) messages)

startrek.cpp: In member function 'Matrix Matrix::operator*(const Matrix&)':
startrek.cpp:76:13: warning: unused variable 'n2' [-Wunused-variable]
   76 |         int n2 = mat2.n, m2 = mat2.m;
      |             ^~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...