Submission #1238265

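| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1238265 |  | SamAnd | Closing Time (IOI23_closing) | C++20 | 43 / 100 | 214 ms | 77568 KiB |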
#include "closing.h"
#include <bits/stdc++.h>
using namespace std;
// Short-hand macros used throughout this submission.
#define m_p make_pair
#define all(x) (x).begin(),(x).end()
#define sz(x) ((int)(x).size())
#define fi first
#define se second
typedef long long ll;
// Two random generators: one time-seeded, one with a fixed seed.
// NOTE(review): neither rnd nor rnf is used anywhere in the visible code.
mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());
mt19937 rnf(2106);
// Capacity of the per-vertex arrays (g, c, u1, u2); the segment-tree arrays
// are sized N * 4. NOTE(review): presumably chosen as ~2x the maximum vertex
// count because `u` holds two entries per vertex — confirm against limits.
const int N = 400005;

// Number of vertices in the current call (assigned at the start of max_score).
int n;
// Adjacency list: g[x] holds (neighbour, edge weight) pairs.
vector<pair<int, int> > g[N];

void dfs0(int x, int p, vector<ll>& d)
{
    if (x == p)
        d[x] = 0;
    for (int i = 0; i < g[x].size(); ++i)
    {
        int h = g[x][i].fi;
        if (h == p)
            continue;
        d[h] = d[x] + g[x][i].se;
        dfs0(h, x, d);
    }
}

// Searches for vertex y starting at x, never stepping back into p (pass
// p == x to search the whole tree from x). On success the vertices of the
// x..y path are appended to v in order and true is returned; on failure v is
// left exactly as it was and false is returned.
//
// Uses an explicit stack instead of recursion so that a path-shaped tree with
// a very large number of vertices cannot overflow the call stack.
bool dfs1(int x, int p, int y, vector<int>& v)
{
    v.push_back(x);
    if (x == y)
        return true;
    // One frame per vertex on the current path (the frames mirror the tail of
    // v): the vertex, the vertex it was entered from, and the index of the
    // next adjacency entry to try.
    struct Frame
    {
        int at;
        int from;
        int next;
    };
    std::vector<Frame> st;
    st.push_back({x, p, 0});
    while (!st.empty())
    {
        Frame& top = st.back();
        if (top.next == static_cast<int>(g[top.at].size()))
        {
            // All neighbours explored without reaching y: retract this vertex.
            v.pop_back();
            st.pop_back();
            continue;
        }
        const int cur = top.at;
        const int h = g[cur][top.next].fi;
        ++top.next;
        if (h == top.from)
            continue;
        v.push_back(h);
        if (h == y)
            return true;
        // `top` may be invalidated by this push_back; it is not used afterwards.
        st.push_back({h, cur, 0});
    }
    return false;
}

// c[x]: set to true in max_score once vertex x is paid for at min(dX[x], dY[x]).
bool c[N];
// u1[x] / u2[x]: index in `u` of vertex x's upgrade entry / pair entry.
int u1[N], u2[N];

// Two entries per vertex x, sorted ascending in max_score:
//   upgrade entry: (2 * (max(dX,dY) - min(dX,dY)), -(x + 1))
//   pair entry:    (max(dX,dY), x)
vector<pair<ll, int> > u;
// Segment tree over positions of `u` (root 1, children 2p and 2p+1), kept by ubd:
//   q  - number of active upgrade entries
//   s  - sum over active entries of (count * key): upgrade entries are added
//        with count 1 (contribute key), pair entries with count 2 (2 * key)
//   s1 - sum of (max - min) over active upgrade entries
//   s2 - sum of max over active pair entries
//   q2 - number of active pair entries
int q[N * 4];
ll s[N * 4];
ll s1[N * 4], s2[N * 4];
int q2[N * 4];

// Clears every segment-tree node of the subtree rooted at `pos`, which covers
// positions [tl, tr]: all counters (q, q2) and sums (s, s1, s2) become zero.
void bil(int tl, int tr, int pos)
{
    s[pos] = s1[pos] = s2[pos] = 0;
    q[pos] = q2[pos] = 0;
    if (tl < tr)
    {
        const int mid = (tl + tr) / 2;
        bil(tl, mid, 2 * pos);
        bil(mid + 1, tr, 2 * pos + 1);
    }
}

// Adds y copies of entry u[x] to the segment tree (a negative y removes them),
// then recomputes every ancestor as the sum of its two children.
// At the leaf, with delta = y * key:
//   upgrade entry (se < 0):  s += delta, s1 += delta / 2, q  += y
//   pair entry    (se >= 0): s += delta, s2 += delta / 2, q2 += y / 2
void ubd(int tl, int tr, int x, int y, int pos)
{
    if (tl != tr)
    {
        const int mid = (tl + tr) / 2;
        const int lc = 2 * pos;
        const int rc = 2 * pos + 1;
        if (x > mid)
            ubd(mid + 1, tr, x, y, rc);
        else
            ubd(tl, mid, x, y, lc);
        s[pos] = s[lc] + s[rc];
        s1[pos] = s1[lc] + s1[rc];
        s2[pos] = s2[lc] + s2[rc];
        q[pos] = q[lc] + q[rc];
        q2[pos] = q2[lc] + q2[rc];
        return;
    }
    const auto delta = y * u[tl].fi;
    s[pos] += delta;
    if (u[tl].se >= 0)
    {
        s2[pos] += delta / 2;
        q2[pos] += y / 2;
    }
    else
    {
        s1[pos] += delta / 2;
        q[pos] += y;
    }
}

// Greedy over the active entries of `u` in sorted (ascending key) order, in
// doubled cost units (the caller passes kk = 2 * remaining budget):
//   * an upgrade entry (se < 0) is one item costing its key, worth 1 upgrade;
//   * a pair entry (se >= 0) is two half-items, each costing its key.
// Items are taken in order while they fit. On return kk has been reduced by
// everything taken, and qq has been increased by the number of upgrades taken.
//
// Because entries are sorted, once an item does not fit, every later item
// costs at least as much. So only a single root-to-leaf descent is needed
// (O(log n)). The previous version kept visiting every subtree whose sum
// exceeded kk (O(n) per call). At a partially-fitting pair leaf it also took
// nothing, because its loop ran q[pos] == 0 times. It then went on to count
// more expensive upgrades out of order.
void qry1(int tl, int tr, ll& kk, int& qq, int pos)
{
    for (;;)
    {
        if (kk >= s[pos])
        {
            // The whole subtree fits.
            kk -= s[pos];
            qq += q[pos];
            return;
        }
        if (tl == tr)
        {
            // Partially-fitting leaf. An upgrade leaf is a single item of
            // cost s[pos] > kk, so nothing fits. A pair leaf is two halves of
            // cost u[tl].fi each: take one half if it fits.
            if (u[tl].se >= 0 && kk >= u[tl].fi)
                kk -= u[tl].fi;
            return;
        }
        const int m = (tl + tr) / 2;
        const int lc = pos * 2;
        if (kk >= s[lc])
        {
            // Left subtree fits entirely; continue into the right one.
            kk -= s[lc];
            qq += q[lc];
            tl = m + 1;
            pos = lc + 1;
        }
        else
        {
            // Greedy stops inside the left subtree; the right one holds only
            // items at least as expensive, none of which can fit afterwards.
            tr = m;
            pos = lc;
        }
    }
}

// Returns the sum of s1 (real upgrade costs, max - min) over the qq cheapest
// active upgrade entries in the subtree rooted at `pos` covering [tl, tr].
// Precondition: qq <= q[pos]. Walks a single root-to-leaf path, adding the
// cost of every left subtree it passes over entirely.
ll qry2(int tl, int tr, int qq, int pos)
{
    ll acc = 0;
    for (;;)
    {
        assert(qq <= q[pos]);
        if (q[pos] == qq)
            return acc + s1[pos];
        if (tl == tr)
            return acc;
        const int mid = (tl + tr) / 2;
        const int lc = 2 * pos;
        if (qq > q[lc])
        {
            acc += s1[lc];
            qq -= q[lc];
            tl = mid + 1;
            pos = lc + 1;
        }
        else
        {
            tr = mid;
            pos = lc;
        }
    }
}

// Returns how many active pair entries (cost max each, real units) can be
// bought with budget kk when they are taken cheapest-first.
//
// Fix: the descent used to test `kk <= s2[pos]` (the node's own sum). That is
// always true at that point, because kk < s2[pos] was just established. So the
// walk always went left and never counted pair entries in right subtrees. The
// decision must compare kk with the LEFT child's sum.
int qry3(int tl, int tr, ll kk, int pos)
{
    if (kk >= s2[pos])
        return q2[pos];
    if (tl == tr)
        return 0;
    int m = (tl + tr) / 2;
    if (kk < s2[pos * 2])
        return qry3(tl, m, kk, pos * 2);
    // The whole left subtree fits; spend the rest on the right subtree.
    return q2[pos * 2] + qry3(m + 1, tr, kk - s2[pos * 2], pos * 2 + 1);
}

int max_score(int N, int X, int Y, long long K,
              std::vector<int> U, std::vector<int> V, std::vector<int> W)
{
    n = N;

    for (int i = 0; i <= n + 1; ++i)
    {
        c[i] = false;
        g[i].clear();
    }

    for (int i = 0; i < n - 1; ++i)
    {
        int x = U[i];
        int y = V[i];
        int z = W[i];
        g[x].push_back(m_p(y, z));
        g[y].push_back(m_p(x, z));
    }

    vector<ll> dX(n);
    dfs0(X, X, dX);
    vector<ll> dY(n);
    dfs0(Y, Y, dY);

    vector<int> v;
    assert(dfs1(X, X, Y, v));
    assert(v[0] == X && v.back() == Y);

    int ans = 0;
    {
    vector<ll> w;
    for (int x = 0; x < n; ++x)
        w.push_back(min(dX[x], dY[x]));
    sort(all(w));
    ll k = K;
    for (int i = 0; i < n; ++i)
    {
        if (k >= w[i])
        {
            k -= w[i];
            ++ans;
        }
    }
    }

    u.clear();
    for (int x = 0; x < n; ++x)
    {
        u.push_back(m_p((max(dX[x], dY[x]) - min(dX[x], dY[x])) * 2, -(x + 1)));
        u.push_back(m_p(max(dX[x], dY[x]), x));
    }
    sort(all(u));
    for (int i = 0; i < sz(u); ++i)
    {
        if (u[i].se >= 0)
            u2[u[i].se] = i;
        else
            u1[-u[i].se - 1] = i;
    }

    ll k = K;
    int pans = 0;
    vector<ll> w1;
    bil(0, sz(u) - 1, 1);
    for (int x = 0; x < n; ++x)
        ubd(0, sz(u) - 1, u2[x], 2, 1);
    for (int i = 0; i < sz(v); ++i)
    {
        int x = v[i];
        c[x] = true;
        if (k >= min(dX[x], dY[x]))
        {
            k -= min(dX[x], dY[x]);
            ++pans;
            w1.push_back(max(dX[x], dY[x]) - min(dX[x], dY[x]));
            ubd(0, sz(u) - 1, u2[x], -2, 1);
            ubd(0, sz(u) - 1, u1[x], 1, 1);
        }
    }
    if (pans != sz(v))
        return ans;

    vector<pair<ll, int> > w;
    for (int x = 0; x < n; ++x)
    {
        if (c[x])
            continue;
        w.push_back(m_p(min(dX[x], dY[x]), x));
    }
    sort(all(w));

    for (int i = -1; i < sz(w); ++i)
    {
        if (i >= 0)
        {
            int x = w[i].se;
            c[x] = true;
            k -= min(dX[x], dY[x]);
            ++pans;
            w1.push_back(max(dX[x], dY[x]) - min(dX[x], dY[x]));
            ubd(0, sz(u) - 1, u2[x], -2, 1);
            ubd(0, sz(u) - 1, u1[x], 1, 1);
            if (k < 0)
                break;
        }
        /*vector<ll> w2;
        for (int x = 0; x < n; ++x)
        {
            if (!c[x])
            {
                w2.push_back(max(dX[x], dY[x]));
            }
        }
        sort(all(w2));
        sort(all(w1));

        vector<ll> p1, p2;
        p1.push_back(0);
        for (int i = 0; i < sz(w1); ++i)
        {
            p1.push_back(p1.back() + w1[i]);
        }
        p2.push_back(0);
        for (int i = 0; i < sz(w2); ++i)
        {
            p2.push_back(p2.back() + w2[i]);
        }

        /*vector<pair<ll, bool> > ww;
        for (int i = 0; i < sz(w1); ++i)
        {
            ww.push_back(m_p(w1[i] * 2, true));
        }
        for (int i = 0; i < sz(w2); ++i)
        {
            ww.push_back(m_p(w2[i], false));
            ww.push_back(m_p(w2[i], false));
        }
        sort(all(ww));
        ll kk = k * 2;
        int q = 0;
        for (int i = 0; i < sz(ww); ++i)
        {
            if (kk >= ww[i].fi)
            {
                kk -= ww[i].fi;
                if (ww[i].se)
                    ++q;
            }
        }*/

        ll kk = k * 2;
        int q = 0;
        qry1(0, sz(u) - 1, kk, q, 1);

        /*for (int i = max(0, q - 2); i <= min(sz(p1) - 1, q + 2); ++i)
        {
            for (int j = 0; j < sz(p2); ++j)
            {
                if (k >= p1[i] + p2[j])
                    ans = max(ans, pans + i + 2 * j);
            }
        }*/

        for (int i = max(0, q - 2); i <= min(::q[1], q + 2); ++i)
        {
            ll kk = k;
            kk -= qry2(0, sz(u) - 1, i, 1);
            if (kk < 0)
                continue;
            int yans = qry3(0, sz(u) - 1, kk, 1);
            yans *= 2;
            yans += i + pans;
            ans = max(ans, yans);
        }
    }
    return ans;
}

/*
1
8 3 5 100
0 1 1
1 2 1
2 3 1
3 4 1
4 5 1
5 6 1
6 7 1

2
7 0 2 10
0 1 2
0 3 3
1 2 4
2 4 2
2 5 5
5 6 3
4 0 3 20
0 1 18
1 2 1
2 3 19

6
3
*/
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...