Submission #1131449

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 1131449 | | huutuan | JOI tour (JOI24_joitour) | C++20 | 100 / 100 | 1658 ms | 128808 KiB |
#include "joitour.h"
#include <bits/stdc++.h>

// 64-bit integer used for every counter and index in this file.  (Replaces the
// former `#define int long long`, which redefined a keyword.)
using ll = long long;

/// One segment-tree node.
///
/// All two-element arrays are indexed by a "type" in {0, 1}; `init` fills
/// slot 0 from label-0 counts and slot 1 from label-2 counts.
///  - Leaf (`size == 1`): `cntp` / `cntc` hold the raw counters of one vertex
///    and `mul` is 1 when that vertex currently has label 1, else 0.
///  - Aggregate: `cntp` / `cntc` hold the sum of the leaf counters weighted by
///    the leaf `mul`, and `mul` is the sum of the leaf `mul` values.
struct Node {
    ll cntp[2];   // "parent side" counters (see `init`: totals minus the vertex's subtree)
    ll cntc[2];   // counters of the first child g[u][0] (see `init`)
    ll lazyp[2];  // pending addition to cntp, not yet pushed to the children
    ll lazyc[2];  // pending addition to cntc, not yet pushed to the children
    ll mul;
    ll size;      // 1 for leaves; sum of child sizes after SegmentTree::build

    Node() {
        std::memset(cntp, 0, sizeof cntp);
        std::memset(cntc, 0, sizeof cntc);
        std::memset(lazyp, 0, sizeof lazyp);
        std::memset(lazyc, 0, sizeof lazyc);
        mul = 0;
        size = 0;
    }

    /// Recomputes this node from two operands.  An operand with `size == 1`
    /// contributes its counters multiplied by its `mul`; any other operand is
    /// taken as already weighted.  Pending lazies are cleared; `size` is left
    /// untouched.
    void merge(const Node &tl, const Node &tr) {
        const ll wl = (tl.size == 1 ? tl.mul : 1);
        const ll wr = (tr.size == 1 ? tr.mul : 1);
        for (ll i = 0; i < 2; ++i) {
            cntp[i] = (tl.cntp[i] * wl) + (tr.cntp[i] * wr);
            cntc[i] = (tl.cntc[i] * wl) + (tr.cntc[i] * wr);
            lazyp[i] = lazyc[i] = 0;
        }
        mul = tl.mul + tr.mul;
    }
};

/// Lazy segment tree over positions 1..n (positions are `tin` values).
/// Supports range-add on cntp / cntc and range queries of the mul-weighted sums.
struct SegmentTree {
    ll n;
    std::vector<Node> t;

    /// Allocates 4*n+1 default nodes.
    void init(ll _n) {
        n = _n;
        t.assign(4 * n + 1, Node());
    }

    /// Fills in `size` for every node; does not touch counters or `mul`.
    void build(ll k, ll l, ll r) {
        if (l == r) {
            t[k].size = 1;
            return;
        }
        ll mid = (l + r) >> 1;
        build(k << 1, l, mid);
        build(k << 1 | 1, mid + 1, r);
        t[k].size = t[k << 1].size + t[k << 1 | 1].size;
    }

    /// Adds `lazyp` / `lazyc` to node k: a leaf receives the raw amount, an
    /// aggregate receives the amount times its `mul` (its stored value is a
    /// mul-weighted sum).
    void apply(ll k, const ll lazyp[2], const ll lazyc[2]) {
        const ll weight = (t[k].size == 1 ? 1 : t[k].mul);
        for (ll i = 0; i < 2; ++i) {
            t[k].lazyp[i] += lazyp[i];
            t[k].cntp[i] += lazyp[i] * weight;
            t[k].lazyc[i] += lazyc[i];
            t[k].cntc[i] += lazyc[i] * weight;
        }
    }

