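제출 #947361 (Submission #947361)

# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (run time) | 메모리 (memory)
947361 | — | GrindMachine | Star Trek (CEOI20_startrek) | C++17 | 100 / 100 | 88 ms | 32744 KiB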
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

// ---- Prelude: type aliases and shorthand macros used throughout the file. ----

// Order-statistics ordered set from GNU pb_ds (red-black tree with
// tree_order_statistics_node_update, which adds find_by_order / order_of_key).
// Not used elsewhere in this file.
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// Short names for 64-bit integers, long doubles and pairs.
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

// Disable C-stdio synchronisation and untie cin from cout for faster I/O.
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
// Replaces std::endl with a plain newline, so `<< endl` never flushes the stream.
#define endl '\n'
// Container size as int. NOTE(review): the parameters of these macros are not
// parenthesised in the expansion (e.g. sz(a) -> (int)a.size()), so pass simple
// expressions only.
#define sz(a) (int)a.size()
// Number of set bits of x, via the 64-bit GCC builtin.
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
// Ceiling of x / y for positive operands.
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

// Loop shorthands (int loop variable): rep = [0, n), rep1 = [1, n],
// rev = s down to e inclusive, trav = range-for binding each element by reference.
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

// "chmin": overwrite `a` with `b` only when `b` compares strictly less than `a`.
template<typename T>
void amin(T &a, T b) {
    if (b < a) a = b;
}

// "chmax": overwrite `a` with `b` only when `a` compares strictly less than `b`.
template<typename T>
void amax(T &a, T b) {
    if (a < b) a = b;
}

// Debug hook: when compiled with LOCAL defined, a local "debug.h" is included
// (presumably providing its own debug(); not part of this file). Otherwise
// debug(x) expands to the constant 42 and its argument is discarded.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*

References:
- the official editorial
- https://youtu.be/Tjv78ZThV5c

*/

// Modulus for the answer and for all matrix arithmetic (1e9 + 7).
constexpr int MOD = 1'000'000'007;
// Size of the per-vertex arrays; vertices are indexed 1..n, so n must be < N.
constexpr int N = 100'005;
// Large sentinels (1e9 + 5 and 1e18 + 5); not referenced elsewhere in this file.
constexpr int inf1 = 1'000'000'005;
constexpr ll inf2 = 1'000'000'000'000'000'005LL;

struct Matrix {
    vector<vector<ll>> a;
    int n, m;

    Matrix() {

    }

    Matrix(int row, int col) {
        n = row, m = col;
        a = vector<vector<ll>>(row, vector<ll>(col));
    }

    Matrix operator*(const Matrix &mat2) {
        int n2 = mat2.n, m2 = mat2.m;
        Matrix res(n, m2);

        rep(i, n) {
            rep(j, m2) {
                rep(k, m) {
                    ll temp = (a[i][k] * mat2.a[k][j]) % MOD;
                    res.a[i][j] = (res.a[i][j] + temp) % MOD;
                }
            }
        }

        return res;
    }

    void exp(ll b) {
        Matrix res(n, m);
        Matrix curr = *this;
        rep(i, n) res.a[i][i] = 1;

        while (b) {
            if (b & 1) res = res * curr;
            curr = curr * curr;
            b /= 2;
        }

        a = res.a;
    }
};

// Adjacency lists of the input tree; vertices are numbered 1..n.
vector<ll> adj[N];
// dp1[u]: number of children of u with dp1 == 0 in the tree rooted at vertex 1
// (filled by dfs1; solve() later clamps it to 0/1).
vector<ll> dp1(N);
// dp2[u]: the same count taken over all neighbours of u, i.e. with u as the root
// (seeded from dp1 and completed by dfs2; solve() later clamps it to 0/1).
vector<ll> dp2(N);

/*
 * Post-order pass over the tree; p is u's parent (-1 at the root).
 * Sets dp1[u] to the number of children v of u with dp1[v] == 0, i.e. the
 * number of losing children. u is treated as winning elsewhere exactly when
 * dp1[u] != 0, because it can then move to a losing child.
 */
void dfs1(ll u, ll p){
    dp1[u] = 0;
    for (const ll child : adj[u]) {
        if (child == p) continue;
        dfs1(child, u);
        if (!dp1[child]) ++dp1[u];
    }
}

/*
 * Top-down rerooting pass. It expects dfs1 to have run and dp2 to equal dp1 on
 * entry. When the root moves from u to its child v, u becomes a child of v.
 * Its losing-neighbour count then drops by one if v itself was losing
 * (dp1[v] == 0). If nothing is left, u is a losing child of v and dp2[v] grows
 * by one. Afterwards, dp2[x] is the number of losing neighbours of x with x as
 * the root.
 */
void dfs2(ll u, ll p){
    for (const ll child : adj[u]) {
        if (child == p) continue;
        const ll remaining = dp2[u] - (dp1[child] == 0 ? 1 : 0);
        if (remaining == 0) dp2[child] += 1;
        dfs2(child, u);
    }
}

// dp3[u][j]: count of vertices at depth parity j (relative to u) collected from
// u's subtree (tree rooted at 1) by the propagation rule in dfs3.
// NOTE(review): presumably the vertices whose change flips u's outcome; see the editorial.
ll dp3[N][2];
// dp4[u][j]: the same quantity with u taken as the root of the whole tree
// (seeded by dfs3, rerooted by dfs4).
ll dp4[N][2];
// dp5[r]: per-root count computed in solve() from dp2 and dp4.
vector<ll> dp5(N);
// Number of winning / losing neighbours of each vertex: dfs3 adds the
// children, and dfs4 adds the parent of every non-root vertex.
vector<ll> win_cnt(N);
vector<ll> lose_cnt(N);
// Per-parity sums of those neighbours' dp3-style values, maintained alongside
// win_cnt / lose_cnt.
ll win_sum[N][2];
ll lose_sum[N][2];

/*
 * Bottom-up pass over the tree rooted at 1; p is u's parent (-1 at the root).
 * For every vertex u it does three things:
 *   - Split u's children by dp1 (non-zero = winning child, zero = losing
 *     child). Accumulate win_cnt/lose_cnt and the per-parity sums
 *     win_sum/lose_sum of the children's dp3 values.
 *   - Start dp3[u] from {1, 0}, i.e. u itself at depth parity 0. If u has no
 *     losing child, add the winning children's sums with parity flipped. If it
 *     has exactly one losing child, add that child's sums with parity flipped.
 *     Otherwise add nothing.
 *   - Seed dp4[u] with dp3[u]; dfs4 later overwrites dp4 for every non-root
 *     vertex.
 * Requires dp1 from dfs1 (raw counts or 0/1 both work here). It relies on
 * dp3[u][1] and the win/lose accumulators of u being zero on entry.
 */
void dfs3(ll u, ll p){
    dp3[u][0] = 1;

    for (const ll v : adj[u]) {
        if (v != p) dfs3(v, u);
    }

    // Only the number of losing children selects the rule below, so count them
    // instead of allocating per-call vectors of winning/losing children.
    ll losing_children = 0;
    for (const ll v : adj[u]) {
        if (v == p) continue;
        if (dp1[v]) {
            win_cnt[u]++;
            win_sum[u][0] += dp3[v][0];
            win_sum[u][1] += dp3[v][1];
        } else {
            losing_children++;
            lose_cnt[u]++;
            lose_sum[u][0] += dp3[v][0];
            lose_sum[u][1] += dp3[v][1];
        }
    }

    if (losing_children <= 1) {
        // Children sit one level deeper, hence the parity swap.
        const ll *inherited = (losing_children == 0) ? win_sum[u] : lose_sum[u];
        dp3[u][0] += inherited[1];
        dp3[u][1] += inherited[0];
    }

    dp4[u][0] = dp3[u][0];
    dp4[u][1] = dp3[u][1];
}

/*
 * Top-down rerooting pass that runs after dfs3. For each child v of u:
 *   1. Take v's contribution out of u's aggregates. These already contain u's
 *      own parent, which was added when u was reached. Then recompute u's
 *      per-parity values as if v were u's parent (dpu).
 *   2. Add u as an extra neighbour of v. It counts as a winning neighbour if u
 *      still has a losing neighbour other than v, otherwise as a losing one.
 *   3. Recompute dp4[v] from v's now-complete aggregates, then recurse.
 * The result is dp4[x] = the dfs3 quantity with x as the root, for every x.
 * It reads dp1/dp2 as raw counts (dp2[u] - 1 must stay meaningful), so it has
 * to run before solve() clamps them to 0/1. win_cnt is kept up to date but is
 * not read here.
 */
void dfs4(ll u, ll p){
    // Same rule as dfs3: start from {1, 0} (the vertex itself at even depth).
    // With zero losing neighbours inherit the winning sums, with exactly one
    // inherit the losing sums; in both cases the parity is flipped.
    auto combine = [](ll losing, const ll *wins, const ll *loses) {
        std::array<ll, 2> res = {1, 0};
        const ll *src = nullptr;
        if (losing == 0) src = wins;
        else if (losing == 1) src = loses;
        if (src != nullptr) {
            res[0] += src[1];
            res[1] += src[0];
        }
        return res;
    };

    for (const ll v : adj[u]) {
        if (v == p) continue;

        // u's aggregates without v, i.e. u viewed as a child of v.
        ll lose = lose_cnt[u];
        ll wsum[2] = {win_sum[u][0], win_sum[u][1]};
        ll lsum[2] = {lose_sum[u][0], lose_sum[u][1]};
        if (dp1[v]) {
            wsum[0] -= dp3[v][0];
            wsum[1] -= dp3[v][1];
        } else {
            lose--;
            lsum[0] -= dp3[v][0];
            lsum[1] -= dp3[v][1];
        }
        const std::array<ll, 2> dpu = combine(lose, wsum, lsum);

        // u's losing neighbours excluding v; non-zero means u is a winning child of v.
        const ll val = dp2[u] - (dp1[v] == 0);
        if (val) {
            win_cnt[v]++;
            win_sum[v][0] += dpu[0];
            win_sum[v][1] += dpu[1];
        } else {
            lose_cnt[v]++;
            lose_sum[v][0] += dpu[0];
            lose_sum[v][1] += dpu[1];
        }

        // v's aggregates now cover every neighbour of v.
        const std::array<ll, 2> dpv = combine(lose_cnt[v], win_sum[v], lose_sum[v]);
        dp4[v][0] = dpv[0];
        dp4[v][1] = dpv[1];

        dfs4(v, u);
    }
}

/*
 * Walk from u along the dp1 orientation, accumulating into dp5[r].
 *   - If dp1[u] == 0, dp5[r] loses dp2[r] and gains 1 when depth is even.
 *   - It then descends (depth + 1) into every child when u has no child with
 *     dp1 == 0, into that single child when there is exactly one, and stops
 *     otherwise.
 * NOTE(review): solve() never calls this function; dp5 is computed there in
 * closed form instead. It looks like an earlier direct formulation; confirm
 * before relying on it.
 */
void dfs5(ll u, ll p, ll depth, ll r){
    if (dp1[u] == 0) {
        const ll even_depth = (depth & 1) ? 0 : 1;
        dp5[r] = dp5[r] - dp2[r] + even_depth;
    }

    ll losing_children = 0;
    ll losing_child = -1;
    for (const ll v : adj[u]) {
        if (v == p || dp1[v]) continue;
        ++losing_children;
        if (losing_children == 1) losing_child = v;
    }

    if (losing_children == 0) {
        for (const ll v : adj[u]) {
            if (v != p) dfs5(v, u, depth + 1, r);
        }
    } else if (losing_children == 1) {
        dfs5(losing_child, u, depth + 1, r);
    }
}

void solve(int test_case)
{
    ll n,d; cin >> n >> d;
    rep1(i,n-1){
        ll u,v; cin >> u >> v;
        adj[u].pb(v), adj[v].pb(u);
    }

    dfs1(1,-1);
    rep1(i,n) dp2[i] = dp1[i];
    dfs2(1,-1);

    dfs3(1,-1);
    dfs4(1,-1);

    rep1(i,n){
        amin(dp1[i],1ll);
        amin(dp2[i],1ll);
    }

    rep1(i,n){
        dp5[i] = n*dp2[i];
        dp5[i] -= (dp4[i][0]+dp4[i][1])*dp2[i];
        dp5[i] += dp4[i][0];
    }

    ll win_ways_w = 0, lose_ways_w = 0;
    rep1(r,n){
        if(dp2[r]){
            win_ways_w += n;
        }
        else{
            lose_ways_w += n;
        }
    }

    ll win_ways_l = 0, lose_ways_l = 0;
    rep1(r,n){
        win_ways_l += dp5[r];
        lose_ways_l += n-dp5[r];
    }

    Matrix base(1,2);
    rep1(i,n){
        base.a[0][dp2[i]]++;
    }

    Matrix mat(2,2);
    mat.a = {
        {lose_ways_l%MOD, win_ways_l%MOD},
        {lose_ways_w%MOD, win_ways_w%MOD}
    };

    mat.exp(d-1);
    base = base*mat;

    ll ans = 0;
    if(dp2[1]){
        ans += n*base.a[0][1];
    }
    
    ans += dp5[1]*base.a[0][0];
    ans %= MOD;

    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    // cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지 — compiler standard-error (stderr) output

startrek.cpp: In member function 'Matrix Matrix::operator*(const Matrix&)':
startrek.cpp:76:13: warning: unused variable 'n2' [-Wunused-variable]
   76 |         int n2 = mat2.n, m2 = mat2.m;
      |             ^~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...