제출 #1263976

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1263976 | tvgk | Construction of Highway (JOI18_construction) | C++20
100 / 100
335 ms | 20856 KiB
#include<bits/stdc++.h>
using namespace std;
#define task "a"          // base file name for the commented-out freopen calls in main
#define se second
#define fi first
#define ll long long
#define ii pair<ll, ll>   // not used in the visible code
const long mxN = 2e5 + 7; // capacity of every per-node / per-position array below

// Per-node / per-position state (all zero-initialised as globals):
//   a[i]    : value of node i; main replaces it by its index in val
//   lf[i]   : largest depth h[] reached inside the subtree of i (set by DFS)
//   root[i] : head node of the chain containing i (set by HLD)
//   in[i]   : position of node i in HLD order, 1..n
//   par[i]  : parent of i;  h[i] : depth of i (node 1 keeps depth 0)
//   stt     : HLD helper, pending chain head (0 = next node opens a new chain)
//   num     : HLD position counter
//   tree    : segment tree over positions; a node stores the compressed value
//             shared by its whole range, or -1 when the range is mixed
//   v[i]    : child endpoint of the i-th input edge (also the insertion order)
//   bit     : Fenwick tree over compressed values; value index t is stored at
//             position t + 1
int n, a[mxN], lf[mxN], root[mxN], in[mxN], par[mxN], h[mxN], stt, num, tree[mxN * 4], v[mxN], bit[mxN];
// val : sorted distinct input values plus 0 (the compression table)
// w[i]: children of node i (DFS reorders them by decreasing lf)
// mem : compressed values inserted into bit during the current Add, used to reset it
vector<int> val, w[mxN], mem;

// Child-ordering predicate: the child whose subtree reaches deeper (larger lf)
// is placed first.
bool cmp(int u, int v)
{
    return lf[v] < lf[u];
}

/*
 * Recursive pass below node j: gives every child its parent (par) and depth (h),
 * stores in lf[j] the greatest depth reached anywhere in j's subtree, and finally
 * reorders w[j] with cmp so the child with the deepest subtree comes first
 * (HLD extends the current chain through that first child).
 */
void DFS(int j)
{
    int deepest = h[j];
    for (const int child : w[j])
    {
        par[child] = j;
        h[child] = h[j] + 1;
        DFS(child);
        deepest = max(deepest, lf[child]);
    }
    lf[j] = deepest;
    sort(w[j].begin(), w[j].end(), cmp);
}

/*
 * Assigns preorder positions (in[]) and chain heads (root[]).
 * The first child in w[j] inherits j's chain because stt is still set when it is
 * entered. Every call clears stt on exit, so each later child opens a new chain
 * headed by itself. Positions along a chain are therefore consecutive and grow
 * with depth.
 */
void HLD(int j)
{
    stt = stt ? stt : j;
    root[j] = stt;
    num += 1;
    in[j] = num;

    for (const int child : w[j])
        HLD(child);

    stt = 0;
}

// Lazy push-down: when node j holds one value for its whole range, copy that
// value into both children. A mixed node (-1) has nothing to propagate.
void Down(int j)
{
    const int shared = tree[j];
    if (shared != -1)
    {
        const int left = 2 * j;
        tree[left] = shared;
        tree[left + 1] = shared;
    }
}

// Pull-up: node j takes its children's common value when they agree (which is
// -1 again if both are mixed). Otherwise it is marked mixed (-1).
void Up(int j)
{
    const int lhs = tree[2 * j];
    const int rhs = tree[2 * j + 1];
    tree[j] = (lhs == rhs) ? lhs : -1;
}

/*
 * Range assignment on the segment tree: every position in [u, v] gets the
 * compressed value nw. (j, l, r) identify the current node and its range;
 * external callers rely on the defaults (root node covering [1, n]).
 */
void Upd(int u, int v, int nw, int j = 1, int l = 1, int r = n)
{
    const bool disjoint = r < u || v < l;
    if (disjoint)
        return;

    const bool covered = u <= l && r <= v;
    if (covered)
    {
        tree[j] = nw;
        return;
    }

    Down(j);
    const int mid = (l + r) / 2;
    const int lc = 2 * j;
    Upd(u, v, nw, lc, l, mid);
    Upd(u, v, nw, lc + 1, mid + 1, r);
    Up(j);
}

// Fenwick prefix sum of bit[1..j]. Because value index t is stored at position
// t + 1, Get_BIT(t) counts inserted elements whose value index is below t.
ll Get_BIT(int j)
{
    int total = 0;
    for (; j != 0; j -= j & (-j))
        total += bit[j];
    return total;
}

// Fenwick point update: adds inc at position j and every covering cell below
// val.size(). Positions >= val.size() are skipped. That affects only the largest
// value index (stored at val.size()), which no prefix query ever reaches.
void Upd_BIT(int j, int inc)
{
    for (; static_cast<size_t>(j) < val.size(); j += j & (-j))
        bit[j] += inc;
}

/*
 * Scans positions [u, v] of one chain, from the highest position (deepest node)
 * down to the lowest, one uniform segment-tree node at a time.
 *
 * For a fully covered node of length len holding value index t:
 *   - it adds len * (number of elements already in bit with value index < t);
 *   - it then inserts those len elements (position t + 1) and records t in mem
 *     so the caller can reset bit afterwards.
 * Returns the sum of these contributions.
 *
 * (j, l, r) identify the current node; external callers use the defaults.
 */
ll Get(int u, int v, int j = 1, int l = 1, int r = n)
{
    if (u > r || l > v)
        return 0;

    if (u <= l && r <= v && tree[j] != -1)
    {
        const int len = r - l + 1;
        // Elements inserted earlier (deeper) with a strictly smaller value.
        ll res = Get_BIT(tree[j]) * len;
        Upd_BIT(tree[j] + 1, len);
        mem.push_back(tree[j]);
        return res;
    }
    Down(j);

    int mid = (l + r) / 2;
    // The right half (deeper positions) must update bit before the left half
    // queries it. Operand evaluation order of '+' is unspecified in C++, so the
    // two recursive calls are sequenced explicitly instead of written as
    // "Get(right) + Get(left)".
    const ll right = Get(u, v, j * 2 + 1, mid + 1, r);
    const ll left = Get(u, v, j * 2, l, mid);
    return right + left;
}

// Zeroes every Fenwick cell that Upd_BIT(j, ...) would touch, undoing an insert
// at position j without clearing the whole array.
void Erase(int j)
{
    for (; static_cast<size_t>(j) < val.size(); j += j & (-j))
        bit[j] = 0;
}

/*
 * Attaches node j and prints one answer.
 *
 * Steps:
 *   1. Gives position in[j] its compressed value a[j].
 *   2. Walks from par[j] to the root one chain at a time, deepest chain first.
 *   3. For each chain, adds Get over the segment from the chain head to the
 *      current node, then overwrites that segment with a[j].
 *   4. Resets the Fenwick cells recorded in mem.
 *   5. Prints the accumulated total on its own line.
 *
 * Presumably the total is the inversion count of the old root->par[j] path
 * values (task JOI18_construction); confirm against the statement.
 */
void Add(int j)
{
    const int colour = a[j];
    Upd(in[j], in[j], colour);

    ll ans = 0;
    for (int cur = par[j]; cur != 0; cur = par[root[cur]])
    {
        const int head = root[cur];
        ans += Get(in[head], in[cur]);
        Upd(in[head], in[cur], colour);
    }

    for (const int value : mem)
        Erase(value + 1);
    mem.clear();

    cout << ans << '\n';
}

/*
 * Entry point.
 *
 * Steps:
 *   - Reads n, the n node values and n - 1 edge lines "u v[i]"
 *     (v[i] becomes a child of u).
 *   - Coordinate-compresses the values.
 *   - Builds the decomposition rooted at node 1.
 *   - Inserts the children in input order, with Add printing one answer each.
 */
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // Local runs: freopen(task".INP", "r", stdin); freopen(task".OUT", "w", stdout);

    cin >> n;
    for (int node = 1; node <= n; ++node)
    {
        cin >> a[node];
        val.push_back(a[node]);
    }

    // Coordinate compression. 0 goes into the table together with the inputs.
    val.push_back(0);
    sort(val.begin(), val.end());
    val.erase(unique(val.begin(), val.end()), val.end());
    for (int node = 1; node <= n; ++node)
        a[node] = static_cast<int>(lower_bound(val.begin(), val.end(), a[node]) - val.begin());

    for (int e = 1; e < n; ++e)
    {
        int parent;
        cin >> parent >> v[e];
        w[parent].push_back(v[e]);
    }

    DFS(1);
    HLD(1);

    // HLD(1) gives the root position 1.
    Upd(1, 1, a[1]);
    for (int e = 1; e < n; ++e)
        Add(v[e]);
}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...