제출 #1237214

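# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리  (# | submitted at | user | problem | language | result | run time | memory)
1237214 | mychecksedad | 통행료 (IOI18_highway) | C++20
51 / 100
233 ms | 327680 KiB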
#include "highway.h"
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define ff first
#define ss second
using vi = vector<int>;
using ll = long long int;
using pii = pair<int, int>;

// Capacity of every per-vertex / per-level array below (2e5 + 100).
constexpr int N = 200100;

int s[N];            // s[v]: subtree size, filled by pre / dfs / dfs2
int dep[N];          // dep[v]: depth below the chosen centroid (set by gg/ggg, then dfs/dfs2)
int PAR[N];          // PAR[v]: id of the edge joining v to its parent in that labelling
vector<pii> g[N];    // g[v]: {neighbour, edge id} for every edge incident to v
vi D[N];             // D[d]: vertices that dfs placed at depth d
vi D2[N];            // D2[d]: vertices that dfs2 placed at depth d
bitset<N> VIS;       // vertices already removed (centroids chosen so far)
vi C[N];             // C[l]: centroids found at decomposition level l, in discovery order
vector<vi> QC[N];    // QC[l][i]: ids of C[l][i]'s edges that lead into its component
// Fills s[x] with the size of x's subtree for every vertex x reachable from v
// without entering a removed (VIS) vertex, treating p as v's parent (callers
// pass p == v for the root). Only the immediate parent is skipped, so the
// reachable part must be a tree or the recursion never ends.
void pre(int v, int p){
  s[v] = 1;
  for(const auto& edge : g[v]){
    const int child = edge.first;
    if(VIS[child] || child == p) continue;
    pre(child, v);
    s[v] += s[child];
  }
}
int num;  // size of the component being decomposed; callers set it to s[root] after pre(root, root)

// Starting from v (with parent p), repeatedly steps into the first non-removed
// child whose subtree holds at least (num + 1) / 2 vertices; returns the vertex
// where no such child exists. Reads s[] as filled by pre() from the same root.
int centro(int v, int p){
  const int half = (num + 1) / 2;
  while(true){
    int heavy = -1;  // vertex ids are non-negative, so -1 means "none found"
    for(const auto& edge : g[v]){
      const int u = edge.first;
      if(!VIS[u] && u != p && s[u] >= half){
        heavy = u;
        break;
      }
    }
    if(heavy < 0) return v;
    p = v;
    v = heavy;
  }
}
void f(int v, int dep){
  pre(v, v);
  num = s[v];
  v = centro(v, v);
  C[dep].pb(v);
  vi q;
  for(auto [u, id]: g[v]){
    if(!VIS[u]){
      q.pb(id);
    }
  }
  QC[dep].pb(q);
  VIS[v] = 1;
  for(auto [u, id]: g[v]){
    if(!VIS[u]) f(u, dep + 1);
  }
}
void dfs(int v, int p){
  s[v] = 1;
  D[dep[v]].pb(v);
  for(auto [u, id]: g[v]){
    if(!VIS[u] && u != p){
      PAR[u] = id;
      dep[u] = dep[v] + 1;
      dfs(u, v);
      s[v] += s[u];
    }
  }
}
void dfs2(int v, int p){
  s[v] = 1;
  D2[dep[v]].pb(v);
  for(auto [u, id]: g[v]){
    if(!VIS[u] && u != p){
      PAR[u] = id;
      dep[u] = dep[v] + 1;
      dfs2(u, v);
      s[v] += s[u];
    }
  }
}
// Replays the centroid decomposition performed by f (same pre/centro calls in
// the same order) until `node` is picked as a centroid. At that moment VIS holds
// exactly the vertices removed before node, so the labelling stays inside the
// component node had in f. Inside that component:
//   - the neighbour across edge `idd` gets depth 1 and is labelled with dfs (D);
//   - the neighbour across edge `idd2` gets depth 1 and is labelled with dfs2 (D2).
// The replay is not cut short after node is handled. Callers reset VIS first.
// Pass idd2 = -1 to label only one side (edge ids are non-negative, so -1 never matches).
void ggg(int v, int node, int idd, int idd2){
  pre(v, v);
  num = s[v];
  v = centro(v, v);
  if(v == node){
    for(auto [u, id]: g[v]){
      if(id == idd){
        PAR[u] = id;
        dep[u] = 1;
        dfs(u, v);
      }
      if(id == idd2){
        PAR[u] = id;
        dep[u] = 1;
        dfs2(u, v);
      }
    }
  }

  VIS[v] = 1;
  for(auto [u, id]: g[v]){
    if(!VIS[u]) ggg(u, node, idd, idd2);
  }
}

// One-sided variant: labels only the subtree across edge `idd` (into D).
// Previously a near-copy of ggg; now it simply delegates to it.
void gg(int v, int node, int idd){
  ggg(v, node, idd, -1);
}




void find_pair(int n, std::vector<int> U, std::vector<int> V, int A, int B) {
  int m = U.size();

  for(int i = 0; i < m; ++i){
    g[U[i]].pb({V[i], i});
    g[V[i]].pb({U[i], i});
  }

  vi w(m);
  ll val = ask(w);
  int k = val / A;

  f(0, 0);

  for(int d = 0; d < 30; ++d){
    if(C[d].empty()) assert(false);

    w.clear();
    w.resize(m);
    for(auto vv: QC[d]){
      for(int x: vv) w[x] = 1;
    }

    ll ress = ask(w);
    if(ress != val){
      // that means we found one centroid which satisfies
      // let's bs
      int sz = QC[d].size();
      int l = 0, r = sz - 2, res = sz - 1;
      while(l <= r){
        int mid = l+r>>1;
        w.clear();
        w.resize(m);
        for(int i = 0; i <= mid; ++i){
          for(int x: QC[d][i]) w[x] = 1;
        }
        ll y = ask(w);
        if(y != val){
          res = mid;
          r = mid - 1;
        }else{
          l = mid + 1;
        }
      }
      // now we now the path passes through C
      int node = C[d][res];

      vi arr = QC[d][res];
      sz = QC[d][res].size();
      // now we do bit coloring...
      // vector<bool> RES;
      // for(int bit = 0; (1<<bit) < sz*2; ++sz){
      //   w.clear();
      //   w.resize(m);
      //   for(int j = 0; j < sz; ++j) if(j&(1<<bit)) w[arr[j]] = 1;
      //   ll y = ask(w);
      //   RES.pb(y != val);
      // }
      // int cnt = 0;
      // for(auto x: RES) cnt += x;
      // cerr << node << '\n';
      l = 0, r = sz-2, res = sz-1;
      while(l <= r){
        int mid = l+r>>1;
        w.clear();
        w.resize(m);
        for(int j = 0; j <= mid; ++j) w[arr[j]] = 1;
        ll y = ask(w);
        if(y != val){
          res = mid;
          r = mid - 1;
        }else{
          l = mid + 1;
        }
      }
      l = 1, r = sz-1;
      int res2 = 0;
      while(l <= r){
        int mid = l+r>>1;
        w.clear();
        w.resize(m);
        for(int j = mid; j < sz; ++j) w[arr[j]] = 1;
        ll y = ask(w);
        if(y != val){
          res2 = mid;
          l = mid + 1;
        }else{
          r = mid - 1;
        }
      }
      // cerr << res << ' ' << res2 << '\n';
      // now we now that there aren't many options..
      if(res == res2){
        // this is nicer tho
        // basically s=0 case
        VIS = 0;
        gg(0, node, arr[res]);
        arr = D[k];
        sz = arr.size();
        l = 0, r = sz - 2, res = sz - 1;
        while(l <= r){
          int mid = l+r>>1;
          w.clear();
          w.resize(m);
          for(int j = 0; j <= mid; ++j) w[PAR[arr[j]]] = 1;
          ll y = ask(w);
          if(y != val){
            res = mid;
            r = mid - 1;
          }else{
            l = mid + 1;
          }
        }
        answer(node, arr[res]);
        // exit(0);
      }else{
        // cerr << arr[res] << ' ' << res2 << " crap\n";
        // now we gotta solve...
        VIS = 0;
        ggg(0, node, arr[res], arr[res2]);

        w.clear();
        w.resize(m);
        for(int i = 0; i <= n; ++i) for(int x: D[i]) w[PAR[x]] = 1;

        ll y = ask(w);

      // cerr << y << ' ';

        int k1;
        for(ll j = 1; j <= n; ++j){
          if(j * B + A * (k-j) == y){
            k1 = j;
            break;
          }
        }
        // cerr << k1 << ' ';
        // so now we know both depths
        int k2 = k - k1;

        arr.clear();
        for(int x: D[k1]) arr.pb(x);
        sz = arr.size();
        l = 0, r = sz - 2, res = sz - 1;
        while(l <= r){
          int mid = l+r>>1;
          w.clear();
          w.resize(m);
          for(int j = 0; j <= mid; ++j) w[PAR[arr[j]]] = 1;
          ll y = ask(w);
          if(y != val){
            res = mid;
            r = mid - 1;
          }else{
            l = mid + 1;
          }
        }
        int s = arr[res];


        arr.clear();
        for(int x: D2[k2]) arr.pb(x);
        sz = arr.size();
        l = 0, r = sz - 2, res = sz - 1;
        while(l <= r){
          int mid = l+r>>1;
          w.clear();
          w.resize(m);
          for(int j = 0; j <= mid; ++j) w[PAR[arr[j]]] = 1;
          ll y = ask(w);
          if(y != val){
            res = mid;
            r = mid - 1;
          }else{
            l = mid + 1;
          }
        }
        int t = arr[res];
        answer(s, t);
      }
      return;
    }

  }

  
  // vector<bool> vis(n);
  // queue<int> q;
  // vi dist(n);
  // q.push(0);
  // vis[0] = 1;
  // vi T, par(n);
  // while(!q.empty()){
  //   int v = q.front(); q.pop();
  //   if(dist[v] == k) T.pb(v);
  //   for(auto [u, id]: g[v]){
  //     if(!vis[u]){
  //       par[u] = id;
  //       dist[u] = dist[v] + 1;
  //       q.push(u);
  //       vis[u] = 1;
  //     }
  //   }
  // }

  // int l = 0, r = int(T.size()) - 2, res = int(T.size())-1;
  // while(l <= r){
  //   int mid = l+r>>1;
  //   w.clear();
  //   w.resize(m, 0);
  //   for(int i = 0; i <= mid; ++i) w[par[T[i]]] = 1;
  //   if(ask(w) > val){
  //     res = mid;
  //     r = mid - 1;
  //   }else{
  //     l = mid + 1;
  //   }
  // }
  // answer(0, T[res]);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...