답안 #877249

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
877249 2023-11-23T04:55:36 Z huutuan Pastiri (COI20_pastiri) C++14
8 / 100
1000 ms 119752 KB
#include<bits/stdc++.h>
#define taskname "sheep"

using namespace std;

const int inf=1e9;

struct Node{
   int val, lazy;
   Node (int a=inf, int b=0): val(a), lazy(b){}
   friend Node operator+(const Node &a, const Node &b){
      return Node(min(a.val, b.val));
   }
};

struct SegmentTree{
   int n;
   vector<Node> t;
   void init(int _n){
      n=_n;
      t.assign(4*n+1, Node());
   }
   void build(int k, int l, int r, int *a){
      if (l==r){
         t[k]=Node(a[l]);
         return;
      }
      int mid=(l+r)>>1;
      build(k<<1, l, mid, a);
      build(k<<1|1, mid+1, r, a);
      t[k]=t[k<<1]+t[k<<1|1];
   }
   void apply(int k, int val){
      t[k].val+=val;
      t[k].lazy+=val;
   }
   void push(int k){
      apply(k<<1, t[k].lazy);
      apply(k<<1|1, t[k].lazy);
      t[k].lazy=0;
   }
   void point_update(int k, int l, int r, int pos, int val){
      if (l==r){
         t[k]=Node(val);
         return;
      }
      push(k);
      int mid=(l+r)>>1;
      if (pos<=mid) point_update(k<<1, l, mid, pos, val);
      else point_update(k<<1|1, mid+1, r, pos, val);
      t[k]=t[k<<1]+t[k<<1|1];
   }
   void update(int k, int l, int r, int L, int R, int val){
      if (r<L || R<l) return;
      if (L<=l && r<=R){
         apply(k, val);
         return;
      }
      push(k);
      int mid=(l+r)>>1;
      update(k<<1, l, mid, L, R, val);
      update(k<<1|1, mid+1, r, L, R, val);
      t[k]=t[k<<1]+t[k<<1|1];
   }
   int walk(int k, int l, int r, int val){
      if (l==r) return l;
      push(k);
      int mid=(l+r)>>1;
      if (t[k<<1].val==val) return walk(k<<1, l, mid, val);
      return walk(k<<1|1, mid+1, r, val);
   }
} st;

const int N=5e5+1;
int n, k, a[N], dist[N], best[N], tin[N], tout[N], tdfs, f[N], vis[N], tour[N];
vector<int> g[N], trace[N], gg[N];

void pre_dfs(int u, int p){
   tin[u]=++tdfs;
   tour[tdfs]=u;
   for (int v:g[u]) if (v!=p){
      dist[v]=dist[u]+1;
      pre_dfs(v, u);
   }
   tout[u]=tdfs;
}

void dfs(int u, int p){
   best[u]=st.t[1].val;
   for (int v:g[u]) if (v!=p){
      st.update(1, 1, n, 1, n, 1);
      st.update(1, 1, n, tin[v], tout[v], -2);
      dfs(v, u);
      st.update(1, 1, n, tin[v], tout[v], 2);
      st.update(1, 1, n, 1, n, -1);
   }
}

void dfs_trace(int u, int p){
   best[u]=st.t[1].val;
   if (f[u]){
      vector<int> idx;
      while (st.t[1].val==best[u]){
         idx.push_back(st.walk(1, 1, n, best[u]));
         st.point_update(1, 1, n, idx.back(), inf);
         trace[u].push_back(tour[idx.back()]);
      }
      for (int i:idx) st.point_update(1, 1, n, i, best[u]);
   }
   for (int v:g[u]) if (v!=p){
      st.update(1, 1, n, 1, n, 1);
      st.update(1, 1, n, tin[v], tout[v], -2);
      dfs_trace(v, u);
      st.update(1, 1, n, tin[v], tout[v], 2);
      st.update(1, 1, n, 1, n, -1);
   }
}

int32_t main(){
   ios_base::sync_with_stdio(false);
   cin.tie(nullptr);
   // freopen(taskname".inp", "r", stdin);
   // freopen(taskname".out", "w", stdout);
   cin >> n >> k;
   bool sub1=1;
   for (int i=1; i<n; ++i){
      int u, v; cin >> u >> v;
      sub1&=abs(u-v)==1;
      g[u].push_back(v);
      g[v].push_back(u);
   }
   for (int i=1; i<=k; ++i){
      int x; cin >> x;
      a[x]=1;
   }
   if (sub1){
      vector<int> v;
      for (int i=1; i<=n; ++i) if (a[i]) v.push_back(i);
      vector<int> ans;
      for (int i=0; i<k; ++i){
         if (i+1<k && (v[i+1]-v[i])%2==0){
            ans.push_back((v[i+1]+v[i])/2);
            ++i;
         }else ans.push_back(v[i]);
      }
      cout << ans.size() << '\n';
      for (int i:ans) cout << i << ' ';
      return 0;
   }
   pre_dfs(1, 0);
   st.init(n);
   for (int i=1; i<=n; ++i){
      if (a[i]) st.point_update(1, 1, n, tin[i], dist[i]);
   }
   dfs(1, 0);
   vector<int> ans;
   for (int i=1; i<=n; ++i){
      bool check=1;
      for (int j:g[i]) check&=best[i]>=best[j];
      if (check) f[i]=1, ans.push_back(i);
   }
   for (int i=1; i<=n; ++i){
      if (a[i]) st.point_update(1, 1, n, tin[i], dist[i]);
   }
   dfs_trace(1, 0);
   for (int i=1; i<=n; ++i){
      sort(trace[i].begin(), trace[i].end(), [&](int x, int y){
         return dist[x]<dist[y];
      });
      for (int j:trace[i]) gg[j].push_back(i);
   }
   set<pair<pair<int, int>, int>> st;
   for (int i=1; i<=n; ++i) if (trace[i].size()) st.insert({{dist[trace[i].back()], -dist[i]}, i});
   vector<int> real;
   while (st.size()){
      int idx=st.rbegin()->second; st.erase(prev(st.end()));
      real.push_back(idx);
      vis[idx]=1;
      for (int j:trace[idx]) if (!vis[j]){
         vis[j]=1;
         for (int l:gg[j]) if (trace[l].size() && l!=idx){
            st.erase({{dist[trace[l].back()], -dist[l]}, l});
            while (trace[l].size() && vis[trace[l].back()]) trace[l].pop_back();
            if (trace[l].size()) st.insert({{dist[trace[l].back()], -dist[l]}, l});
         }
      }
      trace[idx].clear();
   }
   cout << real.size() << '\n';
   for (int i:real) cout << i << ' ';
   return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 103 ms 55248 KB Output is correct
2 Correct 100 ms 62800 KB Output is correct
3 Correct 111 ms 62804 KB Output is correct
4 Correct 138 ms 69984 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 15 ms 46524 KB Output is correct
2 Correct 15 ms 46168 KB Output is correct
3 Execution timed out 1073 ms 119752 KB Time limit exceeded
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 12 ms 45912 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 1098 ms 99688 KB Time limit exceeded
2 Halted 0 ms 0 KB -