Submission #1247376

# | Time | Username | Problem | Language | Result | Execution time | Memory
1247376 | (not captured) | azaylibelz | Two Currencies (JOI23_currencies) | C++20 | 0 / 100 | 21 ms | 11080 KiB
#pragma GCC optimize("O3,unroll-loops")

#include <cstdio>
#include <vector>
#include <algorithm>
#include <numeric>

using namespace std;
using ll = long long;

// Sentinel "infinite" cost returned when a silver-ticket count cannot be realised.
constexpr ll infmax = 1'000'000'000'000'000'000LL;
// Static array bounds: query slots and tree-node slots (nodes are 1-indexed).
constexpr int maxq = 100'005;
constexpr int maxn = 100'005;

// One query read from input: a trip between two tree nodes with a coin budget.
struct Query {
    int id;  // 0-based position in input order; indexes res[] and the per-call result vector
    int s;   // first endpoint (tree node)
    int t;   // second endpoint (tree node)
    ll x;    // gold coins available
    ll y;    // silver coins available
};

// ---- Input ----
int n, m, q;                       // node count, checkpoint count, query count
vector<pair<int, int>> adj[maxn];  // adj[u]: (neighbour, road index) pairs
vector<ll> checkpolls[maxn];       // checkpolls[road]: silver costs of the checkpoints on that road
vector<Query> queries;             // all queries, in input order

// ---- Answers ----
ll res[maxq];                      // res[id]: smallest gold count found feasible by pbs, -1 if none

// ---- Centroid-decomposition scratch state ----
int childSz[maxn];                 // subtree sizes from the last dfsSz
bool removed[maxn];                // true once a node has served as a centroid
vector<int> compNode;              // nodes of the component currently being processed
vector<ll> pathCost[maxn];         // sorted checkpoint costs on the path from the current centroid
vector<ll> pathpSum[maxn];         // prefix sums of pathCost (pathpSum[x][0] == 0)
int subPar[maxn];                  // per-node subtree label assigned while processing a centroid

// Helper to get all nodes in the current component
void getCompNode(int u, int p) {
    compNode.push_back(u);
    for (auto& [v, w] : adj[u]) {
        if (v != p && !removed[v]) {
            getCompNode(v, u);
        }
    }
}

// DFS to calculate subtree sizes
void dfsSz(int u, int p) {
    childSz[u] = 1;
    for (auto& [v, w] : adj[u]) {
        if (v != p && !removed[v]) {
            dfsSz(v, u);
            childSz[u] += childSz[v];
        }
    }
}

// Walks from u toward any child whose subtree holds more than half of the totalSz nodes and
// returns the first node without such a child (a centroid). Relies on childSz from dfsSz
// having been computed from the same root.
int findCentroid(int u, int p, int totalSz) {
    while (true) {
        int heavy = 0;  // 0 is never a node id (nodes are 1-indexed)
        for (const auto& [v, w] : adj[u]) {
            if (v != p && !removed[v] && childSz[v] * 2 > totalSz) {
                heavy = v;
                break;
            }
        }
        if (heavy == 0) return u;
        p = u;
        u = heavy;
    }
}

// DFS to gather path costs and prefix sums from the centroid
void getPathInfo(int u, int p, int rt) {
    subPar[u] = rt;
    for (auto& [v, idx] : adj[u]) {
        if (v != p && !removed[v]) {
            pathCost[v].reserve(pathCost[u].size() + checkpolls[idx].size());
            pathCost[v].resize(pathCost[u].size() + checkpolls[idx].size());
            merge(pathCost[u].begin(), pathCost[u].end(), checkpolls[idx].begin(), checkpolls[idx].end(), pathCost[v].begin());
            
            pathpSum[v].assign(pathCost[v].size() + 1, 0);
            for (size_t i = 0; i < pathCost[v].size(); i++) {
                pathpSum[v][i + 1] = pathpSum[v][i] + pathCost[v][i];
            }
            getPathInfo(v, u, rt);
        }
    }
}

// Returns the minimum total silver needed to pay kSliver checkpoints chosen from the two sorted
// cost lists pathCost[s] and pathCost[t]: the sum of the kSliver smallest values in their union.
// Returns infmax when kSliver is negative or exceeds the combined number of checkpoints.
//
// Taking i items from s and k - i from t costs f(i) = pss[i] + pst[k - i]. Its forward difference
// f(i + 1) - f(i) = costS[i] - costT[k - i - 1] is nondecreasing in i, because both lists are
// sorted ascending. So f is convex, and its minimum is at the first i where that difference is
// >= 0. A binary search over the valid split range finds it in O(log) time.
ll calc(int s, int t, ll kSliver) {
    const vector<ll>& costS = pathCost[s];
    const vector<ll>& costT = pathCost[t];
    const vector<ll>& pss = pathpSum[s];
    const vector<ll>& pst = pathpSum[t];
    const ll nS = static_cast<ll>(costS.size());
    const ll nT = static_cast<ll>(costT.size());

    if (kSliver < 0 || kSliver > nS + nT) {
        return infmax;
    }

    // Valid splits i satisfy 0 <= i <= nS and 0 <= kSliver - i <= nT.
    ll lo = max(0LL, kSliver - nT);
    ll hi = min(nS, kSliver);
    while (lo < hi) {
        const ll mid = lo + (hi - lo) / 2;
        // mid < hi <= min(nS, k)  => costS[mid] and costT[k - mid - 1] are in range.
        if (costS[mid] >= costT[kSliver - mid - 1]) {
            hi = mid;
        } else {
            lo = mid + 1;
        }
    }
    return pss[lo] + pst[kSliver - lo];
}


void solve(int entry, const vector<Query>& current_queries, ll midGold, vector<bool>& ress) {
    if (current_queries.empty()) return;

    compNode.clear();
    getCompNode(entry, 0);
    dfsSz(entry, 0);
    int centroid = findCentroid(entry, 0, compNode.size());

    pathCost[centroid].clear();
    pathpSum[centroid].assign(1, 0);
    subPar[centroid] = centroid;

    for (auto& [v, idx] : adj[centroid]) {
        if (!removed[v]) {
            pathCost[v] = checkpolls[idx];
            pathpSum[v].assign(pathCost[v].size() + 1, 0);
            for (size_t i = 0; i < pathCost[v].size(); i++) {
                pathpSum[v][i + 1] = pathpSum[v][i] + pathCost[v][i];
            }
            getPathInfo(v, centroid, v);
        }
    }
    
    // Using static vectors to avoid repeated memory allocations
    static vector<vector<Query>> subQueries;
    static vector<Query> crossQ;
    if(subQueries.empty()) subQueries.resize(n + 1);

    crossQ.clear();
    for(auto const& q : current_queries){
        if(subPar[q.s] != centroid && subPar[q.s] == subPar[q.t]){
            subQueries[subPar[q.s]].push_back(q);
        } else {
            crossQ.push_back(q);
        }
    }

    for (auto& q : crossQ) {
        ll pcheckpolls = pathCost[q.s].size() + pathCost[q.t].size();
        ll kSliver = pcheckpolls - midGold;

        bool flag = false;
        if (kSliver < 0) {
            flag = true;
        } else {
            ll sliver_cost = calc(q.s, q.t, kSliver);
            if (sliver_cost <= q.y) {
                flag = true;
            }
        }
        ress[q.id] = flag;
    }

    removed[centroid] = true;
    for (auto& [rt, idx] : adj[centroid]) {
        if (!removed[rt] && !subQueries[rt].empty()) {
            solve(rt, subQueries[rt], midGold, ress);
            subQueries[rt].clear(); // Clear for next use
        }
    }
}

// Divide-and-conquer binary search on the gold budget. Every query that is feasible with
// mid = midpoint of [mingold, maxgold] records res[id] = mid and is searched again in
// [mingold, mid - 1]. The rest are searched in [mid + 1, maxgold]. Afterwards res[id] holds
// the smallest feasible budget in the initial range, or keeps its prior value if none is.
void pbs(vector<Query> p_queries, ll mingold, ll maxgold) {
    if (mingold > maxgold || p_queries.empty()) {
        return;
    }
    const ll mid = mingold + (maxgold - mingold) / 2;

    // solve() expects an intact tree, so clear the centroid marks first.
    fill(removed + 1, removed + n + 1, false);
    vector<bool> feasible(q);
    solve(1, p_queries, mid, feasible);

    // Feasible queries move to the front; relative order is preserved on both sides.
    const auto split = stable_partition(
        p_queries.begin(), p_queries.end(),
        [&](const Query& qu) { return static_cast<bool>(feasible[qu.id]); });
    for (auto it = p_queries.begin(); it != split; ++it) {
        res[it->id] = mid;
    }

    vector<Query> lower(p_queries.begin(), split);
    vector<Query> upper(split, p_queries.end());
    pbs(move(lower), mingold, mid - 1);
    pbs(move(upper), mid + 1, maxgold);
}

int main() {
    scanf("%d %d %d", &n, &m, &q);

    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d %d", &u, &v);
        adj[u].push_back({v, i});
        adj[v].push_back({u, i});
    }

    for (int i = 0; i < m; ++i) {
        int p;
        ll c;
        scanf("%d %lld", &p, &c);
        checkpolls[p].push_back(c);
    }

    for (int i = 1; i < n; ++i) {
        sort(checkpolls[i].begin(), checkpolls[i].end());
    }

    queries.resize(q);
    for (int i = 0; i < q; ++i) {
        queries[i].id = i;
        scanf("%d %d %lld %lld", &queries[i].s, &queries[i].t, &queries[i].x, &queries[i].y);
    }

    fill(res, res + q, -1);
    pbs(queries, 0, n);

    for (int i = 0; i < q; ++i) {
        ll gold = res[i];
        ll initial_x = -1;
        // Find original x value for the query with id `i`
        // This is needed because the `queries` vector is copied and moved in `pbs`
        for(int j=0; j<q; ++j){
            if(::queries[j].id == i){
                initial_x = ::queries[j].x;
                break;
            }
        }

        if (gold != -1 && initial_x >= gold) {
            printf("%lld\n", initial_x - gold);
        } else {
            printf("-1\n");
        }
    }

    return 0;
}

Compilation message (stderr)

currencies.cpp: In function 'int main()':
currencies.cpp:192:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  192 |     scanf("%d %d %d", &n, &m, &q);
      |     ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
currencies.cpp:196:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  196 |         scanf("%d %d", &u, &v);
      |         ~~~~~^~~~~~~~~~~~~~~~~
currencies.cpp:204:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  204 |         scanf("%d %lld", &p, &c);
      |         ~~~~~^~~~~~~~~~~~~~~~~~~
currencies.cpp:215:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  215 |         scanf("%d %d %lld %lld", &queries[i].s, &queries[i].t, &queries[i].x, &queries[i].y);
      |         ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...