Submission #844849

# | Time | Username | Problem | Language | Result | Execution time | Memory
844849 | | Cookie | Zagrade (COI17_zagrade) | C++14 | 100 / 100 | 789 ms | 49748 KiB
#include<bits/stdc++.h>
// Shorthand macros used throughout the file.
#define ll long long
#define vt vector
#define pb push_back
// NOTE(review): `pii` and the function-like `sz(v)` macro are not used below.
// The `sz` array declared further down is unaffected by this macro: a
// function-like macro only expands when its name is followed by '('.
#define pii pair<int, int>
#define sz(v) (int)v.size()
using namespace std;
// NOTE(review): base, mod, inf and sq are not referenced anywhere below.
const ll base = 107, mod = 1e9 + 7;
// mxn sizes the per-node arrays (nodes are numbered 1..n); presumably the
// problem's maximum n plus some slack — confirm against the statement.
const int mxn = 3e5 + 5, inf = 1e9, sq = 400;
// Number of nodes in the tree.
int n;
// val[v]: +1 if node v holds '(' and -1 if it holds ')'.
// sz[v]: size of v's subtree inside the current centroid component, as last
// computed by dfs().
int val[mxn + 1], sz[mxn + 1];
// Undirected adjacency lists of the tree.
vt<int>adj[mxn + 1];
// While processing one centroid in build(): cnt[k] is the number of start
// nodes u whose path u -> ... -> centroid never dips below zero and ends with
// k unmatched '(' (filled by dfs2 in add mode).
map<int, int>cnt;
// vis[v] becomes 1 once v has been chosen as a centroid; traversals never
// cross such nodes, which is what splits the tree into smaller components.
bool vis[mxn + 1];
// Computes sz[v] (size of v's subtree when the component is rooted at s) for
// every node reachable from s without stepping into `pre` or into a node that
// is already a centroid (vis[v] == 1). Returns sz[s], i.e. the component size.
//
// Iterative rather than recursive: on a path-shaped tree the recursion depth
// would reach n (up to ~3e5 here), which can overflow a default-sized stack.
int dfs(int s, int pre){
    std::vector<int> order(1, s);    // nodes in BFS order; a parent always precedes its children
    std::vector<int> parent(1, pre); // parent[k] is the BFS parent of order[k]
    for(std::size_t k = 0; k < order.size(); ++k){
        const int u = order[k];
        const int p = parent[k];
        sz[u] = 1;
        for(int v : adj[u]){
            if(v != p && !vis[v]){
                order.push_back(v);
                parent.push_back(u);
            }
        }
    }
    // Sweep backwards: every child is finished before it is folded into its
    // parent. Index 0 is s itself, whose parent (`pre`) is outside the component.
    for(std::size_t k = order.size(); k-- > 1;){
        sz[parent[k]] += sz[order[k]];
    }
    return sz[s];
}
// Starting at s (having arrived from `pre`), repeatedly steps into the first
// unvisited neighbour whose sz[] exceeds half of `need`, never stepping back
// the way it came. Returns the node where no such neighbour exists. Relies on
// sz[] having just been filled by dfs() from the same root.
int centroid(int s, int pre, int need){
    int cur = s;
    int from = pre;
    bool stepped = true;
    while(stepped){
        stepped = false;
        for(int nxt : adj[cur]){
            if(nxt == from || vis[nxt]) continue;
            if(2 * sz[nxt] > need){
                from = cur;
                cur = nxt;
                stepped = true;
                break;
            }
        }
    }
    return cur;
}
// Running total of counted pairs; accumulated by dfs2/build and printed by main().
ll ans = 0;
// Walks the part of the current centroid component hanging off `s` (entered
// from `pre`, never crossing a vis[] node) and either records or queries path
// halves through the current centroid.
//
// pref / mxpref / mnpref are the bracket sum and its running max / min along
// the path walked before s (val[s] is added on entry).
//  - add == true: the sums are taken centroid -> ... -> node, and the starting
//    values already include val[centroid]. A node u is recorded in cnt[pref]
//    when pref equals the running max. That means every partial sum taken from
//    the centroid side (a suffix of the string u -> centroid) is <= the total.
//    Equivalently, reading u -> centroid never dips below zero.
//  - add == false: the sums are taken from the centroid's child down to the
//    node v, with the centroid excluded. When pref equals the running min, no
//    prefix goes below the final total. Then v can close any recorded left
//    half holding k = -pref unmatched '(', so cnt[-pref] is added to ans.
//
// Uses an explicit stack instead of recursion. The visiting order is
// irrelevant: add mode only increments cnt, and query mode only reads it.
// Recursion depth could otherwise reach the component size.
void dfs2(int s, int pre, int pref, int mxpref, int mnpref, bool add){
    struct Frame { int node, parent, pref, mx, mn; };
    std::vector<Frame> todo;
    todo.push_back({s, pre, pref, mxpref, mnpref});
    while(!todo.empty()){
        Frame f = todo.back();
        todo.pop_back();
        f.pref += val[f.node];
        f.mx = std::max(f.mx, f.pref);
        f.mn = std::min(f.mn, f.pref);
        if(add){
            if(f.pref == f.mx) cnt[f.pref]++;
        }else if(f.pref == f.mn){
            // find() rather than operator[]: a miss must not insert a zero entry.
            const auto it = cnt.find(-f.mn);
            if(it != cnt.end()) ans += it->second;
        }
        for(int v : adj[f.node]){
            if(v != f.parent && !vis[v]){
                todo.push_back({v, f.node, f.pref, f.mx, f.mn});
            }
        }
    }
}

// Centroid-decomposition driver. Finds the centroid c of the component that
// contains s, counts the pairs whose path passes through c (via dfs2), marks
// c as visited and recurses into each remaining piece.
//
// Each neighbour subtree is queried before it is recorded, so one sweep only
// pairs a start u in an EARLIER subtree (or u == c) with an end v in a LATER
// one. A second sweep in reverse order covers u in a later subtree with v in
// an earlier one.
void build(int s){
    int c = centroid(s, -1, dfs(s, -1));
    vis[c] = 1;

    // Snapshot the live neighbours once. Sweeping this copy in both
    // directions leaves the global adj[c] untouched.
    std::vector<int> kids;
    for(int v : adj[c]){
        if(!vis[v]) kids.push_back(v);
    }

    // Query subtree v against every left half recorded so far, then record v's own.
    auto sweep = [&](int v){
        dfs2(v, c, 0, 0, 0, 0);
        dfs2(v, c, val[c], std::max(0, val[c]), std::min(0, val[c]), 1);
    };

    cnt.clear();
    // The centroid on its own is a valid left half only when it is '('.
    if(val[c] != -1) cnt[val[c]]++;
    for(int v : kids) sweep(v);
    // Paths that end at the centroid: left halves that are already balanced.
    const auto zero = cnt.find(0);
    if(zero != cnt.end()) ans += zero->second;

    // The centroid is deliberately not re-added, so paths starting at it are
    // counted only in the forward sweep.
    cnt.clear();
    for(auto it = kids.rbegin(); it != kids.rend(); ++it) sweep(*it);

    for(int v : kids) build(v);
}
// Reads n, then one bracket character per node (whitespace is skipped by
// operator>>; anything other than '(' is treated as ')'), then n - 1
// undirected edges. Runs the centroid decomposition from node 1 and prints
// the accumulated `ans` with no trailing newline.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    for(int node = 1; node <= n; ++node){
        char bracket;
        cin >> bracket;
        val[node] = (bracket == '(') ? 1 : -1;
    }

    for(int edge = 1; edge < n; ++edge){
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    build(1);
    cout << ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...