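Submission #1016114

| # | Submitted at | User | Problem | Language | Result | Time | Memory |
|---|---|---|---|---|---|---|---|
| 1016114 | 2024-07-07T11:59:21 Z | hotboy2703 | Simurgh (IOI17_simurgh) | C++17 | 0 / 100 | 1 ms | 1372 KB |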
#include "simurgh.h"

#include<bits/stdc++.h>
using namespace std;
// Short aliases used throughout the solution.
// NOTE: `ll` is deliberately int (not long long): find_roads builds a
// vector<ll> and returns it as std::vector<int>, and vector<ll> is passed to
// count_common_roads, so the two types must coincide.
using ll = int;
#define pll pair <ll,ll>
#define fi first
#define se second
#define MP make_pair
#define sz(a) (ll((a).size()))
// Bit-manipulation helpers; not used anywhere in the code shown.
#define BIT(mask,i) (((mask) >> (i))&1)
#define MASK(i) (1LL << (i))
namespace TRUNG{
    const ll MAXN = 510;
    ll n,m;
    // ans[e] for edge e: 1 = royal, 0 = not royal, -1 = not determined yet.
    ll ans[MAXN*MAXN];
    // Adjacency list of the spanning tree: (neighbour, edge index).
    vector <pll> g[MAXN];
    // Entry/exit times of the DFS over the spanning tree. The subtree of vertex c
    // is exactly the set of vertices w with in[c] <= in[w] <= out[c].
    ll in[MAXN],out[MAXN];
    ll timeDFS;
    // a[e] = endpoints of edge e.
    vector <pll> a;
    // Number of count_common_roads calls issued so far.
    ll cnt_query;

    // Assigns in/out times to every vertex reachable from u in the tree g.
    // p is u's parent; the root is passed as its own parent.
    void dfs(ll u,ll p){
        in[u] = ++timeDFS;
        for (const auto &edge:g[u]){
            const ll v = edge.fi;
            if (v==p)continue;
            dfs(v,u);
        }
        out[u] = timeDFS;
    }

    // True iff vertex w lies in the DFS subtree rooted at vertex root.
    bool in_subtree(ll root,ll w){
        return in[root] <= in[w] && in[w] <= out[root];
    }

    // True iff the tree path between the endpoints of edge j passes through
    // tree edge i, i.e. exactly one endpoint of j lies below tree edge i.
    bool sus_edge(ll i,ll j){
        // The child endpoint of a tree edge is the one the DFS entered later.
        // (Compare entry times with entry times, not with a vertex id.)
        const ll child = (in[a[i].se] > in[a[i].fi]) ? a[i].se : a[i].fi;
        return in_subtree(child,a[j].fi) != in_subtree(child,a[j].se);
    }

    // Disjoint-set union: dsu[root] = -(component size), union by size,
    // path compression in f().
    namespace DSU{
        ll dsu[MAXN];
        void init(){
            memset(dsu,-1,sizeof dsu);
        }
        ll f(ll x){
            if (dsu[x] < 0)return x;
            return (dsu[x] = f(dsu[x]));
        }
        void join(ll x,ll y){
            x = f(x),y = f(y);
            if (x!=y){
                if (dsu[x] > dsu[y])swap(x,y);
                dsu[x] += dsu[y];
                dsu[y] = x;
            }
        }
    }
    // bs[u][v] = 1 while edge u-v is undetermined and not yet placed in a forest;
    // ind[u][v] = index of the edge u-v.
    bitset <MAXN> bs[MAXN];
    ll ind[MAXN][MAXN];
    // tree = spanning-tree edge indices; extra = non-tree edges used to close cycles.
    vector <ll> tree,extra;

    // Returns how many edges of the forest `all` are royal.
    // The forest is completed to a spanning tree with edges from `tree` (whose
    // ans[] must already be known), and their contribution is subtracted from
    // the oracle's answer.
    ll superior_count(vector <ll> all){
        DSU::init();
        for (auto x:all){
            DSU::join(a[x].fi,a[x].se);
        }
        ll res = 0;
        for (auto x:tree){
            if (DSU::f(a[x].fi) != DSU::f(a[x].se)){
                DSU::join(a[x].fi,a[x].se);
                all.push_back(x);
                res-=ans[x];
            }
        }
        // Increment outside assert() so the count survives an NDEBUG build.
        ++cnt_query;
        assert(cnt_query<=8000);
        res += count_common_roads(all);
        return res;
    }

    // Given that exactly k edges of the forest `all` are royal, fixes ans[] for
    // every edge in it by recursive halving (one oracle query per split).
    void solve(vector <ll> all,ll k){
        if (k==0){
            // None of them is royal.
            for (auto x:all)ans[x] = 0;
            return;
        }
        if (k==sz(all)){
            // All of them are royal: no query needed (covers the single-edge case).
            for (auto x:all)ans[x] = 1;
            return;
        }
        const ll mid = sz(all) / 2;
        vector <ll> L(all.begin(),all.begin()+mid);
        vector <ll> R(all.begin()+mid,all.end());
        const ll cnt_L = superior_count(L);
        solve(L,cnt_L);
        solve(R,k-cnt_L);
    }
}
std::vector<int> find_roads(int N, std::vector<int> U, std::vector<int> V) {
    using namespace TRUNG;
    memset(ans,-1,sizeof ans);
    ll m = sz(U),n=N ;
    a.resize(m);
    for (ll i = 0;i < m;i ++)a[i] = MP(U[i],V[i]);
    DSU::init();
    for (ll i = 0;i < m;i ++){
        if (sz(tree)==n-1){
            if (DSU::f(a[i].fi) != DSU::f(a[i].se))extra.push_back(i);
        }
        else{
            if (DSU::f(a[i].fi) != DSU::f(a[i].se))tree.push_back(i);
            if (sz(tree)==n-1)DSU::init();
        }
        DSU::join(a[i].fi,a[i].se);
    }
    for (auto id:tree){
        g[a[id].fi].push_back(MP(a[id].se,id));
        g[a[id].se].push_back(MP(a[id].fi,id));
    }
//    for (auto x:tree)cout<<x<<' ';
//    cout<<'\n';
//    for (auto x:extra)cout<<x<<' ';
//    cout<<'\n';
    dfs(0,0);
    ll cur = count_common_roads(tree);
        assert(++cnt_query<=8000);

    for (auto id:tree){
        if (ans[id] != -1)continue;
        for (auto x:extra){
            if (sus_edge(id,x)){
                auto cal = [&](ll y){
                    vector <ll> tmp;
                    for (auto sss:tree)if (sss != y)tmp.push_back(sss);
                    tmp.push_back(x);
                            assert(++cnt_query<=8000);

                    return count_common_roads(tmp);
                };
                vector <ll> cycle;
                for (auto y:tree){
                    if (sus_edge(y,x)){
                        cycle.push_back(y);
                    }
                }
                ll sum = -1;
                if (ans[x] != -1)sum = ans[x] + cur;
                for (auto y:cycle){
                    if (ans[y] != -1 && sum == -1){
                        sum = cal(y) + ans[y];
                        break;
                    }
                }
                if (sum==-1){
                    vector <ll> com(sz(cycle));
                    for (ll j = 0;j < sz(cycle);j ++){
                        com[j] = cal(cycle[j]);
                        if (com[j] < cur)sum = cur;
                        if (com[j] > cur)sum = cur+1;
                    }
                    if (sum == -1)sum = cur;
                    for (ll j = 0;j < sz(cycle);j ++){
                        ans[cycle[j]] = sum - com[j];
                    }
                    ans[x] = sum - cur;
                }
                else{
                    for (auto y:cycle)if (ans[y] == -1){
                        ans[y] = sum - cal(y);
                    }
                    ans[x] = sum - cur;
                }
                break;
            }
        }
    }
//    for (ll j = 0;j < m;j ++)cout<<ans[j]<<' ';
//    cout<<endl;
    for (ll j = 0;j < m;j ++){
        pll x = a[j];
        ind[x.fi][x.se]=ind[x.se][x.fi]=j;
        if (ans[j] == -1)bs[x.fi][x.se] = bs[x.se][x.fi] = 1;
    }
    while (true){
        bitset <MAXN> rem;
        rem.set();
        bitset <MAXN> tmp;
        vector <ll> cur;
        for (ll i = 0;i < n;i ++){
            if (rem[i]){
                queue <ll> q;
                q.push(i);
                rem[i] = 0;
                while (!q.empty()){
                    ll u = q.front();
                    q.pop();
                    tmp = rem & bs[u];
                    for (ll v = tmp._Find_first();v < n;v = tmp._Find_next(v)){
                        cur.push_back(ind[u][v]);
                        bs[u][v] = bs[v][u] = 0;
                        q.push(v);
                        rem[v] = 0;
                    }
                }
            }
        }
        if (!sz(cur))break;
//        cout<<sz(cur)<<' '<<cur[0]<<endl;
        solve(cur,superior_count(cur));
    }
    vector <ll> r;
    for (ll i = 0;i < m;i ++)if (ans[i]==1)r.push_back(i);
//    assert(cnt_query<=8000);
//    count_common_roads
    return r;
}

Compilation message

simurgh.cpp: In function 'void TRUNG::dfs(ll, ll)':
simurgh.cpp:25:27: warning: unused variable 'id' [-Wunused-variable]
   25 |             ll v = tmp.fi,id = tmp.se;
      |                           ^~
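| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 1372 KB | WA in grader: NO |
| 2 | Halted | 0 ms | 0 KB | - |

| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 1372 KB | WA in grader: NO |
| 2 | Halted | 0 ms | 0 KB | - |

| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 1372 KB | WA in grader: NO |
| 2 | Halted | 0 ms | 0 KB | - |

| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 1372 KB | correct |
| 2 | Incorrect | 1 ms | 1372 KB | WA in grader: NO |
| 3 | Halted | 0 ms | 0 KB | - |

| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 1372 KB | WA in grader: NO |
| 2 | Halted | 0 ms | 0 KB | - |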