#include <bits/stdc++.h>
// sz(x): size of a container as a signed int.
#define sz(x) ((int)(x).size())
// Capacity of the per-node global arrays; presumably the problem's bound on n
// plus slack -- confirm against the statement.
const int N = 2e5 + 5;
using namespace std;
typedef long long ll;
// dp[u]: sparse map keyed by compressed height (see H), filled bottom-up by
// dfs(); main() subtracts the sum of dp[root]'s values from the total cost.
map<int, ll> dp[N];
// adj[u]: children of u, i.e. nodes i with nx[i] == u (self-loops excluded);
// main() rewires every cycle into a chain so adj[] describes a forest.
vector<int> adj[N];
// H[i]: per-node value, replaced by its 0-based rank among distinct values in
// main(); C[i]: per-node weight, all of which are summed into the answer first;
// n: number of nodes read from input.
int H[N], C[N], n;
// Small-to-large merge of two sparse maps.
//
// Postcondition: `a` holds the key-wise sum of both inputs and `b` is empty.
// Swapping two std::map objects is O(1), so only the smaller map's entries
// are re-inserted, which bounds the total insertion work over a whole tree
// to O(n log n) map operations.
void merge(std::map<int, long long>& a, std::map<int, long long>& b) {
    if (a.size() < b.size())
        a.swap(b);
    for (const auto& [key, delta] : b)
        a[key] += delta;
    // Free the merged-away nodes now. Otherwise every child's map keeps its
    // copy alive until exit, and memory grows to O(n log n) entries.
    b.clear();
}
// Builds dp[u] for the tree hanging below u (children are read from adj[]).
//
// dp[x] maps a compressed height h to a weight d. For each node x:
//   1. the children's maps are merged (summed key-wise) into dp[x];
//   2. up to C[x] of total weight is removed from the keys strictly below
//      H[x], largest key first. The last entry touched is trimmed rather
//      than erased if it is only partially absorbed;
//   3. C[x] is added at key H[x].
//
// Implemented iteratively. The forest built in main() can contain a path of
// length ~n, and one recursive frame per node risks overflowing the call stack.
void dfs(int u) {
    // Pre-order listing of u's subtree; every node precedes its descendants.
    std::vector<int> order;
    std::vector<int> pending{u};
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (int v : adj[x])
            pending.push_back(v);
    }
    // Walking the list backwards finalises every child before its parent.
    for (auto pos = order.rbegin(); pos != order.rend(); ++pos) {
        const int x = *pos;
        for (int v : adj[x])
            merge(dp[x], dp[v]);
        ll val = C[x];
        auto it = dp[x].lower_bound(H[x]);
        while (it != dp[x].begin()) {
            it = std::prev(it);
            val -= it->second;
            if (val < 0) {
                // Only part of this entry is absorbed; keep the remainder.
                it->second = -val;
                break;
            }
            it = dp[x].erase(it);
        }
        dp[x][H[x]] += C[x];
    }
}
int main() {
cin >> n;
ll ans = 0;
vector<int> arr(n), nx(n);
for (int i = 0, p; i < n; i++) {
cin >> nx[i] >> H[i] >> C[i];
nx[i]--;
if (nx[i] != i)
adj[nx[i]].push_back(i);
arr[i] = H[i];
ans += C[i];
}
sort(arr.begin(), arr.end());
arr.erase(unique(arr.begin(), arr.end()), arr.end());
for (int i = 0; i < n; i++)
H[i] = lower_bound(arr.begin(), arr.end(), H[i]) - arr.begin();
vector<int> roots, vis(n, 0);
for (int i = 0; i < n; i++) {
if (vis[i])
continue;
int j = i;
while (!vis[j]) {
vis[j] = i + 1;
j = nx[j];
}
if (vis[j] != i + 1)
continue;
vector<int> cycle;
while (vis[j] == i + 1) {
cycle.push_back(j);
vis[j] = -1;
j = nx[j];
}
sort(cycle.begin(), cycle.end(), [&](int u, int v) {
return H[u] < H[v];
});
if (sz(cycle) > 1) {
auto it = adj[cycle[0]].begin();
while (it != adj[cycle[0]].end()) {
if (vis[*it] == -1) {
adj[cycle[0]].erase(it);
break;
}
it = next(it);
}
}
for (j = 1; j < sz(cycle); j++) {
for (int k: adj[cycle[j]])
if (vis[k] != -1)
adj[cycle[0]].push_back(k);
adj[cycle[j]].clear();
adj[cycle[j]].push_back(cycle[j - 1]);
}
roots.push_back(cycle.back());
}
for (int r: roots) {
dfs(r);
for (auto [i, d]: dp[r]) {
ans -= d;
}
}
cout << ans << "\n";
}
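// Judge results table pasted from the submission page; kept as a comment so
// the file still compiles.
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |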