// This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include "bits/stdc++.h"
using namespace std;
#ifdef Nero
#include "Deb.h"
#else
#define deb(...)
#endif
// Capacity of every vertex-indexed array (main stores vertices 1..n).
constexpr int N = 300005;

int a[N];          // bracket weight per vertex: -1 for ')', +1 otherwise (set in main)
int sz[N];         // component sizes filled by get_size for the current root
bool block[N];     // set once a vertex has been chosen as a centroid
vector<int> g[N];  // undirected adjacency lists
int cntp[N];       // cntp[k]: stored paths v -> cent with total k and no negative prefix
int cntn[N];       // cntn[k]: stored paths cent -> v with total -k whose minimum prefix equals the total
int cent;          // centroid of the component currently being processed
long long ans;     // running answer accumulated by dfs and printed by main
// Fills sz[] for the still-unblocked component hanging below v, where p is
// v's parent (init passes p == v for the root).
void get_size(int v, int p) {
    int total = 1;
    for (int child : g[v]) {
        if (block[child] || child == p) {
            continue;
        }
        get_size(child, v);
        total += sz[child];
    }
    sz[v] = total;
}
// Starting at v (parent p), repeatedly steps into an unblocked child whose
// sz[] exceeds x / 2 and returns the first vertex that has no such child.
// x is the size of the whole component; sz[] must come from get_size on the
// same root.
int get_centroid(int v, int p, int x) {
    bool stepped = true;
    while (stepped) {
        stepped = false;
        for (int child : g[v]) {
            if (!block[child] && child != p && sz[child] > x / 2) {
                p = v;
                v = child;
                stepped = true;
                break;
            }
        }
    }
    return v;
}
// Walks the component below the current centroid `cent`, starting at v whose
// parent is p. The incoming values describe the bracket string read along
// cent -> ... -> parent(v); v's own bracket is folded in on entry:
//   sum      - total of the string (+1 per '(', -1 per ')')
//   min_pref - minimum prefix sum of that string, empty prefix (0) included
//   min_suf  - minimum suffix sum of that string, empty suffix (0) included
//
// upd == 0 : query mode. Counts the path between cent and v itself, plus every
//            pairing of v with a half stored in cntp / cntn by earlier branches.
// upd != 0 : adds upd (+1 or -1) to the bucket describing cent -> v:
//            cntn[k] - cent -> v can end a balanced string: total -k and its
//                      minimum prefix equals the total;
//            cntp[k] - v -> cent can start a balanced string: total k and every
//                      prefix of v -> cent is non-negative (min_suf == 0).
void dfs(int v, int p, int sum, int min_pref, int min_suf, int upd) {
    sum += a[v];
    min_pref = min(min_pref, sum);
    min_suf = min(0, min_suf + a[v]);
    if (upd) {
        // A total of 0 (a complete balanced string) must be stored as well, so it
        // can pair with a balanced half on the other side of the centroid. For
        // sum == 0 the two conditions below cannot both hold for a non-empty string.
        if (sum <= 0 && min_pref == sum) {
            cntn[-sum] += upd;
        } else if (sum >= 0 && min_suf == 0) {
            cntp[sum] += upd;
        }
    } else {
        // Path between cent and v alone, in whichever direction is balanced.
        if (sum == 0 && (min_pref == 0 || min_suf == 0)) {
            ans++;
        }
        // Same string with cent's bracket removed: child -> ... -> v.
        int tsum = sum - a[cent];
        int tpref = min(0, min_pref - a[cent]);
        // child -> v closes a stored w -> cent half (tsum == 0: both halves are
        // balanced). If a[cent] is '(' tpref is understated at tsum == 0, but then
        // cntp[0] is empty anyway, since a balanced w -> cent must end with ')'.
        if (tsum <= 0 && tpref == tsum) {
            ans += cntp[-tsum];
        } else if (tsum >= 0 && min_suf == 0) {
            // v -> child opens, and a stored cent -> w half closes it.
            ans += cntn[tsum];
        }
    }
    for (int u : g[v]) {
        if (!block[u] && u != p) {
            dfs(u, v, sum, min_pref, min_suf, upd);
        }
    }
}
// Centroid-decomposition driver for the unblocked component containing root:
// picks its centroid, counts the paths through it via dfs, clears the
// tables again, then recurses into every remaining branch.
void init(int root) {
    get_size(root, root);
    cent = get_centroid(root, root, sz[root]);
    block[cent] = true;

    // Local copy: the recursive init calls at the end overwrite the global.
    const int c = cent;
    const int base = min(0, a[c]);
    // Calls visit(u) for every neighbour u of c that is not blocked, in adjacency order.
    auto each_branch = [&](auto &&visit) {
        for (int u : g[c]) {
            if (!block[u]) {
                visit(u);
            }
        }
    };

    // Query each branch against the earlier ones, then insert it.
    each_branch([&](int u) {
        dfs(u, c, a[c], base, base, 0);
        dfs(u, c, a[c], base, base, 1);
    });
    // Undo every insertion so cntp / cntn are clean for the next centroid.
    each_branch([&](int u) { dfs(u, c, a[c], base, base, -1); });
    // Solve each remaining piece on its own.
    each_branch([](int u) { init(u); });
}
// Reads n, one bracket per vertex, and the n - 1 tree edges; prints the count
// accumulated in ans by the decomposition rooted at vertex 1.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    for (int i = 1; i <= n; ++i) {
        char bracket;
        cin >> bracket;
        a[i] = bracket == ')' ? -1 : 1;
    }

    int edges = n - 1;
    while (edges-- > 0) {
        int x, y;
        cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }

    init(1);
    cout << ans << '\n';
    return 0;
}
/*
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/