답안 #865691

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
865691 2023-10-24T14:09:17 Z boris_mihov Star Trek (CEOI20_startrek) C++17
100 / 100
80 ms 26832 KB
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>
#include <random>
#include <queue>
#include <stack>
#include <set>

// 64-bit signed integer used for wide intermediate products and for the exponent d.
using llong = long long;

// Capacity of every per-vertex array (vertex ids start at 1, a few spare slots are kept).
constexpr int MAXN = 100'000 + 10;
// Prime modulus under which all coefficients and the answer are reduced.
constexpr int MOD = 1'000'000'007;
// Large sentinel value that still fits in a 32-bit int.
constexpr int INF = 2'000'000'000;

// ---- Input --------------------------------------------------------------
// Number of vertices of the tree (edges are read as n - 1 pairs).
int n;
// Second input value; solve() uses d - 1 as an exponent.
llong d;
// Adjacency lists of the tree, indexed by vertex id (1-based).
std::vector<int> g[MAXN];

// ---- Rooted-tree bookkeeping (root is vertex 1) ---------------------------
// Size of the subtree hanging below each vertex.
int sz[MAXN];
// Parent of each vertex (0 for the root).
int myPar[MAXN];

// ---- Win/lose classification ----------------------------------------------
// Number of neighbours counted so far whose own counter is zero.
int cntChildLoses[MAXN];
// Same counter, extended with the parent side during calcForUp().
int cntChildLoses2[MAXN];
// For each vertex, the neighbours that were counted in the counters above.
std::vector<int> zero[MAXN];
// win[v] is set in solve() to (cntChildLoses[v] > 0).
bool win[MAXN];

// Running scalar used by solve(): first the number of vertices with
// win[v] == false, later the value that is printed.
int sumL;

// Modular exponentiation: returns num raised to `base`, reduced modulo MOD.
// `base` is the exponent and is expected to be non-negative; an exponent of
// zero yields 1. Uses square-and-multiply, so the cost is logarithmic in
// `base`.
llong power(int num, llong base)
{
    llong result = 1;
    llong factor = num % MOD;
    for (llong exponent = base; exponent > 0; exponent >>= 1)
    {
        if (exponent & 1)
        {
            result = (result * factor) % MOD;
        }

        factor = (factor * factor) % MOD;
    }

    return result;
}

// Post-order walk of the subtree of `node` (entered from `par`).
// Stores in cntChildLoses[node] how many children returned false and returns
// true exactly when that number is non-zero. A leaf therefore returns false.
bool calcSubtreeWin(int node, int par)
{
    int losingChildren = 0;
    for (const int child : g[node])
    {
        if (child != par && !calcSubtreeWin(child, node))
        {
            ++losingChildren;
        }
    }

    cntChildLoses[node] = losingChildren;
    return losingChildren > 0;
}
 
// Pre-order walk that extends cntChildLoses from "children only" to "all
// neighbours". `par` is the tree parent of `node`, or 0 for the root.
// The parent's counter has already been extended when its children are
// visited, so the test below sees the parent's final value.
void calcUpWin(int node, int par)
{
    if (par != 0)
    {
        // 1 when `node` itself was counted inside cntChildLoses[par], else 0.
        const int ownContribution = (cntChildLoses[node] == 0) ? 1 : 0;

        // The parent has no zero-count neighbour apart from `node`, so the
        // parent side counts as one more zero-count neighbour of `node`.
        if (cntChildLoses[par] == ownContribution)
        {
            ++cntChildLoses[node];
        }
    }

    for (const int next : g[node])
    {
        if (next != par)
        {
            calcUpWin(next, node);
        }
    }
}

// A pair of coefficients handled component-wise modulo MOD.
// solve() multiplies powCoef by a power of n and sumCoef by a running sum, so
// a DPElement stands for the linear expression
//     powCoef * (power of n) + sumCoef * (running sum).
// All operators assume both operands already hold values in [0, MOD) and keep
// the result in that range with a single conditional correction.
struct DPElement
{
    int powCoef;
    int sumCoef;

    // Zero element.
    DPElement() : powCoef(0), sumCoef(0) {}

    // Element with explicitly given coefficients.
    DPElement(int pw, int sm) : powCoef(pw), sumCoef(sm) {}

    // Component-wise sum modulo MOD.
    friend DPElement operator + (const DPElement &a, const DPElement &b)
    {
        DPElement res(a);
        res += b;
        return res;
    }

    // Component-wise difference modulo MOD.
    friend DPElement operator - (const DPElement &a, const DPElement &b)
    {
        DPElement res(a);
        res -= b;
        return res;
    }

    // In-place component-wise addition modulo MOD.
    void operator += (const DPElement &b)
    {
        powCoef = foldDown(powCoef + b.powCoef);
        sumCoef = foldDown(sumCoef + b.sumCoef);
    }

    // In-place component-wise subtraction modulo MOD.
    void operator -= (const DPElement &b)
    {
        powCoef = foldUp(powCoef - b.powCoef);
        sumCoef = foldUp(sumCoef - b.sumCoef);
    }

private:
    // Brings a value from [0, 2 * MOD) back into [0, MOD).
    static int foldDown(int value)
    {
        return value >= MOD ? value - MOD : value;
    }

    // Brings a value from (-MOD, MOD) back into [0, MOD).
    static int foldUp(int value)
    {
        return value < 0 ? value + MOD : value;
    }
};

DPElement dp[MAXN];
DPElement dp2[MAXN];
void calcForSubtree(int node, int par)
{
    sz[node] = 1;
    int which = -1;
    myPar[node] = par;
    cntChildLoses[node] = 0;
    for (const int &u : g[node])
    {
        if (u == par)
        {
            continue;
        }

        calcForSubtree(u, node);
        if (cntChildLoses[u] == 0)
        {
            zero[node].push_back(u);
            which = u;
        }

        cntChildLoses[node] += !cntChildLoses[u];
        sz[node] += sz[u];
    }
    
    dp[node] = {0, 0};
    if (cntChildLoses[node] >= 2)
    {
        dp[node].powCoef = (1LL * sz[node] * n) % MOD;
    } else if (cntChildLoses[node] == 0)
    {
        dp[node].sumCoef = 1;
        for (const int &u : g[node])
        {
            if (u == par)
            {
                continue;
            }

            dp[node] -= dp[u];
            dp[node].powCoef += (1LL * n * sz[u]) % MOD;
            if (dp[node].powCoef >= MOD) dp[node].powCoef -= MOD;
        }
    } else
    {
        dp[node].powCoef += (((1LL * (sz[node] - sz[which]) * n) % MOD));
        if (dp[node].powCoef >= MOD) dp[node].powCoef -= MOD;
        dp[node] -= dp[which];
        dp[node].powCoef += (1LL * n * sz[which]) % MOD;
        if (dp[node].powCoef >= MOD) dp[node].powCoef -= MOD;
    }
}

// calced[v]  : true once withoutMe(v, myPar[v]) has been evaluated.
// wMeVal[v]  : cached result of withoutMe(v, myPar[v]).
// sumZero[v] : filled by calcForUp(); equals {0, 1} plus, for every neighbour
//              of v, (n * size of that side) minus that side's DPElement.
bool calced[MAXN];
DPElement wMeVal[MAXN];
DPElement sumZero[MAXN];

// Returns the DPElement of `par` evaluated with the branch towards `node`
// left out, i.e. over the remaining n - sz[node] vertices. The three cases
// mirror calcForSubtree(), keyed on how many zero-count neighbours `par` has
// once `node` is excluded. Results are memoised per `node`; every call site
// in this file passes par == myPar[node].
// Relies on cntChildLoses2[par], zero[par] and sumZero[par] having been
// filled by calcForUp(par) beforehand.
DPElement withoutMe(int node, int par)
{
    if (calced[node])
    {
        return wMeVal[node];
    }

    DPElement res;
    calced[node] = true;
    // Two or more zero-count neighbours remain: only the size term is kept.
    if (cntChildLoses2[par] - (cntChildLoses[node] == 0) >= 2)
    {
        res.powCoef = (1LL * (n - sz[node]) * n) % MOD;
        return wMeVal[node] = res;
    }

    // No zero-count neighbour remains: start from the all-neighbour sum of
    // `par` and undo the contribution of `node` (add dp back, drop n * sz).
    if (cntChildLoses2[par] - (cntChildLoses[node] == 0) == 0)
    {
        res = sumZero[par];
        res += dp[node];
        res.powCoef -= (1LL * n * sz[node]) % MOD;
        if (res.powCoef < 0) res.powCoef += MOD;
        return wMeVal[node] = res;
    }

    // Exactly one zero-count neighbour remains; pick the entry of zero[par]
    // that is not `node`.
    int which = zero[par][0];
    if (which == node) which = zero[par][1];
    
    // Size of the side hanging off `which`: a child subtree, or everything
    // above `par` when `which` is the parent of `par`.
    int realSz = sz[which];
    if (which == myPar[par]) realSz = n - sz[par];
    res.powCoef += ((((1LL * (n - sz[node]) - realSz) * n) % MOD));
    if (res.powCoef >= MOD) res.powCoef -= MOD;
    // Subtract the DPElement of that side: dp[] for a child, or the
    // (recursively computed) value of the parent with `par` removed.
    if (which != myPar[par]) res -= dp[which];
    else res -= withoutMe(par, which);

    res.powCoef += (1LL * n * realSz) % MOD;
    if (res.powCoef >= MOD) res.powCoef -= MOD;
    return wMeVal[node] = res;
}

// Top-down pass that extends the subtree-only data to the whole tree.
// For `node` with tree parent `par` (0 for the root) it fills
//   cntChildLoses2[node] - zero-count neighbours including the parent side,
//   dp2[node]            - the DPElement of `node` over all its neighbours,
//   sumZero[node]        - the all-neighbour sum used by withoutMe(),
// and then recurses into the children. Must run after calcForSubtree(),
// because it reads dp, sz, cntChildLoses and zero. The values for `node` are
// written before the recursion since the children query them via withoutMe().
void calcForUp(int node, int par)
{
    // wasHere: the parent side counts as a zero-count neighbour of `node`.
    bool wasHere = false;
    cntChildLoses2[node] = cntChildLoses[node];
    // Same test as in calcUpWin(): the parent has no zero-count neighbour
    // other than (possibly) `node` itself.
    if (par != 0 && cntChildLoses2[par] == !cntChildLoses2[node])
    {
        wasHere = true;
        cntChildLoses2[node]++;
        zero[node].push_back(par);
    }

    // For the root the subtree value already covers the whole tree.
    dp2[node] = dp[node];
    if (par != 0)
    {
        if (cntChildLoses2[node] >= 2)
        {
            // Two or more zero-count neighbours: only the size term n * n.
            dp2[node].powCoef = (1LL * n * n) % MOD;
            dp2[node].sumCoef = 0;
        } else if (cntChildLoses2[node] == 0)
        {
            // No zero-count neighbour: 1 on sumCoef, and every neighbour side
            // contributes (n * its size) minus its DPElement.
            dp2[node].sumCoef = 1;
            dp2[node].powCoef = 0;
            for (const int &u : g[node])
            {
                if (u == par)
                {
                    continue;
                }

                dp2[node] -= dp[u];
                dp2[node].powCoef += (1LL * n * sz[u]) % MOD;
                if (dp2[node].powCoef >= MOD) dp2[node].powCoef -= MOD;
            }

            // Parent side: its DPElement is the parent evaluated without
            // `node`, and its size is n - sz[node].
            dp2[node] -= withoutMe(node, par);
            dp2[node].powCoef += (1LL * n * (n - sz[node])) % MOD;
            if (dp2[node].powCoef >= MOD) dp2[node].powCoef -= MOD;
        } else
        {
            // Exactly one zero-count neighbour (`which`): all other vertices
            // are taken in full, that side contributes n * size - DPElement.
            dp2[node] = {0, 0};
            int which = -1;
            int szWhich;
            if (wasHere)
            {
                // The single zero-count neighbour is the parent side.
                which = par;
                szWhich = n - sz[node];
                dp2[node].powCoef += (((1LL * (n - szWhich) * n) % MOD));
                if (dp2[node].powCoef >= MOD) dp2[node].powCoef -= MOD;
                dp2[node] -= withoutMe(node, par);
                dp2[node].powCoef += (1LL * n * szWhich) % MOD;
                if (dp2[node].powCoef >= MOD) dp2[node].powCoef -= MOD;
            } else
            {
                // The single zero-count neighbour is a child; locate it.
                for (const int &u : g[node])
                {
                    if (u == par)
                    {
                        continue;
                    }

                    if (cntChildLoses[u] == 0)
                    {
                        which = u;
                        break;
                    }
                }

                dp2[node].powCoef += (((1LL * (n - sz[which]) * n) % MOD));
                if (dp2[node].powCoef >= MOD) dp2[node].powCoef -= MOD;
                dp2[node] -= dp[which];
                dp2[node].powCoef += (1LL * n * sz[which]) % MOD;
                if (dp2[node].powCoef >= MOD) dp2[node].powCoef -= MOD;
            }
        }
    }

    // sumZero[node]: the "no zero-count neighbour" formula taken over every
    // neighbour of `node`, regardless of which case applied above.
    // withoutMe() on a child later removes that child's share from it.
    sumZero[node].sumCoef = 1;
    if (par != 0) 
    {
        sumZero[node] -= withoutMe(node, par);
        sumZero[node].powCoef += (1LL * n * (n - sz[node])) % MOD;
        if (sumZero[node].powCoef >= MOD) sumZero[node].powCoef -= MOD;     
    }

    for (const int &u : g[node])
    {
        if (u == par)
        {
            continue;
        }

        sumZero[node] -= dp[u];
        sumZero[node].powCoef += (1LL * n * sz[u]) % MOD;
        if (sumZero[node].powCoef >= MOD) sumZero[node].powCoef -= MOD;       
    }

    // Recurse only after every per-node value above is final.
    for (const int &u : g[node])
    {
        if (u == par)
        {
            continue;
        }

        calcForUp(u, node);
    }
}

// 2x2 matrix of residues modulo MOD.
struct Matrix
{
    int t[2][2];

    // Zero matrix.
    Matrix() : t{{0, 0}, {0, 0}} {}

    // Identity matrix. The argument only selects this overload; its value is
    // ignored.
    Matrix(int /*par*/) : t{{1, 0}, {0, 1}} {}

    // Matrix product modulo MOD. Each partial product is reduced before it is
    // added, so the running value always fits in an int.
    friend Matrix operator * (Matrix a, Matrix b)
    {
        Matrix prod;
        for (int cell = 0 ; cell < 4 ; ++cell)
        {
            const int row = cell / 2;
            const int col = cell % 2;

            int acc = 0;
            for (int mid = 0 ; mid < 2 ; ++mid)
            {
                acc += static_cast<int>((1LL * a.t[row][mid] * b.t[mid][col]) % MOD);
                if (acc >= MOD)
                {
                    acc -= MOD;
                }
            }

            prod.t[row][col] = acc;
        }

        return prod;
    }
};

// Raises `m` to the power `base` (expected to be non-negative) with
// square-and-multiply; an exponent of zero yields the identity matrix.
// Entries of `m` are expected to be non-negative, as they are when called
// from solve().
Matrix powerMatrix(Matrix m, llong base)
{
    Matrix result(1);
    for (llong exponent = base; exponent > 0; exponent >>= 1)
    {
        if (exponent & 1)
        {
            result = result * m;
        }

        m = m * m;
    }

    return result;
}

