Submission #232815

# | Time | Username | Problem | Language | Result | Execution time | Memory
232815 |  | shayan_p | Zagrade (COI17_zagrade) | C++14 | 100 / 100 | 1754 ms | 47756 KiB
// Never let them see you bleed...

#include<bits/stdc++.h>

#define F first
#define S second
#define PB push_back
#define sz(s) int((s).size())
#define bit(n,k) (((n)>>(k))&1)

using namespace std;

typedef long long ll;
typedef pair<int,int> pii;

// Capacity of every per-vertex array; main() stores vertices 1-based.
const int maxn = 3e5 + 10;
// The next two constants are not referenced anywhere in the visible source.
const int mod = 1e9 + 7;
const int inf = 1e9 + 10;

// Adjacency lists: main() pushes each input edge into both endpoints' lists.
vector<int> v[maxn];

// a[i] is +1 when the i-th input character is '(' and -1 otherwise (set in main()).
int a[maxn];
// SZ[u] is the subtree size most recently computed for u by dfsSZ().
int SZ[maxn];
// mark[u] is set by solve() once u has been picked as a center; marked
// vertices are skipped by every traversal in this file.
bool mark[maxn];

// Fills SZ[] for the component reachable from u without crossing marked
// vertices: SZ[x] becomes the number of vertices in x's subtree when the
// component is rooted at the initial u.
//   u   - vertex being processed
//   par - vertex we arrived from (-1 for the initial call); never revisited
void dfsSZ(int u, int par = -1){
    int &size = SZ[u];
    size = 1;
    for(const int child : v[u]){
        if(mark[child] || child == par)
            continue;
        dfsSZ(child, u);
        size += SZ[child];
    }
}
// Walks down from u, at every step moving into the first unmarked neighbour
// (other than the one we came from) whose SZ[] exceeds N, and returns the
// vertex where no such neighbour exists. Relies on SZ[] having been filled
// by dfsSZ() for the same root.
//   u   - starting vertex
//   N   - size threshold (solve() passes half of the component size)
//   par - vertex we arrived from (-1 for the initial call)
int dfsCenter(int u, int N, int par = -1){
    while(true){
        bool descended = false;
        for(const int y : v[u]){
            if(mark[y] || y == par || SZ[y] <= N)
                continue;
            // Step into the heavy neighbour and restart the scan from there.
            par = u;
            u = y;
            descended = true;
            break;
        }
        if(!descended)
            return u;
    }
}

// Counter table used by calc(); it is always indexed as cnt[maxn + k], so the
// maxn offset lets k be negative.
int cnt[2 * maxn];

// Per-vertex scratch values written by calc() on every visit:
//   val[u] - sum of a[] along the walk from the calc() start vertex to u
//   pr[u]  - vertex from which u was reached (0 for the start vertex)
//   up[u]  - see calc(); 0 means "nothing pending"
int val[maxn];
int pr[maxn];
int up[maxn];

// Running total accumulated by calc() and printed by main().
ll ANS;

// Depth-first walk over the unmarked vertices reachable from u (without going
// back to par) that maintains val[], pr[] and up[] and performs one of three
// actions on every visited vertex.
//   u    - current vertex
//   bad  - the a[] value (+1 or -1) that makes a vertex "pending": such a
//          vertex sets up[u] = u; any other vertex cancels the nearest pending
//          ancestor, i.e. up[u] = up[pr[up[par]]] (stays 0 if none is pending)
//   task - 0: clear cnt[maxn + val[u]]
//          1: if up[u] == 0, increment cnt[maxn + val[u]]
//         -1: if up[u] == 0, add cnt[maxn + sum - val[u]] to ANS
//   sum  - target offset used only by task -1
//   par  - vertex we arrived from; the default 0 is a sentinel whose val[],
//          pr[] and up[] entries are never written and therefore stay 0
void calc(int u, int bad, int task, int sum, int par = 0){
    pr[u] = par;
    val[u] = val[par] + a[u];
    if(a[u] == bad)
        up[u] = u;
    else
        up[u] = up[pr[up[par]]];

    const bool nothingPending = (up[u] == 0);
    switch(task){
    case 0:
        // The reset is unconditional so every slot touched earlier is cleared.
        cnt[maxn + val[u]] = 0;
        break;
    case 1:
        if(nothingPending)
            cnt[maxn + val[u]]++;
        break;
    case -1:
        if(nothingPending)
            ANS += cnt[maxn + sum - val[u]];
        break;
    default:
        break;
    }

    for(const int next : v[u]){
        if(mark[next] || next == par)
            continue;
        calc(next, bad, task, sum, u);
    }
}

void solve(int r){
    dfsSZ(r);
    r = dfsCenter(r, SZ[r]/2);
    mark[r] = 1;
    if(a[r] == 1)
	cnt[maxn]++;    
    for(int u : v[r]){
	if(mark[u])
	    continue;
	calc(u, 1, -1, -a[r]);
	calc(u, -1, 1, 0);	
    }
    if(a[r] == 1)
	cnt[maxn] = 0;
    for(int u : v[r]){
	if(mark[u])
	    continue;
	calc(u, -1, 0, 0);
    }

    if(a[r] == -1)
	cnt[maxn]++;
    for(int u : v[r]){
	if(mark[u])
	    continue;
	calc(u, -1, -1, -a[r]);
	calc(u, 1, 1, 0);
    }
    if(a[r] == -1)
	cnt[maxn] = 0;    
    for(int u : v[r]){
	if(mark[u])
	    continue;
	calc(u, 1, 0, 0);
    }

    for(int u : v[r]){
	if(!mark[u])
	    solve(u);
    }
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(0); cout.tie();

    int n;
    cin >> n;
    string s;
    cin >> s;	
    for(int i = 0; i < n; i++){
	a[i+1] = (s[i] == '(' ? 1 : -1);
    }
    for(int i = 0; i < n-1; i++){
	int a, b;
	cin >> a >> b;
	v[a].PB(b);
	v[b].PB(a);
    }
    solve(1);
    return cout << ANS << endl, 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...