#include <bits/stdc++.h>
using namespace std;
// ---- Shorthand macros ----
#define mp make_pair
#define pb push_back
#define len(a) (int)a.size()
#define fi first
#define sc second
// Debug helpers: print "name:value" pairs to stderr.
#define d1(w) cerr<<#w<<":"<<w<<endl;
#define d2(w,c) cerr<<#w<<":"<<w<<" "<<#c<<":"<<c<<endl;
#define d3(w,c,z) cerr<<#w<<":"<<w<<" "<<#c<<":"<<c<<" "<<#z<<":"<<z<<endl;
// Segment-tree style child/midpoint helpers (they expect `ind`, `l`, `r` in scope).
// NOTE(review): not used by the code shown here; `left`/`right` also hide
// std::left / std::right, and none of the expansions are parenthesised.
#define left ind+ind
#define right ind+ind+1
#define mid (l+r)/2
// Defined but never invoked; main() reads and writes with scanf/printf.
#define FAST_IO ios_base::sync_with_stdio(false);
// From here on `endl` is a plain newline character (no flush), including in
// the debug macros above at any use site after this line.
#define endl '\n'
typedef long long int ll;
// Template constants. Of these, only INF (initial running minimum passed by the
// centroid step) and N (array capacity) are referenced by the code shown here.
const int maxn = 620;
const long long LINF = 1e18;
const int LOG = 31;
const int INF = 1e9;
// Every per-node array below has N slots, so node ids and n must stay below N.
const int N = 3e5 + 5;
const int M = 25;
const int SQ = 350;
const int MOD = 998244353;
typedef pair <int,int> pii;
// ed: adjacency lists. vec/vec2: per-node lists of half-path balances filled by dfs().
// v: mark/mark2 indices touched during one centroid step, so they can be reset.
vector <int> ed[N],vec[N],vec2[N],v;
// Running count of valid ordered pairs; printed by main().
long long int ans = 0;
// n: node count; sub: subtree sizes written by calc(); a: +1 for '(' and -1
// otherwise (set in main); vis: 1 once a node has been used as a centroid;
// mark/mark2: how many half-paths registered for the current centroid have a
// given balance (indexed by that balance).
int n,sub[N],a[N],vis[N],mark[N],mark2[N];
// Computes the size of the subtree rooted at `cur` within the current component.
// `back` is treated as cur's parent, and nodes with vis set are skipped.
// Every node reached gets its size stored in sub[]. Returns sub[cur].
int calc(int cur, int back = -1) {
    int total = 1;  // cur itself
    for (const int nb : ed[cur]) {
        if (nb == back || vis[nb]) continue;
        total += calc(nb, cur);
    }
    sub[cur] = total;
    return total;
}
// Descends from `cur` towards the heavy side of the current component.
// It stops at the first node where no unvisited neighbour, other than the one
// we arrived from, has sub[] > size / 2, and returns that node.
// Expects sub[] to hold sizes rooted consistently with this walk; solve()
// fills it with calc(cur) just before calling.
int find(int cur, int back, int size) {
    const int half = size / 2;
    while (true) {
        int heavy = -1;
        for (const int nb : ed[cur]) {
            if (nb != back && !vis[nb] && sub[nb] > half) {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) return cur;
        back = cur;
        cur = heavy;
    }
}
// Collects the "half paths" of the subtree hanging off `cur`. That subtree
// excludes `back` and every node with vis == 1.
// For each node x in the subtree, let s(x) = a[cur] + ... + a[x], the bracket
// balance of the path between x and cur. a[] holds only +1 / -1.
//   vec[cur]  receives s(x) when, reading the path from x up to cur, every
//             running balance is >= 0. Such a path can open a valid sequence.
//   vec2[cur] receives -s(x) when the same holds with brackets mirrored
//             (a -> -a). Such a path, read from cur down to x, can close one.
// Only vec[cur] and vec2[cur] are written.
//
// Two defects of the former bottom-up merging version are fixed:
//  * A running balance of exactly 0 (a balanced piece such as "()") was
//    rejected by a `> 0` test. Every path continuing past it was lost, so
//    valid pairs were under-counted.
//  * Each node copied all of its children's values. That costs O(size^2)
//    time and memory on path-shaped trees.
// Now a single top-down walk keeps, per node, the sum above it and the
// max/min of 0 and all strict-ancestor sums. Since
// sum(a[x..y]) = s(x) - (sum above y), the ">= 0 at every step" test reduces
// to comparing s(x) with that max (or, mirrored, the min).
// This is O(subtree size); an explicit stack avoids deep recursion on long paths.
void dfs(int cur, int back) {
    vec[cur].clear();
    vec2[cur].clear();

    struct Frame {
        int node;
        int parent;
        int above;     // sum of a[] from cur down to parent (0 for cur itself)
        int maxAbove;  // max(0, sum from cur to each strict ancestor of node)
        int minAbove;  // min(0, sum from cur to each strict ancestor of node)
    };
    std::vector<Frame> todo;
    todo.push_back({cur, back, 0, 0, 0});
    while (!todo.empty()) {
        const Frame f = todo.back();
        todo.pop_back();

        const int s = f.above + a[f.node];
        // Every step from node upward keeps the balance >= 0 (0 is allowed).
        if (s >= f.maxAbove) vec[cur].push_back(s);
        // Same condition with brackets mirrored.
        if (s <= f.minAbove) vec2[cur].push_back(-s);

        const int hi = std::max(f.maxAbove, s);
        const int lo = std::min(f.minAbove, s);
        for (const int nb : ed[f.node]) {
            if (nb == f.parent || vis[nb] == 1) continue;
            todo.push_back({nb, f.node, s, hi, lo});
        }
    }
}
// Scans one centroid-neighbour subtree top-down and counts pairs that end here.
// solve() starts it at the neighbour with back = -1, zero sums and INF minima.
// sm/mn carry the running bracket balance from the subtree root and its
// smallest value so far; sm2/mn2 carry the same for mirrored brackets.
// Suppose the balance at this node is <= 0 and no lower value was seen on the
// way down. Then the path can follow any registered half-path of balance
// -balance: the combined balance never goes negative and ends at 0.
// Such matches are added from mark[] (mark2[] for the mirror).
void dfs2(int cur, int back, int sm, int mn, int sm2, int mn2) {
    const int bal = sm + a[cur];
    const int mirror = sm2 - a[cur];
    const int low = std::min(mn, bal);
    const int lowMirror = std::min(mn2, mirror);

    if (bal <= 0 && bal == low) ans += mark[-bal];
    if (mirror <= 0 && mirror == lowMirror) ans += mark2[-mirror];

    for (const int nb : ed[cur]) {
        if (nb != back && vis[nb] != 1) dfs2(nb, cur, bal, low, mirror, lowMirror);
    }
}
// One centroid-decomposition step on the component containing `cur`.
// It picks the centroid and marks it removed. Then, for each remaining
// neighbour subtree in turn, it:
//   1. counts paths into that subtree that pair with half-paths registered
//      from earlier subtrees (dfs2);
//   2. registers this subtree's half-paths, extended through the centroid, in
//      mark/mark2, and counts those that are already balanced at the centroid
//      (pairs with the centroid as an endpoint).
// Afterwards it resets the touched counters and recurses into each piece.
void solve(int cur = 1) {
    const int centroid = find(cur, -1, calc(cur));
    vis[centroid] = 1;
    const int mid_val = a[centroid];

    for (const int nb : ed[centroid]) {
        if (vis[nb]) continue;
        dfs2(nb, -1, 0, INF, 0, INF);
        dfs(nb, -1);
        for (const int up : vec[nb]) {
            const int bal = up + mid_val;
            if (bal < 0) continue;
            mark[bal]++;
            if (bal == 0) ans++;
            v.push_back(bal);
        }
        for (const int down : vec2[nb]) {
            const int bal = down - mid_val;
            if (bal < 0) continue;
            mark2[bal]++;
            if (bal == 0) ans++;
            v.push_back(bal);
        }
    }

    for (const int idx : v) {
        mark[idx] = 0;
        mark2[idx] = 0;
    }
    v.clear();

    for (const int nb : ed[centroid]) {
        if (!vis[nb]) solve(nb);
    }
}
// Reads the tree and prints how many ordered node pairs (u, v) have a u -> v
// path that spells a valid bracket sequence.
// Input format:
//   - n;
//   - then n bracket characters (whitespace between them is skipped);
//   - then n - 1 edges "u v" with 1-based endpoints.
// On malformed or out-of-range input nothing is printed and 1 is returned.
// This replaces running on uninitialised or out-of-bounds data.
int main() {
    // The per-node arrays (a, ed, sub, mark, ...) share this capacity;
    // node ids and balances must stay below it.
    const int capacity = static_cast<int>(sizeof(a) / sizeof(a[0]));
    if (scanf("%d", &n) != 1 || n < 0 || n >= capacity) return 1;
    for (int i = 1; i <= n; i++) {
        char c;
        if (scanf(" %c", &c) != 1) return 1;
        a[i] = (c == '(') ? 1 : -1;
    }
    for (int i = 1; i < n; i++) {
        // Named x/y rather than u/v so they do not shadow the global vector `v`.
        int x, y;
        if (scanf("%d %d", &x, &y) != 2) return 1;
        if (x < 1 || x > n || y < 1 || y > n) return 1;
        ed[x].push_back(y);
        ed[y].push_back(x);
    }
    solve();
    printf("%lld\n", ans);
    return 0;
}
Compilation message
zagrade.cpp: In function 'int main()':
zagrade.cpp:103:7: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d",&n);
~~~~~^~~~~~~~~
zagrade.cpp:107:8: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf(" %c",&c);
~~~~~^~~~~~~~~~
zagrade.cpp:114:8: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d %d",&u,&v);
~~~~~^~~~~~~~~~~~~~~
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 22 ms | 21624 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 3085 ms | 837236 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 22 ms | 21624 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |