Submission #292887

#TimeUsernameProblemLanguageResultExecution timeMemory
292887limabeansPower Plant (JOI20_power)C++17
100 / 100
427 ms51896 KiB
#include <bits/stdc++.h>
using namespace std;

// Debug helpers; neither is called by the visible solution code.

// Print `value` and a newline, flush, then end the program with exit status 0.
template <class Printable>
void out(Printable value) {
    cout << value << endl;
    exit(0);
}

// Print an expression's source text next to its current value, e.g. "x is 42".
#define watch(expr) cout << #expr << " is " << (expr) << endl





using ll = long long;

// Common prime modulus; not referenced by the visible solution code.
constexpr ll mod = 1'000'000'007;
// Size of every per-vertex array below (presumably the vertex limit plus slack).
constexpr int maxn = 200'100;

// Approach: assume the answer is rooted at some vertex and combine its children;
// every vertex is tried as that root via the re-rooting trick.
// Key observation: a vertex's contribution upwards is either itself alone, or a set
// of disjoint contributions from children inside its subtree.

int n{};                  // number of vertices, read in main()
std::vector<int> g[maxn]; // adjacency lists with 0-based vertex ids
std::string s;            // one digit per vertex; s[v] - '0' is vertex v's own weight

// A multiset of ints that also keeps the running total of its elements, so
// sum() reads a cached value instead of walking the container.
struct vec {
    multiset<int> ms;  // stored values; duplicates allowed
    int tot = 0;       // sum of every value currently in ms

    // Add one copy of x.
    void insert(int x) {
        tot += x;
        ms.insert(x);
    }

    // Remove exactly one copy of x. Precondition: x is currently stored.
    void erase(int x) {
        auto victim = ms.find(x);
        tot -= x;
        ms.erase(victim);
    }

    // Total of all stored values (0 when empty).
    int sum() { return tot; }

    // True when nothing is stored.
    bool empty() { return ms.empty(); }

    // Largest stored value. Precondition: not empty().
    int max() { return *std::prev(ms.end()); }
};

// dp[v]: positive contributions of v's neighbours. dfs() inserts the children's
//        values; solve() additionally inserts the parent side while re-rooting.
vec dp[maxn];
// cache[v]: the value dfs() returned for v, i.e. what v contributes to its parent.
int cache[maxn]{};

// Bottom-up pass over the subtree of `at`, entered from parent `p` (-1 for the root).
// For every vertex v in that subtree, with self(v) = s[v] - '0':
//   * every strictly positive child value cache[child] is inserted into dp[v];
//   * cache[v] = self(v), or dp[v].sum() - self(v) if that is larger and dp[v] is
//     non-empty.
// Returns cache[at].
//
// Implemented with an explicit stack instead of recursion: the recursive form needs
// call-stack depth equal to the tree height, which for a path of up to maxn
// vertices can overflow the default stack.
int dfs(int at, int p) {
    // Pre-order listing of (vertex, parent); every vertex appears after its parent.
    vector<pair<int, int>> order;
    vector<pair<int, int>> pending{{at, p}};
    while (!pending.empty()) {
        const auto [v, par] = pending.back();
        pending.pop_back();
        order.emplace_back(v, par);
        for (int to : g[v]) {
            if (to != par) pending.emplace_back(to, v);
        }
    }

    // Walk the listing backwards so every child is finished before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = it->first;
        const int par = it->second;
        for (int to : g[v]) {
            // Only strictly positive contributions are recorded.
            if (to != par && cache[to] > 0) dp[v].insert(cache[to]);
        }
        const int self = s[v] - '0';
        int res = self;
        if (!dp[v].empty()) res = max(res, dp[v].sum() - self);
        cache[v] = res;
    }
    return cache[at];
}

// Best value found so far over all candidate roots; updated by solve().
int best=0;

// Re-rooting pass over the subtree of `at`, entered from parent `p` (-1 for the root).
// Requires dfs() to have filled dp[] and cache[] for the same rooting.
// When vertex v is visited, dp[v] holds its children's positive values from dfs()
// plus the value of "the rest of the tree" seen through v's parent. With
// self = s[v] - '0', `best` is raised to:
//   * 1 + dp[v].max(), when s[v] == '1' and dp[v] is non-empty;
//   * max(self, dp[v].sum() - self), the latter only when dp[v] is non-empty.
// On return dp[] is back in the state dfs() left it in.
//
// Implemented with an explicit stack instead of recursion, so tree height (up to
// maxn on a path) cannot overflow the call stack.
void solve(int at, int p) {
    // (vertex, value) pairs inserted into child dp sets; removed again at the end.
    vector<pair<int, int>> added;
    vector<pair<int, int>> pending{{at, p}};
    while (!pending.empty()) {
        const auto [v, par] = pending.back();
        pending.pop_back();
        const int self = s[v] - '0';

        // Score v as the root; dp[v] already includes the parent side here.
        int cur = self;
        if (!dp[v].empty()) {
            if (self == 1) best = max(best, 1 + dp[v].max());
            cur = max(cur, dp[v].sum() - self);
        }
        best = max(best, cur);

        for (int to : g[v]) {
            if (to == par) continue;
            // Value of v as seen from `to`: evaluate v's formula with to's own
            // contribution temporarily removed from dp[v].
            if (cache[to]) dp[v].erase(cache[to]);
            int fromP = self;
            if (!dp[v].empty()) fromP = max(fromP, dp[v].sum() - self);
            if (cache[to]) dp[v].insert(cache[to]);

            dp[to].insert(fromP);
            added.emplace_back(to, fromP);
            pending.emplace_back(to, v);
        }
    }

    // Undo the parent-side insertions so dp[] matches its post-dfs() state.
    for (const auto& [v, val] : added) dp[v].erase(val);
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    // n - 1 undirected edges; endpoints arrive 1-based and are stored 0-based.
    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        cin >> a >> b;
        --a;
        --b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    cin >> s;

    dfs(0, -1);    // bottom-up pass rooted at vertex 0
    solve(0, -1);  // re-rooting pass; writes the answer into `best`
    cout << best << '\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...