Submission #1173517

#TimeUsernameProblemLanguageResultExecution timeMemory
1173517_callmelucianCats or Dogs (JOI18_catdog)C++17
100 / 100
1765 ms79364 KiB
#include <bits/stdc++.h>
using namespace std;

// Short aliases used throughout the solution.
using ll = long long;
using ld = long double;
using pl = pair<ll, ll>;
using pii = pair<int, int>;
using tpl = tuple<int, int, int>;

// Range helpers. Arguments are parenthesized so expressions such as
// all(*ptr) expand to (*ptr).begin(), not *ptr.begin().
#define all(a) (a).begin(), (a).end()
// Drops adjacent duplicates in place (sort first for full deduplication).
#define filter(a) (a).erase(unique(all(a)), (a).end())

constexpr int mn = 100005;        // array capacity: 1e5 vertices + slack
constexpr int INF = 1000000000;   // "impossible" marker in the DP matrices

// 2x2 min-plus matrix describing a contiguous segment of a heavy chain.
// (*this)[i][j] is the minimum cost when the segment's first vertex is in
// state i and its last vertex is in state j; INF marks an impossible pair.
// A node whose four entries are all INF is the neutral element of operator+
// (used for the empty / unused segment-tree slots).
struct node : vector<vector<int>> {
    node() : vector<vector<int>>(2, vector<int>(2, INF)) {}
    node (int a, int b, int c, int d) : vector<vector<int>>({{a, b}, {c, d}}) {}

    // True when every entry is INF (the identity). Checked entry-wise so no
    // temporary node has to be heap-allocated on every combine.
    bool isIdentity() const {
        for (const vector<int> &row : *this)
            for (int x : row)
                if (x != INF) return false;
        return true;
    }

    // Concatenates segment *this (on top) with segment o (directly below).
    // Joining last state s1 of *this to first state s2 of o costs 1 extra
    // when the states differ.
    // Const member with a non-const return: usable on const operands (e.g.
    // elements of `base`) and lets the result be moved.
    node operator + (const node &o) const {
        if (isIdentity()) return o;
        if (o.isIdentity()) return *this;

        node ans;   // starts as all INF
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int s1 = 0; s1 < 2; s1++) {
                    const int top = (*this)[i][s1];
                    if (top >= INF) continue;          // impossible prefix
                    for (int s2 = 0; s2 < 2; s2++) {
                        const int bottom = o[s2][j];
                        if (bottom >= INF) continue;   // impossible suffix
                        ans[i][j] = min(ans[i][j], top + bottom + (s1 ^ s2));
                    }
                }
        return ans;
    }

    // Smallest entry of the matrix (INF if every pair is impossible).
    int minElement() const {
        int ans = INT_MAX;
        for (const vector<int> &row : *this)
            ans = min(ans, *min_element(row.begin(), row.end()));
        return ans;
    }
};

// Leaf matrix for a single vertex, indexed by its label (type[]):
//   base[0] - neighbor: may be in state 0 or state 1, cost 0 either way
//   base[1] - cat:      forced into state 0
//   base[2] - dog:      forced into state 1
// Off-diagonal entries are INF because a one-vertex segment starts and ends
// in the same state.
const vector<node> base{
    node(0, INF, INF, 0),
    node(0, INF, INF, INF),
    node(INF, INF, INF, 0),
};

// Segment tree over one heavy chain, 1-indexed by position along the chain
// (position 1 is the chain head). tr[1] holds the ordered concatenation of
// all leaf matrices; unused slots stay as the identity node().
struct IT {
    vector<node> tr;
    IT (int sz = 0) : tr(4 * sz) {}

    // Sets leaf `pos` to `cur`, then recomputes every ancestor bottom-up.
    // (k, l, r) is the starting node and its range; callers pass (1, 1, size).
    // `cur` is taken by const reference to avoid copying the matrix.
    void update (int pos, const node &cur, int k, int l, int r) {
        while (l < r) {
            const int mid = (l + r) >> 1;
            k <<= 1;
            if (pos <= mid) r = mid;
            else k |= 1, l = mid + 1;
        }
        tr[k] = cur;
        for (k >>= 1; k > 0; k >>= 1)
            tr[k] = tr[k << 1] + tr[k << 1 | 1];
    }

    // Combined matrix of the whole chain, returned by value. An empty tree
    // (default-constructed IT) yields the identity instead of reading tr[1]
    // out of bounds.
    node getNode() const { return tr.empty() ? node() : tr[1]; }
} tree[mn];

