This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include <iostream>
#include <cassert>
#include <cmath>
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <ctime>
#include <set>
#include <algorithm>
#include <functional>
#include <iomanip>
#define ll long long
using namespace std;
// ---- Global state shared by solve() and main() ----
std::vector<std::vector<int>> adj;          // adjacency lists, 0-based node ids (built in main)
std::vector<bool> marked;                   // true once a node has been taken as a centroid
std::vector<int> sz;                        // subtree sizes, recomputed for each component
std::map<std::pair<int,int>,int> myMap;     // per child subtree: count of (path sum, prefix1) keys
// NOTE(review): myMap1 is filled by solve()'s dfs but never read in the visible code.
std::map<std::pair<int,int>,int> myMap1;
std::set<int> mySet;                        // tot[] rows touched in the current pass (for cheap clearing)
// tot[k][p]: number of visited nodes whose path sum is -k and whose prefix1 is p.
// NOTE(review): fixed 300000 rows assumes path sums stay below that bound — confirm vs. limits.
std::map<int,int> tot[(int)3e5];
int n;                                      // number of nodes
std::string s;                              // bracket character of each node, indexed by node id
// Running answer accumulated by solve() and printed by main(). It must be 64-bit: the number of
// balanced paths can be on the order of n^2 (e.g. a long path alternating '(' and ')'), which
// overflows a 32-bit int once n is around 1e5.
long long cntr = 0;
int get (char c) {
    // Bracket weight for a running balance: '(' counts +1, any other character counts -1.
    if (c == '(') {
        return 1;
    }
    return -1;
}
// Descends from curNode (entered from prevNode) through unmarked nodes. At each step it moves into
// the first child whose subtree (per sz[]) holds at least half of net_sz nodes. It returns the node
// where no such child exists. Assumes sz[] was just computed with the same root as the caller's walk.
int find_centroid (int curNode, int prevNode, int net_sz) {
    int here = curNode;
    int from = prevNode;
    for (;;) {
        int heavy = -1;
        for (int nb : adj[here]) {
            if (nb == from || marked[nb]) {
                continue;
            }
            if (sz[nb] * 2 >= net_sz) {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) {
            return here;
        }
        from = here;
        here = heavy;
    }
}
// One centroid-decomposition step over the component of unmarked nodes that contains `root`.
// It finds the component's centroid and marks it. It then adds to the global cntr the number of
// balanced bracket paths that pass through the centroid, and recurses into each remaining
// sub-component.
// NOTE(review): get_sizes, dfs and solve are all recursive; on a path-shaped tree the depth
// approaches the component size — confirm the judge's stack limit.
void solve (int root) { // root: any node of the current component; the centroid is found below
// Recomputes sz[] for this component rooted at `root`; returns the size of curNode's subtree.
std::function<int(int, int)> get_sizes;
get_sizes = [&get_sizes](int curNode, int prevNode) -> int {
assert(!marked[curNode]);
sz[curNode] = 1;
for (int i: adj[curNode]) {
if (i != prevNode and !marked[i]) {
get_sizes(i, curNode);
sz[curNode] += sz[i];
}
}
return sz[curNode];
};
get_sizes(root, root);
// A lone node has no partner in its component, and a one-character path is never balanced.
if (sz[root] == 1) {
return;
}
// Locate and mark the centroid so that no later traversal or recursion crosses it.
int centroid = find_centroid(root, root, sz[root]);
marked[centroid] = true;
// Two passes over the centroid's neighbours; the second runs in reversed order (see reverse()
// below). A cross-subtree match is only counted at the later-visited node, which plays the
// path's starting end. Reversing the order lets the other node of the pair take that role too.
// Paths with the centroid as an endpoint are counted only in the last pass (tc == 0).
int tc = 2;
while (tc--) {
// Reset only the tot[] rows the previous pass touched.
for (int x: mySet) {
tot[x].clear();
}
mySet.clear();
for (int node: adj[centroid]) {
if (!marked[node]) {
// myMap holds keys from the current child subtree only; it is used to cancel matches
// where both endpoints lie in the same child subtree.
myMap.clear();
myMap1.clear();
// dfs parameters (the centroid itself is excluded from every sum):
//   sm      - sum of get() over the path node..curNode
//   prefix1 - min(0, smallest running sum reading node -> curNode)
//   prefix2 - min(0, smallest running sum reading curNode -> node)
//   centroid, tc - copies of the enclosing values (they shadow them inside the lambda)
std::function<void(int,int,int,int, int, int, int)> dfs;
dfs = [&dfs](int curNode, int prevNode, int sm, int prefix1, int prefix2, int centroid, int tc) -> void {
if (sm <= 0) mySet.insert(0 - sm);
myMap[make_pair(sm, prefix1)]++;
// NOTE(review): myMap1 is written here but never read — appears to be dead code.
if (prefix2 >= 0 and sm + get(s[centroid]) >= 0) myMap1[make_pair(sm, prefix1)]++;
// curNode as the start: reading curNode -> centroid never goes negative. It is closed by an
// already-recorded end node with path sum -(sm + centroid) whose prefix1 equals that sum, so the
// running balance never drops below zero on that side. Same-subtree matches are subtracted
// via myMap. (myMap's operator[] may insert a zero entry; that does not change the count.)
if (prefix2 >= 0 and sm + get(s[centroid]) >= 0 and tot[sm + get(s[centroid])].count(0 - get(s[centroid]) - sm)) {
cntr += tot[sm + get(s[centroid])][0 - get(s[centroid]) - sm] - myMap[make_pair(0 - sm - get(s[centroid]), 0 - get(s[centroid]) - sm)];
}
// Record curNode as a potential end node (only non-positive sums can be completed).
if (sm <= 0) tot[0 - sm][prefix1]++;
// Path curNode -> centroid: stays >= 0, totals 1, and the centroid's ')' closes it.
cntr += (sm == 1 and prefix2 == 0 and s[centroid] == ')') * (tc == 0);
// Path centroid -> curNode: the centroid's '(' lifts the minimum prefix -1 to 0, and the total is 0.
cntr += (sm + get(s[centroid]) == 0 and prefix1 == -1 and s[centroid] == '(') * (tc == 0);
for (int i: adj[curNode]) {
if (i != prevNode and !marked[i]) {
dfs(i, curNode, sm + get(s[i]), min(min(prefix1, sm + get(s[i])), 0), min(min(prefix2 + get(s[i]), get(s[i])), 0), centroid, tc);
}
}
};
dfs(node, node, get(s[node]), min(0, get(s[node])), min(0, get(s[node])), centroid, tc);
}
}
// Flip the neighbour order for the second pass; two flips restore the original order.
reverse(adj[centroid].begin(), adj[centroid].end());
}
// Recurse into every sub-component left after removing the centroid.
for (int i: adj[centroid]) {
if (!marked[i]) {
solve(i);
}
}
}
// Reads the tree and the per-node bracket string, runs the centroid decomposition from node 0,
// and prints the accumulated count.
int main() {
    // Unsynchronised C++ streams for faster bulk input.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // Input: node count, one bracket character per node, then n - 1 edges with 1-based endpoints.
    cin >> n >> s;
    adj.resize(n);
    sz.resize(n);
    marked.assign(n, false);

    for (int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft) {
        int a, b;
        cin >> a >> b;
        a -= 1;  // convert to 0-based ids
        b -= 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    solve(0);
    cout << cntr << '\n';
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |