#include <bits/stdc++.h>
using namespace std;
// Capacity for per-node arrays (nodes are 1-indexed, with slack).
constexpr int maxn = 200010;
// Second dimension of the DP memo: highest pick counter x that can be stored, plus slack.
constexpr int maxm = 510;
// Sentinel: solve() returns -inf for states that cannot reach exactly m picks.
constexpr int inf = 1000000010;
using pii = pair<int, int>;

// One node seen as its Euler-tour interval [l, r] (entry time, last entry time in its
// subtree), its value v, and an auxiliary index filled in by main().
struct L
{
    int l, r, v, ind;
};
L f[maxn];
L aux[maxn];

int n, m, root;
int num[maxn];   // node values as read from input
int in[maxn];    // Euler entry time of each node
int out[maxn];   // largest entry time inside each node's subtree
int tt;          // running Euler-time counter used by dfs()
int nxt[2][maxn];
// Memo table for solve(); 200010 * 510 * sizeof(int) is roughly 408 MB.
int dp[maxn][maxm];
vector<int> grafo[maxn];   // undirected adjacency lists of the tree
// Euler-tour walk of the subtree rooted at u, where p is u's parent (0 for the root).
// For every node x reached it sets in[x] = pre-order entry time (from the global
// counter tt), out[x] = the last entry time inside x's subtree, and
// f[x] = {in[x], out[x], num[x], 0}.
// Uses an explicit stack instead of recursion: a path-shaped tree with ~2e5 nodes
// would otherwise need ~2e5 nested calls. Neighbours are visited in grafo order and
// the parent edge is skipped, matching the recursive formulation exactly.
void dfs(int u, int p)
{
    struct Frame { int node, parent; size_t next; };
    vector<Frame> stk;
    in[u] = ++tt;
    stk.push_back({u, p, 0});
    while (!stk.empty())
    {
        Frame &top = stk.back();
        const int node = top.node;
        if (top.next < grafo[node].size())
        {
            const int v = grafo[node][top.next++];
            if (v == top.parent) continue;
            in[v] = ++tt;
            // push_back may reallocate and invalidate `top`; it is not used afterwards.
            stk.push_back({v, node, 0});
        }
        else
        {
            // All children finished: close this node's interval.
            out[node] = tt;
            f[node] = {in[node], out[node], num[node], 0};
            stk.pop_back();
        }
    }
}
// Strict weak ordering on intervals: by right endpoint, ties broken by left endpoint.
bool comp(L a, L b)
{
    return tie(a.r, a.l) < tie(b.r, b.l);
}
// Strict weak ordering on intervals: by left endpoint (entry time), ties broken by right endpoint.
bool comp2(L a, L b)
{
    return tie(a.l, a.r) < tie(b.l, b.r);
}
// Maximum total value obtainable by choosing nodes from Euler positions pos..n so that
// exactly m nodes are chosen overall, given that x-1 have been chosen already. Chosen
// Euler-tour intervals must be pairwise disjoint (no chosen node is an ancestor of another).
// aux[pos] is the interval at position pos; nxt[1][pos] is the first position after it.
// Returns -inf when exactly m picks can no longer be reached.
// NOTE(review): dp's second dimension is maxm, so this assumes m + 1 < maxm.
int solve(int pos, int x)
{
    // Already picked m nodes: take nothing else. Checking this first (instead of only
    // at pos == n+1) keeps x <= m+1, so dp[pos][x] never indexes past maxm columns.
    if (x == m+1) return 0;
    if (pos == n+1) return -inf;
    if (dp[pos][x] != -1) return dp[pos][x];
    // Skip aux[pos]; nodes inside its subtree remain available.
    int best = solve(pos+1, x);
    // Take aux[pos]; continue right after its subtree.
    const int rest = solve(nxt[1][pos], x+1);
    // Extend only feasible states: adding a value to -inf would make an impossible
    // selection look feasible (and can overflow int when values are negative).
    if (rest != -inf) best = max(best, aux[pos].v + rest);
    return dp[pos][x] = best;
}
int main(void)
{
ios::sync_with_stdio(false); cin.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; i++)
{
int p;
cin >> p >> num[i];
if (p)
{
grafo[i].push_back(p);
grafo[p].push_back(i);
}
if (!p) root = i;
}
if (n == 1)
{
cout << num[1] << "\n";
return 0;
}
dfs(root, 0);
sort(f+1, f+n+1, comp);
for (int i = 1; i <= n; i++)
{
aux[i] = f[i];
aux[i].ind = i;
}
sort(aux+1, aux+n+1, comp2);
for (int i = 1; i <= n; i++)
f[aux[i].ind].ind = i;
int ptr = n;
for (int i = n; i >= 1; i--)
{
if (f[i].r >= aux[ptr].l)
{
nxt[0][i] = n+1;
continue;
}
while (f[i].r < aux[ptr-1].l) ptr--;
nxt[0][i] = aux[ptr].ind;
}
for (int i = 1; i <= n; i++)
{
nxt[1][f[i].ind] = f[nxt[0][i]].ind;
if (nxt[0][i] == n+1) nxt[1][f[i].ind] = n+1;
}
memset(dp, -1, sizeof dp);
cout << solve(1, 1) << "\n";
}
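/*
Grader results for this submission (kept inside a comment so the file compiles):

| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Correct   | 278 ms         | 404348 KB | Output is correct    |
| 2 | Correct   | 275 ms         | 404344 KB | Output is correct    |
| 3 | Correct   | 276 ms         | 404344 KB | Output is correct    |
| 4 | Incorrect | 374 ms         | 405500 KB | Output isn't correct |
| 5 | Halted    | 0 ms           | 0 KB      | -                    |
*/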