// Submission #559316
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 559316 | (not captured) | Redhood | Pastiri (COI20_pastiri) | C++14 | 100 / 100 | 987 ms | 139160 KiB
#include<bits/stdc++.h>

#define fi first
#define se second
#define sz(x) (int)(x).size()
#define pb push_back
#define mkp make_pair
using namespace std;

using ll = long long;
using ld = long double;

// Capacity of every per-vertex array (presumably the max vertex count plus slack).
constexpr int N = 500000 + 10;
// Sentinel: initial value of high[][] (set in main) and the "infeasible"
// marker dfs1 may store in ans[][].
constexpr int inf = 1000000000;

// g[v]: adjacency list of vertex v (0-indexed), filled from the input edges.
vector<int> g[N];

// mnd[v]  : distance from v to the nearest sheep inside v's subtree (set by dfs).
// shrt[v] : distance from v to the nearest sheep overall (multi-source BFS in main).
// pred[v] : parent of v in the tree rooted by dfs.
// h[v]    : depth of v (set by dfs).
int mnd[N];
int shrt[N], pred[N], h[N];

// ans[v][s]  : number of shepherds in v's subtree under DP state s (dfs1).
// high[v][s] : min of h[w] - shrt[w] over the shepherds w of that state
//              (inf when there are none).
int ans[N][2], high[N][2];

