// Submission #975321
//
// #      | Submitted at | User     | Problem                     | Language | Result   | Time  | Memory
// 975321 | (not shown)  | efedmrlr | Star Trek (CEOI20_startrek) | C++17    | 45 / 100 | 58 ms | 14420 KiB
#include <bits/stdc++.h>

// Competitive-programming shorthands used throughout this file.
#define lli long long int
#define ld long double
#define pb push_back
#define MP make_pair
// Expand to a container's forward / reverse iterator range.
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
// Loop i over 0, 1, ..., n-1.
#define REP(i, n) for(int i = 0; (i) < (n); (i)++)

using namespace std;

// Speeds up stream I/O: unties cin from cout and stops syncing
// C++ streams with C stdio (so the two must not be mixed afterwards).
void fastio() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
}

constexpr int N = 100005;        // capacity of the node-indexed global tables
constexpr int INF = 1000000500;
constexpr int MOD = 1000000007;  // modulus for all counting
constexpr int B = 2;             // side length of Matrix

// Returns (x + y) mod MOD; expects x, y already in [0, MOD).
int add(int x, int y) {
    const int sum = x + y;
    return sum >= MOD ? sum - MOD : sum;
}

// Returns (x * y) mod MOD, widening to 64 bits to avoid overflow.
int mult(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % MOD);
}

// Returns (x - y) mod MOD; expects x, y already in [0, MOD).
int subt(int x, int y) {
    const int diff = x - y;
    return diff < 0 ? diff + MOD : diff;
}
// Binary exponentiation: returns x^y mod MOD (returns 1 when y <= 0).
int fp(int x, lli y) {
    int result = 1;
    for (int base = x; y > 0; y >>= 1) {
        if (y & 1) result = mult(result, base);
        base = mult(base, base);
    }
    return result;
}

// B x B integer matrix; a default-constructed Matrix is all zeros.
struct Matrix {
    array<array<int, B>, B> mat;
    Matrix() : mat{} {}
};
// Returns the matrix product x * y with every entry reduced modulo MOD.
// Loops over the full B x B extent so the product stays correct if B changes
// (the loop bounds were previously hard-coded to 2).
Matrix mult(Matrix x, Matrix y) {
    Matrix ret;
    REP(i, B) REP(j, B) REP(k, B) {
        ret.mat[i][j] = add(ret.mat[i][j], mult(x.mat[i][k], y.mat[k][j]));
    }
    return ret;
}
// Matrix binary exponentiation: returns x^y (entries mod MOD);
// y <= 0 yields the identity matrix.
Matrix fp(Matrix x, lli y) {
    Matrix acc;
    for (int diag = 0; diag < B; ++diag) acc.mat[diag][diag] = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) acc = mult(acc, x);
        x = mult(x, x);
    }
    return acc;
}
// Input: number of tree nodes and the parameter d (solve() counts n^(2d) configurations).
int n;
lli d;
// 1-based adjacency lists of the tree.
vector<vector<int> > adj(N, vector<int>());
// dpr[v]: number of children c of v (in the current rooting) with dpr[c] == 0.
// A value of 0 is treated as a losing node, nonzero as a winning one.
// dp[v]: dpr[v] recorded while v itself is the root (filled by reroot).
vector<int> dp(N, 0), dpr(N, 0);
// dpcrit[v]: count maintained by dfs2/change_root for the current rooting.
// Presumably the number of nodes in v's subtree whose switch to "winning" flips v's outcome.
// dpc[v]: dpcrit[v] recorded while v is the root.
vector<int> dpcrit(N, 0), dpc(N, 0);
// wc[v] / lc[v]: sum of dpcrit over children with dpr != 0 / dpr == 0.
vector<int> wc(N, 0), lc(N, 0);
// L: number of roots with dp == 0.
int L = 0;
// CL / CW: sums (mod MOD) of dpc over roots with dp == 0 / dp != 0.
int CL = 0, CW = 0;
// Post-order pass rooted at `node`: sets dpr[node] to the number of
// children whose own dpr is zero (i.e. children that are losing positions).
void dfs1(int node, int par) {
    for (int next : adj[node]) {
        if (next != par) {
            dfs1(next, node);
            dpr[node] += (dpr[next] == 0);
        }
    }
}
// Post-order pass (after dfs1) that fills wc, lc and dpcrit for the rooting at `node`.
//   dpr == 0 : dpcrit = 1 + sum of dpcrit over (winning) children
//   dpr == 1 : dpcrit = dpcrit of the single losing child
//   dpr >= 2 : dpcrit stays 0
void dfs2(int node, int par) {
    if (dpr[node] == 0) dpcrit[node] = 1;
    for (int child : adj[node]) {
        if (child == par) continue;
        dfs2(child, node);
        const int childCrit = dpcrit[child];
        if (dpr[child] != 0) {
            wc[node] += childCrit;
            if (dpr[node] == 0) dpcrit[node] += childCrit;
        } else {
            lc[node] += childCrit;
            if (dpr[node] == 1) dpcrit[node] += childCrit;
        }
    }
}
// Moves the root from p to its neighbour x in O(1).
// It first detaches x from p's children, then attaches p as a child of x,
// updating dpr / wc / lc / dpcrit of both endpoints.
void change_root(int p, int x) {
    // Rebuild dpcrit[v] from the aggregated child data of v.
    auto refresh = [](int v) {
        if (dpr[v] >= 2) dpcrit[v] = 0;
        else if (dpr[v] == 1) dpcrit[v] = lc[v];
        else dpcrit[v] = wc[v] + 1;
    };

    // x stops being a child of p.
    if (dpr[x]) {
        wc[p] -= dpcrit[x];
    } else {
        dpr[p]--;
        lc[p] -= dpcrit[x];
    }
    refresh(p);

    // p becomes a child of x.
    if (dpr[p]) {
        wc[x] += dpcrit[p];
    } else {
        dpr[x]++;
        lc[x] += dpcrit[p];
    }
    refresh(x);
}

