제출 #297098

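# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (run time) | 메모리 (memory)
297098 | (not shown) | shayan_p | Highway Tolls (IOI18_highway) | C++17 | 90 / 100 | 636 ms | 12484 KiB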
#include<bits/stdc++.h>
#include "highway.h"

// Competitive-programming shorthand macros.
// NOTE: F and S are object-like macros, so every later identifier spelled
// exactly `F` or `S` is rewritten to `first` / `second` by the preprocessor.
#define F first
#define S second
#define PB push_back
// Element count of a container, as int.
#define sz(s) (int(s.size()))
// The k-th bit of n (0 or 1).
#define bit(n, k) (((n)>>(k))&1)

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;

// maxn: capacity of the per-vertex global arrays below; vertex ids must be
// smaller than maxn (this is not checked anywhere).
// inf: "not yet reached" marker that bfs() stores in h.
const int maxn = 1e5 + 10, inf = 1e9 + 10;

// Reference cost: the ask() result with every edge weight 0 (assigned in
// find_pair).
ll len;
// Adjacency list: v[u] holds (neighbour, edge index) pairs; filled in
// find_pair and traversed by bfs().
vector<pii> v[maxn];
// Hop distances from the most recent bfs() source; inf where unreached.
int h[maxn];
// Per-vertex flags; static storage, so all false at program start.
bool bad[maxn];

// Breadth-first search over the global adjacency list `v`, starting at `s`.
// All maxn entries of h are reset to inf first. On return, h[x] is the number
// of edges on a shortest path from s to x, or inf if x was never reached.
void bfs(int s){
    std::fill_n(h, maxn, inf);
    h[s] = 0;

    std::queue<int> frontier;
    frontier.push(s);
    while(!frontier.empty()){
        const int cur = frontier.front();
        frontier.pop();
        for(const auto& edge : v[cur]){
            const int nxt = edge.first;
            if(h[nxt] != inf)
                continue;               // already discovered at a smaller depth
            h[nxt] = h[cur] + 1;
            frontier.push(nxt);
        }
    }
}

void find_pair(int N, vector<int> U, vector<int> V, int A, int B){
    int m = sz(U);
    vector<int> w(m);
    for(int i = 0; i < m; i++){
	v[U[i]].PB({V[i], i});
	v[V[i]].PB({U[i], i});
    }
    len = ask(w);

    auto del = [&](int u){
		   for(auto [y, c] : v[u]){
		       w[c] = 1;
		   }
	       };
    auto bin = [&](vector<int> vrts, function<bool(vector<int>)> f){
		   int l = 0, r = sz(vrts);
		   while(r-l > 1){
		       int mid = (l+r) >> 1;
		       vector<int> vrts2;
		       for(int i = 0; i < mid; i++)
			   vrts2.PB(vrts[i]);
		       if(f(vrts2))
			   r = mid;
		       else
			   l = mid;
		   }
		   return vrts[l];
	       };
    auto delete_and_ask = [&](vector<int> v){
			      for(int i = 0; i < N; i++){
				  if(bad[i])
				      v.PB(i);
			      }
			      for(int i = 0; i < m; i++)
				  w[i] = 0;		  
			      for(int u:  v)
				  del(u);
			      return ask(w) != len;
			  };
    vector<int> vrts(N), tmp;

    auto erase_bads = [&](int root){
			  for(int x : vrts){
			      if(x == root)
				  break;
			      bad[x] = 1;
			  }
			  tmp = vrts;
			  vrts.clear();
			  for(int x : tmp){
			      if(!bad[x])
				  vrts.PB(x);
			  }
		      };
    
    iota(vrts.begin(), vrts.end(), 0);
    int root = bin(vrts, delete_and_ask);
    erase_bads(root);
    
    bfs(root);
    sort(vrts.begin(), vrts.end(), [&](int i, int j){ return h[i] > h[j]; } );
    int ans1 = bin(vrts, delete_and_ask);	
    erase_bads(ans1);
    
    bfs(ans1);
    sort(vrts.begin(), vrts.end(), [&](int i, int j){ return h[i] > h[j]; } );
    int ans2 = bin(vrts, delete_and_ask);

    answer(ans1, ans2);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...