제출 #1345236

#제출 시각아이디문제언어결과실행 시간메모리
1345236kokokaiCats or Dogs (JOI18_catdog)C++20
100 / 100
367 ms26544 KiB
#include <bits/stdc++.h>
#include "catdog.h"
using namespace std;
using ll = long long;
#define fi first
#define se second
#define task "text"
const int N = 1e5+5;
int chainid[N],chainhead[N],curchain,curpos,pos[N],end_head[N],sz[N],par[N],rev[N];
ll L0[N],L1[N];
int type[N];
vector<int> adj[N];
int n;
void dfs(int u,int p){
    sz[u]=1;
    par[u]=p;
    for(int v:adj[u]){
        if(v==p) continue;
        dfs(v,u);
        sz[u] += sz[v];
    }
}
void hld(int u,int p){
    if(!chainhead[curchain]){
        chainhead[curchain]=u;
    }
    curpos++;
    pos[u]=curpos;
    rev[curpos]=u;
    chainid[u]=curchain;
    end_head[curchain]=u;
    int nxt=0;
    for(int v:adj[u]){
        if(v==p) continue;
        if(sz[v]>sz[nxt]) nxt=v;
    }
    if(nxt) hld(nxt,u);
    for(int v:adj[u]){
        if(v==p || v==nxt) continue;
        curchain++;
        hld(v,u);
    }
}
struct matrix{
        ll a[2][2];
        int n,m;
        void init(){
            a[0][0]=a[0][1]=a[1][0]=a[1][1]=1e9;
        }
}st[4*N],emp;
matrix mul(matrix a,matrix b){
        if(a.a[0][0] == 1e9+1) return b;
        if(b.a[0][0] == 1e9+1) return a;
        matrix res;
        res.init();
        res.n=a.n;
        res.m=b.m;
        for(int i=0;i<2;i++){
            for(int k=0;k<2;k++){
                if(a.a[i][k] >= 1e9) continue;
                for(int j=0;j<2;j++){
                    res.a[i][j]=min(res.a[i][j],a.a[i][k]+b.a[k][j]);
                }
            }
        }
        return res;
}
matrix getmat(int u){
    matrix res;
    res.init();
    if(type[u] != 2){
        res.a[0][0]=L0[u];
        res.a[0][1]=L0[u]+1;
    }
    if(type[u] != 1){
        res.a[1][0]=L1[u]+1;
        res.a[1][1]=L1[u];
    }
    return res;
}
struct segtree{

    void build(int id,int l,int r){
        if(l == r){
            st[id]=getmat(rev[l]);
            return;
        }
        int mid=(l+r)>>1;
        build(id<<1,l,mid);
        build(id<<1|1,mid+1,r);
        st[id]=mul(st[id<<1],st[id<<1|1]);
    }
    void update(int id,int l,int r,int p){
        if(l==r){
            st[id]=getmat(rev[l]);
            return;
        }
        int mid=(l+r)>>1;
        if(mid<p) update(id<<1|1,mid+1,r,p);
        else update(id<<1,l,mid,p);
        st[id]=mul(st[id<<1],st[id<<1|1]);
    }
    matrix get(int id,int l,int r,int u,int v){
        if(r<u || v<l) return emp;
        if(u<=l && r<=v) return st[id];
        int mid=(l+r)>>1;
        return mul(get(id<<1,l,mid,u,v),get(id<<1|1,mid+1,r,u,v));
    }
}seg;

void initialize(int _n,vector<int> A,vector<int> B){
    n=_n;
    emp.a[0][0]=1e9+1;
    for(int i=0;i<A.size();i++){
        int u=A[i],v=B[i];
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1,0);
    hld(1,0);
    seg.build(1,1,n);
}
//cat:1,dog:2
pair<ll,ll> getdp(int head){
    matrix res=seg.get(1,1,n,pos[head],pos[end_head[chainid[head]]]);
    return {min(res.a[0][0],res.a[0][1]),min(res.a[1][0],res.a[1][1])};
}
ll update(int u){
    while(u){
        int head=chainhead[chainid[u]];
        int p=par[head];

        if(p){
            pair<ll,ll> res=getdp(head);
            L0[p] -= min(res.fi,res.se+1);
            L1[p] -= min(res.se,res.fi+1);
        }
        seg.update(1,1,n,pos[u]);
        if(p){
            pair<ll,ll> res=getdp(head);
            L0[p] += min(res.fi,res.se+1);
            L1[p] += min(res.se,res.fi+1);
        }
        u=p;
    }
    pair<ll,ll> ans=getdp(1);
    return min(ans.fi,ans.se);
}
int cat(int u){
    type[u]=1;
    return update(u);
}
int dog(int u){
    type[u]=2;
    return update(u);
}
int neighbor(int u){
    type[u]=0;
    return update(u);
}



#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...