/*
 * Submission #1363461
 *
 * #       | Time | Username | Problem                   | Language | Result    | Execution time | Memory
 * 1363461 |      | vlomaczk | Space Thief (JOI25_thief) | C++20    | 100 / 100 | 27 ms          | 3340 KiB
 */
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
typedef long long ll;
using namespace __gnu_pbds;
using namespace std;
#include "thief.h"

// Order-statistics set backed by a GNU pb_ds red-black tree
// (find_by_order / order_of_key). Not referenced elsewhere in this file.
template <class Key>
using ordered_set = tree<Key, null_type, less<Key>, rb_tree_tag, tree_order_statistics_node_update>;

// Adjacency lists of (neighbour, edge id) pairs, filled in solve():
// edge i is stored as (V[i], 2*i) in g[U[i]] and as (U[i], 2*i+1) in g[V[i]].
vector<vector<pair<int, int>>> g;
vector<int> is_off;  // non-zero => vertex is skipped by the DFS helpers
vector<int> sajz;    // subtree sizes written by sajz_dfs
vector<int> par;     // parents written by sajz_dfs
vector<int> ip;      // edge id (as stored in g[par[x]]) of the edge par[x] -> x
vector<int> Zbior;   // vertices collected by V_dfs

// Roots the active part of the spanning tree (vertices with is_off == 0) at v
// and fills, for every vertex x it reaches:
//   sajz[x] - number of active vertices in x's subtree,
//   par[x]  - x's parent (the root gets par[root] == root),
//   ip[x]   - the edge id stored in g[par[x]] for the edge par[x] -> x
//             (not written for the root).
// Call as sajz_dfs(root, root).
void sajz_dfs(int v, int p) {
    // Record the parent of every vertex, including the root. Otherwise the
    // root's par[] entry keeps a stale value from an earlier DFS. If that value
    // names an active neighbour, find_centroid treats the neighbour as the
    // parent side (size ts - sajz[root] == 0) and never descends into it. It
    // can then return a vertex that is not a centroid.
    par[v] = p;
    sajz[v] = 1;
    for(auto [u, k] : g[v]) {
        if(is_off[u] || u == p) continue;
        ip[u] = k;
        sajz_dfs(u, v);  // sets par[u] = v
        sajz[v] += sajz[u];
    }
}

// Starting at v, walks towards the heavier side until no active neighbour's
// side holds more than ts/2 vertices, and returns that vertex.
// It uses sajz[] and par[] from the most recent sajz_dfs over the component:
// - for the neighbour equal to par[v], the side size is ts - sajz[v];
// - for any other active neighbour u, the side size is sajz[u].
// If par[] of the starting vertex names one of its active neighbours, that
// neighbour is treated as the parent side rather than as a child.
ll find_centroid(ll v, ll ts) {
    const ll half = ts / 2;
    for(;;) {
        ll heavy = -1;
        for(const auto &edge : g[v]) {
            const int u = edge.first;
            if(is_off[u]) continue;
            const ll side = (u == par[v]) ? ts - sajz[v] : (ll)sajz[u];
            if(side > half) {
                heavy = u;
                break;
            }
        }
        if(heavy == -1) return v;
        v = heavy;
    }
}

// Appends v and every active vertex reachable from it without passing
// through p to Zbior, in DFS pre-order.
void V_dfs(int v, int p) {
    Zbior.push_back(v);
    for(const auto &edge : g[v]) {
        const int next = edge.first;
        if(!is_off[next] && next != p) V_dfs(next, v);
    }
}

vector<int> post;  // vertices in DFS post-order (filled by post_dfs)
vector<int> vis;   // visited flags for post_dfs
vector<int> res;   // orientation chosen in solve()

// Depth-first search over g from v that only follows an entry (u, k) of g[v]
// when k % 2 == myres[k / 2]. Each vertex is appended to `post` after all
// vertices reachable from it this way have been visited.
void post_dfs(int v, vector<int> &myres) {
    if(!vis[v]) {
        vis[v] = 1;
        for(const auto &edge : g[v]) {
            const int to = edge.first;
            const int id = edge.second;
            const bool usable = (id % 2 == myres[id / 2]);
            if(usable) post_dfs(to, myres);
        }
        post.push_back(v);
    }
}

vector<int> rep;  // disjoint-set parent links; rep[x] == x marks a root

// Returns the root of v's set, compressing the whole path to the root.
int Find(int v) {
    if(rep[v] == v) return v;
    return rep[v] = Find(rep[v]);
}

// Links the root of a's set under the root of b's set.
// Returns false when a and b were already in the same set.
bool Union(int a, int b) {
    const int root_a = Find(a);
    const int root_b = Find(b);
    if(root_a == root_b) return false;
    rep[root_a] = root_b;
    return true;
}

vector<pair<int, int>> edges;  // (U[i], V[i]) for every edge i
vector<int> mst;               // mst[i] == 1 iff edge i is in the spanning forest
int n;                         // number of vertices
int m;                         // number of edges
int q_cnt = 0;                 // number of calls to Query()
// Completes an orientation and submits it to the grader's query().
// R[i] for spanning-tree edges (mst[i] == 1) is taken as given. Every non-tree
// edge is then oriented along the reverse DFS post-order of the oriented tree,
// which is a topological order because the spanning forest is acyclic. So no
// non-tree edge can close a directed cycle. R is updated in place, so the
// caller sees the full orientation that was queried.
bool Query(vector<int> &R) {
    q_cnt++;
    post.clear();
    vis.assign(n, 0);
    for(int i = 0; i < n; ++i) post_dfs(i, R);
    reverse(post.begin(), post.end());

    vector<int> topo(n);  // topo[v] = position of v in the order
    for(int i = 0; i < n; ++i) topo[post[i]] = i;

    for(int i = 0; i < m; ++i) {
        if(mst[i]) continue;
        auto [a, b] = edges[i];
        // Assign both cases explicitly. A stale 1 passed in for a non-tree
        // edge with topo[a] < topo[b] would otherwise survive as a backward
        // edge.
        R[i] = (topo[a] > topo[b]) ? 1 : 0;
    }
    return query(R);
}

