#include<bits/stdc++.h>
using namespace std;
// ---------------------------------------------------------------------------
// Global state shared by dfs2(), solve(), lca(), dfs() and main().
// All containers are 1-indexed by vertex and sized to n+1 in main().
// ---------------------------------------------------------------------------

// Undirected adjacency lists; main() stores every edge in both directions.
vector<vector<int> > adj;

// power[v] is true when character v-1 of the input string is '1'.
vector<bool> power;

// Preorder visit index of each vertex, assigned by dfs2(); idx is the next
// index to hand out.
vector<int> inOrder;
int idx{0};

// Depth of each vertex below the root (vertex 1), filled by dfs2().
vector<int> level;

// Binary-lifting table: dp[0][v] is v's parent (0 above the root) and
// dp[k][v] = dp[k-1][dp[k-1][v]] is built in main().
vector<vector<int> > dp;

// Per-vertex bookkeeping maintained by dfs(): a counter, the multiset of
// (inOrder[w], w) pairs ordered by preorder index, and the per-vertex result
// (ans[1] is what main() prints).
vector<int> counter;
vector<multiset<pair<int, int> > > differentPaths;
vector<int> ans;
// Walks the tree rooted at u (arriving from p) in preorder and records, for
// every vertex reached:
//   inOrder[v] - its preorder index (taken from the global counter idx),
//   dp[0][v]   - its parent,
//   level[v]   - level[parent] + 1.
// level[u] and dp[0][u] themselves are not touched.
// Implemented with an explicit stack, so the call depth does not grow with
// the depth of the tree; the visiting order matches a recursive DFS exactly.
void dfs2(int u, int p){
    // One frame per vertex on the current root-to-vertex path: the vertex,
    // the neighbour it was entered from, and the next adjacency slot to look at.
    struct Frame {
        int vertex;
        int from;
        size_t next;
    };
    vector<Frame> pending;

    inOrder[u] = idx++;
    pending.push_back(Frame{u, p, 0});

    while(!pending.empty()){
        Frame &top = pending.back();
        if(top.next == adj[top.vertex].size()){
            // Every neighbour of this vertex has been handled.
            pending.pop_back();
            continue;
        }
        const int cur = top.vertex;
        const int child = adj[cur][top.next];
        ++top.next;
        if(child == top.from){
            continue;  // never walk back to the parent
        }
        dp[0][child] = cur;
        level[child] = level[cur] + 1;
        inOrder[child] = idx++;
        // `top` may dangle after this push_back; only locals are used above.
        pending.push_back(Frame{child, cur, 0});
    }
}
// Returns the k-th ancestor of x by decomposing k into powers of two and
// following the binary-lifting table dp. Walking past the root yields 0,
// because dp maps the root (and 0) to 0. Only levels dp[0..20] exist, so k
// must stay below 2^21; this is not checked. For k <= 0, x is returned.
int solve(int x, int k){
    for(int bit = 0; k > 0; ++bit, k >>= 1){
        if(k & 1){
            x = dp[bit][x];
        }
    }
    return x;
}
int lca(int a, int b){
if(level[a] > level[b]){
swap(a, b);
}
int difference = level[b] - level[a];
b = solve(b, difference);
if(a == b){
return a;
}
for(int i = 20; i >= 0; i--){
if(dp[i][a] != dp[i][b]){
a = dp[i][a];
b = dp[i][b];
}
}
return dp[0][a];
}
// Post-order tree DP from u (reached from parent p). Children are processed
// first, then their results are folded into u:
//   ans[u]            - sum of the children's ans, replaced below when power[u];
//   counter[u]        - sum of the children's counter, +1 for every power child
//                       v whose counter[v] == 0;
//   differentPaths[u] - multiset of (inOrder[w], w) pairs, i.e. vertices ordered
//                       by the preorder index assigned in dfs2().
// NOTE(review): recursion depth equals the depth of the tree; very deep
// (path-like) trees may exhaust the call stack.
void dfs(int u, int p){
for(int v : adj[u]){
if(v != p){
dfs(v, u);
ans[u] += ans[v];
counter[u] += counter[v];
if(power[v] && counter[v] <= 1){
differentPaths[u].insert(make_pair(inOrder[v], v));
if(counter[v] == 1){
// Take the smallest-preorder entry of v's set and erase it from u's set,
// then skip both the merge below and the counter increment.
// NOTE(review): assumes differentPaths[v] is non-empty here (not checked).
// NOTE(review): v's set has not been merged into u's set yet, and its entries
// are vertices strictly inside v's subtree, so this erase appears to have no
// effect; v's remaining entries are also never merged into u. Confirm the
// intended order (merge first? erase from differentPaths[v]?).
int thing = (*differentPaths[v].begin()).second;
differentPaths[u].erase(make_pair(inOrder[thing], thing));
continue;
}
counter[u]++;
}
// Small-to-large merge: insert the smaller set into the larger one and keep
// the larger in differentPaths[u]. After the swap, differentPaths[v] holds
// u's previous set and is not cleared.
if((int)differentPaths[u].size() < (int)differentPaths[v].size()){
for(pair<int, int> x : differentPaths[u]){
differentPaths[v].insert(x);
}
swap(differentPaths[u], differentPaths[v]);
}
else{
for(pair<int, int> x : differentPaths[v]){
differentPaths[u].insert(x);
}
}
}
}
if(power[u]){
// For a power vertex: 1 + number of collected vertices, minus the number of
// distinct power vertices that are the LCA of some pair of preorder-adjacent
// entries in differentPaths[u].
// NOTE(review): this scans the whole merged set at every power vertex.
ans[u] = 1 + (int)differentPaths[u].size();
set<int> nodes;
for(pair<int, int> x : differentPaths[u]){
// Re-locates the current element to obtain an iterator to it and its
// successor (the range-for iterator itself is not accessible here).
auto it = differentPaths[u].find(x);
auto it2 = it;
it2++;
if(it2 != differentPaths[u].end()){
int hi = lca((*it).second, (*it2).second);
if(power[hi]){
nodes.insert(hi);
}
}
}
ans[u] -= (int)nodes.size();
}
}
// Reads the tree and the power-vertex string, precomputes preorder indices,
// depths and the binary-lifting table, runs dfs() from vertex 1 and prints
// ans[1].
//
// Input, as read below:
//   n
//   n-1 lines "a b"  - undirected edges between vertices in 1..n
//   s                - s[i] == '1' marks vertex i+1 as a power vertex
int main(){
    // Only C++ streams are used, so detaching them from C stdio is safe and
    // makes large inputs much faster.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0;
    // The tree is rooted at vertex 1 below, so at least one vertex is needed;
    // a failed or non-positive read would otherwise index out of bounds.
    if(!(cin >> n) || n < 1){
        return 0;
    }

    ans.resize(1 + n, 0);
    differentPaths.resize(1 + n);
    level.resize(1 + n);
    // 21 ancestor levels: lca() and solve() index dp[0..20].
    dp.resize(21, vector<int>(1 + n));
    power.resize(1 + n, false);
    adj.resize(1 + n);
    inOrder.resize(1 + n);
    counter.resize(1 + n);

    for(int i = 0; i < n - 1; i++){
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    string s;
    cin >> s;
    // A string shorter than n used to be read past its end (undefined
    // behaviour); missing characters now simply mean "not a power vertex".
    const int marked = min(n, (int)s.size());
    for(int i = 0; i < marked; i++){
        if(s[i] == '1'){
            power[i + 1] = true;
        }
    }

    // Preorder indices, depths and direct parents (dp[0]); the root's parent is 0.
    dfs2(1, 0);
    // dp[i][j] is the 2^i-th ancestor of j; vertex 0 maps to itself.
    for(int i = 1; i < 21; i++){
        for(int j = 1; j <= n; j++){
            dp[i][j] = dp[i - 1][dp[i - 1][j]];
        }
    }

    dfs(1, 0);
    cout << ans[1] << "\n";
}
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |