Submission #1131449

# | Time | Username | Problem | Language | Result | Execution time | Memory
1131449 | (not captured) | huutuan | JOI tour (JOI24_joitour) | C++20 | 100 / 100 | 1658 ms | 128808 KiB
#include "joitour.h"

#include <bits/stdc++.h>

using namespace std;

#define int long long

// Value stored in one segment-tree node.
//   cntp[i], cntc[i]   : two pairs of counters (slot 0 and slot 1 each).
//   lazyp[i], lazyc[i] : pending additions to those counters that have not
//                        yet been pushed to the children.
//   mul                : for a leaf, the factor its counters carry into a
//                        parent (the rest of the file only ever sets 0 or 1);
//                        for a merged node, the sum of the children's mul.
//   size               : 1 marks a leaf whose counters are stored raw; any
//                        other value means the counters are already weighted.
struct Node{
   int cntp[2]{};
   int cntc[2]{};
   int lazyp[2]{};
   int lazyc[2]{};
   int mul{};
   int size{};

   Node()=default;

   // Factor applied to x's counters when x is folded into a parent:
   // a leaf counts only with its own mul, anything else counts as-is.
   static int weight(const Node &x){
      return x.size==1 ? x.mul : 1;
   }

   // Recomputes this node from its two children: counters become the
   // weighted sums, lazy tags are cleared, and mul becomes the sum of the
   // children's mul. `size` is deliberately NOT touched, so callers keep
   // control over whether this node is later treated as a leaf.
   void merge(const Node &tl, const Node &tr){
      const int wl=weight(tl);
      const int wr=weight(tr);
      for (int i=0; i<2; ++i){
         cntp[i]=tl.cntp[i]*wl+tr.cntp[i]*wr;
         cntc[i]=tl.cntc[i]*wl+tr.cntc[i]*wr;
      }
      lazyp[0]=lazyp[1]=0;
      lazyc[0]=lazyc[1]=0;
      mul=tl.mul+tr.mul;
   }
};

// Lazy-propagation segment tree over positions 1..n holding Node values.
// Range "add" updates (update) add lazyp/lazyc to the counters of every
// position in [L, R]. A leaf stores its counters raw, while an internal
// node's counters grow by delta * mul, where mul is the sum of its leaves'
// mul. Aggregated sums therefore only include leaves whose mul is non-zero.
struct SegmentTree{
   int n;
   vector<Node> t;
   // Allocates 4n+1 zeroed nodes. Every size stays 0 until build() runs.
   void init(int _n){
      n=_n;
      t.assign(4*n+1, Node());
   }
   // Fills Node::size for node k covering [l, r]: 1 for a leaf, otherwise the
   // sum of the children. Counters, mul and lazy tags are not touched.
   void build(int k, int l, int r){
      if (l==r){
         t[k].size=1;
         return;
      }
      int mid=(l+r)>>1;
      build(k<<1, l, mid);
      build(k<<1|1, mid+1, r);
      t[k].size=t[k<<1].size+t[k<<1|1].size;
   }
   // Adds the deltas lazyp/lazyc to node k: they are recorded in k's own lazy
   // tags (for later push) and applied to k's counters. A leaf's raw counters
   // change by the delta itself; an internal node's weighted counters change
   // by delta * mul.
   void apply(int k, int lazyp[2], int lazyc[2]){
      for (int i=0; i<2; ++i){
         t[k].lazyp[i]+=lazyp[i];
         t[k].cntp[i]+=lazyp[i]*(t[k].size==1?1:t[k].mul);
         t[k].lazyc[i]+=lazyc[i];
         t[k].cntc[i]+=lazyc[i]*(t[k].size==1?1:t[k].mul);
      }
   }
   // Forwards node k's pending tags to both children, then clears them.
   void push(int k){
      apply(k<<1, t[k].lazyp, t[k].lazyc);
      apply(k<<1|1, t[k].lazyp, t[k].lazyc);
      for (int i=0; i<2; ++i) t[k].lazyp[i]=t[k].lazyc[i]=0;
   }
   // Point assignment: replaces the leaf at `pos` with `val`, then recomputes
   // the ancestors on the way back up.
   void update_val(int k, int l, int r, int pos, const Node &val){
      if (l==r){
         t[k]=val;
         return;
      }
      push(k);
      int mid=(l+r)>>1;
      if (pos<=mid) update_val(k<<1, l, mid, pos, val);
      else update_val(k<<1|1, mid+1, r, pos, val);
      t[k].merge(t[k<<1], t[k<<1|1]);
   }
   // Sets the mul of the leaf at `pos`. Its raw counters are unchanged, but
   // the ancestors are re-merged so the leaf's contribution is re-weighted.
   void update_mul(int k, int l, int r, int pos, int mul){
      if (l==r){
         t[k].mul=mul;
         return;
      }
      push(k);
      int mid=(l+r)>>1;
      if (pos<=mid) update_mul(k<<1, l, mid, pos, mul);
      else update_mul(k<<1|1, mid+1, r, pos, mul);
      t[k].merge(t[k<<1], t[k<<1|1]);
   }
   // Range add of lazyp/lazyc on [L, R]. When L > R no node is fully covered,
   // so nothing is added (tags may still be pushed along the visited path).
   void update(int k, int l, int r, int L, int R, int lazyp[2], int lazyc[2]){
      if (r<L || R<l) return;
      if (L<=l && r<=R){
         apply(k, lazyp, lazyc);
         return;
      }
      push(k);
      int mid=(l+r)>>1;
      update(k<<1, l, mid, L, R, lazyp, lazyc);
      update(k<<1|1, mid+1, r, L, R, lazyp, lazyc);
      t[k].merge(t[k<<1], t[k<<1|1]);
   }
   // Query accumulators. get() folds visited nodes into ans1, with ans2 as
   // scratch. Their size stays 0 (Node::merge never sets size), so a folded
   // leaf's counters end up multiplied by its mul, and ans1 itself is
   // treated as already weighted in later folds.
   Node ans1, ans2;
   // Folds every maximal node lying inside [L, R] into ans1, left to right.
   // Call reset_ans() first. An empty range (L > R) leaves ans1 unchanged.
   void get(int k, int l, int r, int L, int R){
      if (r<L || R<l) return;
      if (L<=l && r<=R){
         ans2.merge(ans1, t[k]);
         ans1=ans2;
         return;
      }
      push(k);
      int mid=(l+r)>>1;
      get(k<<1, l, mid, L, R);
      get(k<<1|1, mid+1, r, L, R);
   }
   // Resets both query accumulators to all-zero nodes.
   void reset_ans(){
      ans1=Node(); ans2=Node();
   }
} st; // the single tree instance used by the rest of the file

// ---- Global state shared by the tree passes and the update routines ----
constexpr int N = 200010;   // array capacity; vertices are 1-indexed (assumes n < N)
int n;                      // number of vertices
int a[N];                   // a[u]: current type of vertex u (indexes arrays of size 3)
int cnt[N][3];              // cnt[u][c]: type-c vertices in u's subtree, computed by dfs()
int cntp[N][3];             // cntp[u][c]: type-c vertices outside u's subtree, computed in init()
int cnt_all[3];             // cnt_all[c]: current number of type-c vertices
vector<int> g[N];           // adjacency lists; children only after dfs_sz()
vector<int> cntc[N][2];     // per light child of u: its subtree's type-0 / type-2 counts
int sumc[N];                // sumc[u]: sum over u's light children of cntc[0] * cntc[1]
int sum;                    // correction subtracted in num_tours()
int par[N];                 // par[u]: parent of u (0 for the root)
int sz[N];                  // sz[u]: subtree size of u
int head[N];                // head[u]: top vertex of u's heavy chain
int tin[N];                 // tin[u]: pre-order index of u (heavy child visited first)
int tout[N];                // tout[u]: largest tin inside u's subtree
int tdfs;                   // running pre-order counter used by dfs()

void dfs_sz(int u, int p){
   sz[u]=1;
   par[u]=p;
   if (p) g[u].erase(find(g[u].begin(), g[u].end(), p));
   for (int &v:g[u]){
      dfs_sz(v, u);
      sz[u]+=sz[v];
      if (sz[v]>sz[g[u][0]]) swap(v, g[u][0]);
   }
   if (g[u].size()) sort(g[u].begin()+1, g[u].end());
}

// Heavy-light pass over the subtree of u, where u belongs to the chain
// headed by h. It:
//   - assigns pre-order times tin/tout (the heavy child g[u][0] is visited
//     first, so each heavy chain occupies consecutive positions);
//   - records chain heads in head[];
//   - accumulates subtree type counts in cnt[u][*];
//   - for every light child v, appends cnt[v][0] and cnt[v][2] to
//     cntc[u][0] and cntc[u][1] (in g[u] order) and adds their product to
//     sumc[u].
void dfs(int u, int h){
   head[u]=h;
   tin[u]=++tdfs;
   cnt[u][a[u]]+=1;
   for (size_t j=0; j<g[u].size(); ++j){
      const int v=g[u][j];
      const bool is_heavy=(v==g[u][0]);
      // A light child starts a new chain headed by itself.
      dfs(v, is_heavy ? h : v);
      for (int c=0; c<3; ++c) cnt[u][c]+=cnt[v][c];
      if (is_heavy) continue;
      cntc[u][0].push_back(cnt[v][0]);
      cntc[u][1].push_back(cnt[v][2]);
      sumc[u]+=cnt[v][0]*cnt[v][2];
   }
   tout[u]=tdfs;
}

// Scratch lazy-tag buffers passed to st.update() by update(). Every entry is
// set immediately before use and reset to 0 right after.
int lazyp[2] = {0, 0};
int lazyc[2] = {0, 0};

// Adds `val` to slot `type` (0 or 1) of the counters recorded for v, a light
// child of u. sumc[u] is kept equal to the sum of the slot products, and when
// u has type 1 the global `sum` receives the same change.
void update_manual(int u, int v, int type, int val){
   // Light children sit sorted in g[u][1..]; their rank there is their slot
   // in cntc[u][*] (dfs pushed them in that order).
   const int idx=lower_bound(g[u].begin()+1, g[u].end(), v)-(g[u].begin()+1);
   int &zeros=cntc[u][0][idx];
   int &twos=cntc[u][1][idx];
   const int before=zeros*twos;
   cntc[u][type][idx]+=val;
   const int delta=zeros*twos-before;
   sumc[u]+=delta;
   if (a[u]==1) sum+=delta;
}

// Propagates a change of `val` (change() passes +1 or -1) in the number of
// slot-`type` vertices located at u. Slot 0 is a type-0 vertex and slot 1 a
// type-2 vertex, per change(). The segment-tree counters and `sum` are kept
// in sync.
// Each affected type-1 vertex w holds a pair product (count0 * count2). When
// the slot-`type` count grows by val, that product grows by val times the
// opposite-slot count (type^1), which is what the weighted sums read here.
void update(int u, int type, int val){
   // Step 1: provisionally treat u as lying outside every vertex's subtree,
   // adding val to cntp[type] at all positions. st.t[1] is the root aggregate.
   lazyp[type]=val;
   sum+=st.t[1].cntp[type^1]*val;
   st.update(1, 1, n, 1, n, lazyp, lazyc);
   lazyp[type]=0;
   // Step 2: walk up the heavy chains from u to the root.
   while (u){
      // Vertices on [head[u], u] contain u in their subtree, so undo the
      // "outside" increment for them (chain positions are contiguous in tin).
      st.reset_ans();
      st.get(1, 1, n, tin[head[u]], tin[u]);
      sum-=st.ans1.cntp[type^1]*val;
      lazyp[type]=-val;
      st.update(1, 1, n, tin[head[u]], tin[u], lazyp, lazyc);
      lazyp[type]=0;

      // Vertices on [head[u], par[u]] have u inside their heavy child's
      // subtree, so bump the heavy-child counter cntc[type]. This range is
      // empty when u is a chain head (tin[par[u]] < tin[head[u]]) or the root
      // (par[u] == 0 and tin[0] == 0).
      lazyc[type]=val;
      st.reset_ans();
      st.get(1, 1, n, tin[head[u]], tin[par[u]]);
      sum+=st.ans1.cntc[type^1]*val;
      st.update(1, 1, n, tin[head[u]], tin[par[u]], lazyp, lazyc);
      lazyc[type]=0;

      // head[u] is a light child of its parent; that counter is kept outside
      // the tree, in cntc/sumc.
      if (par[head[u]]) update_manual(par[head[u]], head[u], type, val);
      u=par[head[u]];
   }
}

