Submission #1301445

# | Time | Username | Problem | Language | Result | Execution time | Memory
1301445 |  | proplayer | Cats or Dogs (JOI18_catdog) | C++20 | 100 / 100 | 500 ms | 65668 KiB
#include "catdog.h"
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
constexpr int maxN = 1000005;    // capacity of every vertex/chain-indexed array (1e6 + 5)
constexpr int inf = 1000000000;  // "impossible" cost (1e9); chosen so inf + inf + 1 still fits in int
int n;                           // number of vertices, set by initialize()
vector<int> adj[maxN];           // adjacency lists, vertices are 1-indexed
int chk[maxN];                   // state of each vertex: 0 = cat, 1 = dog, 2 = neighbor
// Cost matrix for one heavy-chain piece (a single vertex or a run of
// consecutive chain vertices) together with everything attached to it through
// light edges. Colour 0 = cat, colour 1 = dog (see modify()).
//   cost[a][b] = minimum number of edges whose endpoints get different
//                colours when the topmost vertex of the piece has colour a
//                and the bottommost vertex has colour b; `inf` = impossible.
// A matrix with cost[0][0] == -1 is the "empty" value get() returns for
// ranges outside the query; operator+ treats it as a neutral element.
struct Tnode
{
    int cost[2][2];

    // A single unconstrained vertex: top and bottom are the same vertex, so
    // matching colours cost 0 and mismatching colours are impossible.
    Tnode()
    {
        cost[0][0] = 0;
        cost[0][1] = inf;
        cost[1][0] = inf;
        cost[1][1] = 0;
    }

    // Resets every entry to 0; used to start an empty light-child accumulator.
    void init()
    {
        cost[0][0] = cost[0][1] = cost[1][0] = cost[1][1] = 0;
    }

    // Concatenates two adjacent pieces: *this on top, `other` directly below.
    // The edge from this piece's bottom (colour c) to other's top (colour d)
    // costs 1 when c != d. Each result is capped at `inf` (the initial value
    // of `best`), so adding two entries plus 1 never overflows an int.
    Tnode operator + (const Tnode& other) const
    {
        if (other.cost[0][0] == -1) return *this;
        if (cost[0][0] == -1) return other;
        Tnode res;
        for (int a = 0; a < 2; ++a)
            for (int b = 0; b < 2; ++b)
            {
                int best = inf;
                for (int c = 0; c < 2; ++c)
                    for (int d = 0; d < 2; ++d)
                        best = std::min(best, cost[a][c] + other.cost[d][b] + (c != d));
                res.cost[a][b] = best;
            }
        return res;
    }

    // Folds a light child's whole chain (cost matrix `x`, whose head hangs
    // directly below this vertex) into this accumulator.
    void add(const Tnode& x)
    {
        apply_child(x, +1);
    }

    // Exact inverse of add(): removes a previously added light-child chain.
    void del(const Tnode& x)
    {
        apply_child(x, -1);
    }

