답안 #1007915

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1007915 2024-06-25T18:53:46 Z NValchanov Star Trek (CEOI20_startrek) C++17
45 / 100
36 ms 18260 KB
#include <bits/stdc++.h>

#define endl '\n'

using namespace std;

using ll = long long;

constexpr int MAXN = 100000 + 10;                    // array bound for 1-indexed vertices
constexpr ll MAXD = static_cast<ll>(1e18 + 10);      // NOTE: double arithmetic; value is exactly 1e18 (unused)
constexpr ll MOD = 1000000000LL + 7;                 // modulus for all counting

struct matrix
{
    int n;
    vector < vector < ll > > m;

    matrix()
    {
        n = 0;
        m.clear();
    }

    matrix(int n)
    {
        this->n = n;
        m.resize(n);
        for(int i = 0; i < n; i++)
        {
            m[i].resize(n, 0LL);
            m[i][i] = 1LL;
        }
    }

    int size()
    {
        return n;
    }

    friend matrix operator*(matrix a, matrix b)
    {
        int sz = a.size();
        matrix result(sz);

        for(int k = 0; k < sz; k++)
        {
            for(int i = 0; i < sz; i++)
            {
                for(int j = 0; j < sz; j++)
                {
                    result.m[i][j] = (result.m[i][j] + (a.m[i][k] * b.m[k][j]) % MOD) % MOD;
                }
            }
        }

        return result;
    }
};

/// Computes a^b modulo MOD by binary exponentiation.
/// @param a base (any value; reduced into [0, MOD) before use)
/// @param b exponent, expected to be >= 0
/// @return a^b mod MOD, or 1 when b == 0
ll binpow(ll a, ll b)
{
    ll result = 1LL;

    // Reduce the base first: otherwise any |a| >= ~3.04e9 makes a * a overflow
    // signed 64-bit arithmetic (undefined behaviour), and a negative base would
    // produce a negative result.
    a %= MOD;
    if(a < 0)
        a += MOD;

    while(b)
    {
        if(b & 1)
            result = (result * a) % MOD;

        a = (a * a) % MOD;
        b /= 2;
    }

    return result;
}

/// Raises the square matrix `a` to the power `b` by repeated squaring.
/// The accumulator starts as matrix(a.size()), i.e. the identity of that size,
/// so b == 0 yields the identity.
matrix binpow(matrix a, ll b)
{
    matrix acc(a.size());

    for(; b != 0; b /= 2)
    {
        if(b % 2 != 0)
            acc = acc * a;
        a = a * a;
    }

    return acc;
}

/// Row-vector times matrix modulo MOD: returns a * b.
///
/// The dimension is taken from the matrix (sz = b.n) instead of being fixed
/// at 2. `a` must hold at least sz entries, and entries of both operands are
/// assumed to lie in [0, MOD) so each product fits in 64 bits.
vector < ll > mult(const vector < ll >& a, const matrix& b)
{
    const int sz = b.n;
    vector < ll > result(sz, 0LL);

    for(int i = 0; i < sz; i++)
    {
        for(int j = 0; j < sz; j++)
        {
            result[j] = (result[j] + (b.m[i][j] * a[i]) % MOD) % MOD;
        }
    }

    return result;
}

// Input read by read(): vertex count, the value d, 1-indexed adjacency lists.
ll n;
ll d;
vector<ll> adj[MAXN];

// Working per-vertex state for the tree rooted at vertex 1 after dfs/find_crit;
// during reroot() it describes the tree rooted at the vertex being visited.
ll dpt[MAXN];   // number of children c with dpt[c] == 0
ll crit[MAXN];  // computed by find_crit() / change()
ll win[MAXN];   // sum of crit[c] over children c with dpt[c] != 0
ll lose[MAXN];  // sum of crit[c] over children c with dpt[c] == 0

// Snapshots recorded by reroot(): dpt and crit with vertex r as the root.
ll dp[MAXN];
ll dph[MAXN];

// Totals over all roots r, filled by find_win_lose().
ll cntwin;      // sum of dph[r] (mod MOD) over r with dp[r] != 0
ll cntlose;     // sum of dph[r] (mod MOD) over r with dp[r] == 0
ll pot;         // count of r with dp[r] == 0

void read()
{
    cin >> n >> d;
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
}

/// First pass over the tree rooted at the initial `u` (called as dfs(1, 0)).
/// Sets dpt[u] to the number of children c with dpt[c] == 0; the rest of the
/// program treats dpt == 0 as a losing vertex (see cntlose/pot).
void dfs(int u, int par)
{
    for(int v : adj[u])
    {
        if(v != par)
        {
            dfs(v, u);
            dpt[u] += (dpt[v] == 0) ? 1 : 0;
        }
    }
}