    /// Pushes the pending additions of node k to both children and clears them.
    void push(ll k) {
        apply(k << 1, t[k].lazyp, t[k].lazyc);
        apply(k << 1 | 1, t[k].lazyp, t[k].lazyc);
        for (ll i = 0; i < 2; ++i) t[k].lazyp[i] = t[k].lazyc[i] = 0;
    }

    /// Overwrites the leaf at `pos` with `val` and re-merges its ancestors.
    void update_val(ll k, ll l, ll r, ll pos, const Node &val) {
        if (l == r) {
            t[k] = val;
            return;
        }
        push(k);
        ll mid = (l + r) >> 1;
        if (pos <= mid) update_val(k << 1, l, mid, pos, val);
        else update_val(k << 1 | 1, mid + 1, r, pos, val);
        t[k].merge(t[k << 1], t[k << 1 | 1]);
    }

    /// Sets `mul` of the leaf at `pos` and re-merges its ancestors.
    void update_mul(ll k, ll l, ll r, ll pos, ll mul) {
        if (l == r) {
            t[k].mul = mul;
            return;
        }
        push(k);
        ll mid = (l + r) >> 1;
        if (pos <= mid) update_mul(k << 1, l, mid, pos, mul);
        else update_mul(k << 1 | 1, mid + 1, r, pos, mul);
        t[k].merge(t[k << 1], t[k << 1 | 1]);
    }

    /// Range-add of `lazyp` / `lazyc` on [L, R].  An empty range (L > R)
    /// changes nothing: no node is fully covered and every leaf is rejected.
    void update(ll k, ll l, ll r, ll L, ll R, const ll lazyp[2], const ll lazyc[2]) {
        if (r < L || R < l) return;
        if (L <= l && r <= R) {
            apply(k, lazyp, lazyc);
            return;
        }
        push(k);
        ll mid = (l + r) >> 1;
        update(k << 1, l, mid, L, R, lazyp, lazyc);
        update(k << 1 | 1, mid + 1, r, L, R, lazyp, lazyc);
        t[k].merge(t[k << 1], t[k << 1 | 1]);
    }

    // Query accumulators: after `get`, `ans1` holds the mul-weighted sums over
    // the queried range.  Their `size` stays 0, so `merge` treats the running
    // total as already weighted.
    Node ans1, ans2;

    /// Accumulates the range [L, R] into `ans1`; call `reset_ans` first.
    void get(ll k, ll l, ll r, ll L, ll R) {
        if (r < L || R < l) return;
        if (L <= l && r <= R) {
            ans2.merge(ans1, t[k]);
            ans1 = ans2;
            return;
        }
        push(k);
        ll mid = (l + r) >> 1;
        get(k << 1, l, mid, L, R);
        get(k << 1 | 1, mid + 1, r, L, R);
    }

    void reset_ans() {
        ans1 = Node();
        ans2 = Node();
    }
} st;

constexpr ll N = 200000 + 10;

ll n;
ll a[N];                     // current label (0, 1 or 2) of each 1-based vertex
ll cnt[N][3];                // per-label counts inside each subtree (filled by dfs)
ll cntp[N][3];               // cnt[1][*] - cnt[u][*], computed once in init
ll cnt_all[3];               // number of vertices currently carrying each label
std::vector<ll> g[N];        // children lists after dfs_sz; g[u][0] is the largest-subtree child
std::vector<ll> cntc[N][2];  // per remaining child (g[u][1..]): label-0 / label-2 subtree counts
ll sumc[N];                  // sum over g[u][1..] of cntc[u][0][i] * cntc[u][1][i]
ll sum;                      // quantity subtracted from the product in num_tours
ll par[N], sz[N], head[N], tin[N], tout[N], tdfs;

/// Roots the tree: records parent and subtree size, removes the parent from
/// each adjacency list, moves the child with the largest subtree to g[u][0]
/// and sorts the remaining children by id (update_manual binary-searches them).
void dfs_sz(ll u, ll p) {
    sz[u] = 1;
    par[u] = p;
    if (p) g[u].erase(std::find(g[u].begin(), g[u].end(), p));
    for (ll &v : g[u]) {
        dfs_sz(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[g[u][0]]) std::swap(v, g[u][0]);
    }
    if (g[u].size()) std::sort(g[u].begin() + 1, g[u].end());
}

