Submission #410677

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 410677 | | doowey | Cats or Dogs (JOI18_catdog) | C++14 | 100 / 100 | 369 ms | 46260 KiB |
#include "catdog.h"
#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using pii = pair<int, int>;

int n;  // number of vertices, assigned in initialize()

// Capacity of every vertex-indexed array (vertices are numbered from 1).
constexpr int N = 100000 + 100;
int sub[N];        // subtree sizes, filled by dfs()
int par[N];        // parent in the tree rooted at vertex 1; -1 for the root
int link[N];       // head vertex of the heavy chain that contains the vertex
int sz[N];         // position inside its chain, 0 at the chain bottom
vector<int> T[N];  // adjacency lists of the tree

// Fills par[] and sub[] (subtree sizes) for every vertex below u.
// The caller must set par[u] first; initialize() uses -1 for the root.
// Uses an explicit stack instead of recursion.
void dfs(int u){
    vector<int> order;        // discovery order: every parent precedes its children
    vector<int> pending{u};
    while(!pending.empty()){
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        sub[v] = 1;
        for(const int w : T[v]){
            if(w == par[v]) continue;
            par[w] = v;
            pending.push_back(w);
        }
    }
    // Walk the discovery order backwards so that each subtree is complete
    // before it is folded into its parent. order[0] is u itself.
    for(int k = (int)order.size() - 1; k > 0; --k){
        const int v = order[k];
        sub[par[v]] += sub[v];
    }
}

// Heavy-light decomposition of the tree below u (dfs() must have run).
// A child is heavy when its subtree holds more than half of u's subtree.
// At most one child can qualify.
// A heavy child inherits u's chain head in link[]; a light child starts its
// own chain. sz[u] becomes sz[heavy child] + 1, so it counts positions from
// the bottom of the chain, where a vertex with no heavy child has sz == 0.
// The caller must set link[] for the starting vertex.
void make_hld(int u){
    for(const int child : T[u]){
        if(child == par[u]) continue;
        const bool heavy = sub[child] > sub[u] / 2;
        link[child] = heavy ? link[u] : child;
        make_hld(child);
        if(heavy) sz[u] = sz[child] + 1;
    }
}

// Cost of a forbidden state. It exceeds every real answer, and a few of these
// values added together still fit in an int.
const int inf = (int)1e7 + 1;

// Segment tree over the positions of one heavy chain. In this file, callers
// put the chain's bottom vertex at position 0 and its head at position sz - 1.
// Each position takes state 0 or 1. Every pair of adjacent positions in
// different states costs 1.
struct segment_tree{
    // val[a][b]: minimum cost of the covered range when its leftmost position
    // is in state a and its rightmost position is in state b.
    struct Node{
        int val[2][2];
    };
    int sz = 0;              // number of positions; set by build()
    std::vector<Node> TR;    // 1-based heap layout; TR[1] covers the whole chain

    // Concatenates range A (left) with range B (right).
    Node unite(Node A, Node B){
        Node res;
        for(int p = 0; p < 2; p ++ ){
            for(int bi = 0; bi < 2; bi ++ ){
                int best = inf;
                for(int q = 0; q < 2; q ++ ){
                    for(int ai = 0; ai < 2; ai ++ ){
                        // q is the state at A's right end and ai the state at
                        // B's left end. They are adjacent, so a mismatch costs one.
                        best = std::min(best, A.val[p][q] + B.val[ai][bi] + (q != ai));
                    }
                }
                res.val[p][bi] = best;
            }
        }
        return res;
    }

    // Builds positions cl..cr with every leaf unconstrained at cost 0.
    // Start with build(1, 0, len - 1). The root call sizes the storage.
    void build(int node, int cl, int cr){
        if(node == 1){
            sz = cr + 1;
            TR.resize(sz * 4 + 16);
        }
        if(cl == cr){
            TR[node].val[0][0] = 0;
            TR[node].val[1][1] = 0;
            // A single position cannot start and end in different states.
            TR[node].val[0][1] = inf;
            TR[node].val[1][0] = inf;
            return;
        }
        int mid = (cl + cr) / 2;
        build(node * 2, cl, mid);
        build(node * 2 + 1, mid + 1, cr);
        TR[node] = unite(TR[node * 2], TR[node * 2 + 1]);
    }

    // Sets the cost of putting position id in state `mode` to v, then
    // recomputes the ancestors. Call with node = 1, cl = 0, cr = sz - 1.
    void sub(int node, int cl, int cr, int id, int mode, int v){
        if(cl == cr){
            TR[node].val[mode][mode] = v;
            return;
        }
        int mid = (cl + cr) / 2;
        if(mid >= id)
            sub(node * 2, cl, mid, id, mode, v);
        else
            sub(node * 2 + 1, mid + 1, cr, id, mode, v);
        TR[node] = unite(TR[node * 2], TR[node * 2 + 1]);
    }

    Node gg{};        // result of the last get()
    bool go = false;  // true once get() has stored at least one segment in gg

    // Range query: combines the stored segments covering [tl, tr] ∩ [cl, cr]
    // from left to right into gg. Call with node = 1, cl = 0, cr = sz - 1.
    // If the intersection is empty, go stays false and gg is meaningless.
    void get(int node, int cl, int cr, int tl, int tr){
        if(node == 1) go = false;
        if(cr < tl || cl > tr) return;
        if(cl >= tl && cr <= tr){
            if(!go){
                gg = TR[node];
                go = true;  // later segments must be combined, not overwrite gg
            }
            else{
                gg = unite(gg, TR[node]);
            }
            return;
        }
        int mid = (cl + cr) / 2;
        get(node * 2, cl, mid, tl, tr);
        get(node * 2 + 1, mid + 1, cr, tl, tr);
    }

    // Debug helper: prints the two diagonal leaf costs of every position.
    void visit(int node, int cl, int cr){
        if(cl == cr){
            std::cout << "X: " << TR[node].val[0][0] << " Y: " << TR[node].val[1][1] << "\n";
            return;
        }
        int mid = (cl + cr) / 2;
        visit(node * 2, cl, mid);
        visit(node * 2 + 1, mid + 1, cr);
    }
};

segment_tree chain[N];  // one segment tree per heavy chain, indexed by the chain head
int type[N];            // label of each vertex: 0 = neighbor (free), 1 = cat, 2 = dog
// Summed best costs of the light child chains hanging below each vertex,
// assuming the vertex is in state 0 (cx) or state 1 (cy). Maintained by walk().
ll cx[N], cy[N];

void initialize(int _N, vector<int> A, vector<int> B) {
    n = _N;
    par[1] = -1;
    for(int i = 0 ; i < n - 1; i ++ ){
        T[A[i]].push_back(B[i]);
        T[B[i]].push_back(A[i]);
    }
    dfs(1);
    link[1] = 1;

    make_hld(1);
    int pi;
    int cur;
    for(int i = 1; i <= n; i ++ ){
        if(sz[i] == 0){
            pi = i;
            cur = 0;
            while(1){
                cur ++ ;
                if(pi == link[pi]) break;
                pi = par[pi];
            }
            chain[link[i]].sz = cur;
            chain[link[i]].build(1, 0, cur - 1);
        }
    }
}

void walk(int node){
    ll las0, las1;
    int nex;
    while(1){
        nex = par[link[node]];
        las0 = min(chain[link[node]].TR[1].val[0][0],chain[link[node]].TR[1].val[1][0]);
        las1 = min(chain[link[node]].TR[1].val[0][1],chain[link[node]].TR[1].val[1][1]);

        if(nex != -1){
            cx[nex] -= min(las0, las1 + 1);
            cy[nex] -= min(las1, las0 + 1);
        }
        if(type[node] == 0){
            chain[link[node]].sub(1, 0, chain[link[node]].sz - 1, sz[node], 0, cx[node]);
            chain[link[node]].sub(1, 0, chain[link[node]].sz - 1, sz[node], 1, cy[node]);
        }
        else if(type[node] == 1){
            chain[link[node]].sub(1, 0, chain[link[node]].sz - 1, sz[node], 0, cx[node]);
            chain[link[node]].sub(1, 0, chain[link[node]].sz - 1, sz[node], 1, inf);
        }
        else{
            chain[link[node]].sub(1, 0, chain[link[node]].sz - 1, sz[node], 1, cy[node]);
            chain[link[node]].sub(1, 0, chain[link[node]].sz - 1, sz[node], 0, inf);
        }
        las0 = min(chain[link[node]].TR[1].val[0][0],chain[link[node]].TR[1].val[1][0]);
        las1 = min(chain[link[node]].TR[1].val[0][1],chain[link[node]].TR[1].val[1][1]);

        if(nex != -1){
            cx[nex] += min(las0, las1 + 1);
            cy[nex] += min(las1, las0 + 1);
        }

        if(nex == -1) break;
        node = nex;
    }
}

int get(){
    int outp = inf;
    for(int p = 0 ; p < 2; p ++ ){
        for(int q = 0; q < 2; q ++ ){
            outp = min(outp, chain[1].TR[1].val[p][q]);
        }
    }
    return outp;
}

// Stores label t for `node` (0 = neighbor, 1 = cat, 2 = dog), propagates the
// change through the chain trees, and returns the updated answer from get().
static int set_type_and_query(int node, int t){
    type[node] = t;
    walk(node);
    return get();
}

// Marks node as a cat, which forces it into state 0.
int cat(int node) {
    return set_type_and_query(node, 1);
}

// Marks node as a dog, which forces it into state 1.
int dog(int node) {
    return set_type_and_query(node, 2);
}

// Clears node's label so that either state is allowed.
int neighbor(int node) {
    return set_type_and_query(node, 0);
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...