#include <bits/stdc++.h>
using namespace std;
// Every `int` in this file is 64-bit: the pair counter can grow roughly
// quadratically in the number of vertices, which would not fit in 32 bits.
#define int long long
#define ld long double
// Debug printers (not used by the submitted solution): each prints the
// value(s) followed by the stringified expression name.
#define show(x,y) cout << y << " " << #x << endl;
#define show2(x,y,i,j) cout << y << " " << #x << " " << j << " " << #i << endl;
#define show3(x,y,i,j,p,q) cout << y << " " << #x << " " << j << " " << #i << " " << q << " " << #p << endl;
// Braced so the loop and the trailing name print form a single statement
// (safe as the body of an unbraced if/for).
#define show4(x,y) { for(auto it:y) cout << it << " "; cout << #x << endl; }
using pii = pair<int,int>;
using pi2 = pair<pii,int>;
mt19937_64 rng(chrono::system_clock::now().time_since_epoch().count());

// Global tree state; vertices are 1-indexed.
vector<int>adj[300005];   // adjacency lists (undirected edges stored both ways)
int arr[300005];          // +1 for '(' and -1 for ')' at each vertex
int lvl[300005];          // centroid-tree depth once removed; -1 while still alive
int sz[300005];           // subtree sizes from the most recent dfs()
// Recomputes sz[] for the still-alive part of the tree (lvl == -1), rooted at
// `index`. `par` is the vertex we arrived from and is never descended into.
// NOTE: recursion depth can reach the size of the component.
void dfs(int index, int par){
    int total = 1;  // the vertex itself
    for (const auto& nxt : adj[index]) {
        const bool blocked = (nxt == par) || (lvl[nxt] != -1);
        if (blocked) continue;
        dfs(nxt, index);
        total += sz[nxt];
    }
    sz[index] = total;
}
// Finds the centroid of the alive component of `n` vertices. Starting at
// `index` (reached from `par`), it repeatedly steps into the first alive child
// whose subtree holds more than n/2 vertices. The vertex where no such child
// exists is returned. Requires sz[] to have been filled by dfs() rooted at
// `index`.
int dfs2(int index, int par, int n){
    int node = index;
    int prev = par;
    for (;;) {
        bool descended = false;
        int next = node;
        for (const auto& child : adj[node]) {
            if (child == prev || lvl[child] != -1) continue;
            if (sz[child] > n / 2) {
                next = child;
                descended = true;
                break;
            }
        }
        if (!descended) return node;
        prev = node;
        node = next;
    }
}
// Running total of ordered vertex pairs (u, v) whose path spells a balanced
// bracket sequence.
int counter=0;
// mp[e] = number of recorded "first halves" (half-paths read from their far
// endpoint toward the current centroid) with net bracket excess e and no
// prefix dipping below zero. Filled by up()/build() and rolled back per pass.
map<int,int>mp;
// Visits second halves of paths, walking away from the centroid.
// `cur` is the bracket sum accumulated before `index`, and `mini` is the
// minimum prefix sum so far (callers start it at 0 or min(0, arr[centroid])).
// For each endpoint `index`, it adds the number of recorded first halves that
// complete a balanced sequence to `counter`.
void down(int index, int par, int cur, int mini){
    cur+=arr[index];
    mini=min(mini,cur);
    // A first half with excess e completes this second half iff e + cur == 0
    // and e + mini >= 0. Because mini <= cur, both hold exactly when
    // mini == cur and e == -cur.
    if(mini==cur){
        // find() rather than operator[]: a pure lookup must not insert zero
        // entries into the map.
        auto hit=mp.find(-mini);
        if(hit!=mp.end()) counter+=hit->second;
    }
    for(auto it:adj[index]){
        if(it==par) continue;
        if(lvl[it]!=-1) continue;
        down(it,index,cur,mini);
    }
}
// Every key incremented in mp during the current pass; the caller decrements
// each one afterwards to restore mp.
vector<int>undo;
// Records first halves of paths. Walking outward from the centroid, the
// sequence centroid..index is read backwards (index -> centroid) as the
// beginning of a path.
// `cur` is the total bracket sum of the sequence so far. `mini` is
// min(0, every suffix sum), i.e. the lowest prefix sum when reading from
// `index` toward the centroid.
void up(int index, int par, int cur, int mini){
    const int val = arr[index];
    cur += val;
    mini = min(mini + val, 0LL);
    // mini can never exceed 0 here, so "no negative prefix" means mini == 0.
    if (mini == 0) {
        ++mp[cur];
        undo.push_back(cur);
    }
    for (const auto& nxt : adj[index]) {
        if (nxt != par && lvl[nxt] == -1) {
            up(nxt, index, cur, mini);
        }
    }
}
// One centroid-decomposition step on the alive component containing `index`.
// `par` is the previous centroid (-1 at the top level); `l` is the depth in
// the centroid tree.
// It adds to `counter` every ordered pair (u, v) whose path passes through
// this component's centroid and spells a balanced sequence. It then marks the
// centroid as removed and recurses into the remaining pieces.
void build(int index, int par, int l){
dfs(index,par);
int cent=dfs2(index,par,sz[index]);
//counting
// Pass 1: the centroid belongs to the first half (u -> centroid).
// Seed the half consisting of the centroid alone. It is a valid opening only
// when it is '(' (excess 1), and it accounts for pairs (centroid, v).
// Assumes mp holds only zero counts here (the previous pass rolled back via undo).
if(arr[cent]==1){
mp[1]=1;
undo.push_back(1);
}
// For each child subtree in adjacency order: first match its second halves
// (centroid excluded) against first halves recorded from earlier subtrees.
// Then record its own first halves (centroid included).
// This counts u in an earlier subtree and v in a later one.
for(auto it:adj[cent]){
if(lvl[it]!=-1) continue;
down(it,cent,0,0);
up(it,cent,arr[cent],min(0LL,arr[cent]));
}
// Pairs (u, centroid): recorded first halves that are balanced on their own.
counter+=mp[0];
// Roll back pass 1's insertions.
for(auto it:undo) mp[it]--;
undo.clear();
// Pass 2: mirror image. Subtrees are visited in reverse order, the centroid
// now belongs to the second half, and first halves exclude it.
// This counts u in a later subtree and v in an earlier one.
for(int y=adj[cent].size()-1;y>=0;y--){
int it=adj[cent][y];
if(lvl[it]!=-1) continue;
down(it,cent,arr[cent],min(0LL,arr[cent]));
up(it,cent,0,0);
}
// Roll back pass 2's insertions.
for(auto it:undo) mp[it]--;
undo.clear();
// Remove the centroid; dfs/down/up/build skip vertices with lvl != -1.
lvl[cent]=l;
//show2(cent,cent,counter,counter);
// Pairs with both endpoints inside one child subtree do not pass through
// the centroid; they are handled in the recursive calls.
for(auto it:adj[cent]){
if(lvl[it]!=-1) continue;
build(it,cent,l+1);
}
}
// Reads n, the bracket string (one character per vertex) and the n-1 edges.
// It then runs the centroid decomposition from vertex 1 and prints the number
// of ordered pairs whose path is a balanced bracket sequence.
void solve(){
    int n;
    string s;
    cin >> n >> s;
    for (int v = 1; v <= n; v++) {
        arr[v] = (s[v - 1] == '(') ? 1 : -1;
    }
    int a, b;
    for (int e = 1; e < n; e++) {
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    memset(lvl, -1, sizeof(lvl));  // every vertex starts alive
    build(1, -1, 0);
    cout << counter;
}
// Entry point: untie and desynchronize iostreams for fast input, then run a
// single test case.
int32_t main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // freopen("in.txt","r",stdin);
    // freopen("in.txt","w",stdout);
    int t = 1;
    // cin >> t;   // enable for multi-test input
    while (t--) solve();
}
/*
Judge submission-results table captured along with the source (not code):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/