# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
1011085 | imarn | JOI tour (JOI24_joitour) | C++17 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include<bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC target("avx2")
#define ll long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define plx pair<ll,int>
#define f first
#define s second
#define pb push_back
#define all(x) x.begin(),x.end()
#define vi vector<int>
#define vl vector<ll>
#define vvi vector<vi>
using namespace std;
const int mxn=2e5+5;
vector<int>g[mxn],f;
int sz[mxn]{0},lv[mxn]{0},pr[mxn];
bool vis[mxn]{0};
ll ans=0;
int getsz(int u,int p){
sz[u]=1;
for(auto v:g[u])if(v!=p&&!vis[v])sz[u]+=getsz(v,u);
return sz[u];
}
int getct(int u,int p,int r){
for(auto v:g[u])if(v!=p&&!vis[v]&&sz[v]*2>r)return getct(v,u,r);
return u;
}
vector<ll>dp[6][mxn];
ll td[6][mxn]{0};
struct few{
int fw[mxn]{0};
void add(int i,int amt){
for(;i<mxn;i+=i&-i)fw[i]+=amt;
}
ll qr(int i,ll res=0){
for(;i;i-=i&-i)res+=fw[i];
return res;
}
}fw[3][19],wf[3][19];
int tin[19][mxn],tout[19][mxn],cnt[19]{0},ID[19][mxn];
ll getP(int col,int lvl,int u){
return fw[col][lvl].qr(tin[lvl][u]);
}
ll getC(int col,int lvl,int u){
return wf[col][lvl].qr(tout[lvl][u])-wf[col][lvl].qr(tin[lvl][u]-1);
}
void gen1(int u,int p,int x,int id,int x0,int x1,int x2){
dp[f[u]][x][id]++;tin[lv[x]][u]=++cnt[lv[x]];
fw[f[u]][lv[x]].add(tin[lv[x]][u],1);
if(f[u]==2)dp[3][x][id]+=x1;
for(auto v:g[u])if(v!=p&&!vis[v])gen1(v,u,x,id,x0+(f[v]==0),x1+(f[v]==1),x2+(f[v]==2));
tout[lv[x]][u]=cnt[lv[x]];ID[lv[x]][u]=id;
fw[f[u]][lv[x]].add(tout[lv[x]][u]+1,-1);
}
void gen2(int u,int p,int x,int id){
for(auto v:g[u])if(v!=p&&!vis[v])gen2(v,u,x,id);
wf[f[u]][lv[x]].add(tin[lv[x]][u],1);
if(f[u]==1)dp[4][x][id]+=getC(0,lv[x],u);
}
void play(int u,int p){int id=-1;
u = getct(u,u,getsz(u,u));vis[u]=1;
pr[u]=p;if(p!=-1)lv[u]=lv[p]+1;
tin[lv[u]][u]=++cnt[lv[u]];
for(auto v:g[u]){
if(vis[v])continue;
for(int i=0;i<6;i++)dp[i][u].pb(0);id++;
gen1(v,u,u,id,(f[v]==0),(f[v]==1),(f[v]==2));gen2(v,u,u,id);
for(int i=0;i<5;i++)td[i][u]+=dp[i][u][id];td[5][u]+=dp[0][u][id]*dp[2][u][id];
}tout[lv[u]][u]=cnt[lv[u]];
for(int j=0;j<=id;j++){
ans+=dp[0][u][j]*(td[3][u]-dp[3][u][j]);
ans+=dp[4][u][j]*(td[2][u]-dp[2][u][j]);
if(f[u]==1)ans+=dp[0][u][j]*(td[2][u]-dp[2][u][j]);
}if(f[u]==0)ans+=td[3][u];if(f[u]==2)ans+=td[4][u];
for(auto v:g[u])if(!vis[v])play(v,u);
}
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V,int Q){
f=F;for(int i=0;i<N-1;i++)g[U[i]].pb(V[i]),g[V[i]].pb(U[i]);
play(0,-1);
}
void up(int col,int u,int sig){
int x=u;
if(col==0){
ans+=td[3][u]*sig;x=pr[u];int j=ID[lv[x]][u];
while(x!=-1){
ans-=dp[0][x][j]*(td[3][x]-dp[3][x][j]);
ans-=dp[4][x][j]*(td[2][x]-dp[2][x][j]);
ans-=dp[3][x][j]*(td[0][x]-dp[0][x][j]);
ans-=dp[2][x][j]*(td[4][x]-dp[4][x][j]);
if(f[x]==1)ans-=td[0][x]*td[2][x]-td[5][x];
if(f[x]==2)ans-=td[4][x];
td[5][x]-=dp[0][x][j]*dp[2][x][j];
td[4][x]-=dp[4][x][j];
td[0][x]+=sig;dp[0][x][j]+=sig;
dp[4][x][j]+=sig*getP(1,lv[x],u);
td[4][x]+=dp[4][x][j];
td[5][x]+=dp[0][x][j]*dp[2][x][j];
if(f[x]==1)ans+=td[0][x]*td[2][x]-td[5][x];
if(f[x]==2)ans+=td[4][x];
ans+=dp[0][x][j]*(td[3][x]-dp[3][x][j]);
ans+=dp[4][x][j]*(td[2][x]-dp[2][x][j]);
ans+=dp[3][x][j]*(td[0][x]-dp[0][x][j]);
ans+=dp[2][x][j]*(td[4][x]-dp[4][x][j]);
fw[0][lv[x]].add(tin[lv[x]][u],sig);
fw[0][lv[x]].add(tout[lv[x]][u]+1,-sig);
wf[0][lv[x]].add(tin[lv[x]][u],sig);
if(pr[x]==-1)break;x=pr[x];j=ID[lv[x]][u];
}
}
else if(col==1){
ans+=(td[0][u]*td[2][u]-td[5][u])*sig;x=pr[u];int j=ID[lv[x]][u];
while(x!=-1){
ans-=dp[0][x][j]*(td[3][x]-dp[3][x][j]);
ans-=dp[4][x][j]*(td[2][x]-dp[2][x][j]);
ans-=dp[3][x][j]*(td[0][x]-dp[0][x][j]);
ans-=dp[2][x][j]*(td[4][x]-dp[4][x][j]);
if(f[x]==0)ans-=td[3][x];
if(f[x]==2)ans-=td[4][x];
td[4][x]-=dp[4][x][j];
td[3][x]-=dp[3][x][j];
td[1][x]+=sig;dp[1][x][j]+=sig;
dp[4][x][j]+=sig*getC(0,lv[x],u);
dp[3][x][j]+=sig*getC(2,lv[x],u);
td[4][x]+=dp[4][x][j];
td[3][x]+=dp[3][x][j];
if(f[x]==0)ans+=td[3][x];
if(f[x]==2)ans+=td[4][x];
ans+=dp[0][x][j]*(td[3][x]-dp[3][x][j]);
ans+=dp[4][x][j]*(td[2][x]-dp[2][x][j]);
ans+=dp[3][x][j]*(td[0][x]-dp[0][x][j]);
ans+=dp[2][x][j]*(td[4][x]-dp[4][x][j]);
fw[1][lv[x]].add(tin[lv[x]][u],sig);
fw[1][lv[x]].add(tout[lv[x]][u]+1,-sig);
wf[1][lv[x]].add(tin[lv[x]][u],sig);
if(pr[x]==-1)break;x=pr[x];j=ID[lv[x]][u];
}
}
else {
ans+=td[3][u]*sig;x=pr[u];int j=ID[lv[x]][u];
while(x!=-1){
ans-=dp[0][x][j]*(td[3][x]-dp[3][x][j]);
ans-=dp[4][x][j]*(td[2][x]-dp[2][x][j]);
ans-=dp[3][x][j]*(td[0][x]-dp[0][x][j]);
ans-=dp[2][x][j]*(td[4][x]-dp[4][x][j]);
if(f[x]==1)ans-=td[0][x]*td[2][x]-td[5][x];
if(f[x]==0)ans-=td[3][x];
td[5][x]-=dp[0][x][j]*dp[2][x][j];
td[3][x]-=dp[3][x][j];
td[2][x]+=sig;dp[2][x][j]+=sig;
dp[3][x][j]+=sig*getP(1,lv[x],u);
td[3][x]+=dp[3][x][j];
td[5][x]+=dp[0][x][j]*dp[2][x][j];
if(f[x]==1)ans+=td[0][x]*td[2][x]-td[5][x];
if(f[x]==0)ans+=td[3][x];
ans+=dp[0][x][j]*(td[3][x]-dp[3][x][j]);
ans+=dp[4][x][j]*(td[2][x]-dp[2][x][j]);
ans+=dp[3][x][j]*(td[0][x]-dp[0][x][j]);
ans+=dp[2][x][j]*(td[4][x]-dp[4][x][j]);
fw[2][lv[x]].add(tin[lv[x]][u],sig);
fw[2][lv[x]].add(tout[lv[x]][u]+1,-sig);
wf[2][lv[x]].add(tin[lv[x]][u],sig);
if(pr[x]==-1)break;x=pr[x];j=ID[lv[x]][u];
}
}
}
void change(int X, int Y){
up(f[X],X,-1);up(Y,X,1);
f[X]=Y;
}
long long num_tours() {
return ans;
}
/*int main(){
ios_base::sync_with_stdio(0);cin.tie(0);
int n;cin>>n;
vector<int>ff(n+1),u,v;
for(int i=1;i<=n;i++)ff[i]=i%3;
for(int i=1;i<=n-1;i++){
if(2*i<=n)g[i].pb(2*i),g[2*i].pb(i);
if(2*i+1<=n)g[i].pb(2*i+1),g[2*i+1].pb(i);
}init(n,ff,u,v,0);cout<<num_tours()<<'\n';
int q;cin>>q;
change(2,0);
cout<<num_tours()<<'\n';
/*while(q--){
int x,y;cin>>x>>y;change(x,y);
cout<<num_tours()<<'\n';
}*/
}*/