// Not referenced by the live code (only by removed debug output).
bool done[N][2];
/// Roots the part of the tree hanging from v (entered from parent p) and, for
/// every vertex x in it, fills:
///   h[x]    - depth (h[v] = H),
///   pred[x] - parent (pred[v] = p),
///   mnd[x]  - distance from x to the nearest sheep inside x's subtree
///             (0 if x itself has one, 1e9 if its subtree has none).
/// Requires shrt[] to be filled already (shrt[x] == 0 marks a sheep).
/// Uses an explicit work list instead of recursion.
void dfs(int v, int p, int H){
    // Top-down pass: every vertex is appended after its parent.
    vector<int> order(1, v);
    h[v] = H;
    pred[v] = p;
    for (size_t i = 0; i < order.size(); ++i) {
        const int x = order[i];
        for (int y : g[x]) {
            if (y == pred[x])
                continue;
            h[y] = h[x] + 1;
            pred[y] = x;
            order.push_back(y);
        }
    }

    // Bottom-up pass: in reverse order, every child is finished before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        int best = (shrt[x] == 0) ? 0 : static_cast<int>(1e9);
        for (int y : g[x]) {
            if (y != pred[x])
                best = min(best, mnd[y] + 1);
        }
        mnd[x] = best;
    }
}
// did[v][1]: set by dfs1 to 1 + (shepherd placed at v) when it keeps the option
// built from the children's preferred states; stays 0 otherwise. get_ans reads
// it back to rebuild the placement.
int did[N][2];
// bst[u]: the state (0 or 1) of u that `better` prefers; filled by dfs1.
bool bst[N];
/// True iff state 1 of v is strictly preferable to state 0: fewer shepherds,
/// or the same number of shepherds and a smaller `high` value.
bool better(int v){
    const pair<int, int> stateOne(ans[v][1], high[v][1]);
    const pair<int, int> stateZero(ans[v][0], high[v][0]);
    return stateOne < stateZero;
}
/// Post-order DP over the subtree of v. Children are the neighbours u != pred[v].
/// Preconditions: dfs has filled pred/h/mnd; main has filled shrt and set
/// high[][] to inf; ans[][] and did[][] start at 0 (global zero-init).
///
/// For each vertex it computes two states fl in {0, 1}:
///   ans[v][fl]  - number of shepherds used in v's subtree,
///   high[v][fl] - min over those shepherds w of h[w] - shrt[w] (inf if none).
/// State 0: children not on v's nearest-sheep path are taken in state 1; the
///   "small" children (those on it) take their preferred state bst[u].
/// State 1: chooses between two options
///   (a) `now`: every child in state 1, no shepherd at v;
///   (b) as state 0, plus a shepherd at v when the check on x requires one
///       (or inf if a shepherd at v is not allowed there).
///   (b) is kept if v holds a sheep and now.se != h[v], or if `now` is
///   lexicographically greater; otherwise (a) is adopted (ties go to (a)).
///   did[v][1] = 1 + (shepherd at v) records that (b) was kept.
/// NOTE(review): recursion depth equals the tree height (up to n vertices).
void dfs1(int v){
    // Solve all children first, then record which of their states is preferred.
    for(auto &u : g[v]){
        if(u != pred[v]){
            dfs1(u);
            if(better(u))
                bst[u] = 1;
            else
                bst[u] = 0;
        }
    }



    // small: children through which v's nearest in-subtree sheep is reached.
    // dfs guarantees mnd[v] <= mnd[u] + 1, so this test means mnd[u] + 1 == mnd[v].
    // If v itself holds a sheep (mnd[v] == 0), small is empty.
    vector < int > small;
    for(auto &u : g[v]){
        if(u == pred[v])
            continue;
        if(mnd[u] + 1 > mnd[v]){
        }else{
            small.pb(u);
        }
    }

    for(int fl = 0;fl < 2; ++fl){
        // Children not in `small` always contribute their state-1 values.
        for(auto &u : g[v]){
            if(u == pred[v])
                continue;
            if(mnd[u] + 1 > mnd[v]){
                ans[v][fl] += ans[u][1];
                high[v][fl] = min(high[v][fl], high[u][1]);
            }
        }
         if(fl == 0){
                // State 0: small children use their preferred state.
                for(auto &u : small){
                    ans[v][fl] += ans[u][bst[u]];
                    high[v][fl] = min(high[v][fl], high[u][bst[u]]);
                }
            }else{
                /// Option (a): small children also in state 1, no shepherd at v.
                pair < int , int > now = {ans[v][fl], high[v][fl]};
                for(auto &u : small){
                    now.fi += ans[u][1];
                    now.se = min(now.se, high[u][1]);
                }
                /// Option (b): small children in their preferred state,
                /// possibly covered by a shepherd placed at v itself.

                for(auto &u : small){
                    ans[v][1] += ans[u][bst[u]];
                    high[v][1] = min(high[v][1], high[u][bst[u]]);
                }
                // x = h[v] - high[v][1]. Presumably this is how far above v the
                // highest-reaching chosen shepherd's coverage extends. When it is
                // not exactly mnd[v], v must host a shepherd; that is only allowed
                // if v's nearest sheep overall is at the in-subtree distance
                // (shrt[v] == mnd[v]). Otherwise option (b) is marked infeasible.
                int x = h[v] - high[v][fl];
                bool put = 0;
                if(x != mnd[v]){
                    if(shrt[v] == mnd[v]){
                        ans[v][fl]++, put = 1;
                        high[v][fl] = min(high[v][fl], h[v] - shrt[v]);
                    }else
                        ans[v][fl] = inf;
                }


                // Keep (b) when v holds a sheep (shrt[v] == 0) and now.se != h[v]
                // (presumably: no shepherd of (a) covers v), or when (a) is
                // lexicographically worse. Otherwise adopt (a). When (b) was marked
                // inf, (a) wins here as long as now.fi < inf.
                if((shrt[v]==0 && (now.se != h[v]))
                || (now > mkp(ans[v][fl], high[v][fl]))){
                    did[v][fl] = 1 + put;

                }else{
                    ans[v][fl] = now.fi;
                    high[v][fl] = now.se;
                }
            }
    }




}
// Not referenced by the live code.
bool visited[N][2];
// answ[v]: true iff the reconstructed solution places a shepherd at v.
bool answ[N];
/// Rebuilds the placement chosen by dfs1 for the subtree of v, taken in DP
/// state fl, and sets answ[x] for each vertex x that receives a shepherd.
///
/// State propagation (same as the former recursive version):
///  - If fl == 0, or fl != 0 with did[x][fl] > 0 (dfs1 kept option (b)):
///    for fl != 0, answ[x] = did[x][fl] - 1. Each child u is entered in
///    state 1 when mnd[u] + 1 > mnd[x], otherwise in its preferred state bst[u].
///  - If fl != 0 with did[x][fl] == 0 (dfs1 adopted option (a)): every child
///    is entered in state 1.
///
/// An explicit stack replaces recursion, so this pass no longer uses call-stack
/// space proportional to the tree height (which can be ~5e5). Each vertex is
/// handled exactly once and only answ[x] is written, so the visiting order does
/// not affect the result.
void get_ans(int v , int fl){
    vector<pair<int, int>> stk;
    stk.emplace_back(v, fl);
    while(!stk.empty()){
        const int x = stk.back().first;
        const int f = stk.back().second;
        stk.pop_back();

        // did[x][0] is never consulted, matching the original control flow.
        const bool followBest = (f == 0) || did[x][f] > 0;
        if(f != 0 && did[x][f] > 0)
            answ[x] = did[x][f] - 1;

        for(const int u : g[x]){
            if(u == pred[x])
                continue;
            int childState = 1;
            if(followBest && mnd[u] + 1 <= mnd[x])
                childState = bst[u];
            stk.emplace_back(u, childState);
        }
    }
}
signed main(){
    ios_base::sync_with_stdio(0), cin.tie(0) , cout.tie(0);
    int n , k;
    cin >> n >> k;


    for(int i=0;i<n;++i){
        high[i][0]=high[i][1]=inf;
    }

    for(int i = 0; i < n - 1; ++i){
        int a , b;
        cin >> a >> b;
        --a, --b;
        g[a].pb(b);
        g[b].pb(a);
    }
    vector < int > o(k);
    for(auto &i : o)
        cin >> i, --i;
    queue < int > bfs;
    for(auto &u : o)
        bfs.push(u);

    fill(shrt , shrt + N , -1);

    for(auto &u : o)
        shrt[u] = 0;
    while(!bfs.empty()){
        int v = bfs.front();
        bfs.pop();
        for(auto &u : g[v]){
            if(shrt[u] == -1){
                shrt[u] = shrt[v] + 1;
                bfs.push(u);
            }
        }
    }

    dfs(0 , -1, 0);




    dfs1(0);

    get_ans(0 , 1);


//    cout << "fukk \n";

//    for(int i = 0; i < n; ++i){
//        if(done[i][0])
//        cout << i << ' ' << 0 << ' ' << state[i][0].ans << '\n';
//
//        if(done[i][1])
//        cout << i << ' ' << 1 << ' ' << state[i][1].ans << '\n';
//        cout << " NEW \n";
//
//    }


    vector < int > take;
    for(int i = 0; i < n; ++i){
        if(answ[i])
            take.pb(i);
    }
    assert(sz(take) == ans[0][1]);

    cout << ans[0][1] << '\n';
    for(auto &u : take)
        cout << u + 1 << ' ';
    cout << '\n';




    return 0;
}
// Per-subtask results (Verdict | Execution time | Memory | Grader output):
// not captured - all four result tables showed only "Fetching results...".