제출 #864530

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
864530 |  | serifefedartar | Zagrade (COI17_zagrade) | C++17 | 10 / 100 | 3008 ms | 40380 KiB
#include <bits/stdc++.h>
using namespace std;
 
// Disables C-stdio synchronisation and unties cin from cout; expanded at the top of main().
#define fast ios::sync_with_stdio(0);cin.tie(0);
// NOTE: the former `#define s second` / `#define f first` macros were removed.
// Nothing used them as pair accessors, and `s` silently renamed the global bracket
// string below to `second`; any future identifier spelled `s` or `f` would have
// been rewritten as well.
using ll = long long;
const ll MOD = 1e9 + 7;     // not referenced by this solution
const ll LOGN = 20;         // not referenced by this solution
const ll MAXN = 3e5 + 100;  // capacity of the per-node arrays below

// Tree adjacency lists; main() sizes this to n + 1 so nodes are 1-indexed.
vector<vector<int>> graph;
// Bracket of every node; main() prepends a '#' sentinel so that s[v] is node v's bracket.
string s;
// marked[v]: v has already been used as a centroid and is treated as removed.
// sz[v]: size of v's subtree inside the current component, filled by get_sz().
int marked[MAXN], sz[MAXN];
// For the centroid being processed: balance -> number of recorded opening halves.
map<int,int> cnt;
// Answer accumulator: number of (start, end) pairs whose path is a balanced sequence.
ll ans = 0;
int get_sz(int node, int parent) {
	sz[node] = 1;
	for (auto u : graph[node]) {
		if (u == parent)
			continue;
		sz[node] += get_sz(u, node);
	}
	return sz[node];
}

// Walks from `node` toward the heavy side of the unmarked component and returns
// the node where the walk stops.
// At each step, it moves into the first unmarked neighbour (other than the node it
// just left) whose sz[] is at least half of `n`. It stops at the first node that has
// no such neighbour.
// Provided sz[] was filled by get_sz() rooted at the starting node and `n` is the
// component size, the stopping node is a centroid: every remaining piece has at
// most n / 2 nodes.
int find_centro(int node, int parent, int n) {
	for (;;) {
		int heavy = -1;
		for (int v : graph[node]) {
			if (v != parent && !marked[v] && sz[v] * 2 >= n) {
				heavy = v;
				break;
			}
		}
		if (heavy == -1)
			return node;
		parent = node;
		node = heavy;
	}
}

// Matches "closing halves" against the opening halves already stored in cnt.
//
// Walks the path centro -> ... -> node, with the centroid itself excluded. The
// top-level call from solve() passes sum = mx = mn = 0, i.e. the empty prefix.
// After adding `node`, sum is the path's balance and mn its lowest prefix
// balance (0 included).
// Read forwards, such a path can close an opening half of balance k iff it never
// drops below -k and ends exactly at -k. That holds iff sum == mn, with k == -mn.
// Every opening half counted in cnt[-mn] then forms one valid pair with `node`.
// Example: ( ) ) ) ( has prefixes 1 0 -1 -2 -1. Here sum (-1) != mn (-2), so
// this path closes nothing.
// `mx` is kept for signature symmetry with add(); it is not consulted here.
void dfs(int node, int parent, int sum, int mx, int mn) {
	sum += (s[node] == '(' ? 1 : -1);
	mn = min(mn, sum);

	if (sum == mn) {
		// find() rather than operator[]: a pure lookup must not insert zero entries.
		const auto it = cnt.find(-mn);
		if (it != cnt.end())
			ans += it->second;
	}

	for (auto u : graph[node]) {
		if (u == parent || marked[u])
			continue;
		dfs(u, node, sum, mx, mn);
	}
}

// Records the "opening halves" that end at the current centroid.
//
// Walks the path centro -> ... -> node. In the top-level call from solve(), the
// centroid's own bracket is already folded into sum/mx/mn. sum is the running
// balance, and mx/mn are its highest/lowest prefix balance, with the empty
// prefix (0) included.
// Read backwards (node -> ... -> centro), the path never drops below zero exactly
// when its total equals its highest prefix balance, i.e. sum == mx. Such a path
// is counted in cnt under its balance.
// Example: ( ) ) ) ( ( has prefixes 1 0 -1 -2 -1 0, so sum = 0 < mx = 1 and the
// path is not recorded. `mn` is carried along but not consulted here.
void add(int node, int parent, int sum, int mx, int mn) {
	const int balance = sum + (s[node] == '(' ? 1 : -1);
	const int highest = max(mx, balance);
	const int lowest = min(mn, balance);

	if (balance == highest)
		++cnt[balance];

	for (int child : graph[node])
		if (child != parent && !marked[child])
			add(child, node, balance, highest, lowest);
}

void solve(int node) {
	int n = get_sz(node, node);
	int centro = find_centro(node, node, n);

	cnt.clear();
	marked[centro] = true;
	if (s[centro] == '(')
		cnt[1]++;

	for (auto u : graph[centro]) {
		if (marked[u])
			continue;
		dfs(u, centro, 0, 0, 0);
		int val = (s[centro] == '(' ? 1 : -1);
		add(u, centro, val, max(0, val), min(0, val));
	}

	ans += cnt[0];

	cnt.clear();
	reverse(graph[centro].begin(), graph[centro].end());
	for (auto u : graph[centro]) {
		if (marked[u])
			continue;
		dfs(u, centro, 0, 0, 0);
		int val = (s[centro] == '(' ? 1 : -1);
		add(u, centro, val, max(0, val), min(0, val));
	}

	for (auto u : graph[centro]) {
		if (!marked[u])
			solve(u);
	}
}

// Reads n, the bracket string and the n - 1 tree edges from stdin, runs the
// centroid decomposition from node 1, and prints the accumulated answer.
int main() {
	fast
	int n;
	cin >> n >> s;
	s.insert(s.begin(), '#');  // sentinel at index 0, so s[v] is node v's bracket

	graph.assign(n + 1, vector<int>());
	for (int edge = 1; edge < n; ++edge) {
		int a, b;
		cin >> a >> b;
		graph[a].push_back(b);
		graph[b].push_back(a);
	}

	solve(1);
	cout << ans << "\n";
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...