#include <bits/stdc++.h>
#define DIM 100002
#define INF 2000000000000000000LL
using namespace std;
vector <int> L[DIM];
int v[DIM],level[DIM],fth[DIM];
long long sum[DIM];
int n,x,y,i,j,nr,k;
struct idk{
long long maxi,maxi2;
int fiu,fiu2;
} dp[DIM][101][2];
long long dp_up[DIM][101][2];
/// dp[nod][i][0/1] = care e suma maxima daca incep undeva in subarborele lui nod,
/// ajung in nod, am consumat i boabe si am una in nod sau nu
void dfs (int nod, int tata){
int ok = 0;
fth[nod] = tata;
for (auto vecin : L[nod])
if (vecin != tata){
ok = 1;
dfs (vecin,nod);
}
if (!ok){ /// frunza
dp[nod][1][1].maxi = v[fth[nod]] - v[nod];
} else {
/// pun boaba in nod
dp[nod][1][1].maxi = sum[nod] - v[nod];
for (int i=2;i<=nr;i++){
long long maxi = -INF, maxi2 = -INF; int fiu = 0, fiu2 = 0;
for (auto vecin : L[nod]){
if (vecin == tata)
continue;
if (dp[vecin][i-1][0].maxi != -INF){
long long val = sum[nod] - v[vecin] - v[nod] + dp[vecin][i-1][0].maxi;
if (val > maxi){
maxi2 = maxi, fiu2 = fiu;
maxi = val, fiu = vecin;
} else {
if (val > maxi2)
maxi2 = val, fiu2 = vecin;
}}
if (dp[vecin][i-1][1].maxi != -INF){
long long val = sum[nod] - v[vecin] - v[nod] + dp[vecin][i-1][1].maxi;
if (val > maxi){
maxi2 = maxi, fiu2 = fiu;
maxi = val, fiu = vecin;
} else {
if (val > maxi2)
maxi2 = val, fiu2 = vecin;
}}}
dp[nod][i][1].maxi = maxi, dp[nod][i][1].fiu = fiu;
dp[nod][i][1].maxi2 = maxi2, dp[nod][i][1].fiu2 = fiu2;
}
/// nu pun boaba in nod
for (int i=1;i<=nr;i++){
long long maxi = -INF, maxi2 = -INF; int fiu = 0, fiu2 = 0;
for (auto vecin : L[nod]){
if (vecin == tata)
continue;
if (dp[vecin][i][0].maxi != -INF){
long long val = dp[vecin][i][0].maxi - v[nod];
if (val > maxi){
maxi2 = maxi, fiu2 = fiu;
maxi = val, fiu = vecin;
} else {
if (val > maxi2)
maxi2 = val, fiu2 = vecin;
}}
if (dp[vecin][i][1].maxi != -INF){
long long val = dp[vecin][i][1].maxi;
if (val > maxi){
maxi2 = maxi, fiu2 = fiu;
maxi = val, fiu = vecin;
} else {
if (val > maxi2)
maxi2 = val, fiu2 = vecin;
}}}
dp[nod][i][0].maxi = maxi, dp[nod][i][0].fiu = fiu;
dp[nod][i][0].maxi2 = maxi2, dp[nod][i][0].fiu2 = fiu2;
}}}
/// dp_up[nod][i][0/1]
void dfs2 (int nod, int tata){
int nr_fii = 0;
for (auto vecin : L[nod]){
if (vecin != tata){
nr_fii++;
if (nr_fii > 1)
break;
}}
/// calculez strict din dp_up[tata]
if (nod != 1){
/// pun boaba in nod
for (int i=2;i<=nr;i++){
if (dp_up[tata][i-1][0] != -INF){
long long val = dp_up[tata][i-1][0] + sum[nod] - v[nod] - v[tata];
dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
}
if (dp_up[tata][i-1][1] != -INF){
long long val = dp_up[tata][i-1][1] + sum[nod] - v[tata] - v[nod];
dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
}
}
/// nu pun boaba in nod
for (int i=1;i<=nr;i++){
if (dp_up[tata][i][0] != -INF){
long long val = dp_up[tata][i][0] - v[nod];
dp_up[nod][i][0] = max (dp_up[nod][i][0],val);
}
if (dp_up[tata][i][1] != -INF)
dp_up[nod][i][0] = max (dp_up[nod][i][0],dp_up[tata][i][1]);
}
/// acum facem in functie de dp
if (nr_fii > 1 && nod != 1){
/// pun boaba
for (int i=2;i<=nr;i++){
long long val;
if (dp[tata][i-1][0].fiu == nod)
val = dp[tata][i-1][0].maxi2 + sum[nod] - v[nod] - v[tata];
else val = dp[tata][i-1][0].maxi + sum[nod] - v[nod] - v[tata];
dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
////////////////////////
if (dp[tata][i-1][1].fiu == nod)
val = dp[tata][i-1][1].maxi2 + sum[nod] - v[nod] - v[tata];
else val = dp[tata][i-1][1].maxi + sum[nod] - v[nod] - v[tata];
dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
}
/// nu pun boaba
for (int i=1;i<=nr;i++){
long long val;
if (dp[tata][i][0].fiu == nod)
val = dp[tata][i][0].maxi2 - v[nod];
else val = dp[tata][i][0].maxi - v[nod];
dp_up[nod][i][0] = max (dp_up[nod][i][0],val);
if (dp[tata][i][1].fiu == nod)
val = dp[tata][i][1].maxi2 - v[nod];
else val = dp[tata][i][1].maxi - v[nod];
dp_up[nod][i][0] = max (dp_up[nod][i][0],val);
}
}
}
for (auto vecin : L[nod])
if (vecin != tata)
dfs2 (vecin,nod);
}
int main (){
//ifstream cin ("date.in");
//ofstream cout ("date.out");
cin>>n>>nr;
for (i=1;i<=n;i++)
cin>>v[i];
for (i=1;i<n;i++){
cin>>x>>y;
L[x].push_back(y);
L[y].push_back(x);
}
for (i=1;i<=n;i++){
sum[i] = v[i];
for (auto it : L[i])
sum[i] += v[it];
}
if (nr == 1){
long long maxi = 0;
for (i=1;i<=n;i++)
maxi = max (maxi,sum[i] - v[i]);
cout<<maxi;
return 0;
}
for (i=1;i<=n;i++)
for (j=1;j<=nr;j++)
dp[i][j][0].maxi = dp[i][j][1].maxi = -INF;
dfs (1,0);
long long maxi = 0;
for (i=1;i<=n;i++)
for (j=1;j<=nr;j++){
maxi = max (maxi,max(dp[i][j][0].maxi,dp[i][j][1].maxi));
maxi = max (maxi,max(dp_up[i][j][0],dp_up[i][j][1]));
}
cout<<maxi;
return 0;
}
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
6 ms |
2816 KB |
Output is correct |
2 |
Incorrect |
6 ms |
2816 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
6 ms |
2816 KB |
Output is correct |
2 |
Incorrect |
6 ms |
2816 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
560 ms |
489208 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
6 ms |
2816 KB |
Output is correct |
2 |
Incorrect |
6 ms |
2816 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |