답안 #515428

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
515428 | 2022-01-19T07:17:45Z | blue | Wells (CEOI21_wells) | C++17 | 0 / 100 | 10000 ms | 460 KB
#include <iostream>
#include <vector>
#include <algorithm>
#include <queue>
using namespace std;

// Largest vertex label this solution supports; per-vertex arrays are 1-indexed.
constexpr int maxN = 200;
#define sz(x) static_cast<int>((x).size())
using vi = vector<int>;
using vvi = vector<vi>;
using ll = long long;
// Modulus applied to the printed count.
constexpr ll mod = 1'000'000'007;

// Adjacency lists of the input tree, indexed by vertex label 1..N.
vi edge[1 + maxN];
// Adjacency lists restricted to the "good" vertices, re-labelled by main.
vvi new_edge;

// Returns the greatest depth reached when walking `edge` from u without going
// back through p, where u itself is counted at depth d.
int dep(int u, int p, int d)
{
    int deepest = d;
    for (const int next : edge[u])
    {
        if (next != p)
            deepest = max(deepest, dep(next, u, d + 1));
    }
    return deepest;
}

vvi dist;

void dfs(int src, int u, int p)
{
    for(int v: new_edge[u])
    {
        if(v == p) continue;
        dist[src][v] = dist[src][u] + 1;
        dfs(src, v, u);
    }
}

int N, K;

vi occ[1+maxN];

int rt;


int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> N >> K;

    for(int e = 1; e <= N-1; e++)
    {
        int a, b;
        cin >> a >> b;
        edge[a].push_back(b);
        edge[b].push_back(a);
    }

    vi good(1+N, 0);
    int badct = 0;

    for(int i = 1; i <= N; i++)
    {
        vi dist;
        for(int j: edge[i])
        {
            dist.push_back(dep(j, i, 1));
        }

        sort(dist.begin(), dist.end());

        if(dist.back() >= K-1 || (sz(dist) >= 2 && dist[sz(dist)-1] + dist[sz(dist)-2] >= K-1))
            good[i] = 1;
        else
            badct++;
    }

    ll p2 = 1;
    for(int e = 1; e <= badct; e++)
        p2 = (p2 * 2) % mod;

    if(badct == N)
    {
        cout << p2 << '\n';
        return 0;
    }


    vi new_ind(1+(N-badct));
    new_edge = vvi(1+(N-badct));
    int ct = 0;

    for(int i = 1; i <= N; i++)
    {
        if(!good[i]) continue;
        new_ind[i] = ++ct;
    }

    for(int i = 1; i <= N; i++)
    {
        if(!good[i]) continue;
        for(int j: edge[i])
        {
            if(!good[j]) continue;
            new_edge[ new_ind[i] ].push_back(new_ind[j]);
        }
    }
        // for(int i = 1; i <= N; i++) cerr << i << " : " << new_ind[i] << '\n';

    N -= badct;




    dist = vvi(1+N, vi(1+N, 5*N));
    for(int i = 1; i <= N; i++)
    {
        dist[i][i] = 0;
        dfs(i, i, i);
    }


    int x = 1;
    int y = 1;

    for(int i = 1; i <= N; i++)
        if(dist[x][i] > dist[x][y])
            y = i;

    for(int i = 1; i <= N; i++)
        if(dist[y][i] > dist[y][x])
            x = i;

    vi lst{x};
    while(sz(lst) < K)
    {
        int h = lst.back();
        for(int i = 1; i <= N; i++)
        {
            if(dist[x][i] + dist[i][y] == dist[x][h] + dist[h][y] && dist[x][i] == dist[x][h]+1)
            {
                lst.push_back(i);
                break;
            }
        }
    }

    if(N >= 10) while(1);

    int basicRes = 0;

    for(int l: lst)
    {
        for(int i = 0; i < N; i += K)
            occ[i].clear();

        bool works = 1;

        // cerr << "\n\nl = " << l << '\n';


        rt = l;

        vi visit(1+N, 0);


        for(int i = 1; i <= N; i++)
        {
            if(dist[l][i] % K == 0) continue;
            if(visit[i]) continue;

            // cerr << "i = " << i << '\n';


            queue<int> tbv;
            vi vlist;

            tbv.push(i);
            visit[i] = 1;
            while(!tbv.empty())
            {
                int u = tbv.front();
                tbv.pop();
                vlist.push_back(u);

                for(int v: new_edge[u])
                {
                    if(visit[v]) continue;
                    if(dist[l][v] % K == 0) continue;

                    tbv.push(v);
                    visit[v] = 1;
                }
            }

            int z1 = vlist[0];
            int z2 = z1;
            for(int v: vlist)
                if(dist[v][z1] > dist[z2][z1])
                    z2 = v;

            for(int v: vlist)
                if(dist[v][z2] > dist[z2][z1])
                    z1 = v;

            if(dist[z1][z2] >= K-1)
                works = 0;
        }

        // cerr << "pre works = " << works << '\n';





        for(int i = 1; i <= N; i++)
            if(dist[l][i] % K == 0)
                for(int j = i+1; j <= N; j++)
                    if(dist[l][j] % K == 0)
                        if(dist[i][j] < K)
                            works = 0;


        if(works) basicRes++;
    }





    if(basicRes == 0)
        cout << "NO\n";
    else
        cout << "YES\n";

    cout << (basicRes * p2) % mod << '\n';
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 204 KB Output is correct
2 Execution timed out 10005 ms 460 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 204 KB Output is correct
2 Execution timed out 10005 ms 460 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 204 KB Output is correct
2 Execution timed out 10005 ms 460 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 204 KB Output is correct
2 Execution timed out 10005 ms 460 KB Time limit exceeded
3 Halted 0 ms 0 KB -