// Submission #1350955
//
// # | Submitted at | User | Problem | Language | Result | Execution time | Memory
// 1350955 | - | msab3f | Zagrade (COI17_zagrade) | C++20 | 100 / 100 | 373 ms | 45548 KiB
#include <bits/stdc++.h>
using namespace std;

// Capacity of every per-node array below. main() stores nodes at indices
// 1..n, so n has to stay strictly below this bound.
const int MAX_N = 300'000 + 10;

// n   - number of nodes, read by main().
// val - bracket value per node: +1 for '(' and -1 for any other character.
// cnt - scratch counters indexed by a running bracket balance; add_dfs()
//       increments them, count_dfs() reads them, decompose() rolls them back.
int n, val[MAX_N], cnt[MAX_N];
// Undirected adjacency lists. decompose() reverses a node's list in place
// when that node is processed.
vector<int> adj[MAX_N];
// Running answer; count_dfs() and decompose() add to it, main() prints it.
long long ans;

// Subtree sizes written by the most recent size_dfs() call.
int sz[MAX_N];
// mark[v] becomes true once decompose() has finished using v as a centroid;
// every traversal skips marked nodes.
bool mark[MAX_N];

// Fills sz[] for the part of the tree reachable from `u` through unmarked
// nodes without stepping back to `p`. Afterwards sz[x] is the number of
// nodes in the subtree hanging below x (x included).
void size_dfs(int u, int p) {
    int total = 1;  // u itself

    for (int child : adj[u]) {
        if (child == p || mark[child]) {
            continue;  // never walk back up or into an already removed node
        }
        size_dfs(child, u);
        total += sz[child];
    }

    sz[u] = total;
}

// Starting at `u` (reached from `p`), repeatedly steps into the unmarked
// neighbour whose subtree holds more than half of sz[root] nodes and returns
// the node where no such neighbour exists any more.
// Reads sz[], so size_dfs() must have been run from `root` beforehand
// (decompose() does exactly that).
int find_cent(int u, int p, int root) {
    const int half = sz[root] / 2;

    bool descended = true;
    while (descended) {
        descended = false;
        for (int cand : adj[u]) {
            if (cand != p && !mark[cand] && sz[cand] > half) {
                p = u;
                u = cand;
                descended = true;
                break;  // u changed: restart the scan on the new node's list
            }
        }
    }
    return u;
}

// Registers walks that start at the caller's seed and end at u or below it.
//   curr - balance accumulated before entering u
//   mx   - largest balance seen so far on the walk
//   rm   - log of every cnt[] index incremented, so the caller can undo them
// A node is registered in cnt[] only when its balance is at least as large as
// every earlier balance on the walk.
void add_dfs(int u, int p, int curr, int mx, vector<int>& rm) {
    const int balance = curr + val[u];
    int best = mx;

    if (balance >= best) {
        best = balance;
        cnt[balance]++;
        rm.push_back(balance);
    }

    for (int nxt : adj[u]) {
        if (nxt == p || mark[nxt]) {
            continue;
        }
        add_dfs(nxt, u, balance, best, rm);
    }
}

// Matches walks ending at u (or below it) against the walks currently
// registered in cnt[].
//   curr - balance accumulated before entering u
//   mn   - smallest balance seen so far on the walk
// When the balance at u is no larger than any earlier balance, every
// registered walk with the opposite balance (-balance) is added to ans.
void count_dfs(int u, int p, int curr, int mn) {
    const int balance = curr + val[u];
    int lowest = mn;

    if (balance <= lowest) {
        lowest = balance;
        ans += cnt[-balance];
    }

    for (int nxt : adj[u]) {
        if (nxt == p || mark[nxt]) {
            continue;
        }
        count_dfs(nxt, u, balance, lowest);
    }
}

// Processes the component of unmarked nodes containing `v`:
//   1. finds the node returned by find_cent() (called `c` below),
//   2. sweeps c's unmarked neighbours twice - once in list order and once in
//      reversed order - so each subtree is matched against the subtrees
//      visited before it in both directions,
//   3. marks c and recurses into the remaining pieces.
// cnt[] is restored to its previous contents before recursing.
void decompose(int v) {
    size_dfs(v, -1);
    const int c = find_cent(v, -1, v);

    const int seed_mx = max(0, val[c]);
    vector<int> rm;

    // For every live neighbour: first match its subtree against what is
    // already registered, then register the subtree itself (with c's own
    // value as the starting balance).
    auto sweep = [&]() {
        for (int u : adj[c]) {
            if (mark[u]) {
                continue;
            }
            count_dfs(u, c, 0, 0);
            add_dfs(u, c, val[c], seed_mx, rm);
        }
    };

    // Undo every cnt[] increment logged by add_dfs().
    auto rollback = [&]() {
        for (int x : rm) {
            --cnt[x];
        }
        rm.clear();
    };

    // First sweep: c alone is registered as a one-node walk when it is '(',
    // and walks that reach balance 0 exactly at c are counted via cnt[0].
    const bool c_opens = (val[c] == 1);
    if (c_opens) {
        ++cnt[1];
    }
    sweep();
    ans += cnt[0];
    rollback();
    if (c_opens) {
        --cnt[1];
    }

    // Second sweep in the opposite neighbour order; the c-only cases were
    // already handled above, so nothing extra is seeded here.
    reverse(adj[c].begin(), adj[c].end());
    sweep();
    rollback();

    mark[c] = true;

    for (int u : adj[c]) {
        if (!mark[u]) {
            decompose(u);
        }
    }
}

// Reads the node count, one bracket character per node and the n - 1 tree
// edges from stdin, runs the decomposition from node 1 and prints ans.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;

    // '(' contributes +1 to a balance, any other character -1.
    for (int u = 1; u <= n; ++u) {
        char c;
        cin >> c;
        val[u] = (c == '(') ? +1 : -1;
    }

    int a, b;
    int remaining = n - 1;
    while (remaining-- > 0) {
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    decompose(1);

    cout << ans << '\n';
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...