#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
// Competitive-programming shorthands. NOTE: `int` is redefined to `long long`
// for the rest of the file, which is why main is declared `int32_t main()`.
#define int long long
#define fi first
#define se second
// Inclusive ascending / descending loop helpers (unused in this file).
#define L(i, j, k) for (int i = (j); i <= (k); i++)
#define R(i, j, k) for (int i = (j); i >= (k); i--)
#define all(x) x.begin(), x.end()
// Capacity of every node-indexed and component-indexed array below.
const int N = 2e5 + 5;
// Per-node input: a[i] = node that i points to, h[i] = its key (replaced by
// its 1-based rank among distinct keys in solve), c[i] = its weight.
// scc[x] = component id of node x (0 = not yet assigned); vis = visited marks
// (dfs uses node ids, comp uses component ids); sc[k] = total weight of
// component k; sz[k] = number of components in k's subtree of g (set by comp,
// otherwise unused); deg[k] = in-degree of component k in g; k = number of
// components found so far; `waste` is unused.
int n, a[N], h[N], c[N], scc[N], vis[N], sc[N], sz[N], deg[N], k = 0, waste = 0;
// adj: edges a[i] -> i; rev: edges i -> a[i] (self-loops omitted from both);
// g: deduplicated edges between component ids; topo: dfs post-order
// (reversed in solve before the second Kosaraju pass).
vector<int> adj[N], rev[N], g[N], topo;
// cyc[{k, x}] = total weight of the nodes of component k whose key is x.
map<pair<int, int>, int> cyc;
// dis[k] = distinct keys occurring in component k.
set<int> dis[N];
// mp[k] = sparse map built by comp() for component k; see comp().
map<int, int> mp[N];
void dfs(int u, int p) {
vis[u] = 1;
for (int v : adj[u]) if (!vis[v]) dfs(v, u);
topo.push_back(u);
}
// Kosaraju pass 2: assigns component id `k` to `u` and to every node that is
// still unassigned and reachable from `u` along `rev`. It accumulates, for
// each such node x:
//   sc[k] += c[x];  cyc[{k, h[x]}] += c[x];  dis[k] gets h[x].
// Uses an explicit stack so a long `rev` chain (e.g. one huge cycle) cannot
// overflow the call stack. The visiting order may differ from a recursive
// walk, but the assigned set and the accumulated sums/sets are the same.
void rdfs(int u) {
    std::vector<int> stk{u};
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        if (scc[x]) continue;  // pushed twice before being processed
        scc[x] = k;
        sc[k] += c[x];
        cyc[{k, h[x]}] += c[x];
        dis[k].insert(h[x]);
        for (int y : rev[x]) if (!scc[y]) stk.push_back(y);
    }
}
// Fenwick (binary indexed) tree over positions 1..N-1 (not used by solve()).
int fen[N];
// Adds `k` at 1-based position `i`. Positions < 1 are ignored: for i == 0
// the step `i & -i` is 0, so the loop would never advance, and negative
// positions would index out of bounds.
void upd(int i, int k) {
    if (i <= 0) return;
    for (; i < N; i += (i & -i)) fen[i] += k;
}
// Processes component `u` of the condensation forest `g` (children first,
// recursively; depth equals the depth of u's subtree) and leaves the result
// in mp[u]. `p` is the caller's component, -1 at a root.
//
// Read mp[u] as a step function F(x) = sum of mp[u][y] over keys y < x
// (keys are 0 and the compressed keys >= 1), so F(1) == mp[u][0].
// The body builds it as:
//   1. F := sum of the children's functions (small-to-large map merge);
//   2. F += sc[u], stored at key 0, i.e. a constant for every x >= 1;
//   3. for each key v of component u, with d = cyc[{u, v}]:
//        F(x) := min(F(x), F(v) - d) for x <= v, unchanged for x > v.
//      This is done by adding d at key v and removing d units from the
//      largest keys strictly below v (valid assuming non-negative weights).
// Child maps are cleared once merged so copies do not pile up in memory.
// Also records sz[u] (components in u's subtree) and sets vis[u].
void comp(int u, int p) {
    vis[u] = 1;
    sz[u] = 1;
    for (int v : g[u]) if (v != p) {
        comp(v, u);
        sz[u] += sz[v];
        // Keep the largest map in mp[u] so the merge below is small-to-large.
        if (mp[v].size() > mp[u].size()) mp[u].swap(mp[v]);
    }
    auto& cur = mp[u];
    for (int v : g[u]) if (v != p) {
        for (const auto& kv : mp[v]) cur[kv.first] += kv.second;
        mp[v].clear();  // already merged; release the child's nodes
    }
    cur[0] += sc[u];
    for (int v : dis[u]) {
        const int d = cyc[{u, v}];
        cur[v] += d;
        int sto = d;
        // Walk downward from the largest key < v, removing `sto` units.
        auto it = cur.lower_bound(v);
        while (sto && it != cur.begin()) {
            --it;
            const int rm = std::min(sto, it->second);
            sto -= rm;
            it->second -= rm;
            // erase() returns the successor, so the next --it is the next key down.
            if (!it->second) it = cur.erase(it);
        }
    }
}
void solve() {
cin >> n;
int ans = 0;
vector<int> disc;
for (int i = 1; i <= n; i++) {
cin >> a[i] >> h[i] >> c[i];
if (a[i] != i) adj[a[i]].push_back(i);
if (i != a[i]) rev[i].push_back(a[i]);
disc.push_back(h[i]);
}
sort(disc.begin(), disc.end());
disc.erase(unique(disc.begin(), disc.end()), disc.end());
for (int i = 1; i <= n; i++) {
h[i] = lower_bound(disc.begin(), disc.end(), h[i]) - disc.begin() + 1;
}
for (int i = 1; i <= n; i++) if (!vis[i]) dfs(i, -1);
reverse(all(topo));
for (int nd : topo) if (!scc[nd]) {
++k;
rdfs(nd);
}
for (int u = 1; u <= n; u++) for (int v : adj[u]) {
if (scc[u] == scc[v]) continue;
g[scc[u]].push_back(scc[v]);
}
for (int i = 0; i < N; i++) {
sort(all(g[i]));
g[i].erase(unique(g[i].begin(), g[i].end()), g[i].end());
for (int j : g[i]) deg[j]++;
}
for (int i = 1; i <= n; i++) vis[i] = 0;
n++;
for (int i = 1; i <= k; i++) if (!deg[i]) {
comp(i, -1);
ans += mp[i][0];
}
cout << ans << '\n';
}
// Entry point: configures fast iostreams and runs solve() once per test case.
int32_t main() {
    // Untie cin from cout and drop C-stdio synchronisation before any I/O.
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int tests = 1;
    // Single-test input; to read a test count, uncomment:
    // cin >> tests;
    while (tests--) {
        solve();
    }
}
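/*
 * Judge submission-page residue pasted after the code (kept for reference,
 * commented out so the file compiles):
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */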