제출 #1132548

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1132548 | SpyrosAliv | 봉쇄 시간 (IOI23_closing) | C++20
29 / 100
1092 ms | 33536 KiB
// Linear networks (every edge joins i and i+1) are special-cased by solve_line.
#include <bits/stdc++.h>
using namespace std;
#define ll long long

// Problem state shared by all helpers below; (re)initialised by max_score.
int n, a, b;                      // vertex count and the two start vertices x (a) and y (b)
ll k;                             // remaining budget: starts as max_score's l, decremented by the final greedy passes
vector<int> w;                    // edge lengths; solve_line reads w[i] as the length of edge (i, i+1)
vector<ll> minX, minY, minBoth;   // per-vertex distance to x, to y, and the max of the two (solve_line turns them into prefix sums)
vector<vector<pair<int, int>>> tree; // adjacency list: (neighbour, edge length)
vector<bool> mark;                // vertices selected by the current candidate in max_score (set by marks/add)
ll currCost;                      // budget spent by the current candidate
int currAns;                      // score of the current candidate

int solve_line(int x, int y);

void dfs(int node, bool a, int par = -1, ll dep = 0LL) {
    if (!a) minX[node] = dep;
    else minY[node] = dep;
    for (auto [next, c]: tree[node]) {
        if (next == par) continue;
        dfs(next, a, node, dep + c);
    }
}

// Marks every vertex on the tree path from `node` to `goal` (both ends
// included) and adds each one's minBoth value to currCost. `par` is the vertex
// we arrived from. Returns whether `goal` was found from `node` onward, i.e.
// whether `node` itself lies on the path.
bool marks(int node, int goal, int par = -1) {
    bool onPath = (node == goal);
    for (auto it = tree[node].begin(); !onPath && it != tree[node].end(); ++it) {
        if (it->first != par && marks(it->first, goal, node)) onPath = true;
    }
    if (!onPath) return false;
    mark[node] = true;
    currCost += minBoth[node];
    return true;
}

// Returns true when at least one vertex on the tree path from `node` to `goal`
// is marked. `pass` records whether a marked vertex was already seen on the
// way down; `par` is the vertex we came from.
bool check(int node, int goal, bool pass = false, int par = -1) {
    pass = pass || mark[node];
    if (node == goal) return pass;
    for (const auto& edge : tree[node]) {
        if (edge.first != par && check(edge.first, goal, pass, node)) return true;
    }
    return false;
}

// Walks from `node` toward an already-marked vertex. Every unmarked vertex on
// the way is marked, counted once in currAns, and charged its distance to y
// (fromY) or to x (!fromY). Returns true if `node` is marked or a marked vertex
// was reached through it; only the first neighbour (other than `par`) that
// leads to a marked vertex is followed.
bool add(int node, bool fromY, int par = -1) {
    if (mark[node]) return true;
    for (const auto& edge : tree[node]) {
        const int next = edge.first;
        if (next == par || !add(next, fromY, node)) continue;
        mark[node] = true;
        currAns++;
        currCost += fromY ? minY[node] : minX[node];
        return true;
    }
    return false;
}

int solve_easy(int x, int y);

int max_score(int sz, int x, int y, ll l, vector<int> u, vector<int> v, vector<int> weights) {
    a = x; b = y; n = sz;
    tree.clear(); tree.resize(n);
    k = l; w = weights;
    for (int i = 0; i < n-1; i++) {
        //int xx = u[i];
        tree[u[i]].push_back({v[i], w[i]});
        tree[v[i]].push_back({u[i], w[i]});
    }
    minX.clear(); minX.assign(n, 0LL);
    minY.clear(); minY.assign(n, 0LL);
    minBoth.clear(); minBoth.assign(n, 0LL);
    bool line = true;
    for (int i = 0; i < n-1; i++) {if (v[i] != u[i] + 1) {line = false;break;}}
    if (line) return solve_line(x, y);
    if (n >= 501) return solve_easy(x, y);
    dfs(x, false); // min dis from x
    dfs(y, true); // min dis from y
    for (int i = 0; i < n; i++) minBoth[i] = max(minY[i], minX[i]);
    int maxAns = 0;

    for (int l = 0; l < n; l++) {
        for (int r = 0; r < n; r++) {
            // try from node l to node r
            currCost = 0;
            currAns = 0;
            mark.assign(n, false);

            marks(l, r); // mark every node in path from l to r

            //if (!check(x, y)) continue; // if no node was in path from x to y
            for (int i = 0; i < n; i++) {
                if (mark[i]) currAns += 2; // included in both set X and Y
            }

            add(y, true);
            add(x, false);

            if (currCost > k) continue;
            vector<ll> bonus;
            for (int i = 0; i < n; i++) {
                if (!mark[i]) {
                    bonus.push_back(min(minX[i], minY[i]));
                }
            }
            sort(bonus.begin(), bonus.end());
            for (auto x: bonus) {
                if (x + currCost <= k) {
                    currAns++;
                    currCost += x;
                }
                else break;
            }
            maxAns = max(maxAns, currAns);
        }
    }
    int ans = 0;
    vector<ll> bonus;
    for (int i = 0; i < n; i++) bonus.push_back(min(minX[i], minY[i]));
    sort(bonus.begin(), bonus.end());
    for (auto x: bonus) {
        if (x > k) continue;
        ans++;
        k -= x;
    }
    return max(ans, maxAns);
}
/*
int main() {
    cout << max_score(7, 0, 2, 10, {0, 0, 1, 2, 2, 5}, {1, 3, 2, 4, 5, 6}, {2, 3, 4, 2, 5, 3});
}*/

