Submission #956507

# | Time | Username | Problem | Language | Result | Execution time | Memory
956507 | | Double_Slash | Cats or Dogs (JOI18_catdog) | C++17 | 100 / 100 | 175 ms | 34772 KiB
#include "catdog.h"
#include <bits/stdc++.h>

using namespace std;

// Sentinel: default for inc()'s third argument, and (negated) the depth
// assigned to the pseudo-vertex 0 by initialize().
const int INF = 1e9;

// Vertex count, stored by initialize().
int n;

// par[v]: parent of v in the rooted tree (0 for the root); filled by build().
int par[100001];

// heavy[v]: heavy child selected by build(), or 0 when there is none.
int heavy[100001] = {0};

// sz[v]: subtree size accumulated by build().
int sz[100001] = {0};

// depth[v]: distance from the root, filled by build().
int depth[100001] = {0};

// head[v]: parent of the top vertex of v's chain, filled by hld().
int head[100001];

// Adjacency lists of the input tree.
vector<int> adj[100001];
/**
 * Lazy segment tree node covering the positions l..r (inclusive).
 *
 * hld() builds one tree per heavy chain, with positions equal to vertex
 * depths; st[v] points at the root of the tree for the chain containing v.
 *
 * Per leaf the tree stores two counters (cats, dogs), a type set by update()
 * and a label idx set by label(). Internal nodes keep the minimum / maximum
 * of (cats - dogs) over their range and the maximum type over their range.
 *
 * A label of 0 is used by le(), ge() and occupied() as "nothing found", so
 * labels passed to label() must be nonzero for those searches to report them.
 */
struct Node {
    int l, r;                      // covered position range, inclusive
    int cats = 0, dogs = 0;        // leaf counters (only maintained on leaves)
    int mndiff = 0, mxdiff = 0;    // min / max of cats - dogs over [l, r]
    int lazyC = 0, lazyD = 0;      // pending additions to cats / dogs
    int type = 0;                  // leaf: value from update(); internal: max of children
    int idx = 0;                   // leaf label assigned by label()
    Node *lc = nullptr, *rc = nullptr;  // children; stay null on leaves

    /** Builds the full subtree for [l, r]; leaves have no children. */
    Node(int l, int r) : l(l), r(r) {
        if (l < r) {
            int m = (l + r) >> 1;
            lc = new Node{l, m};
            rc = new Node{m + 1, r};
        }
    }

    /**
     * Applies this node's pending lazy values: shifts mndiff/mxdiff, then
     * either hands the values to both children or, on a leaf, adds them to
     * cats/dogs.
     */
    void clean() {
        if (lazyC or lazyD) {
            mndiff += lazyC - lazyD;
            mxdiff += lazyC - lazyD;
            if (l < r) {
                lc->lazyC += lazyC;
                rc->lazyC += lazyC;
                lc->lazyD += lazyD;
                rc->lazyD += lazyD;
            } else {
                cats += lazyC, dogs += lazyD;
            }
            lazyC = lazyD = 0;
        }
    }

    /** Stores label j on the leaf at position i. */
    void label(int i, int j) {
        if (l == r) {
            idx = j;
            return;
        }
        int m = (l + r) >> 1;
        if (i <= m) lc->label(i, j);
        else rc->label(i, j);
    }

    /**
     * Returns a pair derived from the leaf at position i:
     *   - type 1, or type 0 with cats < dogs:  {cats, cats + 1}
     *   - type 2, or type 0 with cats > dogs:  {dogs + 1, dogs}
     *   - otherwise:                           {cats, dogs}
     */
    std::pair<int, int> query(int i) {
        clean();
        if (l == r) {
            // Explicit parentheses: `and` binds tighter than `or`, which is
            // the intended grouping.
            if (type == 1 or (not type and cats < dogs)) {
                return {cats, cats + 1};
            } else if (type == 2 or (not type and cats > dogs)) {
                return {dogs + 1, dogs};
            } else {
                return {cats, dogs};
            }
        }
        int m = (l + r) >> 1;
        if (i <= m) return lc->query(i);
        else return rc->query(i);
    }

    /** Adds c to cats and d to dogs on every position in [ul, ur]. */
    void inc(int ul, int ur, int c, int d) {
        // Clean before the range test so the parent can safely read this
        // node's mndiff/mxdiff even when it is outside the update range.
        clean();
        if (ul > r or ur < l) return;
        if (l >= ul and r <= ur) {
            lazyC += c;
            lazyD += d;
            clean();
            return;
        }
        lc->inc(ul, ur, c, d);
        rc->inc(ul, ur, c, d);
        mndiff = std::min(lc->mndiff, rc->mndiff);
        mxdiff = std::max(lc->mxdiff, rc->mxdiff);
    }