/// Assigns tin/tout and chain heads (g[u][0] continues the chain `h`, every
/// other child starts its own), accumulates subtree label counts, and records
/// the counts and product sum for the children other than g[u][0].
void dfs(ll u, ll h) {
    tin[u] = ++tdfs;
    head[u] = h;
    ++cnt[u][a[u]];
    for (ll v : g[u]) {
        dfs(v, v == g[u][0] ? h : v);
        for (ll i = 0; i < 3; ++i) cnt[u][i] += cnt[v][i];
        if (v != g[u][0]) {
            // Pushed in g[u][1..] order, so index i here matches update_manual.
            cntc[u][0].push_back(cnt[v][0]);
            cntc[u][1].push_back(cnt[v][2]);
            sumc[u] += cnt[v][0] * cnt[v][2];
        }
    }
    tout[u] = tdfs;
}

// Scratch arguments for SegmentTree::update; kept all-zero between calls.
ll lazyp[2], lazyc[2];

/// Adds `val` to the `type` counter of child `v` (one of g[u][1..]) of `u`,
/// keeping sumc[u] in sync, and `sum` too when `u` currently has label 1.
void update_manual(ll u, ll v, ll type, ll val) {
    ll i = std::lower_bound(g[u].begin() + 1, g[u].end(), v) - g[u].begin() - 1;
    if (a[u] == 1) sum -= sumc[u];
    sumc[u] -= cntc[u][0][i] * cntc[u][1][i];
    cntc[u][type][i] += val;
    sumc[u] += cntc[u][0][i] * cntc[u][1][i];
    if (a[u] == 1) sum += sumc[u];
}

/// Applies a change of `val` (+1 / -1) in the number of vertices of `type`
/// (0 for label 0, 1 for label 2) located at vertex `u`, updating `sum` and
/// all stored counters.
void update(ll u, ll type, ll val) {
    // First add `val` to cntp[type] of every position; the loop below takes it
    // back on the root-to-u path.  `sum` moves by val times the weighted total
    // of the opposite counter.
    lazyp[type] = val;
    sum += st.t[1].cntp[type ^ 1] * val;
    st.update(1, 1, n, 1, n, lazyp, lazyc);
    lazyp[type] = 0;

    while (u) {
        // Chain segment head[u]..u: undo the cntp addition made above.
        st.reset_ans();
        st.get(1, 1, n, tin[head[u]], tin[u]);
        sum -= st.ans1.cntp[type ^ 1] * val;
        lazyp[type] = -val;
        st.update(1, 1, n, tin[head[u]], tin[u], lazyp, lazyc);
        lazyp[type] = 0;

        // Chain segment head[u]..par[u] (empty when u is the head): the change
        // is reached through g[w][0], so cntc grows instead.
        lazyc[type] = val;
        st.reset_ans();
        st.get(1, 1, n, tin[head[u]], tin[par[u]]);
        sum += st.ans1.cntc[type ^ 1] * val;
        st.update(1, 1, n, tin[head[u]], tin[par[u]], lazyp, lazyc);
        lazyc[type] = 0;

        // The chain hangs off its parent as one of g[..][1..]; fix that entry.
        if (par[head[u]]) update_manual(par[head[u]], head[u], type, val);
        u = par[head[u]];
    }
}

