// code written by Arjun Tomar
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define mp(x,y) make_pair(x,y)
const int nax=1e5+5;
const int gax=101;
vector<vector<int>> v(nax);
int vis[nax]={0};
map<int,vector<ll>> m[nax];
ll a[nax];
vector<vector<ll>> dp(nax,vector<ll>(gax,0)),dp1(nax,vector<ll>(gax,0));
void dfs(int ci){
vis[ci]=1;
ll g=a[ci];
for(auto x : v[ci]){
if(vis[x]==0) g+=a[x];
}
for(int i=1;i<gax;i++) dp[ci][i]=g;
vector<int> v1;
for(auto x : v[ci]){
if(vis[x]==1) continue;
dfs(x);
for(int i=1;i<gax-1;i++){
dp[ci][i+1]=max(dp[ci][i+1],dp[x][i]+dp[ci][1]-a[x]);
}
v1.push_back(x);
m[ci][x].resize(gax,0);
m[ci][x][1]=dp[ci][1]-a[x];
}
for(int i=1;i<gax-1;i++){
vector<int> pre,suff;
for(auto x : v1){
pre.push_back(dp[x][i-1]-(i>1 ? a[x] : 0) );
}
suff=pre;
for(int j=1;j<(int)pre.size();j++){
pre[j]=max(pre[j-1],pre[j]);
}
for(int j=(int)suff.size()-2;j>=0;j--){
suff[j]=max(suff[j],suff[j+1]);
}
for(int j=0;j<v1.size();j++){
m[ci][v1[j]][i]=max((j>0 ? pre[j-1]:0),(j<(int)suff.size()-1?suff[j+1]:0))+dp[ci][1]-a[v1[j]];
}
}
}
void dfs1(int ci, vector<ll> temp, ll childw){
vis[ci]=1;
dp1[ci][1]=dp[ci][1]+childw;
for(int i=1;i<gax-1;i++){
dp1[ci][i+1]=max(dp[ci][i+1]+childw, dp[ci][1]+temp[i]);
}
for(auto x : v[ci]){
if(vis[x]==1) continue;
vector<ll> t1(gax,0);
vector<ll> t2=m[ci][x];
t1[1]=t2[1]+childw;
for(int i=2;i<gax;i++){
t1[i]=max(t2[i]+childw, t2[1]+temp[i-1]);
}
dfs1(x,t1,a[ci]);
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
int n,r;
cin>>n>>r;
for(int i=0;i<n;i++)
cin>>a[i];
for(int i=0;i<n-1;i++){
int x,y;cin>>x>>y;
x--; y--;
v[x].push_back(y);
v[y].push_back(x);
}
dfs(0);
for(int i=0;i<n;i++) vis[i]=0;
vector<ll> ty(gax,0);
dfs1(0,ty,0);
ll ans=0;
for(int i=0;i<n;i++){
for(int j=1;j<=r;j++){
ans=max(dp1[i][j]-a[i],ans);
}
}
cout<<ans<<'\n';
}
Compilation message
chase.cpp: In function 'void dfs(int)':
chase.cpp:47:22: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
for(int j=0;j<v1.size();j++){
~^~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
102 ms |
171768 KB |
Output is correct |
2 |
Incorrect |
106 ms |
171768 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
102 ms |
171768 KB |
Output is correct |
2 |
Incorrect |
106 ms |
171768 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1159 ms |
391964 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
102 ms |
171768 KB |
Output is correct |
2 |
Incorrect |
106 ms |
171768 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |