Submission #924066

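| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 924066 | (not shown) | boris_mihov | Designated Cities (JOI19_designated_cities) | C++17 | 16 / 100 | 349 ms | 69564 KiB |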
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>

typedef long long llong;
// Competitive-programming shortcut: from here on every `int` (struct fields,
// locals, parameters, array element types) is really `long long`, which is
// why the entry point is declared as `signed main()`.
#define int long long
// Capacity of every per-vertex array; presumably the statement's N <= 200000
// plus slack -- TODO confirm against the problem limits.
const int MAXN = 200000 + 10;
// NOTE(review): not referenced anywhere in the visible code.
const int INF  = 1e9;

// n: number of cities (vertices); q: number of queries (both read in input()).
int n, q;
// One adjacency-list entry. `u` is the neighbouring vertex, `a` is the cost
// attached to the direction owner -> u and `b` the cost of u -> owner;
// input() stores every road in both endpoint lists with the costs swapped.
// Field order matters: callers unpack it as `[u, a, b]`.
struct Edge
{
    int u;
    int a;
    int b;
};

// Max segment tree over Euler-tour positions [0, n - 1] with lazy range
// subtraction. Position p starts as dists[tour[p]]; getMAX() reports the
// largest current value and the tour position holding it.
// Lazy convention: a node's `lazy` is a pending amount that has NOT yet been
// subtracted from that node's own `max`; push() applies it to the node and
// forwards it to the children.
struct SegmentTree
{
    struct Node
    {
        llong max;   // maximum over this segment (current once lazy is pushed)
        llong lazy;  // pending subtraction for this node and its subtree
        int idx;     // tour position where `max` occurs
    
        Node()
        {
            max = lazy = idx = 0;
        }

        // Merges two children. On a tie the right child's position wins.
        // The result is default-constructed, so its lazy is 0.
        friend Node operator + (const Node &left, const Node &right)
        {
            Node res;
            if (left.max > right.max)
            {
                res.max = left.max;
                res.idx = left.idx;
            } else
            {
                res.max = right.max;
                res.idx = right.idx;
            }
            
            return res;
        }
    };

    Node tree[4*MAXN];
    // Applies this node's pending subtraction to its own max and hands it
    // down to both children (skipped for leaves, l == r).
    void push(int node, int l, int r)
    {
        if (tree[node].lazy == 0)
        {
            return;
        }

        tree[node].max -= tree[node].lazy;
        if (l < r)
        {
            tree[2*node].lazy += tree[node].lazy;
            tree[2*node + 1].lazy += tree[node].lazy;
        }

        tree[node].lazy = 0;
    }

    // Leaf l gets value dists[tour[l]] and position l; inner nodes merge.
    void build(int l, int r, int node, llong dists[], const std::vector <int> &tour)
    {
        if (l == r)
        {
            tree[node].max = dists[tour[l]];
            tree[node].idx = l;
            return;
        }

        int mid = (l + r) / 2;
        build(l, mid, 2*node, dists, tour);
        build(mid + 1, r, 2*node + 1, dists, tour);
        tree[node] = tree[2*node] + tree[2*node + 1];
    }

    // Subtracts queryVal from every position in [queryL, queryR].
    // The node is pushed before the range test, even when it is disjoint
    // from the query, so the parent's merge below reads an up-to-date child.
    void update(int l, int r, int node, int queryL, int queryR, int queryVal)
    {
        push(node, l, r);
        if (queryR < l || r < queryL)
        {
            return;
        }

        if (queryL <= l && r <= queryR)
        {
            tree[node].lazy += queryVal;
            push(node, l, r);
            return;
        }

        int mid = (l + r) / 2;
        update(l, mid, 2*node, queryL, queryR, queryVal);
        update(mid + 1, r, 2*node + 1, queryL, queryR, queryVal);
        tree[node] = tree[2*node] + tree[2*node + 1];
    }

    // Builds over the whole range [0, n - 1]; `n` is the global city count.
    void build(llong dists[], const std::vector <int> &tour)
    {
        build(0, n - 1, 1, dists, tour);
    }

    // Subtracts val on tour positions [l, r] within [0, n - 1].
    void update(int l, int r, int val)
    {
        update(0, n - 1, 1, l, r, val);
    }

    // {max value, tour position}. The root is always pushed by update(),
    // so tree[1] is current here.
    std::pair <llong, int> getMAX()
    {
        return {tree[1].max, tree[1].idx};
    }
};

// Sum of the `b` cost over every tree edge visited by sumUpDFS from its root.
llong sumUp;
// Sum of both directional costs of every road (input()).
llong sumOfAll;
// dists[v]: sum of `a` costs on the path from the greedy root to v (buildDFS).
llong dists[MAXN];
// answer[k]: value subtracted from sumOfAll for a query with k cities (solve()).
llong answer[MAXN];
// ifRoot[v]: sumUp re-rooted at v (findOneAnswer).
llong ifRoot[MAXN];
// Largest / second-largest downward `a`-weighted path from v through distinct
// children (findMaxPath), and the vertices where those paths end.
llong maxPath[MAXN];
llong maxPath2[MAXN];
int fromPath[MAXN];
int fromPath2[MAXN];
// Adjacency lists, one Edge per incident road.
std::vector <Edge> g[MAXN];
// Subtree sizes from sumUpDFS. NOTE(review): not read elsewhere in the visible code.
int sz[MAXN];

// Euler-tour interval of each vertex's subtree: tour positions [in[v], out[v]].
int in[MAXN];
int out[MAXN];
// Max segment tree indexed by tour position.
SegmentTree tree;
// Vertices in buildDFS pre-order.
std::vector <int> tour;

void sumUpDFS(int node, int par)
{
    sz[node] = 1;
    for (const auto &[u, a, b] : g[node])
    {
        if (u == par)
        {
            continue;
        }

        sumUp += b;
        sumUpDFS(u, node);
        sz[node] += sz[u];
    }
}

void findOneAnswer(int node, int par, llong sum)
{
    ifRoot[node] = sum;
    answer[1] = std::max(answer[1], ifRoot[node]);
    for (const auto &[u, a, b]  : g[node])
    {
        if (u == par)
        {
            continue;
        }

        findOneAnswer(u, node, sum + a - b);
    }
}

void findMaxPath(int node, int par)
{
    fromPath[node] = node;
    for (const auto &[u, a, b] : g[node])
    {
        if (u == par)
        {
            continue;
        }

        findMaxPath(u, node);
        if (maxPath[node] < maxPath[u] + a)
        {
            fromPath2[node] = fromPath[node];
            fromPath[node] = fromPath[u];
            maxPath2[node] = maxPath[node];
            maxPath[node] = maxPath[u] + a;
        } else if (maxPath2[node] < maxPath[u] + a)
        {
            fromPath2[node] = fromPath[u];
            maxPath2[node] = maxPath[u] + a;
        }
    }
}

int parent[MAXN];
int parentEdge[MAXN];
void buildDFS(int node, int par)
{
    parent[node] = par; 
    in[node] = tour.size();
    tour.push_back(node);

    for (const auto &[u, a, b] : g[node])
    {
        if (u == par)
        {
            continue;
        }

        // std::cout << "Edge: " << node << ' ' << u << ' ' << a << '\n';
        parentEdge[u] = a;
        dists[u] = dists[node] + a;
        buildDFS(u, node);
    }

    out[node] = tour.size() - 1;
}

// vis[v]: v's parent edge (or v itself, for the root) is already covered.
bool vis[MAXN];

// One greedy step: returns the largest remaining gain in the segment tree.
// If it is positive, covers the path from that vertex up to the first covered
// ancestor, subtracting each newly covered edge from every vertex below it.
llong take()
{
    const auto best = tree.getMAX();
    const llong gain = best.first;
    if (gain == 0)
    {
        return gain;
    }

    for (int v = tour[best.second]; !vis[v]; v = parent[v])
    {
        tree.update(in[v], out[v], parentEdge[v]);
        vis[v] = true;
    }

    return gain;
}

void solve()
{
    if (n == 2)
    {
        int sum = g[1][0].a + g[1][0].b;
        int max = std::max(g[1][0].a, g[1][0].b);
        answer[1] = max;
        answer[2] = sum;
        return;
    }

    int root = 1;
    while (root <= n && g[root].size() == 1)
    {
        root++;
    }

    assert(root <= n);
    sumUpDFS(root, 0);
    findOneAnswer(root, 0, sumUp);
    findMaxPath(root, 0);

    root = -1;
    llong last = 0;
    for (int i = 1 ; i <= n ; ++i)
    {
        answer[2] = std::max(answer[2], ifRoot[i] + maxPath[i] + maxPath2[i]);
        if (answer[2] > last)
        {
            last = answer[2];
            root = i;
        }
    }

    assert(root != -1);
    buildDFS(root, 0);
    assert(tour.size() == n);
    tree.build(dists, tour);

    vis[root] = true;
    take();
    take();

    for (int i = 3 ; i <= n ; ++i)
    {
        answer[i] = answer[i - 1] + take();
    }
}

// Reads n, the n - 1 roads (two endpoints and two directional costs each),
// then q. Each road goes into both endpoint lists with its costs swapped, and
// both costs are added to sumOfAll.
void input()
{
    std::cin >> n;
    for (int road = 1 ; road < n ; ++road)
    {
        int from, to, costForward, costBackward;
        std::cin >> from >> to >> costForward >> costBackward;
        sumOfAll += costForward + costBackward;
        g[from].push_back({to, costForward, costBackward});
        g[to].push_back({from, costBackward, costForward});
    }

    std::cin >> q;
}

void print()
{
    for (int i = 1 ; i <= q ; ++i)
    {
        int cnt;
        std::cin >> cnt;
        std::cout << sumOfAll - answer[cnt] << '\n';
    }
}

// Speeds up stream I/O: unties cin/cout from each other and stops syncing
// the C++ streams with C stdio. Must run before any input or output.
void fastIOI()
{
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    std::ios_base::sync_with_stdio(false);
}

// Entry point, declared `signed` because `int` is macro-redefined to
// `long long`. Pipeline: configure I/O, read, precompute, answer queries.
signed main()
{
    fastIOI();
    input();
    solve();
    print();
    return 0;
}

Compilation message (stderr)

In file included from /usr/include/c++/10/cassert:44,
                 from designated_cities.cpp:4:
designated_cities.cpp: In function 'void solve()':
designated_cities.cpp:272:24: warning: comparison of integer expressions of different signedness: 'std::vector<long long int>::size_type' {aka 'long unsigned int'} and 'long long int' [-Wsign-compare]
  272 |     assert(tour.size() == n);
      |            ~~~~~~~~~~~~^~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...