int solve_easy(int x, int y) {
    dfs(x, false);
    dfs(y, true);
        vector<ll> costs;
        for (int i = 0; i < n; i++) {
            costs.push_back(min(minX[i], minY[i]));
            //costs.push_back(minY[i]);
        }
        sort(costs.begin(), costs.end());
        int ans = 0;
        int p = (int)costs.size();
        for (int i = 0; i < p; i++) {
            ll x = costs[i];
            if (k >= x) {
                ans++;
                k -= x;
            }
            else break;
        }
        return ans;
}

// Linear-network case: vertex i is joined to i+1 by an edge of length w[i].
// NOTE(review): appears to assume x <= y ([x, l-1] is paid from x and (r, y]
// from y below) — confirm against the caller's guarantee.
// Tries every segment [l, r] (l <= y, r >= max(x, l)) as the set reached from
// both x and y. It pays for the vertices still needed between the segment and
// x / y, then greedily extends outward with the cheapest single-reach vertices.
// It also tries the option with no shared vertex. The pair loop is O(n^2) and
// each greedy extension is O(n), so the worst case is O(n^3).
// Side effects: overwrites minX/minY/minBoth with prefix sums and decrements
// the global budget k in the final greedy pass.
int solve_line(int x, int y) {
    // minX[i] = distance from x to i along the line.
    for (int i = x+1; i < n; i++) {
        minX[i] = minX[i-1] + w[i - 1];
    }
    for (int i = x-1; i >= 0; i--) {
        minX[i] = minX[i+1] + w[i];
    }
    // minY[i] = distance from y to i along the line.
    for (int i = y+1; i < n; i++) {
        minY[i] = minY[i-1] + w[i - 1];
    }
    for (int i = y-1; i >= 0; i--) {
        minY[i] = minY[i+1] + w[i];
    }
    vector<ll> f, s;
    // minBoth[i] = cost of reaching i from both x and y.
    for (int i = 0; i < n; i++) {
        minBoth[i] = max(minX[i], minY[i]);
    }
    // Keep the raw distances: from here on minX/minY/minBoth hold prefix sums.
    f = minX;
    s = minY;
    for (int i = 1; i < n; i++) {
        minX[i] = minX[i-1] + minX[i];
        minY[i] = minY[i-1] + minY[i];
        minBoth[i] = minBoth[i-1] + minBoth[i];
    }
    int maxAns = 0;
    for (int l = 0; l <= y; l++) {
        for (int r = max(x, l); r < n; r++) { // segment [l, r] is reached from both x and y
            // Each vertex of [l, r] counts twice; its cost is the sum of minBoth over [l, r].
            int ans = (r - l + 1) * 2;
            ll cost = minBoth[r];
            if (l > 0) cost -= minBoth[l-1];
            // Vertices r+1..y must still be reached from y.
            if (y > r) {
                cost += minY[y] - minY[r];
                ans += (y - r);
            }
            // Vertices x..l-1 must still be reached from x.
            if (l > x) {
                cost += minX[l-1];
                if (x > 0) cost -= minX[x-1];
                ans += (l - x);
            }
            if (cost > k) continue;
            // a walks left (reached from x), b walks right (reached from y); each
            // step takes the cheaper of the two next vertices while it still fits.
            // NOTE: these locals shadow the globals a and b.
            int a = min(l-1, x-1), b = max(y+1, r+1);
            while (cost <= k && (a >= 0 || b < n)) {
                ll nxt = 0;
                if (a < 0) {
                    nxt = minY[b];
                    if (b > 0) nxt -= minY[b-1];
                    b++;
                }
                else if (b >= n) {
                    nxt = minX[a];
                    if (a > 0) nxt -= minX[a-1];
                    a--;
                }
                else {
                    // opt1 = distance from y to b, opt2 = distance from x to a (recovered from prefix sums).
                    ll opt1 = minY[b], opt2 = minX[a];
                    if (a > 0) opt2 -= minX[a-1];
                    if (b > 0) opt1 -= minY[b-1];
                    if (opt2 <= opt1) {
                        a--;
                        nxt = opt2;
                    }
                    else {
                        b++;
                        nxt = opt1;
                    }
                }
                if (nxt + cost > k) break;
                cost += nxt;
                ans++;
            }
            maxAns = max(maxAns, ans);
        }
    }
    // Option with no vertex reached from both: each vertex costs its distance to
    // the nearer of x and y; take, cheapest first, every one that still fits.
    int ans = 0;
    vector<ll> sec;
    for (int i = 0; i < n; i++) {
        sec.push_back(min(f[i], s[i]));
    }
    sort(sec.begin(), sec.end());
    for (auto cr: sec) {
        if (cr <= k) {
            k -= cr;
            ans++;
        }
    }
    maxAns = max(maxAns, ans);
    return maxAns;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...