답안 #1007910

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1007910 | 2024-06-25T18:43:21Z | NValchanov | Star Trek (CEOI20_startrek) | C++17 | 45 / 100 | 43 ms | 18436 KB |
#include <bits/stdc++.h>

#define endl '\n'

using namespace std;

using ll = long long;

// Capacity of the node-indexed global tables (nodes are numbered 1..n).
constexpr int MAXN = 100'010;
// Written as an integer literal on purpose: the former `1e18 + 10` is a double
// expression, and 1e18 + 10 rounds to exactly 1e18, silently dropping the "+ 10".
constexpr ll MAXD = 1'000'000'000'000'000'010LL;
// Modulus for every count in the solution.
constexpr ll MOD = 1'000'000'007LL;

/// Square matrix over the integers modulo MOD, stored row-major in `m`.
///
/// - `matrix()`        : empty 0 x 0 matrix.
/// - `matrix(n)`       : the n x n IDENTITY (e.g. the starting value of a matrix power).
/// - `matrix::zero(n)` : the n x n zero matrix.
///
/// Entries are expected to stay in [0, MOD) so every pairwise product fits in `ll`.
struct matrix
{
    int n;
    std::vector<std::vector<ll>> m;

    matrix() : n(0) {}

    matrix(int n) : n(n), m(n, std::vector<ll>(n, 0LL))
    {
        for (int i = 0; i < n; i++)
            m[i][i] = 1LL;
    }

    /// Returns the sz x sz matrix whose entries are all 0.
    static matrix zero(int sz)
    {
        matrix z;
        z.n = sz;
        z.m.assign(sz, std::vector<ll>(sz, 0LL));
        return z;
    }

    int size() const
    {
        return n;
    }

    /// Returns a * b modulo MOD. Both operands must be a.size() x a.size().
    friend matrix operator*(const matrix& a, const matrix& b)
    {
        const int sz = a.size();
        // The accumulator MUST start at zero. Starting from matrix(sz), which is the
        // identity, makes this return I + a*b instead of a*b.
        matrix result = zero(sz);

        for (int i = 0; i < sz; i++)
        {
            for (int k = 0; k < sz; k++)
            {
                const ll aik = a.m[i][k];
                for (int j = 0; j < sz; j++)
                    result.m[i][j] = (result.m[i][j] + (aik * b.m[k][j]) % MOD) % MOD;
            }
        }

        return result;
    }
};

/// Returns a^b modulo MOD, computed by binary exponentiation (b is expected to be >= 0).
/// The base is first reduced into [0, MOD), so `a * a` cannot overflow `ll` even when a
/// caller passes a base that is large or negative.
ll binpow(ll a, ll b)
{
    a %= MOD;
    if (a < 0)
        a += MOD;

    ll result = 1;

    while (b)
    {
        if (b & 1)
            result = (result * a) % MOD;

        a = (a * a) % MOD;
        b /= 2;
    }

    return result;
}

/// Raises the square matrix `a` to the power `b` by repeated squaring.
/// b == 0 yields matrix(a.size()), i.e. the identity of a's size.
matrix binpow(matrix a, ll b)
{
    matrix acc(a.size());  // matrix(n) constructs the n x n identity

    for (; b != 0; b /= 2)
    {
        if (b % 2 != 0)
            acc = acc * a;

        a = a * a;
    }

    return acc;
}

/// Returns the row vector `a` times the matrix `b`, modulo MOD:
///     result[j] = sum_i a[i] * b.m[i][j].
/// The result has one entry per column of `b`. Only the first min(a.size(), rows of b)
/// terms are summed, so mismatched sizes never index out of bounds. (The dimension used
/// to be hardcoded to 2.) Entries are assumed to lie in [0, MOD) so products fit in `ll`.
vector < ll > mult(const vector < ll >& a, const matrix& b)
{
    const size_t rows = std::min(a.size(), b.m.size());
    const size_t cols = b.m.empty() ? 0 : b.m[0].size();

    vector < ll > result(cols, 0LL);
    for (size_t i = 0; i < rows; i++)
    {
        for (size_t j = 0; j < cols; j++)
        {
            result[j] = (result[j] + (b.m[i][j] * a[i]) % MOD) % MOD;
        }
    }

    return result;
}

// ---- input ----
ll n{};                 // number of tree nodes (nodes are numbered 1..n)
ll d{};                 // second input value; solve() uses it as the exponent count
vector<ll> adj[MAXN];   // undirected adjacency lists filled by read()

// ---- per-node tree DP under the current rooting ----
ll dpt[MAXN]{};         // # of children v with dpt[v] == 0; dpt[u] == 0 is treated as "losing"
ll crit[MAXN]{};        // 1 + win[u] if dpt[u] == 0, lose[u] if dpt[u] == 1, else 0
ll win[MAXN]{};         // sum of crit[v] over children v with dpt[v] != 0
ll lose[MAXN]{};        // sum of crit[v] over children v with dpt[v] == 0

// ---- per-root snapshots taken by reroot() ----
ll dp[MAXN]{};          // dpt[r] with the tree rooted at r
ll dph[MAXN]{};         // crit[r] with the tree rooted at r

// ---- aggregates over all roots, filled by find_win_lose() ----
ll cntwin{};            // sum of dph[r] (mod MOD) over roots with dp[r] != 0
ll cntlose{};           // sum of dph[r] (mod MOD) over roots with dp[r] == 0
ll pot{};               // number of roots with dp[r] == 0

void read()
{
    cin >> n >> d;
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
}