/// Second pass over the same rooted tree (requires dfs() to have filled dpt).
///
/// For every vertex u it accumulates over the children c:
///   win[u]  += crit[c]   when dpt[c] != 0,
///   lose[u] += crit[c]   when dpt[c] == 0,
/// and builds crit[u] as
///   1 + win[u]   when dpt[u] == 0,
///   lose[u]      when dpt[u] == 1 (the single child with dpt == 0),
/// leaving crit[u] untouched (0 from static initialisation) otherwise.
void find_crit(int u, int par)
{
    if(dpt[u] == 0)
        crit[u] = 1;

    for(int c : adj[u])
    {
        if(c == par)
            continue;

        find_crit(c, u);

        const bool childLoses = (dpt[c] == 0);
        if(childLoses)
        {
            lose[u] += crit[c];
            if(dpt[u] == 1)
                crit[u] += crit[c];
        }
        else
        {
            win[u] += crit[c];
            if(dpt[u] == 0)
                crit[u] += crit[c];
        }
    }
}

/// Moves the root from u to its neighbour v, updating only u and v.
///
/// Expects u to be the current root and v one of its children. First v is
/// removed from u's aggregates and crit[u] is recomputed. Then u is added as a
/// child of v and crit[v] is recomputed. reroot() relies on change(v, u)
/// undoing this.
void change(int u, int v)
{
    // crit[x] from x's current aggregates (same rule as in find_crit).
    auto recompute = [](int x)
    {
        if(dpt[x] > 1)
            crit[x] = 0;
        else
            crit[x] = (dpt[x] == 1) ? lose[x] : win[x] + 1;
    };

    // Remove `child` from `parent`'s aggregates.
    auto detach = [](int parent, int child)
    {
        if(dpt[child] == 0)
        {
            dpt[parent]--;
            lose[parent] -= crit[child];
        }
        else
        {
            win[parent] -= crit[child];
        }
    };

    // Add `child` to `parent`'s aggregates.
    auto attach = [](int parent, int child)
    {
        if(dpt[child] == 0)
        {
            dpt[parent]++;
            lose[parent] += crit[child];
        }
        else
        {
            win[parent] += crit[child];
        }
    };

    detach(u, v);
    recompute(u);

    attach(v, u);
    recompute(v);
}

/// Rerooting DFS: visits every vertex as the root of the whole tree.
///
/// On entry u must be the current root. The function records dp[u] = dpt[u]
/// and dph[u] = crit[u]. Then, for each neighbour v other than par, it moves
/// the root to v, recurses, and moves the root back.
void reroot(int u, int par)
{
    dp[u] = dpt[u];
    dph[u] = crit[u];

    for(size_t idx = 0; idx < adj[u].size(); idx++)
    {
        const int v = adj[u][idx];
        if(v != par)
        {
            change(u, v);   // v becomes the root
            reroot(v, u);
            change(v, u);   // u is the root again
        }
    }
}

void find_win_lose()
{
    dfs(1, 0);
    find_crit(1, 0);
    reroot(1, 0);

    for(int i = 1; i <= n; i++)
    {
        if(!dp[i])
        {
            pot++;
            cntlose = (cntlose + dph[i]) % MOD;
        }
        else
        {
            cntwin = (cntwin + dph[i]) % MOD;
        }
    }
    // cout << "DPH : " << endl;
    // for(int i = 1; i <= n; i++)
    // {
    //     cout << dph[i] << " ";
    // }
    // cout << endl << endl;
    // cout << "DP : " << endl;
    // for(int i = 1; i <= n; i++)
    // {
    //     cout << dp[i] << " ";
    // }
    // cout << endl;
}

void solve()
{
    matrix mat(2);

    mat.m[0][0] = (cntwin - cntlose + MOD) % MOD;
    mat.m[0][1] = 0LL;
    mat.m[1][0] = 1LL;
    mat.m[1][1] = (pot * (n * n) % MOD) % MOD;

    vector < ll > tmp = {pot, pot * ((n * n) % MOD) % MOD};

    tmp = mult(tmp, binpow(mat, d - 1LL));

    ll ans = 0LL;

    if(dp[1])
        ans = (dph[1] * tmp[0]) % MOD;
    else
        ans = (binpow(n, 2LL * d) - (dph[1] * tmp[0]) % MOD + MOD) % MOD;

    ans = (binpow(n, 2LL * d) - ans + MOD) % MOD;

    cout << ans << endl;
}

/// Entry point: fast unsynchronised iostreams, then read -> analyse tree -> solve.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    read();
    find_win_lose();
    solve();
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Incorrect 1 ms 5212 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5464 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 34 ms 14164 KB Output is correct
13 Correct 36 ms 18260 KB Output is correct
14 Correct 24 ms 10836 KB Output is correct
15 Correct 32 ms 10840 KB Output is correct
16 Correct 28 ms 11056 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 1 ms 5212 KB Output is correct
13 Incorrect 1 ms 5212 KB Output isn't correct
14 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 34 ms 14164 KB Output is correct
13 Correct 36 ms 18260 KB Output is correct
14 Correct 24 ms 10836 KB Output is correct
15 Correct 32 ms 10840 KB Output is correct
16 Correct 28 ms 11056 KB Output is correct
17 Correct 1 ms 5212 KB Output is correct
18 Incorrect 1 ms 5212 KB Output isn't correct
19 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Incorrect 1 ms 5212 KB Output isn't correct
3 Halted 0 ms 0 KB -