Submission #761581

# Submission time Handle Problem Language Result Execution time Memory
761581 2023-06-20T04:42:40 Z gun_gan Zagrade (COI17_zagrade) C++17
0 / 100
136 ms 89848 KB
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MX = 3e5 + 5, inf = 1e9;

int N;
string s;
vector<int> g[MX];

int par[MX], bal[MX];
set<pair<int, int>> path[MX][2]; // open, close
// (balance, count)

int find(int v) {
      return par[v] == v ? v : par[v] = find(par[v]);
}

ll ans = 0;

void dfs(int v, int p) {
      for(auto u : g[v]) {
            if(u == p) continue;
            // build 
            bal[u] = bal[v] + (s[u] == '(' ? 1 : -1);
            dfs(u, v);
      }

      for(auto u : g[v]) {
            if(u == p) continue; 

            // erase invalid
            if(s[v] == ')') {
                  int pu = find(u);
                  auto it = path[pu][0].lower_bound({bal[v], -inf});
                  if(it != path[pu][0].end() && it->first == bal[v]) {
                        path[pu][0].erase(it);
                  } 

                  // it = path[pu][1].lower_bound({bal[v], -inf});
                  // if(it != path[pu][1].end() && it->first == bal[v]) {
                  //       cout << "erased\n";
                  //       path[pu][1].erase(it);
                  // } 
            }

      }

      // straight path
      if(s[v] == ')') {
            for(auto u : g[v]) {
                  if(u == p) continue;

                  int pu = find(u);
                  auto it = path[pu][0].lower_bound({bal[p], -inf});
                  if(it != path[pu][0].end() && it->first == bal[p]) {
                        ans += it->second;
                  }
            }
      } else {
            for(auto u : g[v]) {
                  if(u == p) continue;

                  int pu = find(u);
                  auto it = path[pu][1].lower_bound({bal[p], -inf});
                  if(it != path[pu][1].end() && it->first == bal[p]) {
                        ans += it->second;
                  }
            }
      }

      // calc answer, merging
      for(auto u : g[v]) {
            if(u == p) continue;
            int pu = find(u), pv = find(v);

            // cout << u << " " << v << " " << pu << " " << pv << '\n';

            if((path[pu][0].size() + path[pu][1].size()) > (path[pv][0].size() + path[pv][1].size())) {
                  swap(pu, pv);
            }

            for(auto [b, c] : path[pu][0]) {
                  int k = b - bal[p];
                  k = -k + bal[v];
                  auto it = path[pv][1].lower_bound({k, -inf});
                  if(it != path[pv][1].end() && it->first == k) {
                        ans += 1LL * c * it->second;
                  }
            }

            for(auto [b, c] : path[pu][1]) {
                  int k = b - bal[p];
                  k = -k + bal[v];
                  auto it = path[pv][0].lower_bound({k, -inf});
                  if(it != path[pv][0].end() && it->first == k) {
                        ans += 1LL * c * it->second;
                  }
            }

            // cout << "merging " << u << " and "  << v << '\n';

            par[pu] = pv;
      }

      // add new path
      if(s[v] == '(') {
            int pv = find(v);
            auto it = path[pv][0].lower_bound({bal[v], -inf});
            if(it != path[pv][0].end() && it->first == bal[v]) {
                  path[pv][0].insert({it->first, it->second + 1});
                  path[pv][0].erase(*it);
            } else {
                  path[pv][0].insert({bal[v], 1});
            }
      } else {
            int pv = find(v);
            auto it = path[pv][1].lower_bound({bal[v], -inf});
            if(it != path[pv][1].end() && it->first == bal[v]) {
                  path[pv][1].insert({it->first, it->second + 1});
                  path[pv][1].erase(*it);
            } else {
                  path[pv][1].insert({bal[v], 1});
            }
      }
}     

int main() {
      ios_base::sync_with_stdio(0); cin.tie(0);

      cin >> N >> s;
      s = '#' + s;

      for(int i = 0; i < N - 1; i++) {
            int u, v;
            cin >> u >> v;
            g[u].push_back(v);
            g[v].push_back(u);
      }

      for(int i = 1; i <= N; i++) par[i] = i;

      bal[1] = (s[1] == '(' ? 1 : -1);
      dfs(1, 0);

      // for(int i = 1; i <= N; i++) cout << bal[i] << " "; cout << '\n';

      cout << ans << '\n';
}
# Verdict Execution time Memory Grader output
1 Incorrect 19 ms 35648 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 136 ms 89848 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 19 ms 35648 KB Output isn't correct
2 Halted 0 ms 0 KB -