답안 #1007917

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1007917 2024-06-25T18:58:34 Z NValchanov Star Trek (CEOI20_startrek) C++17
45 / 100
43 ms 18268 KB
#include <bits/stdc++.h>

#define endl '\n'

using namespace std;

using ll = long long;

// Capacity of the per-vertex arrays declared later in this file.
const int MAXN = 1e5 + 10;
// Large bound constant; not referenced by the code visible in this file.
const ll MAXD = 1e18 + 10;
// Modulus used for all modular arithmetic in this file.
const ll MOD = 1e9 + 7;

// Square matrix whose entries are kept as residues modulo MOD.
//
// n is the dimension and m[row][col] holds the entries. Both members are
// public because callers fill individual entries directly.
struct matrix
{
    int n;
    vector < vector < ll > > m;

    // Empty 0 x 0 matrix.
    matrix() : n(0)
    {
    }

    // n x n identity matrix (ones on the diagonal, zeros elsewhere).
    matrix(int n) : n(n)
    {
        m.assign(n, vector < ll >(n, 0LL));
        for(int i = 0; i < n; i++)
            m[i][i] = 1LL;
    }

    // Dimension of the matrix.
    int size() const
    {
        return n;
    }

    // Matrix product a * b with every entry reduced modulo MOD.
    // Both operands are assumed to have the dimension of `a`.
    friend matrix operator*(const matrix &a, const matrix &b)
    {
        const int sz = a.size();

        // The accumulator must start as the ZERO matrix. Constructing it with
        // matrix(sz) would start from the identity and yield I + a * b.
        matrix result;
        result.n = sz;
        result.m.assign(sz, vector < ll >(sz, 0LL));

        for(int i = 0; i < sz; i++)
        {
            for(int k = 0; k < sz; k++)
            {
                const ll aik = a.m[i][k];
                for(int j = 0; j < sz; j++)
                {
                    result.m[i][j] = (result.m[i][j] + (aik * b.m[k][j]) % MOD) % MOD;
                }
            }
        }

        return result;
    }
};

// Modular exponentiation: returns a^b modulo MOD using square-and-multiply.
//
// The base is first normalised into [0, MOD), so bases that are negative or
// not smaller than MOD cannot overflow the 64-bit intermediate products.
// b is expected to be non-negative; b == 0 yields 1.
ll binpow(ll a, ll b)
{
    a %= MOD;
    if(a < 0)
        a += MOD;

    ll result = 1LL;

    while(b)
    {
        // Multiply in the current power of the base for every set bit of b.
        if(b & 1)
            result = (result * a) % MOD;

        a = (a * a) % MOD;
        b /= 2;
    }

    return result;
}

// Matrix power by repeated squaring.
//
// Starts from matrix(a.size()) and, walking the bits of b from least to most
// significant, multiplies the accumulator by the current power of `a` whenever
// the bit is set, squaring `a` after every step. For b == 0 the freshly
// constructed matrix is returned unchanged.
matrix binpow(matrix a, ll b)
{
    matrix acc(a.size());

    for(ll e = b; e != 0; e /= 2)
    {
        if(e & 1)
            acc = acc * a;

        a = a * a;
    }

    return acc;
}

// Row vector times 2 x 2 matrix, modulo MOD.
//
// Returns r with r[j] = a[0] * b.m[0][j] + a[1] * b.m[1][j] (mod MOD) for
// j in {0, 1}. Only the first two components of `a` and the top-left 2 x 2
// corner of `b` are read.
vector < ll > mult(vector < ll > a, matrix b)
{
    vector < ll > out(2, 0LL);

    for(int col = 0; col < 2; col++)
    {
        for(int row = 0; row < 2; row++)
        {
            const ll term = (b.m[row][col] * a[row]) % MOD;
            out[col] = (out[col] + term) % MOD;
        }
    }

    return out;
}

// Input values read by read(): the vertex count and the parameter d.
ll n;
ll d;

// Adjacency lists of the tree, indexed by vertex label.
vector < ll > adj[MAXN];

// Working state for the current rooting, maintained by dfs / find_crit / change.
// dpt[u]  : number of children v of u with dpt[v] == 0.
// win[u]  : sum of crit[v] over children v with dpt[v] != 0.
// lose[u] : sum of crit[v] over children v with dpt[v] == 0.
// crit[u] : 0 if dpt[u] > 1, lose[u] if dpt[u] == 1, win[u] + 1 if dpt[u] == 0.
ll dpt[MAXN];
ll win[MAXN];
ll lose[MAXN];
ll crit[MAXN];

// Snapshots stored by reroot(): dpt[u] and crit[u] at the moment u is the root.
ll dp[MAXN];
ll dph[MAXN];

// Aggregates computed by find_win_lose().
// pot     : number of vertices i with dp[i] == 0.
// cntlose : sum of dph[i] over those vertices, modulo MOD.
// cntwin  : sum of dph[i] over the remaining vertices, modulo MOD.
ll pot;
ll cntlose;
ll cntwin;