void solve(int N, int M, std::vector<int> U, std::vector<int> V) {
    n = N; m = M;

    is_off.assign(n,1);
    sajz.assign(n,0);
    g.assign(n, {});
    par.assign(n, 0);
    ip.assign(n, 0);
	while(edges.size()) edges.pop_back();

	rep.assign(n,0);
    mst.assign(m,0);
	for(int i=0; i<n; ++i) rep[i] = i;
	for(int i=0; i<m; ++i) {
		if(Union(U[i], V[i])) mst[i] = 1;
        edges.push_back({U[i], V[i]});
	}
    for(int i=0; i<m; ++i) {
		if(!mst[i]) continue;
        int a = U[i];
        int b = V[i];
        g[a].push_back({b, 2*i});
        g[b].push_back({a, 2*i+1});
    }
    vector<vector<int>> curr;
    curr.push_back({0});
    for(int i=1; i<n; ++i) curr.back().push_back(i);
    bool ok = 0;
    while(1) {
        vector<vector<int>> nxt;
        vector<vector<int>> change(3);

        for(auto VV : curr) {
            int v = VV[0];
            for(auto x : VV) is_off[x] = 0;
			sajz_dfs(v,v);
            int ctr = find_centroid(v,VV.size());
			v=ctr;
            sajz_dfs(v,v);
            
            vector<pair<int, int>> nei;
            for(auto[u,k] : g[v]) {
                if(is_off[u]) continue;
                nei.push_back({sajz[u], u});
            }
			sort(nei.begin(), nei.end());
			reverse(nei.begin(), nei.end());
            vector<vector<int>> part(3);
            int idx = 0, sum = 0;
			vector<int> S(3);
			ll half = (VV.size())/2;
            for(int i=0; i<(int)nei.size(); ++i) {
                if(S[idx] + nei[i].first > half) idx = (idx+1)%3;
				if(S[idx] + nei[i].first > half) idx = (idx+1)%3;
				S[idx] += nei[i].first;
				/*if(S[idx] > half) {
					cout << "-------------\n";
					cout << S[idx] << " " << half << "\n";
					cout << v << ": "; for(auto x : VV) cout << x << " "; cout << "\n";
					cout << "-------------\n";
				}*/
				part[idx].push_back(nei[i].second);
				idx = (idx+1)%3;
            }

            vector<vector<int>> whole(3);
            for(int k=0; k<3; ++k) {
                while(Zbior.size()) Zbior.pop_back();
                for(auto u : part[k]) {
                    V_dfs(u, v);
                }
                whole[k] = Zbior;
            }

            for(int k=0; k<3; ++k) {
                for(int l=0; l<3; ++l) {
                    for(auto x : whole[l]) {
                        change[k].push_back((ip[x]^(l==k)));
                    }
                }
            }

            for(int k=0; k<3; ++k) {
                while(Zbior.size()) Zbior.pop_back();
                int am = 0;
                int ss = 0;
                for(auto u : part[k]) {
                    ss += sajz[u];
                    am++;
                    V_dfs(u, v);
                }
                if(am > 1) Zbior.push_back(v);
                if(Zbior.size() && ss > 1) nxt.push_back(Zbior);
            }
            for(auto x : VV) is_off[x] = 1;
        }
        for(auto vec : change) {
            vector<int> Q(m,0);
            for(auto x : vec) Q[x/2] = x%2;
            if(Query(Q)) {
                res = Q;
                while(nxt.size()) nxt.pop_back();
                // cerr << "hurra\n";
                ok = 1;
                break;
            }
        }
        if(nxt.empty()) break;
        swap(curr, nxt);
    }
	// cerr << q_cnt << "\n";

    // for(int i=0; i<n-1; ++i) cerr << res[i] << " "; cerr << "\n";
    while(post.size()) post.pop_back();
    vis.assign(n, 0);
    for(int i=0; i<n; ++i) post_dfs(i, res);
    reverse(post.begin(), post.end());

    for(int i=0; i<m; ++i) {
        int a = U[i];
        int b = V[i];
        if(mst[i]) continue;
        g[a].push_back({b, 2*i});
        g[b].push_back({a, 2*i+1});
    }

    // for(auto v : post) cout << v << "\n";

    int lo = 0;
    int hi = n-1;
    while(lo < hi) {
        vector<int> Q = res;
        int mid = (lo+hi)/2;
        for(int i=0; i<=mid; ++i) {
            int v = post[i];
            for(auto[u,k] : g[v]) {
                if(res[k/2]==k%2) Q[k/2]^=1;
            }
        }
        if(query(Q)) lo = mid+1;
        else hi = mid;
    }
    int A = post[lo];

    reverse(post.begin(), post.end());
    lo = 0;
    hi = n-1;
    while(lo < hi) {
        vector<int> Q = res;
        int mid = (lo+hi)/2;
        for(int i=0; i<=mid; ++i) {
            int v = post[i];
            for(auto[u,k] : g[v]) {
                if(res[k/2]!=k%2) Q[k/2]^=1;
            }
        }
        if(query(Q)) lo = mid+1;
        else hi = mid;
    }
    int B = post[lo];
    // cerr << A << " " << B << "\n";
    answer(A,B);
}
/*
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 * # | Result | Execution time | Memory | Grader output
 * Fetching results...
 */