Submission #556297

# | Time | Username | Problem | Language | Result | Execution time | Memory
556297 | | FatihSolak | Star Trek (CEOI20_startrek) | C++17 | 100 / 100 | 131 ms | 41624 KiB
#include <bits/stdc++.h>
#define N 100005
using namespace std;
const int mod = 1e9 + 7;
// Computes a^b modulo `mod` by binary exponentiation (b >= 0; binpow(x, 0) == 1).
// The base is first reduced into [0, mod), so every intermediate product is below
// mod^2 and fits in a long long even when the caller passes an unreduced or
// negative base.
long long binpow(long long a,long long b){
    a %= mod;
    if(a < 0)
        a += mod;   // C++ '%' keeps the sign of the dividend
    long long ret = 1;
    while(b){
        if(b & 1)
            ret = ret * a %mod;
        a = a * a %mod;
        b >>=1;
    }
    return ret;
}
// Accumulator propagated through the tree passes (dfs3 / dfs4) and combined
// with merge / antimerge. Every counter starts at zero.
struct node{
    long long a = 0;
    long long b = 0;
    long long sz = 0;
    long long sz2 = 0;
};
// 3x3 matrix with entries modulo `mod`; used by binexpo to advance the
// per-copy linear recurrence many steps at once.
struct matrix{
    long long a[3][3];
    // Constructs the zero matrix.
    matrix(){
        for(auto &row : a)
            for(auto &cell : row)
                cell = 0;
    }
    // Matrix product modulo `mod`. Entries are expected to already lie in
    // [0, mod), so each a[i][j]*other.a[j][k] stays below mod^2 and fits in a
    // long long. Takes the right operand by const reference and is itself
    // const, so it also works on const matrices without copying the operand.
    matrix operator*(const matrix &other) const{
        matrix ret;
        for(int i = 0;i<3;i++){
            for(int j = 0;j<3;j++){
                for(int k = 0;k<3;k++){
                    ret.a[i][k] = (ret.a[i][k] + a[i][j]*other.a[j][k])%mod;
                }
            }
        }
        return ret;
    }
};
// Raises the 3x3 matrix `a` to the power `b` by repeated squaring and
// returns the result; b == 0 yields the identity matrix.
matrix binexpo(matrix a,long long b){
    matrix acc;
    for(int d = 0; d < 3; ++d)
        acc.a[d][d] = 1;   // start from the identity
    for(; b; b >>= 1){
        if(b & 1)
            acc = acc * a;
        a = a * a;
    }
    return acc;
}
// Adjacency lists of the input tree; solve reads edges as 1-indexed vertex pairs.
vector<int> adj[N];
// NOTE(review): appears unused in this file.
long long ans = 0;
// n: number of vertices; d: second input value, the recurrence in solve is
// advanced d-1 times and d also appears in the exponents 2d-1 / 2d-2.
long long n,d;
// win[v] == 1 when v has a losing neighbour in the whole tree: set by dfs
// (losing child) and by dfs2 (losing parent side). Cleared to 0 in solve.
int win[N];
// sub[v] == 1 when v has a child that is losing within its own subtree
// (tree rooted at 1); filled by dfs.
int sub[N];
// top[v] starts at 1 (solve); dfs2 sets it to 0 when v's parent is losing in
// the tree with v's subtree removed.
int top[N];
// len[v]: size of v's subtree in the tree rooted at 1 (dfs).
int len[N];
// NOTE(review): appears unused; dfs4 declares a local `node sum` that shadows it.
int sum = 0;
// Next Euler-tour timestamp handed out by dfs.
int timer = 1;
// Euler-tour entry / exit times assigned by dfs.
int tin[N],tout[N];
// Rooted DFS (solve starts it at vertex 1 with par = 0). For every visited
// vertex v it assigns Euler-tour times tin[v]/tout[v] and the subtree size
// len[v]. sub[v] becomes 1 exactly when some child is losing within its own
// subtree, mirroring the usual win/lose recurrence; in that case win[v] is
// also set to 1. win[v] is never cleared here. Returns sub[v].
bool dfs(int v,int par){
    tin[v] = timer;
    ++timer;
    bool hasLosingChild = false;
    for(int child : adj[v]){
        if(child == par)
            continue;
        const bool childWins = dfs(child,v);
        if(!childWins)
            hasLosingChild = true;
    }
    sub[v] = hasLosingChild ? 1 : 0;
    if(hasLosingChild)
        win[v] = 1;
    tout[v] = timer - 1;
    len[v] = timer - tin[v];
    return sub[v];
}
// Rerooting pass, run after dfs. It completes win[] for the whole tree.
// For each child u of v, it checks whether v has a losing neighbour other than u.
// If v has none, v is losing in the tree with u's subtree removed. Then u gains
// a losing neighbour (its parent v): win[u] = 1 and top[u] = 0.
// For the root, top[1] keeps the value 1 from solve, so the absent parent is not counted.
void dfs2(int v,int par){
    // cnt = number of losing neighbours of v. The parent side counts when
    // top[v] == 0, and each child u counts when sub[u] == 0.
    int cnt = !top[v];
    for(auto u:adj[v]){
        if(u == par)continue;
        cnt += !sub[u];
    }
    for(auto u:adj[v]){
        if(u == par)continue;
        // Temporarily exclude u itself.
        cnt -= !sub[u];
        if(!cnt){
            // v has no losing neighbour apart from u, so v loses once u's
            // subtree is cut off, and u can move to it.
            win[u] = 1;
            top[u] = 0;
        }
        // Restore u's contribution before handling the next child.
        cnt += !sub[u];
        dfs2(u,v);
    }
}
// Per-vertex accumulators (combined via merge / antimerge):
//   subval[v] - filled bottom-up by dfs3 for v's subtree (rooted at 1);
//   val[v]    - filled top-down by dfs4 for the whole tree as seen from v;
//   topval[v] - filled by dfs4 for the part of the tree outside v's subtree,
//               hanging from v's parent. topval[1] is never assigned and
//               stays zero.
node subval[N];
node val[N];
node topval[N];
// Adds the accumulator of a neighbour (b) into a, with each pair of roles
// swapped: a.a takes b.b, a.b takes b.a, a.sz takes b.sz2 and a.sz2 takes b.sz.
// Presumably the swap reflects a one-level parity flip across the edge.
// sz/sz2 are reduced modulo `mod`; a/b are plain counts and are not reduced.
void merge(node &a,node b){
    a.sz2 = (a.sz2 + b.sz) % mod;
    a.sz = (a.sz + b.sz2) % mod;
    a.b += b.a;
    a.a += b.b;
}
// Inverse of merge: removes a previously merged accumulator b from a, using the
// same role swap. sz/sz2 are brought back into [0, mod), provided both
// operands were already in that range.
void antimerge(node &a,node b){
    const long long keptSz = (a.sz - b.sz2 + mod) % mod;
    const long long keptSz2 = (a.sz2 - b.sz + mod) % mod;
    a.b -= b.a;
    a.a -= b.b;
    a.sz = keptSz;
    a.sz2 = keptSz2;
}
// Bottom-up pass (root 1) filling subval[v] for the subtree rooted at v.
// It uses sub[] and len[] from dfs. The result depends on how many children
// are losing within their own subtree (sub[u] == 0):
//   - none:        v is losing in its subtree. v is counted once in .b, and
//                  every child's subval is folded in via merge.
//   - exactly one: only that child's subval is folded in, and .sz gains
//                  n * (len[v] - len[child]).
//   - two or more: nothing is folded in, and .sz gains n * len[v].
// NOTE(review): the combinatorial meaning of a/b/sz/sz2 is not spelled out in
// this file. Presumably they count attachment choices that keep or flip v's
// outcome; confirm against the problem analysis.
void dfs3(int v,int par){
    // Children u with sub[u] == 0 (losing within their own subtree).
    vector<int> places;
    subval[v] = node();
    for(auto u:adj[v]){
        if(u == par)continue;
        if(!sub[u]){
            places.push_back(u);
        }
        // Children are completed before v combines their subval.
        dfs3(u,v);
    }
    if(places.empty()){
        subval[v].b++;
        for(auto u:adj[v]){
            if(u == par)continue;
            merge(subval[v],subval[u]);
        }
    }
    if(places.size() == 1){
        subval[v].sz += n * (len[v] - len[places[0]]);
        subval[v].sz %= mod;
        merge(subval[v],subval[places[0]]);
    }
    if(places.size() > 1){
        subval[v].sz += n * len[v];
        subval[v].sz %= mod;
    }
}
// Top-down rerooting pass, run after dfs2 and dfs3. It fills two things:
//   - val[v]: the accumulator for the whole tree as seen from v;
//   - topval[u] for every child u: the accumulator for the part of the tree
//     outside u's subtree, rooted at v.
// The case split mirrors dfs3, with the parent side treated as one more
// neighbour. Its size is n - len[v], so "n" replaces len[v] where the whole
// tree is meant.
void dfs4(int v,int par){
    // Losing neighbours of v: the parent side when top[v] == 0, plus each
    // child u with sub[u] == 0.
    vector<int> places;
    if(!top[v])
        places.push_back(par);
    val[v] = node();
    for(auto u:adj[v]){
        if(u == par)continue;
        if(!sub[u]){
            places.push_back(u);
        }
    }
    // --- val[v] ---
    if(places.empty()){
        // v is losing in the whole tree: count it and fold in every side.
        val[v].b++;
        merge(val[v],topval[v]);
        for(auto u:adj[v]){
            if(u == par)continue;
            merge(val[v],subval[u]);
        }
    }
    if(places.size() == 1){
        if(places[0] == par){
            // Only the parent side is losing. The rest of the tree (v's own
            // subtree) has len[v] vertices.
            val[v].sz += n * len[v];
            val[v].sz %= mod;
            merge(val[v],topval[v]);
        }
        else{
            // Only one child is losing. The rest of the tree has
            // n - len[child] vertices.
            val[v].sz += n * (n - len[places[0]]);
            val[v].sz %= mod;
            merge(val[v],subval[places[0]]);
        }
    }
    if(places.size() > 1){
        val[v].sz += n * n;
        val[v].sz %= mod;
    }
    // --- topval[u] for each child u ---
    // sum = merge of the parent side and every child; each child is removed
    // temporarily while its topval is computed.
    node sum;
    merge(sum,topval[v]);
    for(auto u:adj[v]){
        if(u == par)continue;
        merge(sum,subval[u]);
    }
    for(auto u:adj[v]){
        if(u == par)continue;
        // Drop u's own contribution.
        antimerge(sum,subval[u]);
        node tmp = sum;
        // Number of losing neighbours of v other than u.
        int x = places.size() - !sub[u];
        if(x == 0){
            // v is losing once u's subtree is removed.
            tmp.b++;
        }
        if(x == 1){
            // Exactly one other losing neighbour `pos`: only its side is kept.
            tmp = node();
            int pos = places[0];
            if(pos == u)
                pos = places[1];
            if(pos == par){
                tmp.sz += n * ( len[v] - len[u]);
                tmp.sz %= mod;
                merge(tmp,topval[v]);
            }
            else{
                tmp.sz += n * (n - len[pos] - len[u]);
                tmp.sz %= mod;
                merge(tmp,subval[pos]);
            }
        }
        if(x > 1){
            tmp = node();
            tmp.sz += n * (n - len[u]);
            tmp.sz %= mod;
        }
        topval[u] = tmp;
        // Restore u before moving on to the next child.
        merge(sum,subval[u]);
        dfs4(u,v);
    }
}
void solve(){
    cin >> n >> d;
    for(int i = 1;i<n;i++){
        int u,v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    vector<node> v;
    long long winstate = 0;
    for(int i = 1;i<=n;i++){
        win[i] = 0;
        sub[i] = 0;
        top[i] = 1;
    }
    dfs(1,0);
    dfs2(1,0);
    dfs3(1,0);
    dfs4(1,0);
    for(int x = 1;x<=n;x++){
        v.push_back(val[x]);
    }
    long long suma = 0,sumb = 0,sumsz = 0;
    for(auto u:v){
        suma += u.a;
        sumb += u.b;
        sumsz += u.sz;
        suma %= mod;
        sumb %= mod;
        sumsz %= mod;
    }
    for(int i = 1;i<=n;i++){
        winstate += win[i];
    }
    matrix single;
    single.a[0][0] = (suma - sumb +mod)%mod;
    single.a[1][0] = (sumb)%mod;
    single.a[2][0] = (sumsz)%mod;
    single.a[1][1] = n*n%mod;
    single.a[2][2] = n*n%mod;
    matrix total = binexpo(single,d-1);
    long long val = 0;
    val = (val + total.a[0][0] * winstate)%mod;
    val = (val + total.a[1][0] * n)%mod;
    val = (val + total.a[2][0] * 1)%mod;
    winstate = val;
    /*
    for(int i = 0;i<d-1;i++){
        winstate = (winstate * (suma - sumb + mod) + binpow(n,2*i+1)*sumb + binpow(n,2*i)*sumsz)%mod;
    }*/
    winstate = (winstate * (v[0].a - v[0].b + mod) + binpow(n,2*d-1)*v[0].b + binpow(n,2*d-2)*v[0].sz)%mod;
    cout << winstate;

}
// Program entry: enables fast unsynchronised iostreams and runs a single test
// case. When compiled with -DLocal, it redirects I/O to in.txt / out.txt and
// appends the elapsed CPU time.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef Local
    freopen("in.txt","r",stdin);
    freopen("out.txt","w",stdout);
#endif
    // Exactly one test; the count is not read from input.
    for(int remaining = 1; remaining > 0; --remaining)
        solve();
#ifdef Local
    cout << endl << fixed << setprecision(2) << 1000.0 * clock() / CLOCKS_PER_SEC << " milliseconds.";
#endif
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...