Submission #651708

#TimeUsernameProblemLanguageResultExecution timeMemory
651708beaconmcPower Plant (JOI20_power)Pypy 3
0 / 100
37 ms18256 KiB
import sys
sys.setrecursionlimit(1000000000)


n = int(input())
if n==1:
    print(1)
    quit()
edges = [[] for i in range(n+1)]
par = [-1 for i in range(n+1)]
visited = [False for i in range(n+1)]
for i in range(n-1):
    a,b = map(int, input().split())
    edges[a].append(b)
    edges[b].append(a)
sus = input()

for i in range(1,n+1):
    if len(edges[i]) == 1:
        root = i
flag = False

def dfs(a):
    global flag
    
    for i in edges[a]:

        if sus[a-1]==sus[i-1] and sus[i-1]=="1":
            flag = True
        if not visited[i]:
            par[i] = a
            visited[i] = True
            dfs(i)
visited[root] = True
dfs(root)
dps = [-1 for i in range(n+1)]
used = [0 for i in range(n+1)]

def dp(a):

    if dps[a] != -1:
        return dps[a]
    if sus[a-1]=="0":
        
        sums = 0
        for i in edges[a]:
            if i == par[a]: continue
            sums += dp(i)
        dps[a] = sums
        
        return sums

    dps[a] = 1
    maxi = 0
    for i in edges[a]:
        if i == par[a]: continue
        maxi += dp(i)

    if dps[a] <= maxi-1:
        used[a] = 1
    dps[a] = max(dps[a], maxi-1)
    return dps[a]

ans = dp(root)
if used[root]:
    ans += 2
elif sus[root-1]=="1":
    ans += 1
if flag:
    print(max(2,ans))
else:
    print(ans)



#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...