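// Submission #1160511
//
// | #       | Time | Username     | Problem                  | Language | Result    | Execution time | Memory     |
// |---------|------|--------------|--------------------------|----------|-----------|----------------|------------|
// | 1160511 |      | mychecksedad | JOI tour (JOI24_joitour) | C++20    | 100 / 100 | 1363 ms        | 390088 KiB |
// (The Time column was empty in the captured page.)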
#include "joitour.h"
#include<bits/stdc++.h>
using namespace std;
// Contest-style shorthand. Only ll, pb and vi are referenced by the code in
// this file; MOD, MOD1, all, en, ff, ss and pii are left over from a template.
#define ll long long int
#define MOD (1000000000+7)
#define MOD1 (998244353)
#define pb push_back
#define all(x) x.begin(), x.end()
#define en cout << '\n'
#define ff first
#define ss second
#define pii pair<int,int>
#define vi vector<int>
// Capacity of every per-vertex global array; assumed to exceed the maximum
// number of cities (TODO confirm against the problem limits).
const int N = 2e5+100;


// 1-indexed Fenwick (binary indexed) tree of int values.
//   add(v, x): adds x at position v; positions above n are silently ignored.
//   get(v):    prefix sum over positions 1..v (0 when v <= 0).
// The file uses it both as point-update / prefix-sum and, via a +x/-x pair,
// as range-update / point-query.
struct Fenwick{
  int n;
  vector<int> t;

  // Resets the tree to `size` zeroed positions (t[0] is an unused slot).
  void init(int size){
    n = size;
    t.assign(n + 1, 0);
  }

  // NOTE: v must be >= 1 — with v == 0 the index would never advance.
  void add(int v, int x){
    for(; v <= n; v += v & -v) t[v] += x;
  }

  int get(int v){
    int acc = 0;
    for(; v > 0; v &= v - 1) acc += t[v];
    return acc;
  }
};

// tp[v]   : current type of city v (indexes T[c][0..2], so 0, 1 or 2).
// s[v]    : subtree size within the component being processed (set by pre()).
// TIN/TOUT: Euler-tour interval of v inside the current centroid's component;
//           overwritten for every centroid, durable copies are kept in C[v].
// timer   : next TIN to hand out (reset to 1 for each centroid).
int tp[N], s[N], TIN[N], TOUT[N], timer = 1;
// dp[c][i]: aggregates for the subtree whose root has TIN i in centroid c's
// component (c itself never counted). Fields: 0, 1, 2 = number of vertices of
// that type; 01 = pairs (type-1 w, type-0 x below w); 21 = same with type-2 x.
vector<array<ll, 5>> dp[N]; // 0, 1, 2, 01, 21
// ans: current number of tours. bad[c]: sum over c's branches of
// (#type-0 * #type-2) inside that single branch.
ll ans, bad[N];
vi g[N];       // adjacency lists of the tree
bitset<N> vis; // vertices already chosen as centroids
// T[c][k]: Fenwick over TIN positions of c's component for vertices of type k.
Fenwick T[N][3];
// C[v]: one record per centroid ancestor c of v:
// {c, top vertex of v's branch under c, TIN[v], TOUT[v], TIN of that top vertex}.
vector<array<int, 5>> C[N];


// Counts tours passing through the centroid c using the aggregates stored at
// dp[c][v] (v is a TIN index). The base term pairs a type-2 endpoint with a
// (1 above 0) chain and a type-0 endpoint with a (1 above 2) chain. When
// `flag` is set, tours in which c itself (of type tpp) is one of the three
// cities are added as well.
ll calc(int v, int c, int tpp, bool flag){
  const auto& cnt = dp[c][v];
  const ll base = cnt[2] * cnt[3] + cnt[0] * cnt[4]; // 2*10 + 0*12
  if(!flag) return base;
  switch(tpp){
    case 1:  return base + cnt[0] * cnt[2]; // c is the middle city
    case 0:  return base + cnt[4];          // c starts a 0-1-2 chain
    default: return base + cnt[3];          // c is the type-2 end
  }
}

// Fills s[x] (subtree size) for every vertex of the not-yet-removed
// component hanging from v, treating p as v's parent (pass p == v at the root).
void pre(int v, int p){
  s[v] = 1;
  for(int nb : g[v]){
    if(nb == p || vis[nb]) continue;
    pre(nb, v);
    s[v] += s[nb];
  }
}
int num; // size of the component currently being decomposed (set in f)
// Starting at v (entered from p), repeatedly steps into the first active child
// whose s[] subtree size is at least ceil(num / 2); the vertex where no such
// child exists is returned as the component's centroid.
int centro(int v, int p){
  bool descended = true;
  while(descended){
    descended = false;
    for(int u : g[v]){
      const bool heavy = !vis[u] && u != p && s[u] >= (num + 1) / 2;
      if(!heavy) continue;
      p = v;
      v = u;
      descended = true;
      break;
    }
  }
  return v;
}

// Euler-tour DFS over the active component rooted at centroid c.
//   v: current vertex, p: its parent, c: the centroid,
//   d: the child of c whose branch contains v (d == c only when v == c).
// Builds dp[c][TIN[v]] for v's subtree, appends the record
// {c, d, TIN[v], TOUT[v], TIN[d]} to C[v], and registers every v != c in
// c's Fenwick trees.
void dfs(int v, int p, int c, int d){
  TIN[v] = timer++;
  // One slot per visited vertex, so dp[c][TIN[v]] is v's entry.
  dp[c].pb(array<ll,5>{0ll});
  for(int u: g[v]){
    if(u != p && !vis[u]){
      dfs(u, v, c, (v == c ? u : d));
      // dp[c] is re-indexed here (no reference held) because the recursive
      // call push_backs into dp[c] and may reallocate it.
      for(int j = 0; j < 5; ++j){
        dp[c][TIN[v]][j] += dp[c][TIN[u]][j];
      }
      // A type-1 vertex (other than c) forms a 01 pair with every type-0
      // vertex below it, and a 21 pair with every type-2 vertex below it.
      if(tp[v] == 1 && v != c) dp[c][TIN[v]][3] += dp[c][TIN[u]][0], dp[c][TIN[v]][4] += dp[c][TIN[u]][2];
    }
  }
  TOUT[v] = timer - 1;

  C[v].pb({c, d, TIN[v], TOUT[v], TIN[d]});

  if(v != c){
    // Type 0/2: point add, so a range sum counts them inside a subtree.
    // Type 1: +1 on [TIN, TOUT], so get(i) counts type-1 ancestors of i.
    T[c][tp[v]].add(TIN[v], 1);
    if(tp[v] == 1) T[c][tp[v]].add(TOUT[v] + 1, -1);
  
    // v's own type is counted after its children were aggregated.
    dp[c][TIN[v]][tp[v]]++;
  }
}


void f(int v){
  pre(v, v);
  num = s[v];
  v = centro(v, v);
  timer = 1;
  T[v][0].init(num);  
  T[v][1].init(num);  
  T[v][2].init(num);  

  dp[v].clear();
  dp[v].pb(array<ll,5>{0ll});
  dfs(v, v, v, v);

  ans += calc(1, v, tp[v], true);
  bad[v] = 0;
  for(int u: g[v]){
    if(!vis[u]){
      ans -= calc(TIN[u], v, tp[u], false);
      bad[v] += dp[v][TIN[u]][0] * dp[v][TIN[u]][2];
    }
  }
  if(tp[v] == 1) ans -= bad[v];


  vis[v] = 1;
  for(int u: g[v]){
    if(!vis[u]) f(u);
  }
}

// Initialisation entry point declared by joitour.h. It builds the tree from the
// n-1 edges (U[i], V[i]), stores the initial types F, and computes the starting
// answer via the centroid decomposition.
// Q is not used here. g is only appended to, so this assumes a single call.
void init(int n, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
  for(int e = 0; e < n - 1; ++e){
    const int a = U[e];
    const int b = V[e];
    g[a].push_back(b);
    g[b].push_back(a);
  }
  for(int i = 0; i < n; ++i) tp[i] = F[i];
  ans = 0;
  f(0);
}

// Net contribution of centroid c to the answer when c has the given type.
// It counts all tours through c; if c is the middle (type-1) city, it drops
// the same-branch 0/2 pairs in bad[c].
static long long centroid_total(int c, int type){
  long long total = calc(1, c, type, true);
  if(type == 1) total -= bad[c];
  return total;
}

// Adds (delta = +1) or removes (delta = -1) a vertex of `type` with Euler
// interval [tin, tout] from the aggregates dp[c][idx]. It also updates the
// 01 / 21 pair counters, using c's Fenwick trees in their current state.
static void shift_counts(int c, int idx, int type, int tin, int tout, int delta){
  auto& cnt = dp[c][idx];
  cnt[type] += delta;
  if(type == 0){
    cnt[3] += delta * T[c][1].get(tin);
  }else if(type == 1){
    cnt[3] += delta * (T[c][0].get(tout) - T[c][0].get(tin - 1));
    cnt[4] += delta * (T[c][2].get(tout) - T[c][2].get(tin - 1));
  }else{
    cnt[4] += delta * T[c][1].get(tin);
  }
}

// Registers (delta = +1) or unregisters (delta = -1) the vertex in c's
// Fenwick tree for `type`. Type 1 is stored as a range add over [tin, tout].
static void toggle_fenwick(int c, int type, int tin, int tout, int delta){
  T[c][type].add(tin, delta);
  if(type == 1) T[c][type].add(tout + 1, -delta);
}

// Switches city v from type tp[v] to Y and updates `ans`.
// For every centroid ancestor c of v (records in C[v]):
//  - c == v: c is never counted in its own dp, so only c's own-type term
//    changes;
//  - otherwise: take out c's contribution, patch the aggregates at the root
//    (TIN 1) and at the top of v's branch (low_tin), update the Fenwick
//    trees, then add the contribution back.
void change(int v, int Y) {
  const int A = tp[v]; // old type
  const int B = Y;     // new type
  for(auto [c, low, tin, tout, low_tin] : C[v]){
    if(c == v){
      ans -= centroid_total(v, A);
      ans += centroid_total(v, B);
      continue;
    }

    // Detach: c's total and the same-branch correction of v's branch.
    ans -= centroid_total(c, tp[c]);
    ans += calc(low_tin, c, tp[low], false);
    bad[c] -= dp[c][low_tin][0] * dp[c][low_tin][2];

    // Remove v as type A (aggregates first, while T still contains v).
    shift_counts(c, 1, A, tin, tout, -1);
    shift_counts(c, low_tin, A, tin, tout, -1);
    toggle_fenwick(c, A, tin, tout, -1);

    // Insert v as type B (T first, then aggregates).
    toggle_fenwick(c, B, tin, tout, +1);
    shift_counts(c, 1, B, tin, tout, +1);
    shift_counts(c, low_tin, B, tin, tout, +1);

    // Re-attach with the updated aggregates.
    ans -= calc(low_tin, c, tp[low], false);
    bad[c] += dp[c][low_tin][0] * dp[c][low_tin][2];
    ans += centroid_total(c, tp[c]);
  }
  tp[v] = Y;
}

// Current number of valid tours. init() computes it and change() keeps it
// up to date.
long long num_tours() { return ans; }
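// Per-subtask result tables (not yet loaded when the page was captured):
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...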