void init(int32_t _N, vector<int32_t> F, vector<int32_t> U, vector<int32_t> V,
          int32_t Q) {
   n=_N;
   for (int i=1; i<=n; ++i) a[i]=F[i-1];
   for (int i=1; i<n; ++i){
      g[U[i-1]+1].push_back(V[i-1]+1);
      g[V[i-1]+1].push_back(U[i-1]+1);
   }
   memset(cnt, 0, sizeof cnt);
   dfs_sz(1, 0);
   dfs(1, 1);
   for (int u=1; u<=n; ++u){
      ++cnt_all[a[u]];
      for (int i=0; i<3; ++i) cntp[u][i]=cnt[1][i]-cnt[u][i];
      if (a[u]==1){
         sum+=sumc[u];
         sum+=cntp[u][0]*cntp[u][2];
         if (g[u].size()) sum+=cnt[g[u][0]][0]*cnt[g[u][0]][2];
      }
   }
   st.init(n);
   for (int u=1; u<=n; ++u){
      Node val;
      if (g[u].size()){
         val.cntc[0]=cnt[g[u][0]][0];
         val.cntc[1]=cnt[g[u][0]][2];
      }
      val.cntp[0]=cntp[u][0];
      val.cntp[1]=cntp[u][2];
      val.mul=a[u]==1;
      val.size=1;
      st.update_val(1, 1, n, tin[u], val);
   }
   st.build(1, 1, n);
}

int get_val(int u){
   int ans=sumc[u];
   st.reset_ans();
   st.get(1, 1, n, tin[u], tin[u]);
   ans+=st.ans1.cntc[0]*st.ans1.cntc[1];
   ans+=st.ans1.cntp[0]*st.ans1.cntp[1];
   return ans;
}

// Grader entry point: sets the type of vertex X (0-indexed) to Y. The vertex
// is first removed from all aggregates under its old type, then re-added
// under the new one.
void change(int32_t X, int32_t Y) {
   const int u=X+1;
   // sign = -1 removes u's current type, sign = +1 adds it back.
   auto toggle=[&](int sign){
      cnt_all[a[u]]+=sign;
      if (a[u]!=1){
         // Slot mapping for update(): type 0 -> 0, type 2 -> 1.
         update(u, a[u]==2, sign);
      }else if (sign<0){
         // Drop u's products while its leaf is still active, then deactivate.
         sum-=get_val(u);
         st.update_mul(1, 1, n, tin[u], 0);
      }else{
         // Activate u's leaf first so get_val() sees it.
         st.update_mul(1, 1, n, tin[u], 1);
         sum+=get_val(u);
      }
   };
   toggle(-1);
   a[u]=Y;
   toggle(+1);
}

// Grader entry point: the product of the three per-type vertex counts minus
// the maintained correction term `sum`.
long long num_tours() {
   const long long all_triples=cnt_all[0]*cnt_all[1]*cnt_all[2];
   return all_triples-sum;
}

#ifdef sus
// Local stdin/stdout driver (compiled only when `sus` is defined) that
// mirrors the grader protocol.
// Input: N, then the N vertex types, then N-1 edges, then Q, then Q queries
// "X Y". It prints num_tours() after init() and after every change().
// Returns 1 on malformed input.
//
// Every scanf runs outside assert(): with NDEBUG the argument of assert() is
// not evaluated, so reading inside it would silently skip all input.
int32_t main() {
  int32_t N;
  if (scanf("%d", &N) != 1 || N < 1) return 1;

  std::vector<int32_t> F(N);
  for (int32_t i = 0; i < N; i++) {
    if (scanf("%d", &F[i]) != 1) return 1;
  }

  std::vector<int32_t> U(N - 1), V(N - 1);
  for (int32_t j = 0; j < N - 1; j++) {
    if (scanf("%d %d", &U[j], &V[j]) != 2) return 1;
  }

  int32_t Q;
  if (scanf("%d", &Q) != 1) return 1;

  init(N, F, U, V, Q);
  printf("%lld\n", num_tours());
  fflush(stdout);

  for (int32_t k = 0; k < Q; k++) {
    int32_t X, Y;
    if (scanf("%d %d", &X, &Y) != 2) return 1;

    change(X, Y);
    printf("%lld\n", num_tours());
    fflush(stdout);
  }
}

#endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...