#include "catdog.h"
#include<bits/stdc++.h>
using namespace std;
// Capacity for per-vertex arrays (vertices are numbered starting at 1).
const int maxn=100000+10;

// Adjacency lists; after pre() every list keeps only the children of the
// vertex, sorted by decreasing subtree size (heavy child first).
vector<int>adj[maxn];

// Position counter handed out by makehld() (positions start at 1).
int timea=0;
// NOTE(review): not referenced anywhere in the visible code; kept as-is.
int now[maxn][2];
// Current state of each vertex: 0 = neighbor (empty), 1 = cat, 2 = dog.
int val[maxn];
// Offset of the first leaf in the segment tree (leaf of position p is kaf+p).
int kaf=(1<<17);
// Number of vertices, set by initialize().
int n;
// Cost sentinel for a colour choice that is impossible.
int inf=1e6+5;
// Subtree sizes, filled by pre().
int sz[maxn];
// Parent of each vertex (0 for the root), filled by pre().
int part[maxn];
// Head of the heavy path containing each vertex, filled by makehld().
int parh[maxn];
// For a heavy-path head h: the last (deepest) vertex of that path.
int akh[maxn];

// stf[v].first = position of v; stf[v].second = last position in v's subtree.
pair<int,int>stf[maxn];

// Preference sets kept by upd() for the light child paths of each vertex p:
// |allr[p]| is used as p's leaf cost for colour 0, |allb[p]| for colour 1.
set<int>allr[maxn],allb[maxn];

