답안 #1007909

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1007909 2024-06-25T18:41:25 Z NValchanov Star Trek (CEOI20_startrek) C++17
45 / 100
43 ms 19540 KB
#include <bits/stdc++.h>

#define endl '\n'

using namespace std;

using ll = long long;

// Capacity of the 1-indexed per-vertex global arrays.
constexpr int MAXN = 100000 + 10;
// Not referenced anywhere in this file. The initializer is evaluated in double
// precision, so its value is exactly 1000000000000000000 (the "+ 10" is lost).
constexpr ll MAXD = 1e18 + 10;
// Modulus applied to every counted quantity.
constexpr ll MOD = 1000000007LL;

struct matrix
{
    int n;
    vector < vector < ll > > m;

    matrix()
    {
        n = 0;
        m.clear();
    }

    matrix(int n)
    {
        this->n = n;
        m.resize(n);
        for(int i = 0; i < n; i++)
        {
            m[i].resize(n, 0LL);
            m[i][i] = 1LL;
        }
    }

    int size()
    {
        return n;
    }

    friend matrix operator*(matrix a, matrix b)
    {
        int sz = a.size();
        matrix result(sz);

        for(int k = 0; k < sz; k++)
        {
            for(int i = 0; i < sz; i++)
            {
                for(int j = 0; j < sz; j++)
                {
                    result.m[i][j] = (result.m[i][j] + (a.m[i][k] * b.m[k][j]) % MOD) % MOD;
                }
            }
        }

        return result;
    }
};

// Computes a^b modulo MOD by binary exponentiation; returns 1 when b == 0.
// The base is first reduced into [0, MOD) so that a * a can never overflow
// and a negative base still yields a non-negative residue. b is expected to be
// non-negative.
ll binpow(ll a, ll b)
{
    a %= MOD;
    if(a < 0)
        a += MOD;

    ll result = 1;

    while(b)
    {
        if(b & 1)
            result = (result * a) % MOD;

        a = (a * a) % MOD;
        b /= 2;
    }

    return result;
}

// Raises the square matrix `a` to the power `b` by repeated squaring.
// The accumulator starts as matrix(size), which that constructor builds as the
// identity, so b == 0 yields the identity.
matrix binpow(matrix a, ll b)
{
    matrix acc(a.size());

    for(; b != 0; b /= 2)
    {
        if(b & 1)
            acc = acc * a;
        a = a * a;
    }

    return acc;
}

vector < ll > mult(vector < ll > a, matrix b)
{
    vector < ll > result = {0, 0};
    for(int i = 0; i < 2; i++)
    {
        for(int j = 0; j < 2; j++)
        {
            result[j] = (result[j] + b.m[i][j] * a[i]) % MOD;
        }
    }

    return result;
}

// ---- Global state (vertices are 1-indexed, arrays sized MAXN) ----

// n: number of tree vertices; d: second input value (used as n^(2d) and as
// the d - 1 matrix steps in solve()).
ll n, d;

// Adjacency lists of the input tree.
vector < ll > adj[MAXN];

// Per-root results written by reroot():
//   dp[u]  = dpt[u] with the tree rooted at u,
//   dph[u] = crit[u] with the tree rooted at u.
ll dp[MAXN], dph[MAXN];

// DP for the current root (updated in place while rerooting):
//   dpt[u]  = number of children v with dpt[v] == 0,
//   crit[u] = value derived from dpt/win/lose (see find_crit and change),
//   win[u]  = sum of crit[v] over children v with dpt[v] != 0,
//   lose[u] = sum of crit[v] over children v with dpt[v] == 0.
ll dpt[MAXN], crit[MAXN], win[MAXN], lose[MAXN];

// Aggregates over all roots, filled by find_win_lose().
ll cntwin, cntlose, pot;

void read()
{
    cin >> n >> d;
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
}

// Post-order pass over the tree rooted at u (par = parent, 0 for the root):
// dpt[u] ends up counting the children whose own dpt is 0.
void dfs(int u, int par)
{
    for(int child : adj[u])
    {
        if(child != par)
        {
            dfs(child, u);
            dpt[u] += (dpt[child] == 0) ? 1 : 0;
        }
    }
}

// Second post-order pass; it must run after dfs() on the same root. For each
// vertex u and each child v:
//   win[u]  += crit[v] when dpt[v] != 0,
//   lose[u] += crit[v] when dpt[v] == 0,
// and, starting from the zero-initialised globals,
//   crit[u] = 1 + (sum over children with dpt != 0)  if dpt[u] == 0,
//   crit[u] = sum over children with dpt == 0        if dpt[u] == 1,
//   crit[u] is left untouched                        otherwise.
// Presumably crit[u] counts subtree vertices whose switch to a winning state
// flips u's outcome — confirm against the task analysis.
void find_crit(int u, int par)
{
    // dpt is not modified by this pass, so these facts stay valid in the loop.
    const bool u_zero = (dpt[u] == 0);
    const bool u_one = (dpt[u] == 1);

    if(u_zero)
        crit[u] = 1;

    for(int child : adj[u])
    {
        if(child == par)
            continue;

        find_crit(child, u);

        if(dpt[child] != 0)
        {
            win[u] += crit[child];
            if(u_zero)
                crit[u] += crit[child];
        }
        else
        {
            lose[u] += crit[child];
            if(u_one)
                crit[u] += crit[child];
        }
    }
}

