// This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-node arrays; node ids are used directly as indices.
const int MAXN = 300005;
// sz[v]: size of v's subtree as written by the most recent dfs() pass.
int sz[MAXN];
// Running count of matching path pairs. Must be 64-bit: solve() adds products
// of two per-subtree counts, and the total can exceed 2^31 (two adjacent hubs
// with 150000 leaves each already give 2.25e10 paths).
long long ans;
// Number of nodes, read in main().
int n;
// Adjacency lists of the tree.
std::vector<int> v1[MAXN];
// blocked[v]: v is a centroid currently removed from the tree by decompose().
bool blocked[MAXN];
// Node counter incremented by dfs(); reset by decompose() before each pass.
int currsz;
// Tallies for the child subtree currently being scanned by dfsans():
// m1 is keyed by the balance of an opening half-path, m2 by the negated
// balance of a closing half-path.
std::map<int,int> m1;
std::map<int,int> m2;
// The same tallies, merged over the child subtrees solve() has already handled.
std::map<int,int> ansm1;
std::map<int,int> ansm2;
// Bracket string; main() prepends one filler character so s[v] belongs to node v.
std::string s;
// Subtree-size pass over the part of the tree reachable from `curr` without
// stepping back to `par` or onto a blocked node.
// Side effects: bumps the global currsz once per visited node and stores in
// sz[curr] the number of nodes visited at or below curr.
void dfs(int curr,int par){
    ++currsz;
    int& subtree = sz[curr];
    subtree = 1;
    for(const int child : v1[curr]){
        if(child == par || blocked[child]){
            continue;
        }
        dfs(child, curr);
        subtree += sz[child];
    }
}
// Walks from `curr` towards the first unblocked, non-parent neighbour whose
// sz[] exceeds currsz/2, and repeats from there; returns the node at which no
// such neighbour exists. Relies on sz[] and currsz as left by dfs().
int findcentroid(int curr,int par){
    for(;;){
        bool moved = false;
        int heavy = curr;
        for(const int nxt : v1[curr]){
            if(!blocked[nxt] && nxt != par && sz[nxt] > currsz/2){
                heavy = nxt;
                moved = true;
                break;  // the first qualifying neighbour is the one followed
            }
        }
        if(!moved){
            return curr;
        }
        par = curr;
        curr = heavy;
    }
}
// Scans one child subtree of the current centroid and tallies the half-paths
// that can take part in a balanced bracket path through that centroid.
//
// The walk starts at the centroid's neighbour and moves away from it.
//   curr, par    - node being visited and the node it was reached from.
//   curr1, curr2 - bracket balance (+1 per '(', -1 per ')') of the nodes from
//                  the walk's first node through curr; both carry the same value.
//   min1         - 0 exactly when the sequence read from curr back to the walk's
//                  first node never has a negative running balance; negative
//                  otherwise.
//   min2         - minimum running balance (the empty prefix counts as 0) of the
//                  sequence read from the walk's first node to curr.
//
// Records:
//   m1[curr1]++  when curr is '(' and min1 == 0 (usable as the opening side);
//   m2[-curr2]++ when curr is not '(' and min2 == curr2, i.e. the final balance
//                is the minimum (usable as the closing side).
void dfsans(int curr,int par,int min1,int curr1,int min2,int curr2){
    if(blocked[curr]){
        return;
    }
    if(s[curr] == '('){
        if(min1 == 0){
            m1[curr1]++;
        }
    }else{
        if(min2 == curr2){
            m2[-curr2]++;
        }
    }
    for(int x:v1[curr]){
        // Skip the node we came from and removed centroids. The test must be on
        // the neighbour x: curr itself is known to be unblocked at this point.
        if(x==par||blocked[x]){
            continue;
        }
        if(s[x] == '('){
            // Prepending '(' lifts every running balance by one, so a deficit
            // of 1 is repaired; min2 cannot drop because the balance rises.
            dfsans(x,curr,min(0,min1+1),curr1+1,min2,curr2+1);
        }else{
            // Prepending ')' makes the very first running balance -1.
            dfsans(x,curr,min(-1,min1-1),curr1-1,min(min2,curr2-1),curr2-1);
        }
    }
}
// Adds to ans the number of bracket paths that pass through (or end at) the
// centroid `curr` and are matched by the tallies dfsans() produces.
//
// For each unblocked neighbour x, dfsans() fills m1 (opening half-paths keyed by
// balance) and m2 (closing half-paths keyed by negated balance) for x's subtree.
// Those are paired with the tallies of the subtrees handled earlier (ansm1 /
// ansm2), then merged into them. ansm1[0] and ansm2[0] start at 1 so that the
// centroid on its own can serve as either end of a path.
void solve(int curr){
    // Read-only lookup: unlike operator[], a missing key is not inserted, so
    // the accumulated maps only ever hold keys that dfsans() produced.
    const auto countOf = [](const std::map<int,int>& table, int key) -> long long {
        const auto it = table.find(key);
        return it == table.end() ? 0LL : it->second;
    };
    // The centroid's own bracket shifts the balance the other side must cancel:
    // an opening balance a needs a closing key of a+shift, and a closing key b
    // needs an opening balance of b-shift.
    const int shift = (s[curr]=='(') ? 1 : -1;

    ansm1.clear();
    ansm2.clear();
    ansm1[0]++;
    ansm2[0]++;
    for(int x:v1[curr]){
        if(blocked[x]){
            continue;
        }
        m1.clear();
        m2.clear();
        if(s[x] == '('){
            dfsans(x,curr,0,1,0,1);
        }else{
            dfsans(x,curr,-1,-1,-1,-1);
        }
        // Pair this subtree's half-paths with those of earlier subtrees only,
        // so both ends of a counted path lie in different subtrees.
        for(const auto& opening : m1){
            ans+=1LL*opening.second*countOf(ansm2,opening.first+shift);
        }
        for(const auto& closing : m2){
            ans+=1LL*closing.second*countOf(ansm1,closing.first-shift);
        }
        // Only now make this subtree's half-paths available to later subtrees.
        for(const auto& opening : m1){
            ansm1[opening.first]+=opening.second;
        }
        for(const auto& closing : m2){
            ansm2[closing.first]+=closing.second;
        }
    }
}
void decompose(int curr,int par){
currsz = 0;
dfs(curr,curr);
int cent = findcentroid(curr,curr);
solve(cent);
//cout<<cent<<" "<<ans<<endl;
blocked[cent] = true;
for(int x:v1[cent]){
//cout<<x<<endl;
if(!blocked[x]){
//cout<<x<<endl;
decompose(x,cent);
}
}
blocked[cent] = false;
}
// Reads n, the bracket string and the n-1 tree edges from standard input, runs
// the decomposition from node 1 and prints the accumulated answer.
int main(){
    // The input has one line per edge; untie and unsync the streams so reading
    // is not slowed by C stdio synchronisation. No C stdio is used in this file.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    cin>>n;
    cin>>s;
    // Shift the string by one filler character so s[v] is the bracket of node v
    // for the 1-based node ids used in the edge list.
    s='#'+s;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        v1[u].push_back(v);
        v1[v].push_back(u);
    }
    decompose(1,1);
    // '\n' instead of endl: same output, no redundant explicit flush.
    cout<<ans<<'\n';
}
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |