#include <bits/stdc++.h>
using namespace std;
// Capacity for per-vertex arrays (vertices are 1-indexed, plus a little slack).
constexpr int nx = 300000 + 5;
using ll = long long;

ll n;         // number of vertices (read in main)
ll u, v;      // edge endpoints read in main
ll sz[nx];    // subtree sizes written by the most recent dfssz() traversal
ll used[nx];  // nonzero once a vertex has been taken as a centroid
ll cnt[nx];   // cnt[k]: registered path starts whose half-path has bracket balance k
ll res;       // accumulated answer
char c[nx];   // c[i]: bracket character of vertex i
std::vector<ll> d[nx];  // adjacency lists
int dfssz(int u, int p)
{
sz[u]=1;
for (auto v:d[u]) if (v!=p&&!used[v]) sz[u]+=dfssz(v, u);
return sz[u];
}
// Starting at `u` (entered from `p`), repeatedly steps into the first active
// neighbour whose sz[] exceeds half of `rtsz`, and returns the first vertex
// that has no such neighbour. Expects sz[] from a dfssz() rooted at the start.
int findcentroid(int u, int p, int rtsz)
{
    for (;;) {
        bool moved = false;
        for (const auto nb : d[u]) {
            if (nb == p || used[nb] || 2 * sz[nb] <= rtsz) continue;
            // The range d[u] was bound when the loop began; we leave it immediately.
            p = u;
            u = nb;
            moved = true;
            break;
        }
        if (!moved) return u;
    }
}
// Walks one branch hanging off the current centroid (the caller passes the
// centroid as `p`) and, for each vertex u reached, adds to res the number of
// registered starts that the tail "after-centroid ... u" closes.
//   cur: bracket balance of the vertices strictly after the centroid up to u
//   mn : minimum over the non-empty prefix balances of that same tail
// A start with balance X (counted in cnt[X]) plus this tail is balanced iff
// X + cur == 0 and X + mn >= 0, which is exactly cur <= 0 && mn == cur.
// Recursion depth equals the longest path explored.
void dfsquery(int u, int p, int mn, int cur)
{
    if (cur <= 0 && mn == cur) res += cnt[-cur];
    for (const auto nb : d[u]) {
        if (nb == p || used[nb]) continue;
        const int balance = cur + (c[nb] == '(' ? 1 : -1);
        const int lowest = balance < mn ? balance : mn;
        dfsquery(nb, u, lowest, balance);
    }
}
// Adds `vl` (+1 to register, -1 to unregister) to cnt[cur] for every vertex u
// of the branch whose half-path u -> ... -> centroid (both ends inclusive),
// read starting at u, never dips below zero.
//   cur: total bracket balance of that half-path
//   mn : min(0, smallest prefix balance) of it. Prepending a bracket worth s
//        turns it into min(0, mn + s), which is the update applied per step.
// Recursion depth equals the longest path explored.
void dfsadd(int u, int p, int mn, int cur, int vl)
{
    if (mn >= 0) cnt[cur] += vl;
    for (const auto nb : d[u]) {
        if (nb == p || used[nb]) continue;
        const int s = (c[nb] == '(') ? 1 : -1;
        const int prefmin = (mn + s < 0) ? mn + s : 0;
        dfsadd(nb, u, prefmin, cur + s, vl);
    }
}
// Bracket value of vertex x: +1 for '(' and -1 for anything else.
static int stepof(int x)
{
    return c[x] == '(' ? 1 : -1;
}

// Registers (vl = +1) or unregisters (vl = -1) every path start in the branch
// of `centroid` entered through its neighbour `child` (see dfsadd).
static void addbranch(int centroid, int child, int vl)
{
    // State of the one-character sequence c[centroid]:
    // min(0, smallest prefix balance) and its total balance.
    const int mn = c[centroid] == '(' ? 0 : -1;
    const int cur = stepof(centroid);
    const int s = stepof(child);
    dfsadd(child, centroid, min(mn + s, 0), cur + s, vl);
}

// Adds to res every pairing of a currently registered start with an end vertex
// in the branch of `centroid` entered through `child` (see dfsquery).
static void querybranch(int centroid, int child)
{
    const int s = stepof(child);
    dfsquery(child, centroid, s, s);
}

// Centroid-decomposition driver. It finds the centroid of the component that
// contains u, adds to res every balanced path through that centroid, marks it
// used, and recurses into the remaining pieces. Every cnt[] addition made here
// is undone before returning.
// Note: dfssz/dfsquery/dfsadd recurse as deep as the longest path, so a
// path-shaped tree needs a large stack.
void decomposition(int u)
{
    u = findcentroid(u, u, dfssz(u, u));
    used[u] = 1;

    // Pass 1, neighbour order: an end in the current branch pairs with a start
    // in an earlier branch, or with the centroid itself. The centroid alone is
    // a valid start of balance 1 when it is '('.
    if (c[u] == '(') cnt[1]++;
    for (const auto w : d[u]) {
        if (used[w]) continue;
        querybranch(u, w);
        addbranch(u, w, 1);
    }
    if (c[u] == '(') cnt[1]--;
    for (const auto w : d[u]) {
        if (!used[w]) addbranch(u, w, -1);
    }

    // Pass 2, reverse order: starts in later branches, ends in earlier ones.
    // The centroid-only start is not registered again, so centroid -> b paths
    // are counted exactly once (in pass 1).
    for (auto it = d[u].rbegin(); it != d[u].rend(); ++it) {
        const int w = *it;
        if (used[w]) continue;
        querybranch(u, w);
        addbranch(u, w, 1);
    }
    // All branches are registered now. A registered start with balance 0 is a
    // balanced path that ends exactly at the centroid.
    res += cnt[0];
    for (auto it = d[u].rbegin(); it != d[u].rend(); ++it) {
        const int w = *it;
        if (!used[w]) addbranch(u, w, -1);
    }

    for (const auto w : d[u]) {
        if (!used[w]) decomposition(w);
    }
}
// Reads n, then the n bracket characters of vertices 1..n, then n-1 undirected
// edges. Prints the number of balanced paths computed by decomposition().
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i <= n; ++i) cin >> c[i];
    for (int e = 1; e < n; ++e) {
        cin >> u >> v;
        d[u].push_back(v);
        d[v].push_back(u);
    }

    decomposition(1);
    cout << res;
}
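/*
Judge submission status table captured together with the source (results had not loaded yet).
Kept as a comment so the file still compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/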