void read()
{
    cin >> n >> d;
    for(int i = 1; i <= n - 1; i++)
    {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
}

// Post-order traversal from u, never stepping back to par.
// After it returns, dpt[u] has been incremented once for every child whose
// own dpt value is zero.
void dfs(int u, int par)
{
    for(int child : adj[u])
    {
        if(child != par)
        {
            dfs(child, u);
            dpt[u] += (dpt[child] == 0) ? 1 : 0;
        }
    }
}

void find_crit(int u, int par)
{
    if(!dpt[u])
        crit[u] = 1;

    for(int v : adj[u])
    {
        if(v == par)
            continue;
        
        find_crit(v, u);

        if(dpt[v])
            win[u] += crit[v];
        else
            lose[u] += crit[v];

        if(dpt[u] == 1 && !dpt[v])
            crit[u] += crit[v];

        if(!dpt[u] && dpt[v])
            crit[u] += crit[v];
    }
}

void change(int u, int v)
{
    if(!dpt[v])
    {
        dpt[u]--;
        lose[u] -= crit[v];
    }
    else
    {
        win[u] -= crit[v];
    }

    if(dpt[u] > 1)
        crit[u] = 0;
    else if(dpt[u] == 1)
        crit[u] = lose[u];
    else
        crit[u] = win[u] + 1;

    
    if(!dpt[u])
    {
        dpt[v]++;
        lose[v] += crit[u];
    }
    else
    {
        win[v] += crit[u];
    }

    if(dpt[v] > 1)
        crit[v] = 0;
    else if(dpt[v] == 1)
        crit[v] = lose[v];
    else
        crit[v] = win[v] + 1;
}

// Visits every vertex as the root once. On entry u is the current root, so
// dpt[u] and crit[u] are stored in dp[u] and dph[u]. For each neighbour other
// than par the root is moved there with change(), the neighbour is processed
// recursively, and the root is moved back with the mirrored change().
void reroot(int u, int par)
{
    dp[u] = dpt[u];
    dph[u] = crit[u];

    for(int nxt : adj[u])
    {
        if(nxt != par)
        {
            change(u, nxt);
            reroot(nxt, u);
            change(nxt, u);
        }
    }
}

void find_win_lose()
{
    dfs(1, 0);
    find_crit(1, 0);
    reroot(1, 0);

    for(int i = 1; i <= n; i++)
    {
        if(!dp[i])
        {
            pot++;
            cntlose = (cntlose + dph[i]) % MOD;
        }
        else
        {
            cntwin = (cntwin + dph[i]) % MOD;
        }
    }
    // cout << "DPH : " << endl;
    // for(int i = 1; i <= n; i++)
    // {
    //     cout << dph[i] << " ";
    // }
    // cout << endl << endl;
    // cout << "DP : " << endl;
    // for(int i = 1; i <= n; i++)
    // {
    //     cout << dp[i] << " ";
    // }
    // cout << endl;
}

void solve()
{
    matrix mat(2);

    mat.m[0][0] = (cntwin - cntlose + MOD) % MOD;
    mat.m[0][1] = 0LL;
    mat.m[1][0] = 1LL;
    mat.m[1][1] = (n * n) % MOD;

    vector < ll > tmp = {pot, (pot * (n * n) % MOD) % MOD};

    tmp = mult(tmp, binpow(mat, d - 1LL));

    ll ans = 0LL;

    if(dp[1])
        ans = (dph[1] * tmp[0]) % MOD;
    else
        ans = (binpow(n, 2LL * d) - (dph[1] * tmp[0]) % MOD + MOD) % MOD;

    ans = (binpow(n, 2LL * d) - ans + MOD) % MOD;

    cout << ans << endl;
}

// Entry point: configures the standard streams, then reads the input,
// runs the tree analysis and prints the answer.
int main()
{
    // Disable synchronisation with C stdio and untie the streams.
    ios_base :: sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    read();
    find_win_lose();
    solve();
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5464 KB Output is correct
2 Incorrect 2 ms 5208 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5212 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5208 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5208 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5208 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 38 ms 14164 KB Output is correct
13 Correct 43 ms 18268 KB Output is correct
14 Correct 26 ms 10840 KB Output is correct
15 Correct 32 ms 10812 KB Output is correct
16 Correct 36 ms 10832 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5208 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 1 ms 5208 KB Output is correct
13 Incorrect 1 ms 5212 KB Output isn't correct
14 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5212 KB Output is correct
2 Correct 1 ms 5208 KB Output is correct
3 Correct 1 ms 5212 KB Output is correct
4 Correct 1 ms 5212 KB Output is correct
5 Correct 1 ms 5212 KB Output is correct
6 Correct 1 ms 5212 KB Output is correct
7 Correct 1 ms 5212 KB Output is correct
8 Correct 1 ms 5212 KB Output is correct
9 Correct 1 ms 5212 KB Output is correct
10 Correct 1 ms 5212 KB Output is correct
11 Correct 1 ms 5212 KB Output is correct
12 Correct 38 ms 14164 KB Output is correct
13 Correct 43 ms 18268 KB Output is correct
14 Correct 26 ms 10840 KB Output is correct
15 Correct 32 ms 10812 KB Output is correct
16 Correct 36 ms 10832 KB Output is correct
17 Correct 1 ms 5208 KB Output is correct
18 Incorrect 1 ms 5212 KB Output isn't correct
19 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 5464 KB Output is correct
2 Incorrect 2 ms 5208 KB Output isn't correct
3 Halted 0 ms 0 KB -