/*
 * Submission #1201767
 *
 * Submission ID  : 1201767
 * User           : byunjaewoo
 * Problem        : Star Trek (CEOI20_startrek)
 * Language       : C++20
 * Result         : 38 / 100
 * Execution time : 57 ms
 * Memory         : 23880 KiB
 */
#include <bits/stdc++.h>
#define int long long
using namespace std;

// Capacity of every per-node array, and the modulus for the final count.
const int N = 100010;
const int Mod = 1000000007;

// Number of planets and number of portal steps, as read from input.
int n, d;

// Per-node values filled by the DFS passes (tree rooted at planet 1).
int dp[N];   // dfs1: 1 iff the node has at least one losing child
int ep[N];   // dfs2: 1 iff the part of the tree above the node is losing
int cx[N];   // dfs3: flip count for the node within its own subtree
int cy[N];   // dfs4: helper value for the part of the tree above the node
int cnt[N];  // dfs1: number of losing children
int w[N];    // dfs2: 1 iff the node wins when the whole tree is rooted there
int c[N];    // dfs4: flip count when the whole tree is rooted at the node

// Accumulators and result storage used by main.
int c1, c2, ans[N], rans;

// Adjacency lists of the tree.
vector<int> adj[N];

// Post-order pass over the tree rooted at planet 1 (call as dfs1(1, 0)).
// Afterwards, for every node x:
//   cnt[x] = number of children of x whose dp value is 0 (losing children);
//   dp[x]  = 1 iff x has at least one losing child, i.e. the player to move
//            at x wins when only x's subtree is available.
void dfs1(int curr, int prev) {
    for(const int child : adj[curr]) {
        if(child==prev) continue;
        dfs1(child, curr);
        if(dp[child]==0) {
            ++cnt[curr];
            dp[curr]=1;
        }
    }
}

// Top-down pass (run after dfs1, call as dfs2(1, 0)).
// ep[child] becomes 1 when the component above `child` (curr plus everything
// not in child's subtree, viewed from child) is losing. That happens when curr
// has no losing child other than `child` and its own upper part is not losing.
// w[x] = 1 iff x wins when the whole tree is rooted at x, i.e. x has a losing
// child below it (dp) or a losing component above it (ep).
void dfs2(int curr, int prev) {
    for(const int child : adj[curr]) {
        if(child==prev) continue;
        // Losing children of curr other than `child` itself.
        const int otherLosing = cnt[curr] - (dp[child] ? 0 : 1);
        if(otherLosing==0 && ep[curr]==0) ep[child]=1;
        dfs2(child, curr);
    }
    w[curr] = (dp[curr] || ep[curr]) ? 1 : 0;
}

// Post-order pass (run after dfs1, call as dfs3(1, 0)). cx[x] counts the
// planets u in x's subtree such that giving u one extra losing child flips
// the result at x:
//   x losing (dp 0): x itself plus every planet that flips a child -> 1 + sum;
//   x winning with exactly one losing child: whatever flips that child;
//   x winning with two or more losing children: nothing can flip it -> 0.
void dfs3(int curr, int prev) {
    for(const int child : adj[curr])
        if(child!=prev) dfs3(child, curr);

    if(!dp[curr]) {
        int total=1;
        for(const int child : adj[curr])
            if(child!=prev) total+=cx[child];
        cx[curr]=total;
        return;
    }

    int losingChildren=0, onlyLosing=0;
    for(const int child : adj[curr]) {
        if(child==prev || dp[child]) continue;
        ++losingChildren;
        onlyLosing=child;
    }
    cx[curr] = (losingChildren==1) ? cx[onlyLosing] : 0;
}

// Rerooting pass (run after dfs1-dfs3, call as dfs4(1, 0)). For every node x:
//   c[x]  = number of planets u such that giving u one extra losing child
//           flips the result of the game when the whole tree is rooted at x;
//   cy[x] = the same count for the "upward" component of x (everything
//           reached through its parent, rooted at the parent); 0 at the root.
// It reads dp/cnt (dfs1), ep (dfs2: x's upward component is losing) and cx
// (dfs3: the count for x's own subtree). cy[curr] must already hold its value
// when dfs4(curr, prev) is entered; the root relies on the zeroed global.
//
// If a node's neighbour components contain `losing` losing ones:
//   losing == 0 -> the node loses. It flips if u is the node itself, or if u
//                  flips any (winning) neighbour to losing: 1 + sum of counts.
//   losing == 1 -> the node wins. It flips only if that one losing neighbour
//                  flips to winning: that neighbour's count.
//   losing >= 2 -> one extra leaf cannot remove two losers: 0.
// The previous version built cy[] from cx[curr]. For a winning curr, cx[curr]
// only covers its losing child, so curr's other children (and curr itself)
// were dropped. For example, edges 1-2, 1-3, 3-4 gave c[2]=1 instead of 2.
void dfs4(int curr, int prev) {
    auto flips=[](int losing, int sumAll, int sumLosing) -> int {
        if(losing==0) return 1+sumAll;
        if(losing==1) return sumLosing;
        return 0;
    };

    // The upward component acts as one more neighbour. At the root it is
    // absent (ep=0, cy=0) and contributes nothing.
    int losing=cnt[curr]+ep[curr];
    int sumAll=cy[curr];
    int sumLosing=ep[curr] ? cy[curr] : 0;
    for(int next:adj[curr]) if(next!=prev) {
        sumAll+=cx[next];
        if(!dp[next]) sumLosing+=cx[next];
    }
    c[curr]=flips(losing, sumAll, sumLosing);

    // Remove `next` from the aggregates to get curr's component as seen from
    // next (it is losing exactly when ep[next] was set by dfs2).
    for(int next:adj[curr]) if(next!=prev) {
        int nextLoses=!dp[next];
        cy[next]=flips(losing-nextLoses, sumAll-cx[next],
                       sumLosing-(nextLoses ? cx[next] : 0));
        dfs4(next, curr);
    }
}

// 2x2 matrix over Z_Mod, used to advance the per-universe recurrence.
using Mat2 = array<array<int, 2>, 2>;

// Product of two 2x2 matrices modulo Mod (entries assumed to lie in [0, Mod)).
Mat2 matMul(const Mat2& a, const Mat2& b) {
    Mat2 r{};
    for(int i=0; i<2; i++)
        for(int k=0; k<2; k++)
            for(int j=0; j<2; j++)
                r[i][j]=(r[i][j]+a[i][k]*b[k][j])%Mod;
    return r;
}

// m^e modulo Mod by binary exponentiation; m^0 is the identity. O(log e).
Mat2 matPow(Mat2 m, int e) {
    Mat2 r{};
    r[0][0]=r[1][1]=1;
    for(; e>0; e>>=1) {
        if(e&1) r=matMul(r, m);
        m=matMul(m, m);
    }
    return r;
}

// Reads the tree and d, runs the four DFS passes, and prints the count
// (mod Mod) of portal placements for which the player to move first wins
// from planet 1.
//
// With w[v] (v wins as root) and c[v] (planets whose extra losing child flips
// v) from the DFS passes, let c1 = sum w[v] and c2 = sum c[v]*(1-2*w[v]).
// Let L_i be the number of losing configurations, summed over the entry
// planet of a universe that is followed by i more portals. Then
//     L_0 = n - c1,   L_i = (n - c1) * n^(2i) - c2 * L_(i-1),
// and the answer is w[1]*n^(2d) + (1-2*w[1]) * c[1] * L_(d-1).
// The recurrence is advanced with a matrix power, so the cost is O(log d)
// instead of O(d). The old loop also stored ans[i] for every i < d in an
// array of size N, which overflowed once d exceeded N.
signed main() {
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin>>n>>d;
    for(int i=1; i<n; i++) {
        int u, v; cin>>u>>v;
        adj[u].push_back(v), adj[v].push_back(u);
    }
    dfs1(1, 0), dfs2(1, 0), dfs3(1, 0), dfs4(1, 0);

    for(int i=1; i<=n; i++) c1+=w[i], c2+=c[i]*(1-2*w[i]);
    // |c2| can reach about n^2, so bring both sums into [0, Mod) before they
    // are multiplied by residues (otherwise the products can overflow).
    c1%=Mod;
    c2=(c2%Mod+Mod)%Mod;

    const int sq=(n%Mod)*(n%Mod)%Mod;        // n^2 choices per portal
    Mat2 step{};
    step[0][0]=(Mod-c2)%Mod;                 // -c2 * L_(i-1)
    step[0][1]=((n-c1)%Mod+Mod)%Mod;         // (n - c1) * n^(2i)
    step[1][1]=sq;                           // n^(2i) -> n^(2(i+1))
    // Starting from (L_(-1), n^0) = (0, 1), d steps give (L_(d-1), n^(2d)),
    // i.e. column 1 of step^d. For d = 0 this yields (0, 1).
    const Mat2 p=matPow(step, d);
    const int losing=p[0][1], total=p[1][1];

    rans=(c[1]%Mod)*losing%Mod;
    if(w[1]) rans=(total-rans+Mod)%Mod;
    cout<<rans;
    return 0;
}
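/*
 * Grader output. The per-test results had not loaded when this page was
 * copied, so every table shows only its placeholder.
 *
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 */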