/// Post-order pass over the subtree of `u` (`par` is its parent, 0 for the root).
/// Adds 1 to dpt[u] for every child v that finishes with dpt[v] == 0
/// (dpt starts zeroed, so dpt[u] ends up counting such children).
void dfs(int u, int par)
{
    for (int child : adj[u])
    {
        if (child != par)
        {
            dfs(child, u);
            dpt[u] += (dpt[child] == 0);
        }
    }
}

/// Second post-order pass (run after dfs with the same root).
/// For every node u it fills:
///   win[u]  = sum of crit[v] over children v with dpt[v] != 0
///   lose[u] = sum of crit[v] over children v with dpt[v] == 0
///   crit[u] = 1 (only when dpt[u] == 0)
///           + crit[v] of each child v with dpt[v] != 0, when dpt[u] == 0
///           + crit[v] of the child v with dpt[v] == 0,  when dpt[u] == 1
void find_crit(int u, int par)
{
    if (dpt[u] == 0)
        crit[u] = 1;

    for (int v : adj[u])
    {
        if (v == par)
            continue;

        find_crit(v, u);

        const bool child_losing = (dpt[v] == 0);
        if (child_losing)
            lose[u] += crit[v];
        else
            win[u] += crit[v];

        // A losing child propagates only if it is u's single losing child;
        // a winning child propagates only if u has no losing child at all.
        const bool propagates = child_losing ? (dpt[u] == 1) : (dpt[u] == 0);
        if (propagates)
            crit[u] += crit[v];
    }
}

// Recomputes crit[x] from dpt[x], lose[x] and win[x]:
// 0 if dpt[x] > 1, lose[x] if dpt[x] == 1, otherwise win[x] + 1.
static void refresh_crit(int x)
{
    crit[x] = dpt[x] > 1 ? 0 : (dpt[x] == 1 ? lose[x] : win[x] + 1);
}

/// Moves the root across edge (u, v): v stops being a child of u, and u becomes a
/// child of v. dpt / win / lose / crit of both endpoints are updated accordingly.
void change(int u, int v)
{
    // Step 1: detach v from u.
    const bool v_losing = (dpt[v] == 0);
    dpt[u] -= v_losing ? 1 : 0;
    (v_losing ? lose[u] : win[u]) -= crit[v];
    refresh_crit(u);

    // Step 2: attach u under v, using u's freshly updated values.
    const bool u_losing = (dpt[u] == 0);
    dpt[v] += u_losing ? 1 : 0;
    (u_losing ? lose[v] : win[v]) += crit[u];
    refresh_crit(v);
}

/// Walks the tree with the root currently at `u` (entered from `par`).
/// Records dp[u] = dpt[u] and dph[u] = crit[u] for this rooting. Then, for each child v,
/// it moves the root to v, recurses, and moves the root back before the next child.
void reroot(int u, int par)
{
    dp[u] = dpt[u];
    dph[u] = crit[u];

    for (int next : adj[u])
    {
        if (next != par)
        {
            change(u, next);   // root now at `next`
            reroot(next, u);
            change(next, u);   // undo: root back at `u`
        }
    }
}

void find_win_lose()
{
    dfs(1, 0);
    find_crit(1, 0);
    reroot(1, 0);

    for(int i = 1; i <= n; i++)
    {
        if(!dp[i])
        {
            pot++;
            cntlose = (cntlose + dph[i]) % MOD;
        }
        else
        {
            cntwin = (cntwin + dph[i]) % MOD;
        }
    }
    // cout << "DPH : " << endl;
    // for(int i = 1; i <= n; i++)
    // {
    //     cout << dph[i] << " ";
    // }
    // cout << endl << endl;
    // cout << "DP : " << endl;
    // for(int i = 1; i <= n; i++)
    // {
    //     cout << dp[i] << " ";
    // }
    // cout << endl;
}

void solve()
{
    matrix mat(2);

    mat.m[0][0] = (cntwin - cntlose + MOD) % MOD;
    mat.m[0][1] = 0LL;
    mat.m[1][0] = 1LL;
    mat.m[1][1] = (n * n) % MOD;

    vector < ll > tmp = {pot, pot * ((n * n) % MOD) % MOD};

    tmp = mult(tmp, binpow(mat, d - 1));

    ll ans = 0;

    if(dp[1])
        ans = (dph[1] * tmp[0]) % MOD;
    else
        ans = (binpow(n, 2LL * d) - (dph[1] * tmp[0]) % MOD + MOD) % MOD;

    ans = (binpow(n, 2LL * d) - ans + MOD) % MOD;

    cout << ans << endl;
}

int main()
{
    // Fast, unsynchronized iostreams; the program never mixes in C stdio.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    read();
    find_win_lose();
    solve();
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Incorrect 1 ms 5212 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5368 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5304 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5304 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 39 ms 14172 KB Output is correct
13 Correct 43 ms 18436 KB Output is correct
14 Correct 27 ms 10764 KB Output is correct
15 Correct 39 ms 10836 KB Output is correct
16 Correct 33 ms 10840 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5304 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 1 ms 5212 KB Output is correct
13 Incorrect 1 ms 5212 KB Output isn't correct
14 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5304 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 39 ms 14172 KB Output is correct
13 Correct 43 ms 18436 KB Output is correct
14 Correct 27 ms 10764 KB Output is correct
15 Correct 39 ms 10836 KB Output is correct
16 Correct 33 ms 10840 KB Output is correct
17 Correct 1 ms 5212 KB Output is correct
18 Incorrect 1 ms 5212 KB Output isn't correct
19 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Incorrect 1 ms 5212 KB Output isn't correct
3 Halted 0 ms 0 KB -