// Visits every node as the root. On entry, dpr/dpcrit of `node` describe
// the tree rooted at `node`; they are recorded into dp/dpc. The root is then
// moved into each child subtree and moved back afterwards.
void reroot(int node, int par) {
    dp[node] = dpr[node];
    dpc[node] = dpcrit[node];
    for (int child : adj[node]) {
        if (child != par) {
            change_root(node, child);  // child becomes the root
            reroot(child, node);
            change_root(child, node);  // restore node as the root
        }
    }
}

// Reads the tree and d, then prints the number of configurations
// (out of n^(2d)) in which node 1 of layer 0 is a winning start.
void solve() {
    cin >> n >> d;
    REP(i, n - 1) {
        int u, v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    // Single-tree values rooted at 1, then for every root via rerooting.
    dfs1(1, 0);
    dfs2(1, 0);
    reroot(1, 0);

    // L = number of roots with dp == 0 (losing);
    // CL / CW = sum of dpc over losing / winning roots.
    for(int i = 1; i <= n; i++) {
        if(!dp[i]) {
            L++;
            CL = add(CL, dpc[i]);
        }
        else {
            CW = add(CW, dpc[i]);
        }
    }

    // X_k = sum over start nodes r in layer k of the number of configurations
    // of the later layers in which r is losing. The recurrence is
    //     X_d = L,   X_k = (CW - CL) * X_{k+1} + L * n^(2(d-k)).
    // Row vector [X_{k+1}, L * n^(2(d-k))] times `rel` gives
    // [X_k, L * n^(2(d-k+1))]; applying it d-1 times to
    // [L, L * n^2] yields X_1 in entry [0][0].
    const int n2 = mult(n, n);
    Matrix rel;
    rel.mat[0] = {subt(CW, CL), 0};
    rel.mat[1] = {1, n2};
    Matrix st;
    st.mat[0][0] = L;
    // The second component of the start vector lives in row 0
    // (previously written to mat[1][1], which dropped every L*n^(2k) term for d >= 2).
    st.mat[0][1] = mult(L, n2);
    st = mult(st, fp(rel, d - 1));
    const int losing_next = st.mat[0][0];  // X_1

    const int total = fp(n, 2ll * d);
    // Configurations in which node 1 of layer 0 changes its single-tree outcome
    // (portal from one of its dpc[1] nodes into a losing node of layer 1).
    const int flipped = mult(dpc[1], losing_next);
    const int ans = dp[1] ? subt(total, flipped) : flipped;
    cout << ans << "\n";
}

// Entry point: enable fast stream I/O, then solve the single test case.
int main() {
    fastio();
    solve();
    return 0;
}
// Per-group result tables (8 tables, each with columns:
//   # | Verdict | Execution time | Memory | Grader output)
// Table 1: Fetching results...
// Table 2: Fetching results...
// Table 3: Fetching results...
// Table 4: Fetching results...
// Table 5: Fetching results...
// Table 6: Fetching results...
// Table 7: Fetching results...
// Table 8: Fetching results...