    /** Sets the type of the leaf at position i to t and refreshes ancestors. */
    void update(int i, int t) {
        if (i < l or i > r) return;
        if (l == r) {
            type = t;
            return;
        }
        lc->update(i, t);
        rc->update(i, t);
        type = std::max(lc->type, rc->type);
    }

    /**
     * Finds the rightmost leaf at a position <= i whose cats - dogs is <= x.
     * Returns {label, cats - dogs} of that leaf, or {0, 0} if there is none.
     */
    std::pair<int, int> le(int i, int x) {
        clean();
        if (i < l) return {0, 0};
        if (mndiff > x) return {0, 0};
        if (l == r) return {idx, mndiff};
        auto q = rc->le(i, x);  // right first: the rightmost match wins
        return q.first ? q : lc->le(i, x);
    }

    /**
     * Finds the rightmost leaf at a position <= i whose cats - dogs is >= x.
     * Returns {label, cats - dogs} of that leaf, or {0, 0} if there is none.
     */
    std::pair<int, int> ge(int i, int x) {
        clean();
        if (i < l) return {0, 0};
        if (mxdiff < x) return {0, 0};
        if (l == r) return {idx, mxdiff};
        auto q = rc->ge(i, x);  // right first: the rightmost match wins
        return q.first ? q : lc->ge(i, x);
    }

    /**
     * Finds the rightmost leaf at a position <= i whose type is nonzero.
     * Returns {label, type} of that leaf, or {0, 0} if there is none.
     */
    std::pair<int, int> occupied(int i) {
        if (i < l or not type) return {0, 0};
        if (l == r) return {idx, type};
        auto q = rc->occupied(i);
        return q.first ? q : lc->occupied(i);
    }
} *st[100001] = {nullptr};

// Depth-first pass from vertex i: records par and depth for every vertex
// below i, accumulates subtree sizes in sz, and selects heavy[i].
void build(int i) {
    for (const int child : adj[i]) {
        if (child == par[i]) continue;  // do not walk back up the tree
        par[child] = i;
        depth[child] = depth[i] + 1;
        build(child);
        sz[i] += sz[child];
        if (sz[child] > sz[heavy[i]]) heavy[i] = child;
    }
    // Here sz[i] still counts descendants only. The largest child stays
    // "heavy" only if it holds more than half of them (integer division);
    // afterwards i itself is added to its own size.
    const int descendants = sz[i];
    if (sz[heavy[i]] <= descendants / 2) heavy[i] = 0;
    sz[i] = descendants + 1;
}

// Assigns vertex i to the chain whose top vertex is h, then starts a new
// chain at every non-heavy child. All vertices of one chain share a single
// segment tree, created at the chain's bottom vertex and spanning the depths
// depth[h]..depth[bottom].
void hld(int i, int h) {
    head[i] = par[h];  // vertex just above the chain (0 for the root's chain)
    const int next = heavy[i];
    if (next == 0) {
        st[i] = new Node{depth[h], depth[i]};
    } else {
        hld(next, h);
        st[i] = st[next];
    }
    for (const int child : adj[i]) {
        const bool startsNewChain = child != par[i] and child != next;
        if (startsNewChain) hld(child, child);
    }
}

/**
 * Adds the delta (c, d) to the (cats, dogs) counters stored on the chain of
 * vertex i, starting at i and moving towards the root, and forwards the
 * resulting delta to the chains above.
 *
 * @param i  vertex to start from; 0 (the pseudo-parent of the root) stops
 *           the recursion.
 * @param c  amount added to the cats counter.
 * @param d  amount added to the dogs counter; the default INF means
 *           "same as c".
 *
 * Uniform delta (c == d): the whole chain prefix ending at i receives it and
 * the same delta continues from head[i]. A zero delta changes nothing, so it
 * returns immediately.
 *
 * Non-uniform delta: c - d is expected to be -2, -1, 1 or 2 (any other value
 * is ignored). The deepest vertex at or above i on this chain is looked up
 * for each of three conditions:
 *   - its type is nonzero (Node::occupied),
 *   - its cats - dogs is <= lo (Node::le),
 *   - its cats - dogs is >= hi (Node::ge),
 * where lo/hi depend on c - d. The deepest of those vertices receives the
 * delta together with everything between it and i, and an adjusted uniform or
 * +-1 delta continues from its parent. If no such vertex exists the whole
 * prefix receives (c, d) and the same delta continues from head[i].
 */
