// Submission #637847
//
// Username:       fatemetmhr
// Problem:        Cats or Dogs (JOI18_catdog)
// Language:       C++17
// Result:         100 / 100
// Execution time: 793 ms
// Memory:         28740 KiB
//  ~Be name khoda~  //
 
#include "catdog.h"
#include<bits/stdc++.h>
 
using namespace std;
 
typedef long long ll;
 
// Shorthand macros. Only `pb` is used in this file; mp/all/fi/se are unused.
#define pb       push_back
#define mp       make_pair
#define all(x)   x.begin(), x.end()
#define fi       first
#define se       second
 
// Array bounds: maxn5 bounds the number of vertices, and maxnt (roughly
// 4 * maxn5) sizes the segment-tree array `seg`.
// maxn, maxn3, mod and inf are not referenced anywhere in this file.
const int maxn  =  1e6   + 10;
const int maxn5 =  1e5   + 10;
const int maxnt =  4e5   + 10;
const int maxn3 =  1e3   + 10;
const int mod   =  1e9   +  7;
const ll  inf   =  1e18;
 
// n: number of vertices (0-based internally).
// Indexed by vertex v:
//   id[v]   position of v in heavy-path order (segment-tree index);
//           ti is the last position handed out (starts at -1)
//   top[v]  head vertex of v's heavy chain
//   par[v]  parent of v in the tree rooted at 0 (-1 for the root)
//   big[v]  heavy child of v (-1 if v has no children)
//   sz[v]   size of v's subtree
// Indexed by chain head h:
//   retnode[h]  last (deepest) vertex of the chain starting at h
// Indexed by position p = id[v]:
//   a[p]      label of the vertex: 0 = none, 1 = cat, 2 = dog
//   av[p][c]  number of light child chains whose combined state is c
// ans: sum over all heavy chains of the chain's add[0]; this is the value
//      returned by cat/dog/neighbor.
int n, id[maxn5], ti = -1, top[maxn5], par[maxn5];
int retnode[maxn5], a[maxn5], av[maxn5][3], big[maxn5];
int sz[maxn5], ans = 0;
// Undirected adjacency lists of the tree (0-based vertices).
vector <int> adj[maxn5];

// A segment-tree node: a 3-state transfer function for a contiguous piece of
// a heavy chain. States are 0 = undecided, 1 = cat, 2 = dog.
// For an incoming state i, fnres[i] is the state leaving the piece and add[i]
// is the number of edges cut inside it.
// `empty` marks the identity element used for out-of-range queries.
// A default-constructed node has every entry zero and is not empty.
struct catdog{
    int add[3];
    int fnres[3];
    bool empty;
    catdog() : add{0, 0, 0}, fnres{0, 0, 0}, empty(false) {}
} seg[maxnt], empty_node, res;

// Composes two chain pieces: `first` is applied before `second`.
// Callers pass the deeper piece (higher heavy-path positions) as `first`.
// For incoming state i, `first` maps i to first.fnres[i] at cost
// first.add[i]; `second` then maps that state onward at cost
// second.add[first.fnres[i]].
// An operand flagged `empty` acts as the identity.
// The result is built in a local rather than in a file-scope scratch object,
// so the operator has no side effects on global state.
catdog operator +(const catdog& first, const catdog& second){
    if(first.empty)
        return second;
    if(second.empty)
        return first;
    catdog out;
    for(int i = 0; i < 3; i++){
        const int mid_state = first.fnres[i];
        out.fnres[i] = second.fnres[mid_state];
        out.add[i] = second.add[mid_state] + first.add[i];
    }
    return out;
}

void update(int l, int r, int id, int v){
    if(r - l == 1){
        if(a[l] == 1){
            seg[v].fnres[0] = seg[v].fnres[1] = seg[v].fnres[2] = 1;
            seg[v].add[0] = seg[v].add[1] = seg[v].add[2] = av[l][2];
            seg[v].add[2]++;
        }
        if(a[l] == 2){
            seg[v].fnres[0] = seg[v].fnres[1] = seg[v].fnres[2] = 2;
            seg[v].add[0] = seg[v].add[1] = seg[v].add[2] = av[l][1];
            seg[v].add[1]++;
        }
        if(a[l] == 0){
            for(int i = 0; i < 3; i++){
                av[l][i]++;
                if(av[l][1] < av[l][2]){
                    seg[v].fnres[i] = 2;
                    seg[v].add[i] = av[l][1];
                }
                if(av[l][1] > av[l][2]){
                    seg[v].fnres[i] = 1;
                    seg[v].add[i] = av[l][2];
                }
                if(av[l][1] == av[l][2]){
                    seg[v].fnres[i] = 0;
                    seg[v].add[i] = av[l][1];
                }
                av[l][i]--;
            }
     
        }
        return;
    }

    int mid = (l + r) >> 1;
    if(id < mid)
        update(l, mid, id, v * 2);
    else
        update(mid, r, id, v * 2 + 1);
    seg[v] = seg[v * 2 + 1] + seg[v * 2];
    return;
}