// Computes and prints the answer for the tree stored in g (n vertices) and
// the exponent parameter d. Expects input() to have run and d >= 1.
//
// Steps:
//  1. calcSubtreeWin + calcUpWin classify every vertex; sumL becomes the
//     number of vertices with win[v] == false.
//  2. calcForSubtree + calcForUp compute dp (rooted at vertex 1) and dp2
//     (all neighbours) for every vertex; `sum` is the total of dp2.
//  3. With S = sum.sumCoef and P = sum.powCoef, the matrix
//         | -S   n^3 - P |
//         |  0   n^2     |
//     raised to the power d - 1 and applied to the vector (sumL, 1) advances
//     the recurrence  x <- -S * x + (n^3 - P) * (n^2)^k  by d - 1 steps.
//  4. The printed value is dp[1].powCoef * n^(2(d-1)) + dp[1].sumCoef * x,
//     everything modulo MOD.
void solve()
{
    calcSubtreeWin(1, 0);
    calcUpWin(1, 0);

    sumL = 0;
    for (int v = 1 ; v <= n ; ++v)
    {
        win[v] = (cntChildLoses[v] > 0);
        if (!win[v])
        {
            ++sumL;
        }
    }

    calcForSubtree(1, 0);
    calcForUp(1, 0);

    DPElement sum;
    for (int v = 1 ; v <= n ; ++v)
    {
        sum += dp2[v];
    }

    Matrix m;
    m.t[0][0] = MOD - sum.sumCoef;
    // n^3 <= 1e15 fits in 64 bits; adding MOD keeps the difference positive.
    m.t[0][1] = (MOD + 1LL * n * n * n - sum.powCoef) % MOD;
    m.t[1][0] = 0;
    m.t[1][1] = (1LL * n * n) % MOD;
    m = powerMatrix(m, d - 1);

    // First row of the powered matrix applied to the vector (sumL, 1).
    sumL = (1LL * sumL * m.t[0][0] + m.t[0][1]) % MOD;
    sumL = (1LL * dp[1].powCoef * power(n, 2 * (d - 1)) + 1LL * dp[1].sumCoef * sumL) % MOD;
    std::cout << sumL << '\n';
}

// Reads n and d from standard input, followed by n - 1 undirected edges, and
// stores every edge in both adjacency lists.
void input()
{
    std::cin >> n >> d;

    int edgesLeft = n - 1;
    while (edgesLeft-- > 0)
    {
        int a, b;
        std::cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
}

// Speeds up iostream usage: turns off synchronisation with C stdio and
// removes the tied stream from both std::cin and std::cout.
void fastIOI()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
}

