#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
// ---- Global state shared by init(), change() and num_tours() ----
int n{0};                     // number of vertices, set by init()
vector<int> parent;           // parent[v] in the tree rooted at 0; parent[0] is left at 0
vector<vector<int>> adja;     // undirected adjacency lists built from the edge list
vector<vector<int>> enfants;  // children lists of the rooted tree (filled by dfs)
vector<bool> visited;         // DFS marks used while rooting the tree
vector<int> shop;             // shop[v]: current type of vertex v (combine() tests 0, 1 and 2)
// Per-subtree aggregates maintained by combine(); every field starts at zero.
struct Node {
    long long rep{0};   // triples of types (0, 1, 2) inside the subtree whose type-1
                        // vertex lies on the path joining the other two
    long long nb0{0};   // vertices of type 0 in the subtree
    long long nb1{0};   // vertices of type 1 in the subtree
    long long nb2{0};   // vertices of type 2 in the subtree
    long long nb10{0};  // pairs (u, w): u of type 1 in the subtree, w of type 0 in u's subtree
    long long nb12{0};  // same as nb10, with w of type 2
};
vector<Node> dp;        // dp[v] holds the aggregates of the subtree rooted at v
void combine(int node) {
dp[node].nb0 = 0;
dp[node].nb1 = 0;
dp[node].nb2 = 0;
for (int enfant : enfants[node]) {
dp[node].nb0 += dp[enfant].nb0;
dp[node].nb1 += dp[enfant].nb1;
dp[node].nb2 += dp[enfant].nb2;
}
if (shop[node] == 0)
dp[node].nb0 ++;
if (shop[node] == 1)
dp[node].nb1 ++;
if (shop[node] == 2)
dp[node].nb2 ++;
dp[node].nb10 = 0;
dp[node].nb12 = 0;
for (int enfant : enfants[node]) {
dp[node].nb10 += dp[enfant].nb10;
dp[node].nb12 += dp[enfant].nb12;
}
if (shop[node] == 1) {
dp[node].nb10 += dp[node].nb0;
dp[node].nb12 += dp[node].nb2;
}
dp[node].rep = 0;
for (int enfant : enfants[node]) {
dp[node].rep += dp[enfant].rep;
}
if (shop[node] == 1) {
long long lhs = 0, rhs = 0;
long long enlever = 0;
for (int enfant : enfants[node]) {
lhs += dp[enfant].nb0;
rhs += dp[enfant].nb2;
enlever += dp[enfant].nb0 * dp[enfant].nb2;
}
dp[node].rep += (lhs * rhs) - enlever;
}
if (shop[node] == 0) {
for (int enfant : enfants[node]) {
dp[node].rep += dp[enfant].nb12;
}
}
if (shop[node] == 2) {
for (int enfant : enfants[node]) {
dp[node].rep += dp[enfant].nb10;
}
}
long long gauche, droite, enlever;
gauche = 0; droite = 0; enlever = 0;
for (int enfant : enfants[node]) {
gauche += dp[enfant].nb0;
droite += dp[enfant].nb12;
enlever += dp[enfant].nb0 * dp[enfant].nb12;
}
dp[node].rep += (gauche * droite) - enlever;
gauche = 0; droite = 0; enlever = 0;
for (int enfant : enfants[node]) {
gauche += dp[enfant].nb2;
droite += dp[enfant].nb10;
enlever += dp[enfant].nb2 * dp[enfant].nb10;
}
dp[node].rep += (gauche * droite) - enlever;
return;
}
// Roots the tree at `node`. Every vertex reached from it gets its parent
// recorded and is appended to that parent's children list. Neighbours are
// examined in adjacency order, the same order the former recursive version used.
// Precondition: visited[node] is already true (init() sets visited[0]).
// It uses an explicit stack instead of recursion, so a path-shaped tree with many
// vertices cannot overflow the call stack.
void dfs(int node) {
    // Each frame: (vertex, index of the next neighbour of that vertex to examine).
    vector<pair<int, size_t>> frames;
    frames.emplace_back(node, 0);
    while (!frames.empty()) {
        const int cur = frames.back().first;
        size_t& idx = frames.back().second;
        if (idx == adja[cur].size()) {
            frames.pop_back();
            continue;
        }
        const int voi = adja[cur][idx++];
        if (!visited[voi]) {
            visited[voi] = true;
            parent[voi] = cur;
            enfants[cur].push_back(voi);
            // May invalidate `idx`; it is not used again in this iteration.
            frames.emplace_back(voi, 0);
        }
    }
}
// Computes dp[] for every vertex in the subtree of `node`, children before
// parents. Vertices are first listed in pre-order with an explicit stack.
// They are then combined in reverse order, so each child is combined before
// its parent without recursion depth proportional to the tree height.
void regne(int node) {
    vector<int> order;
    vector<int> todo(1, node);
    while (!todo.empty()) {
        const int cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        for (int enfant : enfants[cur]) {
            todo.push_back(enfant);
        }
    }
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        combine(*it);
    }
}
// Builds the rooted tree and all subtree aggregates.
//   N    : number of vertices (0 .. N-1)
//   f    : initial type of every vertex
//   u, v : endpoints of the N-1 edges
//   q    : not used here
// Every global container is reset, not just resized. A second call to init()
// therefore cannot inherit adjacency/children entries or aggregates from an
// earlier tree.
void init(int N, vector<int> f, vector<int> u, vector<int> v, int q) {
    (void)q;
    n = N;
    dp.clear();
    dp.resize(n);
    adja.clear();
    adja.resize(n);
    enfants.clear();
    enfants.resize(n);
    // parent[0] must stay 0: change() stops climbing when parent[cur] == cur.
    parent.assign(n, 0);
    shop.assign(n, 0);
    visited.assign(n, false);
    for (int i = 0; i + 1 < n; i++) {
        adja[u[i]].push_back(v[i]);
        adja[v[i]].push_back(u[i]);
    }
    visited[0] = true;
    dfs(0);
    for (int i = 0; i < n; i++) {
        shop[i] = f[i];
    }
    regne(0);
}
// Sets vertex x to type y and refreshes the aggregates of x and of each of its
// ancestors, bottom-up. The climb stops at the root, recognised by
// parent[root] == root. The root (vertex 0) is then recombined last.
void change(int x, int y) {
    shop[x] = y;
    int cur = x;
    while (parent[cur] != cur) {
        combine(cur);
        cur = parent[cur];
    }
    combine(0);
}
// Number of counted triples in the whole tree: the `rep` aggregate of root 0.
long long num_tours() {
    return dp.front().rep;
}
/* Residue of the online judge's results page, pasted after the source.
   It is kept as a comment so that the file still compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/