Submission #134270

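| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 134270 | (not shown) | evpipis | Cats or Dogs (JOI18_catdog) | C++11 | 38 / 100 | 216 ms | 14824 KiB |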
#include "catdog.h"
#include <bits/stdc++.h>
using namespace std;

#define pb push_back

// Capacity of every per-vertex array; vertices are numbered 1..n.
const int len = 100005;

int par[len];    // parent in the tree rooted at 1 (par[1] == 0), set by fix()
int dep[len];    // depth (root 0, dep[0] = -1), set by fix()/initialize()
int col[len];    // 1 = dog, -1 = cat, 0 = uncolored
int sz[len];     // subtree size, set by fix()
int head[len];   // head[c] = first (topmost) vertex of HLD chain c
int chain[len];  // chain id of each vertex
int arr[len];    // arr[k] = vertex stored at base-array position k
int pos[len];    // pos[u] = base-array position of vertex u

int ans;         // value returned by cat()/dog()/neighbor()
int n;           // number of vertices
int cnt;         // positions handed out so far by hld()
int cntchain;    // id of the most recently opened chain

vector<int> adj[len];  // adjacency lists

// Segment-tree node over base-array positions.
struct node{
    int mn;    // minimum dif in the segment (excluding this node's own lazy)
    int mx;    // maximum dif in the segment (excluding this node's own lazy)
    int col;   // OR of the color bits in the segment (nonzero = some vertex colored)
    int lazy;  // pending addition not yet applied to mn/mx

    node(int a = 0, int b = 0, int c = 0, int d = 0)
        : mn(a), mx(b), col(c), lazy(d) {}
};

node tree[len * 4];  // 1-indexed segment tree; node p has children 2p and 2p+1

// Roots the subtree at u: fills par, dep and sz for every vertex below u.
// Recursive; the recursion depth equals the subtree height.
void fix(int u){
    sz[u] = 1;
    for (int v : adj[u]){
        if (v == par[u])
            continue;
        par[v] = u;
        dep[v] = dep[u] + 1;
        fix(v);
        sz[u] += sz[v];
    }
}

// Heavy-light decomposition of u's subtree. u takes the next base-array
// position and joins chain cntchain. The heavy child (largest sz, first one
// on ties) is visited first, so it continues the same chain; every light
// child opens a new chain.
void hld(int u){
    ++cnt;
    pos[u] = cnt;
    arr[cnt] = u;
    chain[u] = cntchain;
    if (head[cntchain] == 0)
        head[cntchain] = u;

    int heavy = -1;
    for (int v : adj[u])
        if (v != par[u] && (heavy == -1 || sz[heavy] < sz[v]))
            heavy = v;

    if (heavy == -1)
        return;  // leaf: no children to visit

    hld(heavy);
    for (int v : adj[u]){
        if (v == par[u] || v == heavy)
            continue;
        ++cntchain;
        hld(v);
    }
}

// Combines two children's aggregates; the result carries no pending lazy.
node join(node a, node b){
    node merged;
    merged.mn = std::min(a.mn, b.mn);
    merged.mx = std::max(a.mx, b.mx);
    merged.col = a.col | b.col;
    return merged;
}

void prop(int p, int l, int r){
    if (tree[p].lazy == 0) return;

    tree[p].mn += tree[p].lazy;
    tree[p].mx += tree[p].lazy;
    if (l != r){
        tree[2*p].lazy += tree[p].lazy;
        tree[2*p+1].lazy += tree[p].lazy;
    }
    tree[p].lazy = 0;
}

// Adds x to dif on base-array positions [i, j] (range update with lazy tags).
// A fully covered node only receives the tag. Partially covered nodes
// recurse, bring both children up to date, and re-join them.
void update(int p, int l, int r, int i, int j, int x){
    prop(p, l, r);
    if (j < l || r < i)
        return;
    if (l >= i && r <= j){
        tree[p].lazy += x;
        return;
    }

    const int mid = (l + r) / 2;
    const int lc = 2*p, rc = lc + 1;
    update(lc, l, mid, i, j, x);
    update(rc, mid + 1, r, i, j, x);

    prop(lc, l, mid);
    prop(rc, mid + 1, r);
    tree[p] = join(tree[lc], tree[rc]);
}

/*
 * Toggles the color bit of base-array position i (point update) and
 * recomputes the aggregates on the way back up.
 *
 * Invariant shared with update(): after prop(q), tree[q] holds the exact
 * mn/mx and its pending lazy has been pushed to its children, whose own
 * mn/mx do NOT include it yet. The recursive call only prop()s the child
 * we descend into, so the sibling must be prop()ed explicitly before
 * join(). Otherwise join() reads the sibling's stale mn/mx and the parent
 * silently loses the pending addition. That breaks later query() range
 * checks and therefore fin().
 */
void colupd(int p, int l, int r, int i){
    prop(p, l, r);
    if (l == r){
        tree[p].col ^= 1;
        return;
    }

    int mid = (l+r)/2;
    if (i <= mid)
        colupd(2*p, l, mid, i);
    else
        colupd(2*p+1, mid+1, r, i);

    // Bring both children up to date; prop() is a no-op on the one that
    // the recursive call already processed.
    prop(2*p, l, mid);
    prop(2*p+1, mid+1, r);

    tree[p] = join(tree[2*p], tree[2*p+1]);
}

/*
 * Within the positions of [i, j] covered by node p (range [l, r]), finds the
 * longest suffix (ending at the right end of that overlap) whose vertices
 * all satisfy a <= dif <= b and are uncolored.
 *
 * Returns:
 *   arr[k] - the vertex at the leftmost position k of that suffix;
 *   0      - if the suffix is empty (the rightmost position already fails);
 *   -1     - if [l, r] does not intersect [i, j].
 *
 * The right half is searched first. The left half is only needed when the
 * right half's suffix reaches all the way to mid+1, or when the right half
 * lies outside the query range.
 */
int query(int p, int l, int r, int i, int j, int a, int b){
    prop(p, l, r);

    if (r < i || j < l)
        return -1;
    // Whole node inside the range and every vertex qualifies.
    if (i <= l && r <= j && a <= tree[p].mn && tree[p].mx <= b && !tree[p].col)
        return arr[l];
    // A leaf inside the range that fails the test.
    if (l == r)
        return 0;

    int mid = (l+r)/2, rig = query(2*p+1, mid+1, r, i, j, a, b), lef;
    if (rig == arr[mid+1] || rig == -1){
        lef = query(2*p, l, mid, i, j, a, b);
        if (lef == -1)
            return rig; // rig can't be -1 as well
        if (lef == 0)
            return max(0, rig); // left suffix is empty: keep arr[mid+1], or 0 if rig == -1
        return lef;
    }
    return rig;
}

// Point query: current dif stored at base-array position i. Walks from node
// p down to the leaf, pushing lazies along the way.
int ask(int p, int l, int r, int i){
    for (;;){
        prop(p, l, r);
        if (l == r)
            return tree[p].mn;

        const int mid = (l + r) / 2;
        if (i <= mid){
            p = 2*p;
            r = mid;
        }
        else{
            p = 2*p + 1;
            l = mid + 1;
        }
    }
}

// Adds x to dif on every vertex of the tree path from u up to its ancestor v
// (both inclusive). Each HLD chain on the path gets one range update.
void upd(int u, int v, int x){
    while (chain[u] != chain[v]){
        const int top = head[chain[u]];
        update(1, 1, n, pos[top], pos[u], x);
        u = par[top];
    }
    update(1, 1, n, pos[v], pos[u], x);
}

// Climbs from u toward the root. Returns the topmost vertex w such that every
// vertex on the path u..w is uncolored and has a <= dif <= b, or 0 if u itself
// does not qualify.
int fin(int u, int a, int b){
    int best = 0;
    while (u != 0){
        const int top = head[chain[u]];
        const int cur = query(1, 1, n, pos[top], pos[u], a, b);
        if (cur == top){
            // The whole chain segment qualifies: keep climbing.
            best = top;
            u = par[top];
            continue;
        }
        // The run ends inside this chain (or right at its bottom when cur == 0).
        return cur == 0 ? best : cur;
    }
    return best;
}

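/* Naive O(depth) versions of upd and fin over a plain dif[] array, kept
   commented out for reference next to the HLD-based versions above.
void upd(int u, int v, int x){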
    while (dep[u] >= dep[v])
        dif[u] += x, u = par[u];
}

int fin(int u, int l, int r){
    int ans = 0;
    while (u != 0 && col[u] == 0 && l <= dif[u] && dif[u] <= r)
        ans = u, u = par[u];
    return ans;
}*/

/*
 * change(u, t): one of u's children changed its contribution to u's dif in
 * the way encoded by t. Apply that change to u and, where u's own
 * contribution changes as a result, to u's ancestors. The global ans is
 * adjusted along the way.
 *
 * The shifts applied below are +2 for t == 3, +1 for t == 1/2, -1 for
 * t == -1/-2 and -2 for t == -3. So "red", "white" and "green" act as
 * contributions of -1, 0 and +1 to the parent's dif. cat()/dog()/neighbor()
 * show the mapping: an uncolored vertex contributes the sign of its own dif,
 * a dog (col == 1) always contributes green, and a cat (col == -1) always
 * contributes red.
 *
 * u == 0 means the changed vertex was the root, so there is nothing to update.
 */
void change(int u, int t){
    /*
    3: red-green
    2: white-green
    1: red-white
    0: nothing
    -1: green-white
    -2: white-red
    -3: green-red
    */

    if (u == 0)
        return;

    //printf("change: u = %d, t = %d\n", u, t);

    if (t == 3){
        // Uncolored vertices with dif == -1 move to +1, i.e. they forward the
        // same red -> green change upward. fin() returns the topmost such
        // vertex on the path starting at u (0 if u itself does not qualify).
        int v = fin(u, -1, -1), dif;
        //printf("v = %d\n", v);
        if (v != 0){
            upd(u, v, 2);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        // v is the first vertex that does not simply forward the change:
        // settle ans here, and recurse once more if v's own contribution
        // changes in a different way.
        dif = ask(1, 1, n, pos[v]);
        if (col[v] == 1)
            ans--;
        else if (col[v] == -1)
            ans++;
        else if (dif <= -3)
            ans++;
        else if (dif == -2)
            ans++, change(par[v], 1);  // -2 -> 0: v goes red -> white
        else if (dif == 0)
            ans--, change(par[v], 2);  // 0 -> +2: v goes white -> green
        else if (dif >= 1)
            ans--;
        upd(v, v, 2);
    }
    else if (t == 2 || t == 1){
        // A +1 shift: uncolored vertices with dif in {-1, 0} change colour
        // (-1 -> 0: red -> white, 0 -> 1: white -> green). Each forwards a
        // change whose type depends only on its own dif.
        // NOTE(review): the per-vertex ans adjustments along this path appear
        // to telescope. That is why only the top vertex v is compared with the
        // incoming type t, and c becomes the type v forwards.
        int v = fin(u, -1, 0), c = t, dif;
        //printf("v = %d\n", v);
        if (v != 0){
            dif = ask(1, 1, n, pos[v]);
            if (t == 2 && dif == -1)
                ans++, c = 1;
            if (t == 1 && dif == 0)
                ans--, c = 2;
            upd(u, v, 1);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        // v's dif is outside [-1, 0] or v is colored, so its own contribution
        // keeps its colour: only ans needs adjusting.
        dif = ask(1, 1, n, pos[v]);
        if (c == 1){
            if (col[v] == 1)
                ans--;
            else if (col[v] == -1)
                ans += 0;
            else if (dif <= -2)
                ans += 0;
            else if (dif >= 1)
                ans--;
        }
        else{
            if (col[v] == 1)
                ans += 0;
            else if (col[v] == -1)
                ans++;
            else if (dif <= -2)
                ans++;
            else if (dif >= 1)
                ans += 0;
        }
        upd(v, v, 1);
    }
    else if (t == -1 || t == -2){
        // Mirror image of the branch above: a -1 shift, where uncolored
        // vertices with dif in {0, 1} change colour (0 -> -1: white -> red,
        // 1 -> 0: green -> white).
        int v = fin(u, 0, 1), c = t, dif;
        //printf("v = %d\n", v);
        if (v != 0){
            dif = ask(1, 1, n, pos[v]);
            if (t == -1 && dif == 0)
                ans--, c = -2;
            if (t == -2 && dif == 1)
                ans++, c = -1;
            upd(u, v, -1);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        // v's dif is outside [0, 1] or v is colored: only ans changes.
        dif = ask(1, 1, n, pos[v]);
        if (c == -1){
            if (col[v] == 1)
                ans += 0;
            else if (col[v] == -1)
                ans--;
            else if (dif >= 2)
                ans += 0;
            else if (dif <= -1)
                ans--;
        }
        else{
            if (col[v] == 1)
                ans++;
            else if (col[v] == -1)
                ans += 0;
            else if (dif >= 2)
                ans++;
            else if (dif <= -1)
                ans += 0;
        }
        upd(v, v, -1);
    }
    else if (t == -3){
        // Mirror image of t == 3: uncolored vertices with dif == +1 move to -1
        // and forward the same green -> red change.
        int v = fin(u, 1, 1), dif;
        //printf("v = %d\n", v);
        if (v != 0){
            upd(u, v, -2);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        dif = ask(1, 1, n, pos[v]);
        if (col[v] == 1)
            ans++;
        else if (col[v] == -1)
            ans--;
        else if (dif >= 3)
            ans++;
        else if (dif == 2)
            ans++, change(par[v], -1);  // 2 -> 0: v goes green -> white
        else if (dif == 0)
            ans--, change(par[v], -2);  // 0 -> -2: v goes white -> red
        else if (dif <= -1)
            ans--;
        upd(v, v, -2);
    }
}

// Debug helper: dumps the current dif of every vertex, then a blank line.
void print(){
    int i = 1;
    while (i <= n){
        printf("i = %d, dif = %d\n", i, ask(1, 1, n, pos[i]));
        ++i;
    }
    printf("\n");
}

// Builds the tree from the N-1 edges (A[e], B[e]), roots it at vertex 1 and
// computes the heavy-light decomposition.
void initialize(int N, vector<int> A, vector<int> B){
    n = N;
    for (int e = 0; e + 1 < n; ++e){
        const int x = A[e], y = B[e];
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    fix(1);
    dep[0] = -1;  // sentinel parent of the root
    hld(1);
}

// Colors u as a cat (col = -1), then propagates the resulting change in u's
// contribution (green/white -> red) to its parent. Returns the updated ans.
int cat(int u){
    col[u] = -1;
    colupd(1, 1, n, pos[u]);

    const int d = ask(1, 1, n, pos[u]);
    if (d > 0){
        ans += d;
        change(par[u], -3);
    }
    else if (d == 0){
        change(par[u], -2);
    }
    return ans;
}

// Colors u as a dog (col = 1), then propagates the resulting change in u's
// contribution (red/white -> green) to its parent. Returns the updated ans.
int dog(int u){
    col[u] = 1;
    colupd(1, 1, n, pos[u]);

    const int d = ask(1, 1, n, pos[u]);
    if (d < 0){
        ans -= d;
        change(par[u], 3);
    }
    else if (d == 0){
        change(par[u], 2);
    }
    return ans;
}

// Removes u's color. u's contribution falls back from its fixed colour to
// the sign of its own dif; that change is propagated to the parent, then
// u's color bit is cleared. Returns the updated ans.
int neighbor(int u){
    const int d = ask(1, 1, n, pos[u]);
    const bool wasDog = (col[u] == 1);

    if (wasDog){
        if (d < 0){
            ans += d;
            change(par[u], -3);
        }
        else if (d == 0){
            change(par[u], -1);
        }
    }
    else{
        if (d > 0){
            ans -= d;
            change(par[u], 3);
        }
        else if (d == 0){
            change(par[u], 1);
        }
    }

    col[u] = 0;
    colupd(1, 1, n, pos[u]);
    return ans;
}
/*
test cases:
5
1 2
2 3
2 4
1 5
8
1 3
2 4
2 5
3 3
3 4
1 4
1 1
2 2
*/

Compilation message (stderr)

catdog.cpp: In function 'void fix(int)':
catdog.cpp:27:23: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int j = 0; j < adj[u].size(); j++){
                     ~~^~~~~~~~~~~~~~~
catdog.cpp: In function 'void hld(int)':
catdog.cpp:47:23: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int j = 0; j < adj[u].size(); j++){
                     ~~^~~~~~~~~~~~~~~
catdog.cpp:58:23: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int j = 0; j < adj[u].size(); j++){
                     ~~^~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...