// Program entry point: configure the streams, read the tree, print the answer.
int main()
{
    fastIOI();
    input();
    solve();
}
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 3 ms 8868 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 2 ms 8796 KB Output is correct
3 Correct 2 ms 8636 KB Output is correct
4 Correct 2 ms 8792 KB Output is correct
5 Correct 2 ms 8792 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 2 ms 8792 KB Output is correct
3 Correct 2 ms 8796 KB Output is correct
4 Correct 2 ms 8796 KB Output is correct
5 Correct 2 ms 8796 KB Output is correct
6 Correct 2 ms 8796 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 2 ms 8792 KB Output is correct
3 Correct 2 ms 8796 KB Output is correct
4 Correct 2 ms 8796 KB Output is correct
5 Correct 2 ms 8796 KB Output is correct
6 Correct 2 ms 8796 KB Output is correct
7 Correct 2 ms 8796 KB Output is correct
8 Correct 2 ms 8796 KB Output is correct
9 Correct 2 ms 8796 KB Output is correct
10 Correct 2 ms 8884 KB Output is correct
11 Correct 2 ms 8796 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 2 ms 8792 KB Output is correct
3 Correct 2 ms 8796 KB Output is correct
4 Correct 2 ms 8796 KB Output is correct
5 Correct 2 ms 8796 KB Output is correct
6 Correct 2 ms 8796 KB Output is correct
7 Correct 2 ms 8796 KB Output is correct
8 Correct 2 ms 8796 KB Output is correct
9 Correct 2 ms 8796 KB Output is correct
10 Correct 2 ms 8884 KB Output is correct
11 Correct 2 ms 8796 KB Output is correct
12 Correct 59 ms 21456 KB Output is correct
13 Correct 65 ms 26708 KB Output is correct
14 Correct 42 ms 15824 KB Output is correct
15 Correct 49 ms 17232 KB Output is correct
16 Correct 46 ms 15952 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 2 ms 8792 KB Output is correct
3 Correct 2 ms 8796 KB Output is correct
4 Correct 2 ms 8796 KB Output is correct
5 Correct 2 ms 8796 KB Output is correct
6 Correct 2 ms 8796 KB Output is correct
7 Correct 2 ms 8796 KB Output is correct
8 Correct 2 ms 8796 KB Output is correct
9 Correct 2 ms 8796 KB Output is correct
10 Correct 2 ms 8884 KB Output is correct
11 Correct 2 ms 8796 KB Output is correct
12 Correct 2 ms 8792 KB Output is correct
13 Correct 2 ms 8796 KB Output is correct
14 Correct 2 ms 8840 KB Output is correct
15 Correct 2 ms 8792 KB Output is correct
16 Correct 2 ms 8792 KB Output is correct
17 Correct 2 ms 8796 KB Output is correct
18 Correct 2 ms 8792 KB Output is correct
19 Correct 2 ms 8796 KB Output is correct
20 Correct 2 ms 8796 KB Output is correct
21 Correct 2 ms 8796 KB Output is correct
22 Correct 2 ms 8796 KB Output is correct
23 Correct 2 ms 8792 KB Output is correct
24 Correct 2 ms 8796 KB Output is correct
25 Correct 2 ms 8796 KB Output is correct
26 Correct 2 ms 9096 KB Output is correct
27 Correct 2 ms 8796 KB Output is correct
28 Correct 2 ms 8792 KB Output is correct
29 Correct 2 ms 8792 KB Output is correct
30 Correct 2 ms 8796 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 2 ms 8792 KB Output is correct
3 Correct 2 ms 8796 KB Output is correct
4 Correct 2 ms 8796 KB Output is correct
5 Correct 2 ms 8796 KB Output is correct
6 Correct 2 ms 8796 KB Output is correct
7 Correct 2 ms 8796 KB Output is correct
8 Correct 2 ms 8796 KB Output is correct
9 Correct 2 ms 8796 KB Output is correct
10 Correct 2 ms 8884 KB Output is correct
11 Correct 2 ms 8796 KB Output is correct
12 Correct 59 ms 21456 KB Output is correct
13 Correct 65 ms 26708 KB Output is correct
14 Correct 42 ms 15824 KB Output is correct
15 Correct 49 ms 17232 KB Output is correct
16 Correct 46 ms 15952 KB Output is correct
17 Correct 2 ms 8792 KB Output is correct
18 Correct 2 ms 8796 KB Output is correct
19 Correct 2 ms 8840 KB Output is correct
20 Correct 2 ms 8792 KB Output is correct
21 Correct 2 ms 8792 KB Output is correct
22 Correct 2 ms 8796 KB Output is correct
23 Correct 2 ms 8792 KB Output is correct
24 Correct 2 ms 8796 KB Output is correct
25 Correct 2 ms 8796 KB Output is correct
26 Correct 2 ms 8796 KB Output is correct
27 Correct 2 ms 8796 KB Output is correct
28 Correct 2 ms 8792 KB Output is correct
29 Correct 2 ms 8796 KB Output is correct
30 Correct 2 ms 8796 KB Output is correct
31 Correct 2 ms 9096 KB Output is correct
32 Correct 2 ms 8796 KB Output is correct
33 Correct 2 ms 8792 KB Output is correct
34 Correct 2 ms 8792 KB Output is correct
35 Correct 2 ms 8796 KB Output is correct
36 Correct 69 ms 21452 KB Output is correct
37 Correct 65 ms 26832 KB Output is correct
38 Correct 37 ms 15832 KB Output is correct
39 Correct 51 ms 17236 KB Output is correct
40 Correct 49 ms 15956 KB Output is correct
41 Correct 62 ms 24072 KB Output is correct
42 Correct 63 ms 25220 KB Output is correct
43 Correct 33 ms 15056 KB Output is correct
44 Correct 49 ms 16844 KB Output is correct
45 Correct 45 ms 15952 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 8796 KB Output is correct
2 Correct 3 ms 8868 KB Output is correct
3 Correct 2 ms 8796 KB Output is correct
4 Correct 2 ms 8796 KB Output is correct
5 Correct 2 ms 8636 KB Output is correct
6 Correct 2 ms 8792 KB Output is correct
7 Correct 2 ms 8792 KB Output is correct
8 Correct 2 ms 8796 KB Output is correct
9 Correct 2 ms 8792 KB Output is correct
10 Correct 2 ms 8796 KB Output is correct
11 Correct 2 ms 8796 KB Output is correct
12 Correct 2 ms 8796 KB Output is correct
13 Correct 2 ms 8796 KB Output is correct
14 Correct 2 ms 8796 KB Output is correct
15 Correct 2 ms 8796 KB Output is correct
16 Correct 2 ms 8796 KB Output is correct
17 Correct 2 ms 8884 KB Output is correct
18 Correct 2 ms 8796 KB Output is correct
19 Correct 59 ms 21456 KB Output is correct
20 Correct 65 ms 26708 KB Output is correct
21 Correct 42 ms 15824 KB Output is correct
22 Correct 49 ms 17232 KB Output is correct
23 Correct 46 ms 15952 KB Output is correct
24 Correct 2 ms 8792 KB Output is correct
25 Correct 2 ms 8796 KB Output is correct
26 Correct 2 ms 8840 KB Output is correct
27 Correct 2 ms 8792 KB Output is correct
28 Correct 2 ms 8792 KB Output is correct
29 Correct 2 ms 8796 KB Output is correct
30 Correct 2 ms 8792 KB Output is correct
31 Correct 2 ms 8796 KB Output is correct
32 Correct 2 ms 8796 KB Output is correct
33 Correct 2 ms 8796 KB Output is correct
34 Correct 2 ms 8796 KB Output is correct
35 Correct 2 ms 8792 KB Output is correct
36 Correct 2 ms 8796 KB Output is correct
37 Correct 2 ms 8796 KB Output is correct
38 Correct 2 ms 9096 KB Output is correct
39 Correct 2 ms 8796 KB Output is correct
40 Correct 2 ms 8792 KB Output is correct
41 Correct 2 ms 8792 KB Output is correct
42 Correct 2 ms 8796 KB Output is correct
43 Correct 69 ms 21452 KB Output is correct
44 Correct 65 ms 26832 KB Output is correct
45 Correct 37 ms 15832 KB Output is correct
46 Correct 51 ms 17236 KB Output is correct
47 Correct 49 ms 15956 KB Output is correct
48 Correct 62 ms 24072 KB Output is correct
49 Correct 63 ms 25220 KB Output is correct
50 Correct 33 ms 15056 KB Output is correct
51 Correct 49 ms 16844 KB Output is correct
52 Correct 45 ms 15952 KB Output is correct
53 Correct 68 ms 26764 KB Output is correct
54 Correct 80 ms 25020 KB Output is correct
55 Correct 28 ms 14296 KB Output is correct
56 Correct 55 ms 20308 KB Output is correct
57 Correct 48 ms 16212 KB Output is correct
58 Correct 47 ms 15952 KB Output is correct
59 Correct 47 ms 17240 KB Output is correct
60 Correct 48 ms 15920 KB Output is correct