    // Cheapest cost over all four (top, bottom) colour combinations.
    int ans() const
    {
        return std::min({cost[0][0], cost[0][1], cost[1][0], cost[1][1]});
    }

private:
    // Shared body of add()/del(). From the child chain `x` derive
    //   cat = cheapest cost if this vertex is a cat,
    //   dog = cheapest cost if this vertex is a dog,
    // where the chain head may take the other colour by paying 1 for the
    // light edge. Column 0 of the accumulator collects the "cat" totals and
    // column 1 the "dog" totals; `sign` is +1 to add and -1 to remove.
    void apply_child(const Tnode& x, int sign)
    {
        int cat = std::min(x.cost[0][0], x.cost[0][1]);
        int dog = std::min(x.cost[1][0], x.cost[1][1]);
        cat = std::min(cat, dog + 1);
        dog = std::min(dog, cat + 1);
        cost[0][0] += sign * cat;
        cost[1][0] += sign * cat;
        cost[0][1] += sign * dog;
        cost[1][1] += sign * dog;
    }
}
st[2 * maxN], f[maxN];
// Heavy-light decomposition bookkeeping (vertices and chains are 1-indexed).
int hepa[maxN];  // hepa[c]: head (topmost vertex) of chain c, 0 while unassigned
int idpa[maxN];  // idpa[u]: chain that vertex u belongs to (never set for vertex 0)
int id[maxN];    // id[u]:   position of vertex u in the segment tree's base array
int sz[maxN];    // sz[u]:   number of vertices in the subtree of u
int bc[maxN];    // bc[u]:   heavy child of u, -1 for a leaf
int dep[maxN];   // dep[u]:  depth of u (the root keeps 0)
int p[maxN];     // p[u]:    parent of u, 0 for the root
int in[maxN];    // in[c]:   first base-array position of chain c
int out[maxN];   // out[c]:  last base-array position of chain c
int cnt;         // last base-array position handed out by hld()
int cntpa;       // id of the most recently opened chain
// First DFS: for `u` and every vertex below it, records the parent p[],
// depth dep[], subtree size sz[] and heavy child bc[] (the child with the
// largest subtree, the earliest one in adj[] on ties; -1 for a leaf).
void predfs(int u, int par)
{
    p[u] = par;
    bc[u] = -1;
    sz[u] = 1;
    const auto& neighbours = adj[u];
    for (const int child : neighbours)
    {
        if (child != par)
        {
            dep[child] = dep[u] + 1;
            predfs(child, u);
            sz[u] += sz[child];
            const bool no_heavy_yet = (bc[u] == -1);
            if (no_heavy_yet || sz[child] > sz[bc[u]]) bc[u] = child;
        }
    }
}
// Second DFS: heavy-light decomposition. `u` joins the current chain `cntpa`
// and gets the next base-array position; its heavy child is visited first so
// the chain occupies consecutive positions [in[c], out[c]], and every light
// child opens a brand-new chain.
void hld(int u)
{
    const int chain = cntpa;
    idpa[u] = chain;
    id[u] = ++cnt;
    if (!hepa[chain])  // first vertex of this chain becomes its head
    {
        hepa[chain] = u;
        in[chain] = cnt;
    }
    out[chain] = cnt;

    const int heavy = bc[u];
    if (heavy > 0) hld(heavy);
    for (const int v : adj[u])
    {
        if (v != p[u] && v != heavy)
        {
            ++cntpa;
            hld(v);
        }
    }
}
// Point assignment: sets base position i to `val` in the segment-tree node
// `id` covering [l, r], then recombines the ancestors (left child on top).
void update(int id, int l, int r, int i, Tnode val)
{
    if (i < l || r < i) return;
    if (l == r)
    {
        st[id] = val;
        return;
    }
    const int mid = (l + r) / 2;
    const int lc = id << 1;
    const int rc = lc | 1;
    if (i <= mid)
        update(lc, l, mid, i, val);
    else
        update(rc, mid + 1, r, i, val);
    st[id] = st[lc] + st[rc];
}
// Range query: combined cost matrix of base positions [u, v] (top to bottom)
// within the segment-tree node `id` covering [l, r]. A disjoint node yields
// the "empty" matrix (cost[0][0] == -1), which operator+ skips.
Tnode get(int id, int l, int r, int u, int v)
{
    const bool disjoint = (v < l) || (u > r);
    if (disjoint)
    {
        Tnode empty;
        empty.cost[0][0] = -1;
        return empty;
    }
    if (u <= l && r <= v) return st[id];
    const int mid = (l + r) >> 1;
    const Tnode upper = get(2 * id, l, mid, u, v);
    const Tnode lower = get(2 * id + 1, mid + 1, r, u, v);
    return upper + lower;
}
// Turns a vertex's light-child accumulator `x` into its single-vertex cost
// matrix by forbidding (setting to inf) what its state rules out:
//   type 0 (cat)      -> only cost[0][0] survives
//   type 1 (dog)      -> only cost[1][1] survives
//   type 2 (neighbor) -> both diagonal entries survive
// Any other `type` returns `x` unchanged.
Tnode modify(int type, Tnode x)
{
    int (&c)[2][2] = x.cost;
    switch (type)
    {
    case 0:
        c[0][1] = c[1][0] = c[1][1] = inf;
        break;
    case 1:
        c[0][0] = c[0][1] = c[1][0] = inf;
        break;
    case 2:
        c[0][1] = c[1][0] = inf;
        break;
    default:
        break;
    }
    return x;
}
// Re-evaluates vertex u after its state chk[u] (or its accumulator f[u])
// changed, then walks up chain by chain: the vertex each chain hangs from has
// that chain's old contribution removed and the new one added to its
// accumulator, and is itself re-evaluated in the next iteration.
void query(int u)
{
    // idpa[0] is never assigned (stays 0), so stepping past the root ends the walk.
    while (idpa[u] != 0)
    {
        const int chain = idpa[u];
        const int top = p[hepa[chain]];  // vertex this chain hangs from; 0 for the root chain
        // The root chain hangs from nothing, so there is no accumulator to maintain.
        if (top != 0) f[top].del(get(1, 1, n, in[chain], out[chain]));
        update(1, 1, n, id[u], modify(chk[u], f[u]));
        if (top != 0) f[top].add(get(1, 1, n, in[chain], out[chain]));
        u = top;
    }
}

// Builds the tree from the N-1 edges (A[i], B[i]) on vertices 1..N, computes
// the heavy-light decomposition rooted at vertex 1, and marks every vertex as
// a neighbor (state 2: either colour allowed) with an empty accumulator.
// Presumably declared in catdog.h and called once by the grader before any
// cat()/dog()/neighbor() call; it does not clear state from a previous call.
// No explicit segment-tree build is needed: modify(2, f[u]) for a zeroed
// accumulator equals a default-constructed Tnode, which is what every leaf
// of `st` already holds.
void initialize(int N, vector<int> A, vector<int> B)
{
    n = N;
    for (int i = 0; i < n - 1; ++i)
    {
        adj[A[i]].push_back(B[i]);
        adj[B[i]].push_back(A[i]);
    }
    predfs(1, 0);
    cnt = 0;
    cntpa = 1;  // chain ids start at 1 so that idpa[x] == 0 means "no chain"
    hld(1);
    std::fill(chk, chk + n + 1, 2);
    for (int i = 1; i <= n; ++i) f[i].init();
}
// Shared body of cat()/dog()/neighbor(): stores the new state of vertex v
// (0 = cat, 1 = dog, 2 = neighbor), propagates the change to the root, and
// returns the best cost of the root chain (chain 1), i.e. the minimum number
// of edges joining differently coloured vertices over the whole tree.
static int set_state_and_answer(int v, int state)
{
    chk[v] = state;
    query(v);
    return get(1, 1, n, in[1], out[1]).ans();
}
int cat(int v)
{
    return set_state_and_answer(v, 0);
}
int dog(int v)
{
    return set_state_and_answer(v, 1);
}
int neighbor(int v)
{
    return set_state_and_answer(v, 2);
}


# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...