// Returns the composition of all positions in [lq, rq) that lie inside
// node v = [l, r). Deeper positions (right part) are applied before
// shallower ones (left part). A node disjoint from the query contributes
// `empty_node`, the identity.
catdog get(int l, int r, int lq, int rq, int v){
    const bool disjoint = (rq <= l) || (r <= lq);
    if(disjoint)
        return empty_node;
    const bool inside = (lq <= l) && (r <= rq);
    if(inside)
        return seg[v];
    const int mid = (l + r) >> 1;
    const catdog deeper = get(mid, r, lq, rq, v * 2 + 1);
    const catdog shallower = get(l, mid, lq, rq, v * 2);
    return deeper + shallower;
}

// Initializes node v covering [l, r). Every leaf becomes the identity
// transfer: state s maps to s, and add[] stays zero from the constructor.
// Internal nodes are combined right child first, as in update().
void build(int l, int r, int v){
    if(r - l > 1){
        const int mid = (l + r) >> 1;
        build(l, mid, v * 2);
        build(mid, r, v * 2 + 1);
        seg[v] = seg[v * 2 + 1] + seg[v * 2];
        return;
    }
    for(int s = 1; s < 3; s++)
        seg[v].fnres[s] = s;
}
 
void upd(int v){
    auto last = get(0, n, id[top[v]], id[retnode[top[v]]] + 1, 1);
    ans -= last.add[0];
    update(0, n, id[v], 1);
    auto cur = get(0, n, id[top[v]], id[retnode[top[v]]] + 1, 1);
    ans += cur.add[0];
    if(par[top[v]] == -1)
        return;
    av[id[par[top[v]]]][last.fnres[0]]--;
    av[id[par[top[v]]]][cur.fnres[0]]++;
    upd(par[top[v]]);
    return;
}

// Computes par[], sz[] and big[] for the subtree rooted at v.
// The caller must set par[v] beforehand; initialize() sets par[0] = -1.
// big[x] is the first child in adjacency order with the largest subtree,
// or -1 if x has no children.
// Iterative on purpose: a recursive DFS nests as deep as the tree height,
// which can be ~N on a path-shaped tree and may overflow the call stack.
void dfs_det(int v){
    // BFS order from v: every vertex appears after its parent.
    std::vector<int> order(1, v);
    for(std::size_t k = 0; k < order.size(); k++){
        const int x = order[k];
        for(int u : adj[x]) if(u != par[x]){
            par[u] = x;
            order.push_back(u);
        }
    }
    // Reverse BFS order finalizes every child before its parent.
    for(auto it = order.rbegin(); it != order.rend(); ++it){
        const int x = *it;
        sz[x] = 1;
        big[x] = -1;
        for(int u : adj[x]) if(u != par[x]){
            sz[x] += sz[u];
            if(big[x] == -1 || sz[u] > sz[big[x]])
                big[x] = u;
        }
    }
}

// Assigns heavy-path positions starting from v. The caller sets top[v].
// Vertices are numbered in preorder, visiting the heavy child first and then
// the light children in adjacency order. Each chain therefore occupies a
// contiguous range of positions, top to bottom.
// For each vertex x it also:
//   - records retnode[top[x]] as the deepest vertex of x's chain;
//   - gives the heavy child x's chain head and makes each light child a new
//     chain head.
// Uses an explicit stack: recursion depth would equal the tree height, up to
// ~N on a path, and could overflow the call stack.
void dfs_hld(int v){
    std::vector<int> stk(1, v);
    while(!stk.empty()){
        const int x = stk.back();
        stk.pop_back();
        id[x] = ++ti;
        retnode[top[x]] = x;
        if(big[x] == -1)
            continue;
        // Push light children in reverse adjacency order, so they pop in
        // adjacency order once the heavy subtree is done.
        for(auto it = adj[x].rbegin(); it != adj[x].rend(); ++it){
            const int u = *it;
            if(u != big[x] && u != par[x]){
                top[u] = u;
                stk.push_back(u);
            }
        }
        // The heavy child extends x's chain and is pushed last, so it is
        // numbered right after x.
        top[big[x]] = top[x];
        stk.push_back(big[x]);
    }
}

// Shared implementation of the grader callbacks:
//   - v is 1-based;
//   - stores `state` (0 = none, 1 = cat, 2 = dog) at v's heavy-path position;
//   - propagates the change upward;
//   - returns the updated global `ans`.
static int set_state(int v, int state) {
    const int vertex = v - 1;
    a[id[vertex]] = state;
    upd(vertex);
    return ans;
}

// Marks vertex v (1-based) as a cat and returns the updated answer.
int cat(int v) {
    return set_state(v, 1);
}

// Marks vertex v (1-based) as a dog and returns the updated answer.
int dog(int v) {
    return set_state(v, 2);
}

// Clears the label of vertex v (1-based) and returns the updated answer.
int neighbor(int v) {
    return set_state(v, 0);
}
 
// Grader entry point. N is the number of vertices, and edge e joins A[e] and
// B[e] (both 1-based). It:
//   - builds the identity segment tree;
//   - builds 0-based adjacency lists;
//   - roots the tree at vertex 0 and computes the heavy-path decomposition.
void initialize(int N, std::vector<int> A, std::vector<int> B) {
    n = N;
    empty_node.empty = true;   // identity element for out-of-range queries
    build(0, n, 1);
    for(int e = 0; e + 1 < n; ++e){
        const int x = A[e] - 1;
        const int y = B[e] - 1;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    par[0] = -1;
    dfs_det(0);
    dfs_hld(0);
}
// The judge page had three result tables (columns: #, Verdict,
// Execution time, Memory, Grader output). None were captured: each table
// showed only "Fetching results...".