// Recomputes crit[x] from dpt[x] and the cached child sums win[x] / lose[x]:
// 0 if dpt[x] > 1, lose[x] if dpt[x] == 1, otherwise win[x] + 1.
static void recompute_crit(int x)
{
    if(dpt[x] > 1)
        crit[x] = 0;
    else if(dpt[x] == 1)
        crit[x] = lose[x];
    else
        crit[x] = win[x] + 1;
}

// Moves the DP root across edge (u, v), where u is the current root. First v
// is detached from u's child aggregates and crit[u] is recomputed. Then u is
// attached to v's child aggregates and crit[v] is recomputed. reroot() relies
// on change(v, u) undoing this.
void change(int u, int v)
{
    // Detach v from u.
    const bool v_zero = (dpt[v] == 0);
    if(v_zero)
    {
        dpt[u]--;
        lose[u] -= crit[v];
    }
    else
    {
        win[u] -= crit[v];
    }
    recompute_crit(u);

    // Attach u under v.
    const bool u_zero = (dpt[u] == 0);
    if(u_zero)
    {
        dpt[v]++;
        lose[v] += crit[u];
    }
    else
    {
        win[v] += crit[u];
    }
    recompute_crit(v);
}

// Visits every vertex as the root: on entry the DP arrays describe the tree
// rooted at u, so dp[u] / dph[u] record dpt[u] / crit[u]. The root is then
// moved to each child in turn and moved back afterwards.
void reroot(int u, int par)
{
    dp[u] = dpt[u];
    dph[u] = crit[u];

    for(int next : adj[u])
    {
        if(next != par)
        {
            change(u, next);   // root moves from u to next
            reroot(next, u);
            change(next, u);   // root moves back to u
        }
    }
}

// Runs the rooted DP from vertex 1, reroots it to every vertex, then aggregates:
//   pot     = number of roots r with dp[r] == 0,
//   cntlose = sum of dph[r] over those roots (mod MOD),
//   cntwin  = sum of dph[r] over all other roots (mod MOD).
void find_win_lose()
{
    dfs(1, 0);
    find_crit(1, 0);
    reroot(1, 0);

    for(int r = 1; r <= n; r++)
    {
        if(dp[r] != 0)
        {
            cntwin = (cntwin + dph[r]) % MOD;
            continue;
        }

        pot++;
        cntlose = (cntlose + dph[r]) % MOD;
    }
}

void solve()
{
    matrix mat(2);

    mat.m[0][0] = (cntwin - cntlose + MOD) % MOD;
    mat.m[0][1] = 0LL;
    mat.m[1][0] = 1LL;
    mat.m[1][1] = (n * n) % MOD;

    vector < ll > tmp = {pot, pot * ((n * n) % MOD) % MOD};

    tmp = mult(tmp, binpow(mat, d - 1));

    ll ans = 0;

    if(dp[1])
        ans = (dph[1] * tmp[0]) % MOD;
    else
        ans = (binpow(n, 2LL * d) - (dph[1] * tmp[0]) % MOD + MOD) % MOD;

    ans = (binpow(n, 2LL * d) - ans + MOD) % MOD;

    cout << ans << endl;
}

// Entry point: fast iostreams, read the tree, run the DP, print the answer.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    read();
    find_win_lose();
    solve();
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Incorrect 1 ms 3164 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 3164 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 3164 KB Output is correct
5 Correct 1 ms 3164 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 3164 KB Output is correct
3 Correct 1 ms 5300 KB Output is correct
4 Correct 1 ms 5276 KB Output is correct
5 Correct 1 ms 3164 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 3164 KB Output is correct
3 Correct 1 ms 5300 KB Output is correct
4 Correct 1 ms 5276 KB Output is correct
5 Correct 1 ms 3164 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5208 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 3164 KB Output is correct
10 Correct 1 ms 3164 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 3164 KB Output is correct
3 Correct 1 ms 5300 KB Output is correct
4 Correct 1 ms 5276 KB Output is correct
5 Correct 1 ms 3164 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5208 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 3164 KB Output is correct
10 Correct 1 ms 3164 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 37 ms 15468 KB Output is correct
13 Correct 43 ms 19540 KB Output is correct
14 Correct 30 ms 11856 KB Output is correct
15 Correct 34 ms 12124 KB Output is correct
16 Correct 33 ms 12116 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 3164 KB Output is correct
3 Correct 1 ms 5300 KB Output is correct
4 Correct 1 ms 5276 KB Output is correct
5 Correct 1 ms 3164 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5208 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 3164 KB Output is correct
10 Correct 1 ms 3164 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 1 ms 5208 KB Output is correct
13 Incorrect 1 ms 5212 KB Output isn't correct
14 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5208 KB Output is correct
2 Correct 1 ms 3164 KB Output is correct
3 Correct 1 ms 5300 KB Output is correct
4 Correct 1 ms 5276 KB Output is correct
5 Correct 1 ms 3164 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5208 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 3164 KB Output is correct
10 Correct 1 ms 3164 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 37 ms 15468 KB Output is correct
13 Correct 43 ms 19540 KB Output is correct
14 Correct 30 ms 11856 KB Output is correct
15 Correct 34 ms 12124 KB Output is correct
16 Correct 33 ms 12116 KB Output is correct
17 Correct 1 ms 5208 KB Output is correct
18 Incorrect 1 ms 5212 KB Output isn't correct
19 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Incorrect 1 ms 3164 KB Output isn't correct
3 Halted 0 ms 0 KB -