// Per-vertex data from the rooted DFS (vertices are 1-indexed, root is 1).
int depth[mn];         // depth of the vertex; the root has depth 1
int sz[mn];            // subtree size
int par[mn];           // parent; 0 for the root
int sumChild[2][mn];   // sumChild[s][v]: summed cost of the chains hanging off v when v is in state s
int n;                 // number of vertices

// Heavy-light decomposition data.
int chain[mn];         // head of the heavy chain containing the vertex
int chainSz[mn];       // chain length, indexed by the chain head
int type[mn];          // current label: 0 = neighbor, 1 = cat, 2 = dog

vector<int> adj[mn];   // undirected adjacency lists

// Computes sz[] for the subtree rooted at u (entered from parent p) and
// returns sz[u].
int szDfs (int u, int p) {
    int total = 1;
    for (int child : adj[u]) {
        if (child == p) continue;
        total += szDfs(child, u);
    }
    sz[u] = total;
    return total;
}

// Heavy-light decomposition pass. Fills chain[], chainSz[], par[], depth[].
// u: current vertex, p: its parent, d: its depth.
// toP: true when u is the heavy child of p, so it joins p's chain.
// Subtree sizes are computed once, on entry at the root (vertex 1).
void dfs (int u, int p, int d, bool toP) {
    if (u == 1) szDfs(u, p);

    const int head = toP ? chain[p] : u;
    chain[u] = head;
    ++chainSz[head];
    par[u] = p;
    depth[u] = d;

    // Largest subtree first: the first non-parent neighbour is the heavy child.
    sort(adj[u].begin(), adj[u].end(), [] (int a, int b) { return sz[a] > sz[b]; });
    bool isHeavy = true;
    for (int child : adj[u]) {
        if (child == p) continue;
        dfs(child, u, d + 1, isHeavy);
        isHeavy = false;
    }
}

void modify (int u) {
    while (u) {
        // pre-calculate some information for current node
        int head = chain[u], sz = chainSz[head], id = depth[u] - depth[head] + 1;
        node cur = tree[head].getNode();

        // remove contribution to sumChild[...][par[head]]
        sumChild[0][par[head]] -= min(min(cur[0][0], cur[0][1]), min(cur[1][0], cur[1][1]) + 1);
        sumChild[1][par[head]] -= min(min(cur[1][0], cur[1][1]), min(cur[0][0], cur[0][1]) + 1);

        // update current node
        cur = base[type[u]];
        if (cur[0][0] != INF) cur[0][0] += sumChild[0][u];
        if (cur[1][1] != INF) cur[1][1] += sumChild[1][u];
        tree[head].update(id, cur, 1, 1, sz);

        // re-add contribution to sumChild[...][par[head]]
        cur = tree[head].getNode();
        sumChild[0][par[head]] += min(min(cur[0][0], cur[0][1]), min(cur[1][0], cur[1][1]) + 1);
        sumChild[1][par[head]] += min(min(cur[1][0], cur[1][1]), min(cur[0][0], cur[0][1]) + 1);

        u = par[head];
    }
}

int solve() { return tree[1].getNode().minElement(); }

int neighbor (int u) {
    type[u] = 0, modify(u);
    return solve();
}

int cat (int u) {
    type[u] = 1, modify(u);
    return solve();
}

int dog (int u) {
    type[u] = 2, modify(u);
    return solve();
}

// Builds the tree from edge lists a/b (n - 1 edges over vertices 1..n).
// Runs the decomposition, then gives every chain a segment tree whose leaves
// all start as base[0] (neighbor).
void initialize (int n, vector<int> a, vector<int> b) {
    ::n = n;
    for (int e = 0; e < n - 1; e++) {
        const int x = a[e];
        const int y = b[e];
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    dfs(1, 0, 1, 0);

    for (int head = 1; head <= n; head++) {
        if (chain[head] != head) continue;   // only chain heads own a tree
        const int len = chainSz[head];
        tree[head] = IT(len);
        for (int pos = 1; pos <= len; pos++)
            tree[head].update(pos, base[0], 1, 1, len);
    }
}

#ifdef LOCAL
// Compiled only when LOCAL is defined: a quick manual run of the API.
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);

    // 5 vertices, edges 1-2, 2-3, 2-4, 4-5
    vector<int> from = {1, 2, 2, 4};
    vector<int> to = {2, 3, 4, 5};
    initialize(5, from, to);

    const int r1 = cat(3);
    const int r2 = dog(5);
    cout << r1 << " " << r2 << "\n";

    const int r3 = cat(2);
    const int r4 = dog(1);
    cout << r3 << " " << r4 << "\n";

    const int r5 = neighbor(2);
    cout << r5 << "\n";

    return 0;
}
#endif // LOCAL
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...