// Current answer: the sum of the optimum of every heavy path.
int mainres=0;
// Summary of a contiguous run of positions on a heavy path.
// res[s][e] is the minimum cost of the run when its first vertex gets
// colour s and its last vertex gets colour e. The cost is the leaf costs
// plus one for every adjacent pair with different colours; it is inf when
// impossible. ted counts the leaves merged into the summary; ted == 0
// marks an empty (identity) summary.
// `fn` is the empty summary returned for out-of-range queries.
struct node{
    int res[2][2];
    int ted;
    node():res{{0,0},{0,0}},ted(0){}
}fn;
struct seg{
node seg[(1<<18)];
node merge(node a,node b){
if(a.ted==0){
return b;
}
if(b.ted==0){
return a;
}
node ret;
ret.res[0][0]=ret.res[1][0]=ret.res[1][1]=ret.res[0][1]=inf;
ret.ted=a.ted+b.ted;
for(int i=0;i<2;i++){
for(int j=0;j<2;j++){
for(int ii=0;ii<2;ii++){
for(int jj=0;jj<2;jj++){
ret.res[i][j]=min(ret.res[i][j],a.res[i][ii]+b.res[jj][j]+(ii!=jj));
}
}
}
}
return ret;
}
void add(int i,int wr,int wb,int to){
if(i==0){
return ;
}
if(i>=kaf){
// //cout<<i<<" "<<i-kaf<<" "<<wr<<" "<<wb<<" "<<to<<"\n";
seg[i].ted=1;
if(to==0){
seg[i].res[0][0]=seg[i].res[1][1]=seg[i].res[1][0]=seg[i].res[0][1]=inf;
seg[i].res[0][0]=wr;
seg[i].res[1][1]=wb;
return add((i>>1),wr,wb,to);
}
if(to==1){
seg[i].res[0][0]=seg[i].res[1][1]=seg[i].res[1][0]=seg[i].res[0][1]=inf;
seg[i].res[1][1]=wb;
return add((i>>1),wr,wb,to);
}
seg[i].res[0][0]=seg[i].res[1][1]=seg[i].res[1][0]=seg[i].res[0][1]=inf;
seg[i].res[0][0]=wr;
return add((i>>1),wr,wb,to);
}
seg[i]=merge(seg[(i<<1)],seg[(i<<1)^1]);
return add((i>>1),wr,wb,to);
}
node pors(int i,int l,int r,int tl,int tr){
if(l>r||l>tr||r<tl||tl>tr){
return fn;
}
if(l>=tl&&r<=tr){
return seg[i];
}
int m=(l+r)>>1;
return merge(pors((i<<1),l,m,tl,tr),pors((i<<1)^1,m+1,r,tl,tr));
}
}seg;
bool cmp(int a,int b){
return sz[a]>sz[b];
}
void pre(int u,int par=0){
sz[u]=1;
part[u]=par;
for(auto x:adj[u]){
if(x==par){
continue;
}
pre(x,u);
sz[u]+=sz[x];
}
sort(adj[u].begin(),adj[u].end());
if(u!=1){
adj[u].erase(lower_bound(adj[u].begin(),adj[u].end(),par));
}
sort(adj[u].begin(),adj[u].end(),cmp);
}
void makehld(int u=1,int par=1){
// //cout<<u<<" salam "<<par<<endl;
parh[u]=par;
timea++;
stf[u].first=timea;
akh[par]=u;
if((int)adj[u].size()>0){
makehld(adj[u][0],par);
for(int i=1;i<(int)adj[u].size();i++){
makehld(adj[u][i],adj[u][i]);
}
}
stf[u].second=timea;
}
void upd(int u){
//cout<<u<<" haha "<<" "<<allr[u].size()<<" "<<allb[u].size()<<endl;
node av=seg.pors(1,0,kaf-1,stf[parh[u]].first,stf[akh[parh[u]]].first);
int z=min(min(av.res[0][0],av.res[0][1]),min(av.res[1][0],av.res[1][1]));
mainres-=z;
if(parh[u]!=1){
if((av.res[0][0]==z||av.res[0][1]==z)&&(av.res[1][0]==z||av.res[1][1]==z)){
//hehe
}
else if(av.res[0][0]==z||av.res[0][1]==z){
allb[part[parh[u]]].erase(u);
}
else{
allr[part[parh[u]]].erase(u);
}
}
seg.add(stf[u].first+kaf,(int)allr[u].size(),(int)allb[u].size(),val[u]);
av=seg.pors(1,0,kaf-1,stf[parh[u]].first,stf[akh[parh[u]]].first);
z=min(min(av.res[0][0],av.res[0][1]),min(av.res[1][0],av.res[1][1]));
mainres+=z;
if(parh[u]!=1){
if((av.res[0][0]==z||av.res[0][1]==z)&&(av.res[1][0]==z||av.res[1][1]==z)){
//hehe
}
else if(av.res[0][0]==z||av.res[0][1]==z){
allb[part[parh[u]]].insert(u);
}
else{
allr[part[parh[u]]].insert(u);
}
return upd(part[parh[u]]);
}
}
// Builds the tree from the N-1 edges (A[e], B[e]).
// Then roots it at vertex 1 and decomposes it into heavy paths.
void initialize(int N, std::vector<int> A, std::vector<int> B) {
    n=N;
    for(int e=0;e+1<n;e++){
        const int a=A[e];
        const int b=B[e];
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    pre(1);
    makehld();
}
// Puts a cat (state 1) on vertex v and returns the updated answer.
int cat(int v) {
    val[v] = 1;
    upd(v);
    const int answer = mainres;
    return answer;
}
// Puts a dog (state 2) on vertex v and returns the updated answer.
int dog(int v){
    val[v] = 2;
    upd(v);
    const int answer = mainres;
    return answer;
}
// Clears vertex v back to a neighbor (state 0) and returns the updated answer.
int neighbor(int v) {
    val[v] = 0;
    upd(v);
    const int answer = mainres;
    return answer;
}
// Grader results pasted from the judge's results page:
// #  | Result    | Execution time | Memory   | Grader output
// 1  | Correct   | 4 ms           | 20828 KB | Output is correct
// 2  | Incorrect | 4 ms           | 20824 KB | Output isn't correct
// 3  | Halted    | 0 ms           | 0 KB     | -
// Grader results pasted from the judge's results page:
// #  | Result    | Execution time | Memory   | Grader output
// 1  | Correct   | 4 ms           | 20828 KB | Output is correct
// 2  | Incorrect | 4 ms           | 20824 KB | Output isn't correct
// 3  | Halted    | 0 ms           | 0 KB     | -
// Grader results pasted from the judge's results page:
// #  | Result    | Execution time | Memory   | Grader output
// 1  | Correct   | 4 ms           | 20828 KB | Output is correct
// 2  | Incorrect | 4 ms           | 20824 KB | Output isn't correct
// 3  | Halted    | 0 ms           | 0 KB     | -