이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n, k;
ll arr[200002];
vector<int> link[200002];
void solvek1();
void solvek2();
void solvek3();
int main(){
scanf("%d %d", &n, &k);
for(int i=1; i<n; i++){
int x, y;
scanf("%d %d", &x, &y);
link[x].push_back(y);
link[y].push_back(x);
}
for(int i=1; i<=n; i++) scanf("%lld", &arr[i]);
if(k==1) solvek1();
if(k==2) solvek2();
if(k==3) solvek3();
}
namespace solve1{
int par[200002];
ll depth[200002];
void dfs(int x, int p=-1){
depth[x] += arr[x];
for(auto y: link[x]){
if(y==p) continue;
depth[y] = depth[x], par[y] = x;
dfs(y, x);
}
}
void solve(){
dfs(1);
int x = max_element(depth+1, depth+n+1) - depth;
vector<int> v;
while(x) v.push_back(x), x = par[x];
reverse(v.begin(), v.end());
printf("%lld\n%d\n", depth[v.back()], (int)v.size());
for(auto x: v) printf("%d ", x);
}
}
void solvek1(){
solve1::solve();
}
namespace solve2{
int par[200002];
ll depth[200002];
vector<int> child[200002];
void dfsChild(int x, int p=-1){
for(auto y: link[x]){
if(y==p) continue;
child[x].push_back(y), par[y] = x, depth[y] = depth[x] + 1;
dfsChild(y, x);
}
}
ll U[200002][2]; /// 돌아오는 DP, [0]은 x 부모에서 시작, [1]은 x에서 시작
int trackU[200002][2]; /// 역추적 용도
void dfsU(int x){
if(child[x].empty()){ /// 리프 노드인 경우
U[x][0] = U[x][1] = arr[x];
return;
}
vector<pair<ll, int> > profit_pointToU0, profit_pointToU1;
ll pointSum = arr[x]; /// 자기 직접 자식들의 모든 이익을 더한 값
for(int y: child[x]){
dfsU(y);
pointSum += arr[y];
profit_pointToU0.push_back(make_pair(U[y][0] - arr[y], y));
profit_pointToU1.push_back(make_pair(U[y][1] - arr[y], y));
}
sort(profit_pointToU0.rbegin(), profit_pointToU0.rend());
sort(profit_pointToU1.rbegin(), profit_pointToU1.rend());
/// U[x][0]: 부모에서 출발해 x가 비어 있는 채로 아래를 순회하고 돌아오는 비용
/// point 값 중 하나를 U[y][1]로 바꿀 수 있다.
assert(profit_pointToU0[0].first >= 0);
U[x][0] = pointSum + profit_pointToU1[0].first;
trackU[x][0] = profit_pointToU1[0].second;
/// U[x][1]: x에서 출발해 부모가 비어 있는 채로 아래를 순회하고 돌아오는 비용
/// point 값 중 하나를 U[y][0]으로 바꿀 수 있다.
assert(profit_pointToU1[0].first >= 0);
U[x][1] = pointSum + profit_pointToU0[0].first;
trackU[x][1] = profit_pointToU0[0].second;
}
ll I[200002][2]; /// 돌아오지 않는 DP, [0]은 x 부모에서 시작, [1]은 x에서 시작
int typeI[200002][2], trackI[200002][2], trackI2[200002][2], trackI3[200002][2]; /// 역추적 용도
bool trackDetail[200002][2]; /// 0이면 I, 1이면 U
/// type -1: I[0]->I[1]로 바꿀 때, 0: 잔여 노드 챙기고 내려갈 때, 1: 그냥 하나로 내려갈 때
void dfsI(int x){
if(child[x].empty()){
I[x][0] = I[x][1] = arr[x];
return;
}
vector<pair<ll, int> > profit_pointToU0, profit_pointToU1;
vector<pair<ll, int> > profit_pointToI0, profit_pointToI1;
ll pointSum = arr[x];
for(auto y: child[x]){
dfsI(y);
pointSum += arr[y];
profit_pointToU0.push_back(make_pair(U[y][0] - arr[y], y));
profit_pointToU1.push_back(make_pair(U[y][1] - arr[y], y));
profit_pointToI0.push_back(make_pair(I[y][0] - arr[y], y));
profit_pointToI1.push_back(make_pair(I[y][1] - arr[y], y));
}
sort(profit_pointToU0.rbegin(), profit_pointToU0.rend());
sort(profit_pointToU1.rbegin(), profit_pointToU1.rend());
sort(profit_pointToI0.rbegin(), profit_pointToI0.rend());
sort(profit_pointToI1.rbegin(), profit_pointToI1.rend());
/// I[x][1]: x에서 출발해 아래로 내려가는 최댓값
/// Case 0. U[x][1]을 택한다.
I[x][1] = U[x][1], typeI[x][1] = -2;
/// Case 1. 자식 중 하나의 I[0]을 취할 수 있다.
for(auto y: child[x]) if(I[x][1] < arr[x] + max(U[y][0], I[y][0])){
I[x][1] = arr[x] + max(U[y][0], I[y][0]);
typeI[x][1] = 1, trackI[x][1] = y, trackDetail[x][1] = U[y][0]>I[y][0];
}
/// Case 2. 자식 중 하나의 U[0] or U[1]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취할 수 있다.
for(auto a: child[x]) for(auto b: child[x]){
if(a==b) continue;
ll v = pointSum - arr[a] - arr[b] + max(U[a][0], U[a][1]) + max(U[b][1], I[b][1]);
if(I[x][1] >= v) continue;
I[x][1] = v, typeI[x][1] = 0, trackI[x][1] = a, trackI2[x][1] = b, trackDetail[x][1] = U[b][1]>I[b][1];
}
/// I[x][0]: x 부모에서 출발해 아래로 내려가는 최댓값
/// Case 0. I[1]을 택한다.
I[x][0] = I[x][1], typeI[x][0] = -1;
/// Case 0-2. U[x][0]을 택한다.
if(I[x][0] > U[x][0]) I[x][0] = U[x][0], typeI[x][0] = -2;
/// Case 1. 자식 중 하나의 I[0]을 취할 수 있다.
for(auto y: child[x]) if(I[x][0] < arr[x] + max(U[y][0], I[y][0])){
I[x][0] = arr[x] + max(U[y][0], I[y][0]);
typeI[x][0] = 1, trackI[x][0] = y, trackDetail[x][0] = U[y][0]>I[y][0];
}
/// Case 2. 자식 중 하나의 U[1]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취할 수 있다.
for(auto a: child[x]) for(auto b: child[x]){
if(a==b) continue;
ll v = pointSum - arr[a] - arr[b] + U[a][1] + max(U[b][1], I[b][1]);
if(I[x][0] >= v) continue;
I[x][0] = v, typeI[x][0] = 0, trackI[x][0] = a, trackI2[x][0] = b, trackDetail[x][0] = U[b][1]>I[b][1];
}
/// Case 3. 자식 중 하나의 U[1]을 취하고, 다른 자식의 U[0]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취한다.
for(auto a: child[x]) for(auto b: child[x]) for(auto c: child[x]){
if(a==b || b==c || a==c) continue;
ll v = pointSum - arr[a] - arr[b] - arr[c] + U[a][1] + U[b][0] + max(U[c][1], I[c][1]);
if(I[x][0] >= v) continue;
I[x][0] = v, typeI[x][0] = 2, trackI[x][0] = a, trackI2[x][0] = b, trackI3[x][0] = c, trackDetail[x][0] = U[c][1]>I[c][1];
}
/// Case 4. 자식 중 하나의 U[0]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취한다.
for(auto a: child[x]) for(auto b: child[x]){
if(a==b) continue;
ll v = pointSum - arr[a] - arr[b] - U[a][0] + max(U[b][1], I[b][1]);
if(I[x][0] >= v) continue;
I[x][0] = v, typeI[x][0] = 3, trackI[x][0] = a, trackI2[x][0] = b, trackDetail[x][0] = U[b][1]>I[b][1];
}
/// Case 5. 자식 중 하나의 U[1]을 취하고, 다른 자식의 U[0]을 취하고, 잔여 노드를 먹는다.
for(auto a: child[x]) for(auto b: child[x]){
if(a==b) continue;
ll v = pointSum - arr[a] - arr[b] + U[a][1] + U[b][0];
if(I[x][0] >= v) continue;
I[x][0] = v, typeI[x][0] = 4, trackI[x][0] = a, trackI2[x][0] = b;
}
}
vector<int> ans;
void track(int x, char type, int j){
if(type == 'U'){
if(child[x].empty()) ans.push_back(x);
else if(j==0){
int nxt = trackU[x][j];
for(auto y: child[x]) if(y != nxt) ans.push_back(y);
track(nxt, 'U', 1);
ans.push_back(x);
}
else if(j==1){
int nxt = trackU[x][j];
ans.push_back(x);
track(nxt, 'U', 0);
for(auto y: child[x]) if(y != nxt) ans.push_back(y);
}
}
else if(type == 'I'){
if(child[x].empty()) ans.push_back(x);
else if(typeI[x][j] == -2) track(x, 'U', j);
else if(typeI[x][j] == -1) track(x, 'I', !j);
else if(j==0){
if(typeI[x][j] == 0){
int a = trackI[x][j], b = trackI2[x][j];
track(a, 'U', 1);
ans.push_back(x);
for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
track(b, trackDetail[x][j] ? 'U' : 'I', 1);
}
else if(typeI[x][j] == 1){
int a = trackI[x][j];
ans.push_back(x);
track(a, trackDetail[x][j] ? 'U' : 'I', 0);
}
else if(typeI[x][j] == 2){
int a = trackI[x][j], b = trackI2[x][j], c = trackI3[x][j];
track(a, 'U', 1);
ans.push_back(x);
track(b, 'U', 0);
for(auto y: child[x]) if(y!=a && y!=b && y!=c) ans.push_back(y);
track(c, trackDetail[x][j] ? 'U' : 'I', 1);
}
else if(typeI[x][j] == 3){
int a = trackI[x][j], b = trackI2[x][j];
ans.push_back(x);
track(a, 'U', 0);
for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
track(b, trackDetail[x][j] ? 'U' : 'I', 1);
}
else if(typeI[x][j] == 4){
int a = trackI[x][j], b = trackI2[x][j];
track(a, 'U', 1);
ans.push_back(x);
track(b, 'U', 0);
for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
}
else exit(1);
}
else{
if(typeI[x][j] == 0){
int a = trackI[x][j], b = trackI2[x][j];
ans.push_back(x);
track(a, 'U', 0);
for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
track(b, trackDetail[x][j] ? 'U' : 'I', 1);
}
else if(typeI[x][j] == 1){
int a = trackI[x][j];
ans.push_back(x);
track(a, trackDetail[x][j] ? 'U' : 'I', 0);
}
else exit(1);
}
}
}
void solve(){
dfsChild(1);
dfsU(1);
dfsI(1);
// printf("Dfs Results\n");
// for(int i=1; i<=n; i++){
// printf("%2d: U0 %2lld U1 %2lld I0 %2lld I1 %2lld\n", i, U[i][0], U[i][1], I[i][0], I[i][1]);
// }
printf("%lld\n", I[1][1]);
track(1, 'I', 1);
printf("%d\n", (int)ans.size());
for(auto x: ans) printf("%d ", x);
ll v = 0;
for(auto x: ans) v += arr[x];
assert(v == I[1][1]);
for(int i=0; i<(int)ans.size()-1; i++){
int x = ans[i], y = ans[i+1];
assert(x==par[y] || y==par[x] || par[x]==par[y] || par[par[x]]==y || par[par[y]]==x);
}
for(int i=1; i<=n; i++) assert(U[i][0] >= U[i][1]);
// for(int i=1; i<=n; i++) assert(typeI[i][0] < 3);
set<int> st;
for(auto x: ans) st.insert(x);
assert(st.size() == ans.size());
}
}
void solvek2(){
solve2::solve();
}
namespace solve3{
int par[200002];
ll depth[200002];
vector<int> child[200002];
void dfs(int x, int p=-1){
for(auto y: link[x]){
if(y==p) continue;
par[y] = x, depth[y] = depth[x] + 1;
child[x].push_back(y);
dfs(y, x);
}
}
vector<int> ans;
bool chk[200002];
void solve(){
dfs(1);
int coolTime = 1;
vector<int> vec (1, 1);
while(!vec.empty()){
int x = vec.back(); coolTime--;
if(coolTime == 0){
assert(!chk[x]);
chk[x] = 1, ans.push_back(x), coolTime = 3;
}
if(child[x].empty()){
if(!chk[x]) chk[x] = 1, ans.push_back(x), coolTime = 3;
vec.pop_back();
}
else{
int y = child[x].back(); child[x].pop_back();
vec.push_back(y);
}
}
printf("%lld\n%d\n", accumulate(arr+1, arr+n+1, 0LL), n);
for(int x: ans) printf("%d ", x);
}
}
void solvek3(){
solve3::solve();
}
컴파일 시 표준 에러 (stderr) 메시지
Main.cpp: In function 'int main()':
Main.cpp:16:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
16 | scanf("%d %d", &n, &k);
| ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:19:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
19 | scanf("%d %d", &x, &y);
| ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:23:34: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
23 | for(int i=1; i<=n; i++) scanf("%lld", &arr[i]);
| ~~~~~^~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |