This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may receive a different result if resubmitted.
import sys
from collections import defaultdict
from sys import setrecursionlimit, stdin, stdout
setrecursionlimit(1000000)
class DSU:
    """Disjoint-set forest over the integers 0..n-1 using path compression."""

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))

    def find(self, a):
        """Return the root of a's set, re-pointing every node on the way at it."""
        root = a
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: hang each visited node directly under the root.
        node = a
        while node != root:
            nxt = self.parent[node]
            self.parent[node] = root
            node = nxt
        return root

    def merge(self, a, b):
        """Join the sets of a and b; a's root becomes the root of the union."""
        root_a = self.find(a)
        root_b = self.find(b)
        self.parent[root_b] = root_a
def setup(v, parent):
for p in graf[v]:
if p == parent:
continue
depth[p] = depth[v] + 1
up[p] = v
setup(p, v)
# Read the whole input at once: ~n*3 separate input() calls are slow in CPython.
data = sys.stdin.buffer.read().split()
n, k = int(data[0]), int(data[1])
graf = [[] for _ in range(n + 1)]
dsu = DSU(n)
up = [0] * n     # up[v]: neighbour of v towards vertex 0 (filled by setup)
depth = [0] * n  # depth[v]: edge distance from vertex 0 (filled by setup)
pos = 2
for _ in range(n - 1):
    a = int(data[pos]) - 1
    b = int(data[pos + 1]) - 1
    pos += 2
    graf[a].append(b)
    graf[b].append(a)
setup(0, -1)

# groups[s] lists the (0-based) vertices whose group label is s + 1.
groups = [[] for _ in range(max(n, k) + 1)]
for i in range(n):
    groups[int(data[pos + i]) - 1].append(i)

# top[r] is the shallowest vertex of the component whose DSU root is r.
# Components are always kept as connected subtrees, so this vertex is unique.
top = list(range(n))


def _join_path(u, v):
    """Merge every component lying on the tree path between u and v into one."""
    while True:
        ru = dsu.find(u)
        rv = dsu.find(v)
        if ru == rv:
            return
        # The component whose top is deeper (or equally deep) cannot contain
        # the LCA of u and v, so the edge from its top to that top's parent
        # lies on the u-v path; absorb it into the component above.
        if depth[top[ru]] < depth[top[rv]]:
            ru = rv
        low = top[ru]
        high = up[low]
        keep = top[dsu.find(high)]  # the upper component's top remains the top
        dsu.merge(high, low)
        top[dsu.find(high)] = keep


# Contract each group together with all tree paths between its members.
# Merging only the members would leave disconnected components and the
# contracted graph would no longer be a tree.
for group in groups:
    for member in group[1:]:
        _join_path(group[0], member)

# Count leaves (components with exactly one outgoing edge) of the contracted tree.
comp = [dsu.find(i) for i in range(n)]
deg = [0] * n
for i in range(n):
    ci = comp[i]
    for p in graf[i]:
        if comp[p] != ci:
            deg[ci] += 1
leaves = sum(1 for d in deg if d == 1)
print((leaves + 1) // 2)
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |