// Note: this submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include <bits/stdc++.h>
#include <iostream>
#include <cstdlib>
#include <iomanip>
#include <vector>
#include <cmath>
#include <assert.h>
#include <ctime>
#include <math.h>
#include <queue>
#include <string>
#include <numeric>
#include <fstream>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <map>
#include <stack>
#include <random>
#include <list>
#include <bitset>
#include <algorithm>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// File-wide namespace imports: the standard library and the GNU policy-based
// data structures (needed for the ordset alias below).
using namespace std;
using namespace __gnu_pbds;
// Short aliases for the wide arithmetic types.
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// Pair shorthands (macros, so they expand textually wherever they appear).
#define pll pair<ll, ll>
#define pii pair<int, int>
#define pdd pair<ld, ld>
// Accessor shorthands for pair members.
#define ff first
#define ss second
// Expands to the begin/end iterator pair of v. The argument is not
// parenthesised and is evaluated twice, so pass a plain variable name only.
#define all(v) v.begin(),v.end()
// Red-black tree over pii keys with the order-statistics node update policy.
// Not referenced in the visible part of this file.
typedef tree<
    pii,
    null_type,
    less<pii>,
    rb_tree_tag,
    tree_order_statistics_node_update> ordset;
// NOTE(review): -O3 and -Os are both requested here and contradict each other;
// which level actually takes effect should be confirmed against the GCC docs.
#pragma GCC optimize("-O3")
#pragma GCC optimize("unroll-loops")
#pragma GCC optimize("-Os")
// Large sentinel value (2 * 10^18); not referenced in the visible part of this file.
ll INF = 2e18;
// Disabled random generator, kept for reference:
//mt19937 gen(time(0));
// Greatest common divisor of n and m, always returned as a non-negative value.
//
// gcd(x, 0) == gcd(0, x) == |x| and gcd(0, 0) == 0.
// Signs are stripped up front: the % operator keeps the sign of its left
// operand, so a loop that compares the raw operands and reduces the larger
// one makes no progress (and never terminates) once an argument is negative.
// LLONG_MIN cannot be negated and is not a supported input.
ll gcd(ll n, ll m){
    if(n < 0) n = -n;
    if(m < 0) m = -m;
    // Euclid's algorithm: (n, m) -> (m, n mod m) until the remainder is zero.
    while(m != 0){
        ll r = n % m;
        n = m;
        m = r;
    }
    return n;
}
// Least common multiple of n and m, computed as n / gcd(n, m) * m.
//
// The division happens before the multiplication to keep the intermediate
// value small; the final product can still overflow ll for large inputs.
// Returns 0 when gcd(n, m) is 0 (both inputs are 0) instead of dividing by zero.
ll lcm(ll n, ll m){
    ll nod = gcd(n, m);
    if(nod == 0) return 0;
    return n / nod * m;
}
// Modulus used by binpow: 1e9 + 7 (a double expression converted to ll).
ll mod = 1e9 + 7;
// Modular exponentiation: returns n^m reduced by the global `mod`, using
// iterative square-and-multiply (O(log m) multiplications, no recursion).
//
// n is reduced modulo `mod` before the first multiplication, so every product
// is formed from operands smaller than `mod` in magnitude and a large |n| can
// no longer overflow ll. This assumes mod * mod itself fits in ll.
// m <= 0 yields 1.
// For negative n the result follows C++ remainder semantics and may be
// negative; callers needing a value in [0, mod) must normalise it themselves.
ll binpow(ll n, ll m){
    ll result = 1;
    ll base = n % mod;
    while(m > 0){
        // Multiply in the current power of the base when this bit of m is set.
        if(m & 1) result = (result * base) % mod;
        base = (base * base) % mod;
        m >>= 1;
    }
    return result;
}
// Shared state filled in by solve() and consumed by dfs().
// a[v]: neighbours of vertex v (0-based); every edge is stored in both directions.
vector<vector<int>> a;
// dp[v]: per-vertex value computed by dfs(); sized to n and zero-filled by solve().
vector<int> dp;
// Running maximum, seeded by solve(), raised by dfs(), and printed by solve().
int ans = 0;
// One character per vertex, read from input; dfs() treats '0' differently
// from every other character.
string s;
// Depth-first pass over the subtree of v, entered from parent p (-1 for the root).
//
// After the children are processed, dp[v] holds the sum of their dp values.
// If s[v] is '0' that sum is kept as is. Otherwise ans is first raised to
// (largest child dp) + 1, and dp[v] becomes max(1, sum - 1).
// In both cases ans is finally raised to the resulting dp[v].
void dfs(int v, int p){
    int bestChild = 0;
    for(size_t i = 0; i < a[v].size(); ++i){
        const int to = a[v][i];
        if(to != p){
            dfs(to, v);
            dp[v] += dp[to];
            bestChild = max(bestChild, dp[to]);
        }
    }
    if(s[v] != '0'){
        ans = max(ans, bestChild + 1);
        dp[v] = max(1, dp[v] - 1);
    }
    ans = max(ans, dp[v]);
}
void solve(){
    int n;
    cin >> n;
    a.resize(n);
    dp.resize(n, 0);
    for(int i = 0; i < n - 1; i++){
        int u, v;
        cin >> u >> v;
        u--; v--;
        a[u].push_back(v);
        a[v].push_back(u);
    }
    cin >> s;
    int cnt = 0;
    for(int i = 0; i < n; i++) if(s[i] == '1') cnt++;
    ans = min(cnt, 2);
    dfs(0, -1);
    cout << ans << '\n';
}
// Program entry point: detaches iostreams from C stdio and unties cin from
// cout, then runs solve() for a fixed count of one test case.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // Reading the test-case count from input is disabled; it would go here.
    for (int cases = 1; cases > 0; --cases) {
        solve();
    }
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |