#include <bits/stdc++.h>
using namespace std;
// Adjacency lists of the tree (0-indexed nodes), filled in main() from the edge list.
vector<vector<int>> adjlist;
// neighbours[x] = sum of val[y] over every node y adjacent to x (built in main()).
// NOTE(review): stored as int; if node values can be large (e.g. up to 1e9)
// these sums overflow — consider long long throughout up/down/dp.
vector<int> neighbours;
// Per-node value read from input (presumably the pigeon count — confirm).
vector<int> val;
// Memo tables for up()/down(): indexed [node][k] with k in 0..v.
// .first caches state c==0, .second caches state c==2; -1 means "not computed".
vector<vector<pair<int,int>>> upv;
vector<vector<pair<int,int>>> downv;
// Top-level budget read from input; dp() passes it as k to up()/down().
int v;
int down(int node, int par, int c, int k){
if (k<=0){
return 0;
}
int s = -1;
if (c!=1){
s=(c==0)?0:2;
}
if (s!=-1){
if (s==0){
if (downv[node][k].first!=-1){
return downv[node][k].first;
}
} else {
if (downv[node][k].second!=-1){
return downv[node][k].second;
}
}
}
int mx = 0;
if (c==0){
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(
mx,
max(
down(i,node,0,k-1)-val[i],
max(
down(i,node,1,k-1)-val[i]+val[node],
down(i,node,2,k-1)+val[node]
)
)+neighbours[node]);
}
} else if (c==1){
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(mx,down(i,node,0,k)-val[i]);
}
} else {
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(mx,max(down(i,node,1,k),down(i,node,2,k)));
}
}
if (s!=-1){
if (s==0){
return downv[node][k].first=mx;
} else {
return downv[node][k].second=mx;
}
} else {
return mx;
}
}
int up(int node, int par, int c, int k){
if (k<=0){
return 0;
}
int s = -1;
if (c!=1){
s=(c==0)?0:2;
}
if (s!=-1){
if (s==0){
if (upv[node][k].first!=-1){
return upv[node][k].first;
}
} else {
if (upv[node][k].second!=-1){
return upv[node][k].second;
}
}
}
int mx = 0;
if (c==0){
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(
mx,
max(
up(i,node,0,k-1)-val[i],
max(
up(i,node,1,k-1)-val[i],
up(i,node,2,k-1)
)
)+neighbours[node]);
}
} else if (c==1){
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(mx,up(i,node,0,k));
}
} else {
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(mx,max(up(i,node,1,k),up(i,node,2,k)));
}
}
if (s!=-1){
if (s==0){
return upv[node][k].first=mx;
} else {
return upv[node][k].second=mx;
}
} else {
return mx;
}
}
int dp(int node, int par){
int mx = max(
up(node,par,0,v),
max(up(node,par,1,v),up(node,par,2,v)));
mx=max(mx,
max(
down(node,par,0,v)-val[node],
max(down(node,par,1,v),down(node,par,2,v))
)
);
for (int i : adjlist[node]){
if (i==par){continue;}
mx=max(mx,dp(i,node));
}
// middle not used
for (int k = 1; k<v; k++){
int maxu1=0, maxu1i=0, maxu2=0, maxu2i=0,
maxd1=0, maxd1i=0, maxd2=0, maxd2i=0;
for (int i : adjlist[node]){
if (i==par){continue;}
int uv = max(up(i,node,1,k),up(i,node,2,k));
int dv = max(down(i,node,1,v-k),down(i,node,2,v-k));
if (uv>=maxu1){
maxu1=uv;maxu1i=i;
} else if (uv>=maxu2){
maxu2=uv;maxu2i=i;
}
if (dv>=maxu1){
maxd1=dv;maxd1i=i;
} else if (dv>=maxd2){
maxd2=dv;maxd2i=i;
}
}
if (maxd1i!=maxu1i){
mx=max(mx,maxu1+maxd1);
} else {
mx=max(mx,max(maxu1+maxd2,maxu2+maxd1));
}
}
//middle used
if (v>=3){
for (int k = 1; k<v-1; k++){
//if centre is start
int maxu1=0, maxu1i=0, maxu2=0, maxu2i=0,
maxd1=0, maxd1i=0, maxd2=0, maxd2i=0;
for (int i : adjlist[node]){
if (i==par){continue;}
int uv = max(up(i,node,1,k)-val[i],up(i,node,1,k));
int dv = max(down(i,node,0,v-1-k)-val[node]-val[i],
max(down(i,node,1,v-1-k)-val[i],down(i,node,2,v-1-k)));
if (uv>=maxu1){
maxu1=uv;maxu1i=i;
} else if (uv>=maxu2){
maxu2=uv;maxu2i=i;
}
if (dv>=maxu1){
maxd1=dv;maxd1i=i;
} else if (dv>=maxd2){
maxd2=dv;maxd2i=i;
}
}
if (maxd1i!=maxu1i){
mx=max(mx,maxu1+maxd1+neighbours[node]);
} else {
mx=max(mx,max(maxu1+maxd2,maxu2+maxd1)+neighbours[node]);
}
}
for (int k = 1; k<v-1; k++){
//if centre is middle
int maxu1=0, maxu1i=0, maxu2=0, maxu2i=0,
maxd1=0, maxd1i=0, maxd2=0, maxd2i=0;
for (int i : adjlist[node]){
if (i==par){continue;}
int uv = up(i,node,0,k);
int dv = max(down(i,node,0,v-1-k)-val[node]-val[i],
max(down(i,node,1,v-1-k)-val[i],down(i,node,2,v-1-k)));
if (uv>=maxu1){
maxu1=uv;maxu1i=i;
} else if (uv>=maxu2){
maxu2=uv;maxu2i=i;
}
if (dv>=maxu1){
maxd1=dv;maxd1i=i;
} else if (dv>=maxd2){
maxd2=dv;maxd2i=i;
}
if (maxd1i!=maxu1i){
mx=max(mx,maxu1+maxd1+neighbours[node]+val[node]);
} else {
mx=max(mx,max(maxu1+maxd2,maxu2+maxd1)+neighbours[node]+val[node]);
}
}
}
}
return mx;
}
// Reads n and the budget v, then n node values, then n-1 undirected edges
// given as 1-indexed endpoints. Builds the adjacency lists and per-node
// neighbour-value sums, sizes both memo tables to [n][v+1] filled with -1,
// and prints dp() evaluated with node 0 as the root.
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n;
    cin >> n >> v;

    adjlist.resize(n);
    neighbours.resize(n);
    val.resize(n);

    // One row per node; -1 marks "not yet computed" for every k in 0..v.
    const vector<pair<int,int>> unsetRow(v + 1, make_pair(-1, -1));
    upv.resize(n, unsetRow);
    downv.resize(n, unsetRow);

    for (int& value : val) {
        cin >> value;
    }

    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        cin >> a >> b;
        --a;
        --b;
        adjlist[a].push_back(b);
        adjlist[b].push_back(a);
        neighbours[a] += val[b];
        neighbours[b] += val[a];
    }

    cout << dp(0, -1) << '\n';
}
/*
Judge feedback pasted from the submission page (kept as a comment so the file compiles).

Compilation message:
chase.cpp: In function 'int dp(int, int)':
chase.cpp:139:35: warning: variable 'maxu2i' set but not used [-Wunused-but-set-variable]
chase.cpp:140:31: warning: variable 'maxd2i' set but not used [-Wunused-but-set-variable]
chase.cpp:166:36: warning: variable 'maxu2i' set but not used [-Wunused-but-set-variable]
chase.cpp:167:32: warning: variable 'maxd2i' set but not used [-Wunused-but-set-variable]
chase.cpp:192:36: warning: variable 'maxu2i' set but not used [-Wunused-but-set-variable]
chase.cpp:193:32: warning: variable 'maxd2i' set but not used [-Wunused-but-set-variable]

Grader results (four tables, in the order they were pasted):

| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Incorrect | 5 ms           | 384 KB    | Output isn't correct |
| 2 | Halted    | 0 ms           | 0 KB      | -                    |

| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Incorrect | 5 ms           | 384 KB    | Output isn't correct |
| 2 | Halted    | 0 ms           | 0 KB      | -                    |

| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Incorrect | 2803 ms        | 181836 KB | Output isn't correct |
| 2 | Halted    | 0 ms           | 0 KB      | -                    |

| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Incorrect | 5 ms           | 384 KB    | Output isn't correct |
| 2 | Halted    | 0 ms           | 0 KB      | -                    |
*/