/// Grader entry point: builds the rooted tree (root = vertex 1, vertices are
/// shifted to 1-based), the initial `sum`, and the segment tree.
/// `Q` is accepted for interface compatibility and not used.
void init(int32_t _N, std::vector<int32_t> F, std::vector<int32_t> U,
          std::vector<int32_t> V, [[maybe_unused]] int32_t Q) {
    n = _N;
    for (ll i = 1; i <= n; ++i) a[i] = F[i - 1];
    for (ll i = 1; i < n; ++i) {
        g[U[i - 1] + 1].push_back(V[i - 1] + 1);
        g[V[i - 1] + 1].push_back(U[i - 1] + 1);
    }
    std::memset(cnt, 0, sizeof cnt);
    dfs_sz(1, 0);
    dfs(1, 1);

    for (ll u = 1; u <= n; ++u) {
        ++cnt_all[a[u]];
        for (ll i = 0; i < 3; ++i) cntp[u][i] = cnt[1][i] - cnt[u][i];
        if (a[u] == 1) {
            sum += sumc[u];
            sum += cntp[u][0] * cntp[u][2];
            if (g[u].size()) sum += cnt[g[u][0]][0] * cnt[g[u][0]][2];
        }
    }

    st.init(n);
    for (ll u = 1; u <= n; ++u) {
        Node val;
        if (g[u].size()) {
            val.cntc[0] = cnt[g[u][0]][0];
            val.cntc[1] = cnt[g[u][0]][2];
        }
        val.cntp[0] = cntp[u][0];
        val.cntp[1] = cntp[u][2];
        val.mul = a[u] == 1;
        val.size = 1;
        st.update_val(1, 1, n, tin[u], val);
    }
    // Order kept from the original: build only fills in `size`, and merge
    // treats the still-zero sizes of aggregates above as "already weighted".
    st.build(1, 1, n);
}

/// Contribution of vertex `u` to `sum`: sumc[u] plus the products of its two
/// cntc counters and of its two cntp counters.  The leaf is read through
/// `get`, which multiplies by the leaf's `mul`, so `u` must have mul == 1 at
/// call time (see the call order in `change`).
ll get_val(ll u) {
    ll ans = sumc[u];
    st.reset_ans();
    st.get(1, 1, n, tin[u], tin[u]);
    ans += st.ans1.cntc[0] * st.ans1.cntc[1];
    ans += st.ans1.cntp[0] * st.ans1.cntp[1];
    return ans;
}

/// Grader entry point: relabels 0-based vertex `X` to label `Y`.
void change(int32_t X, int32_t Y) {
    ll u = X + 1, type = Y;

    // Remove the old label.
    --cnt_all[a[u]];
    if (a[u] == 1) {
        sum -= get_val(u);  // must run while mul is still 1
        st.update_mul(1, 1, n, tin[u], 0);
    } else {
        update(u, a[u] == 2, -1);
    }

    // Install the new label.
    a[u] = type;
    ++cnt_all[a[u]];
    if (a[u] == 1) {
        st.update_mul(1, 1, n, tin[u], 1);  // must run before get_val
        sum += get_val(u);
    } else {
        update(u, type == 2, 1);
    }
}

/// Grader entry point: product of the three label counts minus `sum`.
long long num_tours() {
    return cnt_all[1] * cnt_all[0] * cnt_all[2] - sum;
}

#ifdef sus
// Local test harness.  Reads are checked outside assert() so they still
// execute when the file is compiled with NDEBUG.
int main() {
    int32_t N;
    if (scanf("%d", &N) != 1) return 1;
    std::vector<int32_t> F(N);
    for (int32_t i = 0; i < N; i++) {
        if (scanf("%d", &F[i]) != 1) return 1;
    }
    std::vector<int32_t> U(N - 1), V(N - 1);
    for (int32_t j = 0; j < N - 1; j++) {
        if (scanf("%d %d", &U[j], &V[j]) != 2) return 1;
    }
    int32_t Q;
    if (scanf("%d", &Q) != 1) return 1;
    init(N, F, U, V, Q);
    printf("%lld\n", num_tours());
    fflush(stdout);
    for (int32_t k = 0; k < Q; k++) {
        int32_t X, Y;
        if (scanf("%d %d", &X, &Y) != 2) return 1;
        change(X, Y);
        printf("%lld\n", num_tours());
        fflush(stdout);
    }
}
#endif
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...