void inc(int i, int c, int d = INF) {
    if (not i) return;
    if (d == INF) d = c;
    if (c == d) {
        if (c == 0) return;  // nothing to add anywhere above either
        st[i]->inc(0, depth[i], c, c);
        inc(head[i], c, c);
        return;
    }

    const int diff = c - d;
    int lo = 0, hi = 0;  // thresholds for the le / ge searches
    switch (diff) {
        case -2: lo = 0;  hi = 2; break;
        case -1: lo = -1; hi = 2; break;
        case 1:  lo = -2; hi = 1; break;
        case 2:  lo = -2; hi = 0; break;
        default: return;
    }
    // Integer division: +-1 for diff == +-2, 0 for diff == +-1. It is the
    // correction applied when the stopping vertex sits exactly on its
    // threshold.
    const int step = diff / 2;

    // All three lookups must read the tree before any counter is changed.
    const auto [lowIdx, lowVal] = st[i]->le(depth[i], lo);
    const auto [highIdx, highVal] = st[i]->ge(depth[i], hi);
    const auto [occIdx, occType] = st[i]->occupied(depth[i]);

    // A failed lookup yields vertex 0, whose depth initialize() sets to -INF,
    // so it never wins a depth comparison against a real vertex.
    if (occIdx and depth[occIdx] >= max(depth[lowIdx], depth[highIdx])) {
        st[i]->inc(depth[occIdx], depth[i], c, d);
        inc(par[occIdx], occType == 1 ? c : d);
    } else if (depth[lowIdx] > depth[highIdx]) {
        st[i]->inc(depth[lowIdx], depth[i], c, d);
        inc(par[lowIdx], c, lowVal == lo ? c - step : c);
    } else if (depth[highIdx] > depth[lowIdx]) {
        st[i]->inc(depth[highIdx], depth[i], c, d);
        inc(par[highIdx], highVal == hi ? d + step : d, d);
    } else {
        st[i]->inc(0, depth[i], c, d);
        inc(head[i], c, d);
    }
}

int update(int i, int t) {
    auto [c0, d0] = st[i]->query(depth[i]);
    st[i]->update(depth[i], t);
    if (par[i]) {
        auto [c1, d1] = st[i]->query(depth[i]);
        int dc = c1 - c0, dd = d1 - d0;
        inc(par[i], dc, dd);
    }
    auto [c, d] = st[1]->query(depth[1]);
    return min(c, d);
}

// Builds all global structures for a tree with N vertices whose edges are
// (A[k], B[k]) for k = 0..N-2: adjacency lists, the rooting at vertex 1, the
// chain decomposition with its segment trees, and the leaf labels.
void initialize(int N, vector<int> A, vector<int> B) {
    n = N;
    for (int e = 0; e + 1 < n; ++e) {
        const int u = A[e];
        const int v = B[e];
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    par[1] = 0;
    build(1);
    hld(1, 1);

    // Each leaf remembers which vertex it stands for.
    for (int v = 1; v <= n; ++v) {
        st[v]->label(depth[v], v);
    }

    // Vertex 0 means "not found" in inc(); giving it the lowest possible
    // depth makes it lose every depth comparison there.
    depth[0] = -INF;
}

// Sets vertex v to type 1 and returns the value computed by update().
int cat(int v) {
    constexpr int kCatType = 1;
    return update(v, kCatType);
}

// Sets vertex v to type 2 and returns the value computed by update().
int dog(int v) {
    constexpr int kDogType = 2;
    return update(v, kDogType);
}

// Resets vertex v to type 0 and returns the value computed by update().
int neighbor(int v) {
    constexpr int kNoType = 0;
    return update(v, kNoType);
}

Compilation message (stderr)

catdog.cpp: In member function 'std::pair<int, int> Node::query(int)':
catdog.cpp:56:39: warning: suggest parentheses around '&&' within '||' [-Wparentheses]
   56 |             if (type == 1 or not type and cats < dogs) {
      |                              ~~~~~~~~~^~~~~~~~~~~~~~~
catdog.cpp:58:46: warning: suggest parentheses around '&&' within '||' [-Wparentheses]
   58 |             } else if (type == 2 or not type and cats > dogs) {
      |                                     ~~~~~~~~~^~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...