Submission #890265

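| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 890265 | | boris_mihov | Construction of Highway (JOI18_construction) | C++17 | 100 / 100 | 1020 ms | 23932 KiB |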
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>
#include <stack>

using llong = long long;

constexpr int MAXN = 100'000 + 10;                  // array capacity: vertices are 1..n plus slack
constexpr llong INF = 1'000'000'000'000'000'000LL;  // NOTE(review): appears unused in this file
constexpr int INTINF = 1'000'000'000;               // identity values for min/max in the segment tree

int n;  // number of vertices, set by input()
// Range-assignment segment tree over positions 0..n-1 (uses the global `n`).
// Each node stores the min and max value of its range, so a range holds a
// single value exactly when min == max.
struct SegmentTree
{
    struct Node
    {
        int min = INTINF;   // identity for std::min when merging
        int max = -INTINF;  // identity for std::max when merging
        int lazy = 0;       // pending "assign this value" tag; 0 means none

        Node() = default;

        Node(int lo, int hi) : min(lo), max(hi)
        {
        }

        // Merge two children; the merged node carries no pending tag.
        friend Node operator + (const Node &left, const Node &right)
        {
            return Node(std::min(left.min, right.min), std::max(left.max, right.max));
        }
    };

    Node tree[4*MAXN];

    // Resolve the pending tag of `node` (covering [l, r]): write it into the
    // node's own min/max and hand it to both children unless [l, r] is a leaf.
    void push(int node, int l, int r)
    {
        const int pending = tree[node].lazy;
        if (pending != 0)
        {
            tree[node].min = pending;
            tree[node].max = pending;
            if (l < r)
            {
                tree[2*node].lazy = pending;
                tree[2*node + 1].lazy = pending;
            }
            tree[node].lazy = 0;
        }
    }

    // Recompute `node` from its two children (which must already be resolved).
    void pull(int node)
    {
        tree[node] = tree[2*node] + tree[2*node + 1];
    }

    // Leaf at position p receives c[tour[p]].
    void build(int l, int r, int node, int c[], const std::vector <int> &tour)
    {
        if (l != r)
        {
            const int mid = (l + r) / 2;
            build(l, mid, 2*node, c, tour);
            build(mid + 1, r, 2*node + 1, c, tour);
            pull(node);
            return;
        }

        tree[node].min = tree[node].max = c[tour[l]];
    }

    // Assign queryVal to every position in [queryL, queryR].
    void update(int l, int r, int node, int queryL, int queryR, int queryVal)
    {
        push(node, l, r);
        const bool disjoint = queryR < l || r < queryL;
        if (disjoint)
        {
            return;
        }

        if (queryL <= l && r <= queryR)
        {
            tree[node].lazy = queryVal;
            push(node, l, r);
            return;
        }

        const int mid = (l + r) / 2;
        update(l, mid, 2*node, queryL, queryR, queryVal);
        update(mid + 1, r, 2*node + 1, queryL, queryR, queryVal);
        pull(node);
    }

    // Min/max aggregate of [queryL, queryR].
    Node query(int l, int r, int node, int queryL, int queryR)
    {
        push(node, l, r);
        if (queryL <= l && r <= queryR)
        {
            return tree[node];
        }

        const int mid = (l + r) / 2;
        Node res;
        if (queryL <= mid)
        {
            res = res + query(l, mid, 2*node, queryL, queryR);
        }
        if (queryR > mid)
        {
            res = res + query(mid + 1, r, 2*node + 1, queryL, queryR);
        }
        return res;
    }

    // Rightmost position in [queryL, queryR] whose value differs from
    // queryVal, or -1 if every position there equals queryVal.
    int search(int l, int r, int node, int queryL, int queryR, int queryVal)
    {
        push(node, l, r);
        if (queryR < l || r < queryL)
        {
            return -1;
        }

        const bool covered = queryL <= l && r <= queryR;
        const bool allEqual = tree[node].min == queryVal && tree[node].max == queryVal;
        if (covered && allEqual)
        {
            return -1;
        }

        if (l == r)
        {
            return l;
        }

        const int mid = (l + r) / 2;
        const int rightHit = search(mid + 1, r, 2*node + 1, queryL, queryR, queryVal);
        return rightHit != -1 ? rightHit : search(l, mid, 2*node, queryL, queryR, queryVal);
    }

    void build(int c[], const std::vector <int> &tour)
    {
        build(0, n - 1, 1, c, tour);
    }

    void update(int l, int r, int value)
    {
        update(0, n - 1, 1, l, r, value);
    }

    // Value stored at position idx.
    int getValue(int idx)
    {
        return query(0, n - 1, 1, idx, idx).min;
    }

    // True when [l, r] holds at least two distinct values.
    bool checkIfDifferent(int l, int r)
    {
        const Node res = query(0, n - 1, 1, l, r);
        return res.min != res.max;
    }

    int search(int l, int r, int value)
    {
        return search(0, n - 1, 1, l, r, value);
    }
};

// Binary indexed tree over indices 1..n (uses the global `n`). Updates can be
// logged so that reset() clears the structure by undoing exactly those updates
// instead of zeroing every cell.
struct Fenwick
{
    int tree[MAXN];
    std::stack <std::pair <int,int>> updateTracker;

    // Add `value` at index `idx`; when `shouldAdd` is true the change is logged for reset().
    void update(int idx, int value, bool shouldAdd = true)
    {
        if (shouldAdd)
        {
            updateTracker.push({idx, value});
        }

        for (int i = idx ; i <= n ; i += i & -i)
        {
            tree[i] += value;
        }
    }

    // Sum of the values at indices 1..idx (0 when idx <= 0).
    int query(int idx)
    {
        int sum = 0;
        for (int i = idx ; i > 0 ; i &= i - 1)  // clear the lowest set bit
        {
            sum += tree[i];
        }
        return sum;
    }

    // Undo every logged update, most recent first, leaving the log empty.
    void reset()
    {
        while (!updateTracker.empty())
        {
            const std::pair <int,int> last = updateTracker.top();
            updateTracker.pop();
            update(last.first, -last.second, false);
        }
    }
};

// Input: vertex colours (rank-compressed in solve()) and the edges a[i] -> b[i].
int c[MAXN];
int a[MAXN];
int b[MAXN];
std::vector <int> g[MAXN];

// Rooted-tree data written by dfs().
int parent[MAXN];
int sz[MAXN];
int heavy[MAXN];

// Heavy-light decomposition written by decompose().
int head[MAXN];
int in[MAXN];
std::vector <int> tour;

// Structures queried and updated by calcCost().
SegmentTree tree;
Fenwick fenwick;

// Recursively root the subtree of `node` (whose parent is `par`): records
// parent[], subtree sizes sz[], and heavy[] = the first child with the largest
// subtree (heavy[] keeps its zero-initialised value for a leaf).
void dfs(int node, int par)
{
    parent[node] = par;
    sz[node] = 1;
    for (const int child : g[node])
    {
        if (child != par)
        {
            dfs(child, node);
            sz[node] += sz[child];
            if (sz[heavy[node]] < sz[child])
            {
                heavy[node] = child;
            }
        }
    }
}

// Heavy-light decomposition of the subtree of `node`, whose chain starts at `h`.
// Records head[] and the position in[] of each vertex in `tour`. The heavy child
// is visited first, so every chain occupies a contiguous range of `tour`.
void decompose(int node, int h)
{
    in[node] = tour.size();
    head[node] = h;
    tour.push_back(node);

    const int heavyChild = heavy[node];
    if (heavyChild != 0)
    {
        decompose(heavyChild, h);
    }

    for (const int child : g[node])
    {
        const bool skip = child == parent[node] || child == heavyChild;
        if (!skip)
        {
            decompose(child, child);  // each light child starts a new chain
        }
    }
}

// Walk from `node` up to the root (parent 0), splitting the path into maximal
// runs of equal value in the segment tree. Returns the number of pairs (u, w)
// on that path with u above w and value(u) > value(w), counted with the Fenwick
// tree. Afterwards, every vertex on the path is recoloured to c[added].
llong calcCost(int node, int added)
{
    fenwick.reset();
    llong total = 0;

    for (int v = node ; v != 0 ; )
    {
        const int chainHead = head[v];
        const int colour = tree.getValue(in[v]);

        // Topmost vertex of the run of `colour` ending at v inside this chain.
        const int top = tree.checkIfDifferent(in[chainHead], in[v])
            ? tour[tree.search(in[chainHead], in[v], colour) + 1]
            : chainHead;

        const int runLength = in[v] - in[top] + 1;
        total += 1LL * runLength * fenwick.query(colour - 1);  // smaller values seen deeper on the path
        fenwick.update(colour, runLength);
        tree.update(in[top], in[v], c[added]);
        v = parent[top];
    }

    return total;
}

std::pair <int,int> sorted[MAXN];

// Replace each colour c[i] by its rank among the distinct colours (1..k).
// Then build the heavy-light decomposition rooted at vertex 1 and the segment
// tree over it, and print calcCost(a[i], b[i]) for each of the n-1 edges in input order.
void solve()
{
    for (int i = 1 ; i <= n ; ++i)
    {
        sorted[i] = {c[i], i};
    }

    std::sort(sorted + 1, sorted + 1 + n);
    int cnt = 0;
    for (int i = 1 ; i <= n ; ++i)
    {
        // Open a new rank for the first colour and whenever the colour changes.
        // Ranks must start at 1. Comparing the first colour with the never-written
        // sorted[0] (= {0, 0}) would give a colour of 0 rank 0. But 0 is the segment
        // tree's "no pending assignment" tag, and Fenwick::update never advances
        // from index 0.
        if (i == 1 || sorted[i].first != sorted[i - 1].first)
        {
            ++cnt;
        }
        c[sorted[i].second] = cnt;
    }

    dfs(1, 0);
    decompose(1, 1);
    tree.build(c, tour);

    for (int i = 1 ; i < n ; ++i)
    {
        std::cout << calcCost(a[i], b[i]) << '\n';
    }
}

// Read n, the colours c[1..n] and the n-1 edges (a[i], b[i]) from stdin.
// Each edge is stored only in the direction a[i] -> b[i].
void input()
{
    std::cin >> n;
    for (int v = 1 ; v <= n ; ++v)
    {
        std::cin >> c[v];
    }

    for (int e = 1 ; e < n ; ++e)
    {
        std::cin >> a[e] >> b[e];
        g[a[e]].push_back(b[e]);
    }
}

// Speed up stream I/O: untie cin/cout from each other and stop syncing with C stdio.
// Must be called before any I/O is performed.
void fastIOI()
{
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    std::ios_base::sync_with_stdio(false);
}

int main()
{
    fastIOI();  // configure streams before any input is read
    input();
    solve();
}
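| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|

Fetching results...

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|

Fetching results...

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|

Fetching results...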