Submission #1233378

# | Time | Username | Problem | Language | Result | Execution time | Memory
1233378 | | Ghulam_Junaid | Closing Time (IOI23_closing) | C++20 | 35 / 100 | 1096 ms | 70676 KiB
#include <bits/stdc++.h>
#include "closing.h"
// #include "grader.cpp"
using namespace std;

// 64-bit integer used for all distances and for the budget.
typedef long long ll;

// Capacity of the per-vertex arrays below (vertices are stored 1-based).
const int N = 2e5 + 10;
// n    : vertex count of the current max_score call.
// x[0], x[1] : the two start vertices, already shifted to 1-based labels.
int n, x[2];
// k          : remaining budget; max_score resets it from its kk argument.
// dist[v][id]: weighted distance from x[id] to v, written by dfs.
// pref[i][id]: running sum of dist[1..i][id]; written in max_score and not
//              read anywhere else in the visible code.
ll k, dist[N][2], pref[N][2];
// Adjacency lists: g[v] holds (edge weight, neighbour) pairs.
vector<pair<int, int>> g[N];

/**
 * Propagates distances from v through the adjacency lists in g.
 *
 * For every vertex u reached from v (never stepping straight back to the
 * vertex it was entered from), sets dist[u][id] = dist[parent][id] + w where
 * w is the weight of the edge used. dist[v][id] itself is NOT written here;
 * the caller must initialise it before the call.
 *
 * The traversal uses an explicit stack rather than recursion so that a very
 * deep tree (e.g. a long chain) cannot exhaust the call stack.
 *
 * @param v   vertex to start from.
 * @param id  which distance column (second index of dist) to fill.
 * @param p   vertex that must not be stepped back into from v; -1 for none.
 */
void dfs(int v, int id, int p = -1){
    // Each entry is (vertex, vertex it was entered from).
    std::vector<std::pair<int, int>> pending;
    pending.emplace_back(v, p);

    while (!pending.empty()){
        const auto [cur, par] = pending.back();
        pending.pop_back();

        for (const auto& [w, u] : g[cur]){
            if (u == par) continue;
            dist[u][id] = dist[cur][id] + w;
            pending.emplace_back(u, cur);
        }
    }
}

int max_score(int nn, int xx, int yy, ll kk, vector<int> uu, vector<int> vv, vector<int> ww){
    xx++, yy++;
    for (int i = 0; i < nn - 1; i ++) 
        uu[i]++, vv[i]++;

    n = nn, x[0] = xx, x[1] = yy, k = kk;
    for (int i = 0; i < n - 1; i ++){
        g[uu[i]].push_back({ww[i], vv[i]});
        g[vv[i]].push_back({ww[i], uu[i]});
    }
    
    multiset<ll> st[2];
    for (int id : {0, 1}){
        dist[x[id]][id] = 0;
        dfs(x[id], id);
        for (int i = 1; i <= n; i ++){
            pref[i][id] = pref[i - 1][id] + dist[i][id];
            st[id].insert(dist[i][id]);
        }
    }

    int ans = 0;
    while (st[0].size() and st[1].size()){
        ll a = *st[0].begin(), b = *st[1].begin();
        if (a <= b){
            if (k < a) break;
            k -= a;
            st[0].erase(st[0].begin());
            ans++;
            continue;
        }
        if (k < b) break;
        k -= b;
        st[1].erase(st[1].begin());
        ans++;
        continue;
    }

    for (int mid = x[0]; mid <= x[1]; mid ++){
        k = kk;
        set<pair<ll, int>> st[2];
        int vis[n + 1] = {}, cur = 0;
        memset(vis, 0, sizeof vis);

        for (int i = 1; i < x[0]; i ++)
            for (int id : {0, 1})
                st[id].insert({dist[i][id], i});
        for (int i = x[0]; i < mid; i ++){
            vis[i] = 1;
            k -= dist[i][0];
            st[1].insert({max(0ll, dist[i][1] - dist[i][0]), i});
        }
        k -= max(dist[mid][0], dist[mid][1]);
        vis[mid] = 3;
        for (int i = mid + 1; i <= x[1]; i ++){
            vis[i] = 2;
            k -= dist[i][1];
            st[0].insert({max(0ll, dist[i][0] - dist[i][1]), i});
        }
        for (int i = x[1] + 1; i <= n; i ++)
            for (int id : {0, 1})
                st[id].insert({dist[i][id], i});

        if (k < 0) continue;
        
        cur = x[1] - x[0] + 2;
        int ite = 0;
        while (!st[0].empty() or !st[1].empty()){
            ite++;
            if (ite > 1e6) return 1/0;
            ll d1, v1, d2, v2;
            if (st[0].empty()){
                d2 = (*st[1].begin()).first;
                v2 = (*st[1].begin()).second;
                d1 = 1e18;
                v1 = -1;
            }
            else if (st[1].empty()){
                d1 = (*st[0].begin()).first;
                v1 = (*st[0].begin()).second;
                d2 = 1e18;
                v2 = -1;
            }
            else{
                d1 = (*st[0].begin()).first;
                v1 = (*st[0].begin()).second;
                d2 = (*st[1].begin()).first;
                v2 = (*st[1].begin()).second;
            }

            if (d1 <= d2){
                if (k < d1) break;
                k -= d1;
                cur++;
                st[0].erase(st[0].begin());
                vis[v1] |= 1;
                if (vis[v1] & 2) continue;
                if (st[1].find({dist[v1][1], v1}) == st[1].end()) continue;
                st[1].erase({dist[v1][1], v1});
                st[1].insert({max(0ll, dist[v1][1] - dist[v1][0]), v1});
            }
            else{
                if (k < d2) break;
                k -= d2;
                cur++;
                st[1].erase(st[1].begin());
                vis[v2] |= 2;
                if (vis[v2] & 1) continue;
                if (st[0].find({dist[v2][0], v2}) == st[0].end()) continue;
                st[0].erase({dist[v2][0], v2});
                st[0].insert({max(0ll, dist[v2][0] - dist[v2][1]), v2});
            }
        }

        ans = max(ans, cur);
    }

    for (int i = 0; i <= n; i ++) g[i].clear();
    return ans;
}

Compilation message (stderr)

closing.cpp: In function 'int max_score(int, int, int, ll, std::vector<int>, std::vector<int>, std::vector<int>)':
closing.cpp:90:36: warning: division by zero [-Wdiv-by-zero]
   90 |             if (ite > 1e6) return 1/0;
      |                                   ~^~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...