제출 #1235780

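# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
1235780 | kilikuma | JOI tour (JOI24_joitour) | C++20
36 / 100
3012 ms | 53568 KiB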
#include "joitour.h"
#include <bits/stdc++.h>

using namespace std; 

// ---- Global tree state, filled in by init() ----
int n{0};                               // number of vertices
std::vector<int> parent;                // parent of each vertex once rooted at 0 (parent[0] is left at 0)
std::vector<std::vector<int>> adja;     // undirected adjacency lists
std::vector<std::vector<int>> enfants;  // children lists of the tree rooted at 0
std::vector<bool> visited;              // marks vertices already reached while rooting
std::vector<int> shop;                  // current shop kind of each vertex (0, 1 or 2)

// Aggregates describing the subtree rooted at one vertex; recomputed by combine().
struct Node {
  // Triples (a, b, c) inside the subtree with shop kinds 0, 1, 2 respectively
  // and b on the path between a and c.
  long long rep{0};
  long long nb0{0};   // vertices of shop kind 0 in the subtree
  long long nb1{0};   // vertices of shop kind 1 in the subtree
  long long nb2{0};   // vertices of shop kind 2 in the subtree
  // Sum, over kind-1 vertices w in the subtree, of kind-0 vertices below w.
  long long nb10{0};
  // Sum, over kind-1 vertices w in the subtree, of kind-2 vertices below w.
  long long nb12{0};
};

std::vector<Node> dp;  // dp[v] holds the aggregates of v's subtree

void combine(int node) {
  dp[node].nb0 = 0; 
  dp[node].nb1 = 0; 
  dp[node].nb2 = 0; 
  for (int enfant : enfants[node]) {
    dp[node].nb0 += dp[enfant].nb0; 
    dp[node].nb1 += dp[enfant].nb1; 
    dp[node].nb2 += dp[enfant].nb2; 
  }
  if (shop[node] == 0)
    dp[node].nb0 ++; 
  if (shop[node] == 1)
    dp[node].nb1 ++;
  if (shop[node] == 2) 
    dp[node].nb2 ++; 
  dp[node].nb10 = 0; 
  dp[node].nb12 = 0; 
  for (int enfant : enfants[node]) {
    dp[node].nb10 += dp[enfant].nb10; 
    dp[node].nb12 += dp[enfant].nb12; 
  }
  if (shop[node] == 1) {
    dp[node].nb10 += dp[node].nb0; 
    dp[node].nb12 += dp[node].nb2; 
  }
  dp[node].rep = 0; 
  for (int enfant : enfants[node]) {
    dp[node].rep += dp[enfant].rep; 
  }
  if (shop[node] == 1) {
    long long lhs = 0, rhs = 0; 
    long long enlever = 0; 
    for (int enfant : enfants[node]) {
      lhs += dp[enfant].nb0; 
      rhs += dp[enfant].nb2;
      enlever += dp[enfant].nb0 * dp[enfant].nb2;  
    }
    dp[node].rep += (lhs * rhs) - enlever; 
  }
  if (shop[node] == 0) {
    for (int enfant : enfants[node]) {
      dp[node].rep += dp[enfant].nb12; 
    }
  }
  if (shop[node] == 2) {
    for (int enfant : enfants[node]) {
      dp[node].rep += dp[enfant].nb10; 
    }
  }
  long long gauche, droite, enlever; 
  gauche = 0; droite = 0; enlever = 0; 
  for (int enfant : enfants[node]) {
    gauche += dp[enfant].nb0; 
    droite += dp[enfant].nb12; 
    enlever += dp[enfant].nb0 * dp[enfant].nb12; 
  } 
  dp[node].rep += (gauche * droite) - enlever; 
  gauche = 0; droite = 0; enlever = 0; 
  for (int enfant : enfants[node]) {
    gauche += dp[enfant].nb2; 
    droite += dp[enfant].nb10; 
    enlever += dp[enfant].nb2 * dp[enfant].nb10; 
  }
  dp[node].rep += (gauche * droite) - enlever; 
  return; 
}

// Roots the tree at `node`. The caller must already have set visited[node].
// For every vertex reachable from `node` this records parent[] and appends the
// vertex to its parent's enfants[] list.
//
// The traversal is iterative, using an explicit stack of (vertex, next
// neighbour index) frames. A recursive version nests one call per tree level,
// so a path-shaped tree with ~2e5 vertices can overflow the call stack.
// Neighbours are inspected in the same order as the recursive DFS, so
// enfants[] comes out identical.
void dfs(int node) {
  std::vector<std::pair<int, std::size_t>> frames;
  frames.emplace_back(node, std::size_t{0});
  while (!frames.empty()) {
    auto& [u, next] = frames.back();
    if (next == adja[u].size()) {
      frames.pop_back();  // every neighbour of u handled: "return" to its parent
      continue;
    }
    const int voi = adja[u][next++];
    if (!visited[voi]) {
      visited[voi] = true;
      parent[voi] = u;
      enfants[u].push_back(voi);
      // emplace_back may reallocate and invalidate u / next.
      // They are not used again; the loop re-reads frames.back().
      frames.emplace_back(voi, std::size_t{0});
    }
  }
}

void regne(int node) {
  for (int enfant : enfants[node]) {
    regne(enfant); 
  }
  combine(node); 
}

// Grader entry point. Builds the tree from the N-1 edges (u[e], v[e]), roots it
// at vertex 0, stores the initial shop kinds f and computes every subtree's
// aggregates. The q argument is not used here.
void init(int N, vector<int> f, vector<int> u, vector<int> v, int q) {
  n = N;

  // Size every per-vertex array. parent[] is value-initialised, so parent[0]
  // stays 0; change() relies on that to recognise the root.
  visited.assign(n, false);
  enfants.resize(n);
  shop.resize(n);
  parent.resize(n);
  adja.resize(n);
  dp.resize(n);

  // Undirected adjacency lists from the edge list.
  for (int e = 0; e < n - 1; ++e) {
    const int a = u[e];
    const int b = v[e];
    adja[a].push_back(b);
    adja[b].push_back(a);
  }

  // Initial shop kind of every vertex.
  std::copy(f.begin(), f.begin() + n, shop.begin());

  // Root the tree at 0, then fill dp[] bottom-up.
  visited[0] = true;
  dfs(0);
  regne(0);
}

// Grader entry point: sets the shop kind of vertex x to y and refreshes dp[].
// Only x and its ancestors have subtrees containing x, so only their entries
// are recomputed, bottom-up. The walk stops at the vertex whose parent is
// itself (the root, parent[0] == 0), and the root is combined last.
// NOTE(review): combine() loops over a vertex's children, so one update costs
// the total child count of x and all its ancestors. On deep or star-shaped
// trees that is Theta(N) per update, which is likely why the large subtasks
// time out.
void change(int x, int y) {
  shop[x] = y;
  int cur = x;
  while (parent[cur] != cur) {
    combine(cur);
    cur = parent[cur];
  }
  combine(0);
}

// Grader entry point: the root's subtree is the whole tree, so its triple
// count is the global answer.
long long num_tours() {
  const long long answer = dp[0].rep;
  return answer;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...