Submission #884148

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 884148 | | mgl_diamond | Pastiri (COI20_pastiri) | C++17 | 100 / 100 | 482 ms | 112032 KiB |
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ii = pair<ll, ll>;

// Competitive-programming shorthand macros used throughout this file.
// foru(i, l, r): ascending loop, i = l, l+1, ..., r (inclusive).
#define foru(i, l, r) for(int i=(l); i<=(r); ++i)
// ford(i, l, r): descending loop, i = l, l-1, ..., r (inclusive).
#define ford(i, l, r) for(int i=(l); i>=(r); --i)
// fore(x, v): range-for binding each element of v to x by (non-const) reference.
#define fore(x, v) for(auto &x : v)
// all(x): begin/end iterator pair, e.g. sort(all(v)).
#define all(x) (x).begin(), (x).end()
// sz(x): container size converted to int.
#define sz(x) (int)(x).size()
// Short names for std::pair members.
#define fi first
#define se second
// Base name of the optional local test files used by setIO ("input.inp" / "input.out").
#define file "input"

// Replaces a with b when b is strictly smaller (a > b); returns whether a changed.
template<class T> bool minimize(T &a, T b) {
  if (!(a > b)) return false;
  a = b;
  return true;
}
// Replaces a with b when b is strictly larger (a < b); returns whether a changed.
template<class T> bool maximize(T &a, T b) {
  if (!(a < b)) return false;
  a = b;
  return true;
}

void setIO() {
  ios::sync_with_stdio(0);
  cin.tie(0); cout.tie(0);
  if (fopen(file".inp", "r")) {
    freopen(file".inp", "r", stdin);
    freopen(file".out", "w", stdout);
  }
}

// Size of the per-vertex arrays; presumably the problem's maximum n (5*10^5)
// plus slack -- TODO confirm against the statement.
const int N = 5e5+5;

// n = number of vertices, k = number of sheep (both read in main).
int n, k;
// Adjacency lists of the tree (vertices 1..n), filled in main.
vector<int> adj[N];
// sheep[v] is true iff a sheep stands on vertex v.
bool sheep[N];

void solve1() {
  vector<int> rem;
  foru(i, 1, n) if (sheep[i]) rem.push_back(i);
  sort(all(rem));
  vector<int> ans;
  while (!rem.empty()) {
    int u = rem.back();
    if (rem.size() >= 2 && (rem[sz(rem)-2]+u)%2==0) {
      int v = rem[sz(rem)-2];
      ans.push_back((u+v)/2);
      rem.pop_back();
    } else ans.push_back(u);
    rem.pop_back();
  }
  cout << sz(ans) << "\n";
  fore(x, ans) cout << x << " ";
}

// State for solve2 (bitmask DP; masks have 15 bits, so at most 15 sheep):
//   kth[s] - bit index (0-based) of sheep vertex s inside the masks
//   DP[m]  - fewest shepherds found whose guarded sheep sets cover mask m
//   F[m]   - a vertex whose set of nearest sheep is exactly m (0 if none)
//   T[m]   - back-pointer used to reconstruct the choice behind DP[m]
int kth[N], DP[1<<15], F[1<<15];
ii T[1<<15];

// Exact solver for few sheep (each subset of sheep must fit in a 15-bit mask).
//   1. BFS from every sheep s gives dist[kth[s]][v] for all vertices v.
//   2. For each vertex v, the mask of sheep at minimum distance from v is the
//      set a shepherd on v guards; F[mask] remembers one such vertex.
//   3. DP over masks: DP[m] = fewest shepherds covering m, with downward
//      propagation (anything covering m also covers each submask of m).
//   4. T back-pointers are followed from the full mask to list the shepherds.
// Prints DP[full], then the chosen shepherd vertices.
void solve2() {
  memset(DP, 0x3f, sizeof(DP));
  vector<int> tmp;            // sheep vertices, in increasing label order
  vector<vector<int>> dist;   // dist[kth[s]][v] = distance from sheep s to v
  int id=0;
  foru(i, 1, n) if (sheep[i]) {
    // DP/F/T have 1<<15 entries, so only 15 sheep fit in a mask.
    assert(id < 15);
    tmp.push_back(i);
    queue<int> qu;
    qu.push(i);
    dist.push_back(vector<int>(n+1, -1));
    dist[id][i] = 0;
    while (!qu.empty()) {
      int u = qu.front();
      qu.pop();
      fore(v, adj[u]) if (dist[id][v] == -1) {
        dist[id][v] = dist[id][u]+1;
        qu.push(v);
      }
    }
    kth[i] = id++;
  }
  // mask = sheep at minimum distance from vertex i (ties all included).
  foru(i, 1, n) {
    int mask = 0;
    int mn = N;
    fore(v, tmp)
      if (dist[kth[v]][i] < mn) {
        mn = dist[kth[v]][i];
        mask = 1<<kth[v];
      }
      else if (dist[kth[v]][i] == mn)
        mask += 1<<kth[v];
    F[mask] = i;
  }

  int full = (1<<id)-1;
  DP[0] = 0;
  foru(m, 0, full) {
    // Place a shepherd guarding exactly `sub`; the rest (m^sub) is a smaller mask.
    for(int sub=m; sub>0; sub=(sub-1)&m) if (F[sub] > 0) {
      if (minimize(DP[m], DP[m^sub] + 1)) {
        T[m].fi = sub;
        T[m].se = m^sub;
      }
    }
    // Covering m also covers every submask: push DP[m] down (no shepherd added).
    for(int sub=m; sub>0; sub=(sub-1)&m)
      if (minimize(DP[sub], DP[m])) {
        T[sub].fi = -1;
        T[sub].se = m;
      }
  }
  cout << DP[full] << "\n";
  // T[x] = {sub, x^sub}: shepherd F[sub] used. T[x] = {-1, sup}: inherited from superset sup.
  vector<int> shepherd;
  while(full > 0) {
    if (T[full].fi != -1) shepherd.push_back(F[T[full].fi]);
    full = T[full].se;
  }
  fore(x, shepherd) {
    cout << x << " ";
  }
}

// State for solve3:
//   vis[v]  - set by dfs_second; solve3 skips sheep whose vis is already set
//   dist[v] - distance from v to its nearest sheep (multi-source BFS in solve3)
//   high[v] - depth of v in the DFS tree rooted at vertex 1 (set by dfs_first)
//   up[s]   - for sheep s, the vertex dfs_first found in mark[high[s]]: the
//             shallowest vertex on the root path with dist + high == high[s]
int vis[N], dist[N], high[N], up[N];

// mark[key] - during dfs_first, shallowest vertex on the current root path
//             with dist + high == key (0 if none); keys can reach ~2n.
int mark[N*2];
// block[d] - sheep at depth d, in DFS pre-order (filled by dfs_first).
vector<int> block[N];

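// Invariant used by dfs_first (dist = distance to the nearest sheep,
// high = depth from root 1): for an ancestor v of sheep u,
//   dist[v] + high[v] == high[u]  <=>  dist[v] == high[u] - high[v],
// i.e. v's nearest-sheep distance equals its distance to u, so u is one of
// v's nearest sheep and a shepherd placed on v guards u.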

// DFS of the tree from root u (parent p) that does the following:
//   * sets high[] (depth) for every descendant;
//   * for each sheep x, sets up[x] to the shallowest vertex on its root path
//     with dist + high == high[x] (found via mark[]);
//   * appends each sheep x to block[high[x]].
// Uses an explicit stack instead of recursion, so a path-shaped tree of ~5e5
// vertices cannot overflow the call stack. Pre-order and post-order actions
// and the child order (adj order) match the recursive version, so block[] is
// filled in the same order.
void dfs_first(int u, int p) {
  struct Frame { int v, parent; std::size_t next; };
  // Pre-order work for vertex x (high[x] must already be set).
  auto enter = [](int x) {
    const int key = dist[x] + high[x];
    // Keep only the shallowest vertex on the current root path for this key.
    if (!mark[key]) mark[key] = x;
    if (sheep[x]) {
      // A sheep has dist 0 (from the BFS run before this DFS), so its own key
      // is high[x] and mark[high[x]] is at least x itself.
      assert(mark[high[x]] > 0);
      up[x] = mark[high[x]];
      block[high[x]].push_back(x);
    }
  };

  std::vector<Frame> stk;
  enter(u);
  stk.push_back({u, p, 0});
  while (!stk.empty()) {
    Frame &top = stk.back();
    const int cur = top.v;
    if (top.next < adj[cur].size()) {
      const int nxt = adj[cur][top.next++];
      if (nxt == top.parent) continue;
      high[nxt] = high[cur] + 1;
      enter(nxt);
      stk.push_back({nxt, cur, 0});  // may invalidate `top`; it is not used after this
    } else {
      // Post-order: release the mark slot if this vertex claimed it.
      const int key = dist[cur] + high[cur];
      if (mark[key] == cur) mark[key] = 0;
      stk.pop_back();
    }
  }
}

void dfs_second(int u) {
  vis[u] = 1;
  fore(v, adj[u]) {
    if (vis[v]) continue;
    if (dist[u]==dist[v]+1) dfs_second(v);
  }
}

void solve3() {
  queue<int> qu;
  memset(dist, -1, sizeof(dist));
  foru(i, 1, n) if (sheep[i]) { dist[i] = 0; qu.push(i); }
  while (!qu.empty()) {
    int u = qu.front();
    qu.pop();
    fore(v, adj[u]) if (dist[v] == -1) { dist[v] = dist[u]+1; qu.push(v); }
  }
  dfs_first(1, 0);
  vector<int> ans;
  ford(i, n, 0) if (!block[i].empty()) {
    fore(u, block[i]) if (!vis[u]) {
      ans.push_back(up[u]);
      dfs_second(up[u]);
    }
  }
  cout << sz(ans) << "\n";
  fore(x, ans) cout << x << " ";
}

// Reads the tree (n vertices, n-1 edges) and the k sheep positions, then
// prints a set of shepherds via the general greedy solver.
int main() {
  setIO();

  cin >> n >> k;
  foru(i, 2, n) {
    int u, v;
    cin >> u >> v;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }
  foru(i, 1, k) {
    int node;
    cin >> node;
    sheep[node] = 1;
  }

  // solve3 is the general algorithm and is used for every input. solve1 (path)
  // and solve2 (k <= 15) are kept only as subtask solutions / cross-checks.
  // Do not pick solve1 from vertex degrees alone: it also needs the path's
  // vertices to be labelled 1..n in path order, which a degree check does not
  // establish.
  solve3();
}

Compilation message (stderr)

pastiri.cpp: In function 'void setIO()':
pastiri.cpp:22:12: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
   22 |     freopen(file".inp", "r", stdin);
      |     ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
pastiri.cpp:23:12: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
   23 |     freopen(file".out", "w", stdout);
      |     ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...