This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Number of vertices (read in main).
int n;
// Tree adjacency lists, filled in main from the n-1 input edges (both directions).
// Vertices appear to be 1-indexed (loops run i=1..n); the fixed bound 200200
// caps the supported n — TODO confirm against the problem limits.
vector<int> adj[200200];
// Vertex weights a[1..n], read in main.
ll a[200200];
namespace foo{
ll dp[200200];
int nxt[200200];
void dfs(int s, int pa = 0){
dp[s] = a[s], nxt[s] = 0;
for (auto &v:adj[s]) if (v!=pa){
dfs(v, s);
if (dp[s] < dp[v] + a[s]){
dp[s] = dp[v] + a[s];
nxt[s] = v;
}
}
}
void solve1(){
dfs(1);
printf("%lld\n", dp[1]);
vector<int> V;
for (int i=1;i;i=nxt[i]){
V.push_back(i);
}
printf("%d\n", (int)V.size());
for (auto &x:V) printf("%d ", x);
printf("\n");
}
} // namespace foo
namespace bar{
// Solver used by main when k == 2: a tree DP over states dp[s][x][y] followed
// by a reconstruction pass (track) that links the chosen visiting order via nxt[].
// Both dfs and track recurse once per tree level, so very deep trees rely on a
// large call stack.
//
// State meanings, as read from how track() rebuilds each state
// (NOTE(review): inferred from the reconstruction code — confirm against the
// problem statement):
//   [0][1]  segment starts at s and ends at a child of s;
//   [1][0]  the same segment reversed: starts at a child of s, ends at s;
//   [0][2]  segment starts at s, free end;
//   [1][2]  segment starts at a child of s, free end;
//   [0][0]  (only requested by track) the single vertex s.
// track asserts the chosen dp values are >= 0, which presumably relies on
// nonnegative weights a[] — TODO confirm.
constexpr ll INF = 4e18;
// dp[s][x][y]: best total weight for state (x, y) in s's subtree; -INF = impossible.
ll dp[200200][2][3];
// Choices remembered for reconstruction:
//   v1[s]          child used by [0][1] / [1][0];
//   v2[s], typ2[s] children and variant chosen for [0][2];
//   v3[s], typ3[s] children and variant chosen for [1][2].
// Ids <= 0 are "no child" placeholders. In v3[s][0] an id > n encodes
// "visit child (id - n) on its own".
// nxt[u]: successor of u in the reconstructed order (0 = end of order).
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
// G[s]: children of s (filled by dfs); track erases children as it places them.
vector<int> G[200200];
// Computes dp and the choice arrays for every vertex of s's subtree (parent pa).
void dfs(int s, int pa = 0){
// Candidate lists of (gain, child id), sorted descending below. Three (0, -i)
// placeholders mean "no child", so the fixed-size top-k scans always have entries.
vector<pair<ll, int>> C1, C2, C3, C6;
for (int i=1;i<=3;i++){
C1.emplace_back(0, -i);
C2.emplace_back(0, -i);
C3.emplace_back(0, -i);
C6.emplace_back(0, -i);
}
// sum: total weight of s's children (baseline of visiting each child alone).
ll sum = 0;
int deg = 0;
for (auto &v:adj[s]) if (v!=pa){
G[s].push_back(v);
dfs(v, s);
sum += a[v];
deg++;
// C1: extra gain of covering v's subtree via [1][0] instead of visiting v alone.
C1.emplace_back(dp[v][1][0] - a[v], v);
// C2: extra gain of finishing in v's subtree via [0][2].
C2.emplace_back(dp[v][0][2] - a[v], v);
// C3: full value of entering v's subtree via [1][2] (not relative to a[v]).
C3.emplace_back(dp[v][1][2], v);
// C6: extra gain via [0][1]; same numbers as C1 because dp[v][1][0] is
// assigned dp[v][0][1] at the end of dfs.
C6.emplace_back(dp[v][0][1] - a[v], v);
}
sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
sort(C3.begin(), C3.end(), greater<pair<ll, int>>());
sort(C6.begin(), C6.end(), greater<pair<ll, int>>());
// calc 0 1
// s, then at most one child's subtree via [1][0], every other child alone.
dp[s][0][1] = a[s] + sum;
dp[s][0][1] += C1[0].first;
v1[s] = C1[0].second;
// A leaf cannot end at one of its children.
if (deg==0) dp[s][0][1] = -INF;
// calc 0 2
// typ2 == 2: s, one child via [1][0], the remaining children alone, then finish
// in a different child via [0][2]. The top two of C1 and C2 are scanned to
// allow for the distinct-children constraint.
dp[s][0][2] = -INF;
for (int i=0;i<2;i++){
for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
ll val = a[s] + sum + C1[i].first + C2[j].first;
if (dp[s][0][2] < val){
dp[s][0][2] = val;
typ2[s] = 2;
v2[s][0] = C1[i].second;
v2[s][1] = C2[j].second;
}
}
}
// typ2 == 1: s, then directly into one child's subtree via [1][2] (it starts
// at a child of that child); the other children of s are not visited.
// A placeholder here leaves just s.
if (dp[s][0][2] < a[s] + C3[0].first){
dp[s][0][2] = a[s] + C3[0].first;
typ2[s] = 1;
v2[s][0] = C3[0].second;
}
// calc 1 2
dp[s][1][2] = -INF;
// C4: C1 plus one (0, n + v) entry per child, meaning "start at child v on its
// own" (track uses [0][0] for it); other C4 entries start via [0][1].
vector<pair<ll, int>> C4 = C1;
for (auto &v:adj[s]) if (v!=pa) C4.emplace_back(0, n+v);
sort(C4.begin(), C4.end(), greater<pair<ll, int>>());
// typ3 == 3: start in child C4[i], go to s, one child via [1][0] (C1), the
// remaining children alone, finish in another child via [0][2] (C2).
// All three children must differ; placeholder starts (rc4 < 0) are skipped.
for (int i=0;i<min(6, (int)C4.size());i++){
for (int j=0;j<3;j++){
for (int k=0;k<3;k++){
int rc4 = C4[i].second;
if (rc4 > n) rc4 -= n;
if (rc4==C1[j].second || rc4==C2[k].second || C1[j].second==C2[k].second) continue;
if (rc4 < 0) continue;
ll val = a[s] + sum + C4[i].first + C1[j].first + C2[k].first;
if (dp[s][1][2] < val){
dp[s][1][2] = val;
typ3[s] = 3;
v3[s][0] = C4[i].second;
v3[s][1] = C1[j].second;
v3[s][2] = C2[k].second;
}
}
}
}
// typ3 == 2: start in child C4[i], go to s, then straight into a different
// child's subtree via [1][2] (C3). Other children are skipped, so `sum` is not
// added; the start child's own weight is added back because C4 values are
// relative to a[v].
for (int i=0;i<min(4, (int)C4.size());i++){
for (int j=0;j<2;j++) if (C4[i].second != C3[j].second && C4[i].second > 0){
int rc4 = C4[i].second;
if (rc4 > n) rc4 -= n;
if (rc4==C3[j].second) continue;
if (rc4 < 0) continue;
ll val = a[s] + C4[i].first + C3[j].first;
if (C4[i].second > 0) val += a[rc4];
if (dp[s][1][2] < val){
dp[s][1][2] = val;
typ3[s] = 2;
v3[s][0] = C4[i].second;
v3[s][1] = C3[j].second;
}
}
}
// C5: C3 restricted to real children, made relative to a[v] so it can be
// combined with `sum`.
vector<pair<ll, int>> C5;
for (int i=0;i<(int)C3.size();i++) if (C3[i].second > 0){
C5.emplace_back(C3[i].first-a[C3[i].second], C3[i].second);
}
sort(C5.begin(), C5.end(), greater<pair<ll, int>>());
// typ3 == 1 (order as rebuilt by track): the remaining children alone, then
// child C6[i] via [0][1] (skipped if it is a placeholder), then s, then finish
// in child C5[j] via [1][2].
for (int i=0;i<2;i++){
for (int j=0;j<min(2, (int)C5.size());j++) if (C6[i].second != C5[j].second){
ll val = a[s] + sum + C5[j].first + C6[i].first;
if (dp[s][1][2] < val){
dp[s][1][2] = val;
typ3[s] = 1;
v3[s][0] = C6[i].second;
v3[s][1] = C5[j].second;
}
}
}
// A leaf has no child to start from.
if (deg==0) dp[s][1][2] = -INF;
// calc 1 0
// [1][0] is [0][1] reversed, so it has the same value. [1][2] may also end at
// s; typ3 == 0 tells track to rebuild it as [1][0].
dp[s][1][0] = dp[s][0][1];
if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;
}
// Links the vertices chosen for state (x, y) at s through nxt[] and returns
// {first, last} vertex of that segment. Erases each placed child from G[s];
// erasing a child that is no longer present would erase end(), so each
// (vertex, state) is expected to be reconstructed only once.
pair<int, int> track(int s, int x, int y){
// [0][0]: just s (used for C4 "alone" entries with id > n).
if (x==0 && y==0) return {s, s};
assert(s > 0);
assert(dp[s][x][y] >= 0);
// [1][2] was realized by [1][0] (typ3 == 0, set at the end of dfs).
if (x==1 && y==2 && typ3[s]==0) y = 0;
// [0][1]: s, then v1's subtree via [1][0], then every other child appended after.
if (x==0 && y==1){
pair<int, int> ret = {s, s};
if (v1[s] > 0){
auto [l1, r1] = track(v1[s], 1, 0);
nxt[s] = l1;
ret.second = r1;
G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
}
for (auto &v:G[s]){
nxt[ret.second] = v;
ret.second = v;
}
return ret;
}
if (x==0 && y==2){
pair<int, int> ret = {s, s};
// typ2 == 1: s, then v2[0]'s subtree via [1][2]; other children unvisited.
if (typ2[s]==1){
if (v2[s][0] > 0){
auto [l1, r1] = track(v2[s][0], 1, 2);
nxt[s] = l1;
ret.second = r1;
}
return ret;
}
// else (typ2[s]==2)
// s, v2[0] via [1][0], other children alone, then v2[1] via [0][2].
if (v2[s][0] > 0){
auto [l1, r1] = track(v2[s][0], 1, 0);
nxt[s] = l1;
ret.second = r1;
G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
}
if (v2[s][1] > 0){
G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
}
for (auto &v:G[s]){
nxt[ret.second] = v;
ret.second = v;
}
if (v2[s][1] > 0){
nxt[ret.second] = v2[s][1];
ret.second = v2[s][1];
// The [0][2] segment of v2[1] starts at v2[1] itself, which is already linked.
auto [l2, r2] = track(v2[s][1], 0, 2);
ret.second = r2;
}
return ret;
}
// [1][0]: other children (prepended), then v1's [0][1] segment, ending at s.
if (x==1 && y==0){
pair<int, int> ret = {s, s};
if (v1[s] > 0){
auto [l1, r1] = track(v1[s], 0, 1);
nxt[r1] = s;
ret.first = l1;
G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
}
for (auto &v:G[s]){
nxt[v] = ret.first;
ret.first = v;
}
assert(ret.first != s);
return ret;
}
assert(x==1 && y==2);
pair<int, int> ret = {s, s};
// typ3 == 2: start child (via [0][1], or alone if id > n), then s, then
// v3[1]'s subtree via [1][2].
if (typ3[s]==2){
if (v3[s][0] > 0){
int rv3 = v3[s][0], ri = 0, rj = 1;
if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;
auto [l1, r1] = track(rv3, ri, rj);
ret.first = l1;
nxt[r1] = s;
G[s].erase(find(G[s].begin(), G[s].end(), rv3));
}
if (v3[s][1] > 0){
auto [l2, r2] = track(v3[s][1], 1, 2);
nxt[s] = l2;
ret.second = r2;
}
assert(ret.first != s);
return ret;
}
// typ3 == 1: remaining children (prepended), v3[0] via [0][1], then s, then
// v3[1] via [1][2].
if (typ3[s]==1){
assert(v3[s][1] > 0);
G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
if (v3[s][0] > 0){
G[s].erase(find(G[s].begin(), G[s].end(), v3[s][0]));
auto [l2, r2] = track(v3[s][0], 0, 1);
nxt[r2] = ret.first;
ret.first = l2;
}
for (auto &v:G[s]){
nxt[v] = ret.first;
ret.first = v;
}
auto [l1, r1] = track(v3[s][1], 1, 2);
nxt[s] = l1;
ret.second = r1;
return ret;
}
//typ3[s] == 3
// start child (via [0][1], or alone if id > n), s, v3[1] via [1][0], the
// remaining children alone, then v3[2] via [0][2].
if (v3[s][0] > 0){
int rv3 = v3[s][0], ri = 0, rj = 1;
if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;
auto [l1, r1] = track(rv3, ri, rj);
ret.first = l1;
nxt[r1] = s;
G[s].erase(find(G[s].begin(), G[s].end(), rv3));
}
if (v3[s][1] > 0){
auto [l2, r2] = track(v3[s][1], 1, 0);
nxt[s] = l2;
ret.second = r2;
G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
}
if (v3[s][2] > 0){
G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
}
for (auto &v:G[s]){
nxt[ret.second] = v;
ret.second = v;
}
if (v3[s][2] > 0){
auto [l3, r3] = track(v3[s][2], 0, 2);
nxt[ret.second] = l3;
ret.second = r3;
}
assert(ret.first != s);
return ret;
}
// Runs the DP from vertex 1 and prints the best [0][2] value, then the number
// of vertices in the reconstructed order and the order itself. Returns the value.
ll solve2(){
for (int i=1;i<=n;i++) nxt[i] = 0, G[i].clear();
dfs(1);
ll ans = dp[1][0][2];
printf("%lld\n", ans);
track(1, 0, 2);
vector<int> V;
for (int i=1;i;i=nxt[i]){
V.push_back(i);
}
printf("%d\n", (int)V.size());
for (auto &x:V) printf("%d ", x);
printf("\n");
return ans;
}
} // namespace bar
namespace baz{
void dfs1(int s, int pa);
void dfs2(int s, int pa);
vector<int> V;
void dfs1(int s, int pa){
V.push_back(s);
for (auto &v:adj[s]) if (v!=pa) dfs2(v, s);
}
void dfs2(int s, int pa){
for (auto &v:adj[s]) if (v!=pa) dfs1(v, s);
V.push_back(s);
}
void solve3(){
dfs1(1, 0);
ll sum = 0;
for (int i=1;i<=n;i++) sum += a[i];
printf("%lld\n", sum);
printf("%d\n", n);
for (auto &x:V) printf("%d ", x);
printf("\n");
}
} // namespace baz
// Reads n and k, the n-1 tree edges, and the n vertex weights, then dispatches:
// k == 1 -> foo::solve1, k == 2 -> bar::solve2, any other k -> baz::solve3.
// Returns 1 without printing anything if the input is truncated or malformed,
// or if n or an edge endpoint does not fit the fixed-size global arrays.
int main(){
    // Capacity of the global per-vertex arrays; index n must stay in range.
    const int kCap = (int)(sizeof(a) / sizeof(a[0]));
    int k;
    if (scanf("%d %d", &n, &k) != 2) return 1;
    if (n < 1 || n >= kCap) return 1;
    for (int i=1;i<=n-1;i++){
        int x, y;
        if (scanf("%d %d", &x, &y) != 2) return 1;
        if (x < 1 || x > n || y < 1 || y > n) return 1;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    for (int i=1;i<=n;i++){
        if (scanf("%lld", a+i) != 1) return 1;
    }
    if (k==1) foo::solve1();
    else if (k==2) bar::solve2();
    else baz::solve3();
    return 0;
}
Compilation message (stderr)
Main.cpp: In function 'int main()':
Main.cpp:426:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
426 | scanf("%d %d", &n, &k);
| ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:430:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
430 | scanf("%d %d", &x, &y);
| ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:435:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
435 | for (int i=1;i<=n;i++) scanf("%lld", a+i);
| ~~~~~^~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |