#include <bits/stdc++.h>
using namespace std;
// ---- Shorthand type aliases ------------------------------------------------
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using vii = std::vector<int>;
using vll = std::vector<ll>;
using vpii = std::vector<pii>;
using vpll = std::vector<pll>;

// ---- Helper macros ---------------------------------------------------------
#define pb push_back
#define eb emplace_back
#define upb upper_bound
#define lpb lower_bound
#define ppb pop_back
#define X first
#define Y second
#define all(a) a.begin(), a.end()
#define len(a) (int) (a.size())

// ---- Constants -------------------------------------------------------------
// MOD and BASE are template leftovers; nothing in this file references them.
constexpr ll MOD = 1e9 + 7;
constexpr ll BASE = 32;
// Capacity of every per-node array below (node ids are 0 .. n-1).
constexpr int MAXN = 3e5 + 7;

// ---- Global state ----------------------------------------------------------
int n;              // number of nodes read by task()
int stsz[MAXN];     // subtree sizes written by getstsz()/centroid()
vii G[MAXN];        // adjacency lists, 0-indexed
char z[MAXN];       // per-node character: '(' counts +1, anything else -1
bool used[MAXN];    // nodes already taken as centroids by solve()
int paths[MAXN];    // per-centroid counters indexed by bracket excess (see process())
ll ans = 0;         // running answer printed by task()
// Recomputes stsz[] for the subtree rooted at i (entered from p), ignoring
// nodes already marked in used[], and returns the size of i's subtree.
int getstsz(int i, int p = -1) {
	int size = 1;
	for(const int to : G[i]) {
		if(to == p || used[to]) continue;
		size += getstsz(to, i);
	}
	stsz[i] = size;
	return size;
}
// Returns a centroid of the live component of total size sz, walking down
// from i. When called without a parent (p == -1) it first refreshes stsz[]
// for everything below i. While descending it subtracts the heavy child's
// size from the node it leaves, exactly as the recursive form did.
int centroid(int i, int sz, int p = -1) {
	if(p == -1) {
		int total = 1;
		for(auto nx : G[i])
			if(!used[nx]) total += getstsz(nx, i);
		stsz[i] = total;
	}
	int cur = i;
	int par = p;
	while(true) {
		// First live child holding more than half of the component, if any.
		int heavy = -1;
		for(auto nx : G[cur]) {
			if(!used[nx] && nx != par && stsz[nx] * 2 > sz) {
				heavy = nx;
				break;
			}
		}
		if(heavy < 0) return cur;
		stsz[cur] -= stsz[heavy];
		par = cur;
		cur = heavy;
	}
}
// DFS over the live subtree entered at i from p, extending a bracket path by
// one node per step ('(' = +1, anything else = -1).
//   delta    - running sum of the path, including i.
//   mindelta - folded as min(0, mindelta + step); starting from 0 (as the
//              insert calls in solve do) it is min(0, smallest suffix sum).
//   md       - running minimum of delta along the path.
// cnt == true  (query):  when the total is <= 0 and equals the running
//                        minimum, add paths[-delta] to ans.
// cnt == false (insert): when no suffix sum is negative and the total is
//                        >= 0, increment paths[delta].
void process(int i, int p, bool cnt, int delta, int mindelta, int md) {
	const int step = (z[i] == '(') ? 1 : -1;
	delta += step;
	mindelta = min(0, mindelta + step);
	md = min(md, delta);

	if(!cnt) {
		if(mindelta >= 0 && delta >= 0) paths[delta]++;
	} else if(delta <= 0 && delta == md) {
		ans += paths[-delta];
	}

	for(auto nx : G[i]) {
		if(used[nx] || nx == p) continue;
		process(nx, i, cnt, delta, mindelta, md);
	}
}
// Centroid-decomposition driver. Counts every balanced path that passes
// through the centroid of the live component containing i, then recurses
// into the components that remain after removing that centroid.
//
// A path v -> ... -> cent -> ... -> u is split into a start half (v up to,
// but excluding, cent), recorded in paths[] by process(..., false, ...), and
// an end half (cent through u), matched against paths[] by
// process(..., true, ...). Each child subtree is queried before it is
// inserted, so the two halves never come from the same subtree. A second
// pass over the children in reverse order covers the other orientation.
void solve(int i) {
	// Component size, computed once and reused for centroid() and clearing.
	const int s = getstsz(i);
	const int cent = centroid(i, s);
	used[cent] = 1;
	const int delta = (z[cent] == '(' ? 1 : -1);

	// Pass 1 (adjacency order). paths[0] = 1 stands for an empty start half,
	// so paths that begin at cent itself are counted here, and only here.
	paths[0]++;
	for(auto nx : G[cent]) {
		if(used[nx]) continue;
		process(nx, cent, true, delta, delta, delta);
		process(nx, cent, false, 0, 0, 0);
	}
	// Paths ending at cent: a ')' at cent must close a start half of excess 1.
	if(delta < 0) ans += paths[-delta];
	// Start halves contain fewer than s nodes, so only paths[0..s] can be set.
	fill(paths, paths + s + 1, 0);

	// Pass 2 (reverse adjacency order): each subtree now meets the subtrees
	// that followed it in pass 1. Reverse iterators leave G[cent] untouched.
	for(auto it = G[cent].rbegin(); it != G[cent].rend(); ++it) {
		const int nx = *it;
		if(used[nx]) continue;
		process(nx, cent, true, delta, delta, delta);
		process(nx, cent, false, 0, 0, 0);
	}
	fill(paths, paths + s + 1, 0);

	for(auto nx : G[cent])
		if(!used[nx])
			solve(nx);
}
// Reads one test case: n, then n node characters ('(' counts as open,
// anything else as close), then n-1 undirected 1-indexed edges. Runs the
// decomposition from node 0 and prints the accumulated answer.
void task() {
	cin >> n;
	for(int v = 0; v < n; ++v)
		cin >> z[v];
	for(int e = 1; e < n; ++e) {
		int a, b;
		cin >> a >> b;
		--a;
		--b;
		G[a].pb(b);
		G[b].pb(a);
	}
	solve(0);
	cout << ans << '\n';
}
// Entry point: unsynchronised, untied iostreams, then a single test case.
int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);
	for(int tt = 1; tt > 0; --tt)
		task();
	return 0;
}
/*
 * Judge results table pasted from the submission page (not part of the program):
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 */