제출 #794326

#제출 시각아이디문제언어결과실행 시간메모리
79432679brueTravelling Trader (CCO23_day2problem2)C++17
11 / 25
210 ms49608 KiB
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

int n, k;
ll arr[200002];
vector<int> link[200002];

void solvek1();
void solvek2();
void solvek3();

int main(){
    scanf("%d %d", &n, &k);
    for(int i=1; i<n; i++){
        int x, y;
        scanf("%d %d", &x, &y);
        link[x].push_back(y);
        link[y].push_back(x);
    }
    for(int i=1; i<=n; i++) scanf("%lld", &arr[i]);
    if(k==1) solvek1();
    if(k==2) solvek2();
    if(k==3) solvek3();
}

namespace solve1{
    int par[200002];
    ll depth[200002];

    void dfs(int x, int p=-1){
        depth[x] += arr[x];
        for(auto y: link[x]){
            if(y==p) continue;
            depth[y] = depth[x], par[y] = x;
            dfs(y, x);
        }
    }

    void solve(){
        dfs(1);
        int x = max_element(depth+1, depth+n+1) - depth;
        vector<int> v;
        while(x) v.push_back(x), x = par[x];
        reverse(v.begin(), v.end());
        printf("%lld\n%d\n", depth[v.back()], (int)v.size());
        for(auto x: v) printf("%d ", x);
    }
}

void solvek1(){
    solve1::solve();
}

namespace solve2{
    int par[200002];
    ll depth[200002];

    vector<int> child[200002];

    void dfsChild(int x, int p=-1){
        for(auto y: link[x]){
            if(y==p) continue;
            child[x].push_back(y), par[y] = x, depth[y] = depth[x] + 1;
            dfsChild(y, x);
        }
    }

    ll U[200002][2]; /// 돌아오는 DP, [0]은 x 부모에서 시작, [1]은 x에서 시작
    int trackU[200002][2]; /// 역추적 용도

    void dfsU(int x){
        if(child[x].empty()){ /// 리프 노드인 경우
            U[x][0] = U[x][1] = arr[x];
            return;
        }

        vector<pair<ll, int> > profit_pointToU0, profit_pointToU1;

        ll pointSum = arr[x]; /// 자기 직접 자식들의 모든 이익을 더한 값
        for(int y: child[x]){
            dfsU(y);
            pointSum += arr[y];
            profit_pointToU0.push_back(make_pair(U[y][0] - arr[y], y));
            profit_pointToU1.push_back(make_pair(U[y][1] - arr[y], y));
        }

        sort(profit_pointToU0.rbegin(), profit_pointToU0.rend());
        sort(profit_pointToU1.rbegin(), profit_pointToU1.rend());

        /// U[x][0]: 부모에서 출발해 x가 비어 있는 채로 아래를 순회하고 돌아오는 비용
        /// point 값 중 하나를 U[y][1]로 바꿀 수 있다.
        assert(profit_pointToU0[0].first >= 0);
        U[x][0] = pointSum + profit_pointToU1[0].first;
        trackU[x][0] = profit_pointToU1[0].second;

        /// U[x][1]: x에서 출발해 부모가 비어 있는 채로 아래를 순회하고 돌아오는 비용
        /// point 값 중 하나를 U[y][0]으로 바꿀 수 있다.
        assert(profit_pointToU1[0].first >= 0);
        U[x][1] = pointSum + profit_pointToU0[0].first;
        trackU[x][1] = profit_pointToU0[0].second;
    }

    ll I[200002][2]; /// 돌아오지 않는 DP, [0]은 x 부모에서 시작, [1]은 x에서 시작
    int typeI[200002][2], trackI[200002][2], trackI2[200002][2], trackI3[200002][2]; /// 역추적 용도
    bool trackDetail[200002][2]; /// 0이면 I, 1이면 U
    /// type -1: I[0]->I[1]로 바꿀 때, 0: 잔여 노드 챙기고 내려갈 때, 1: 그냥 하나로 내려갈 때

    void dfsI(int x){
        if(child[x].empty()){
            I[x][0] = I[x][1] = arr[x];
            return;
        }

        vector<pair<ll, int> > profit_pointToU0, profit_pointToU1;
        vector<pair<ll, int> > profit_pointToI0, profit_pointToI1;
        ll pointSum = arr[x];

        for(auto y: child[x]){
            dfsI(y);
            pointSum += arr[y];
            profit_pointToU0.push_back(make_pair(U[y][0] - arr[y], y));
            profit_pointToU1.push_back(make_pair(U[y][1] - arr[y], y));
            profit_pointToI0.push_back(make_pair(I[y][0] - arr[y], y));
            profit_pointToI1.push_back(make_pair(I[y][1] - arr[y], y));
        }

        sort(profit_pointToU0.rbegin(), profit_pointToU0.rend());
        sort(profit_pointToU1.rbegin(), profit_pointToU1.rend());
        sort(profit_pointToI0.rbegin(), profit_pointToI0.rend());
        sort(profit_pointToI1.rbegin(), profit_pointToI1.rend());

        /// I[x][1]: x에서 출발해 아래로 내려가는 최댓값
        /// Case 0. U[x][1]을 택한다.
        I[x][1] = U[x][1], typeI[x][1] = -2;

        /// Case 1. 자식 중 하나의 I[0]을 취할 수 있다.
        for(auto y: child[x]) if(I[x][1] < arr[x] + max(U[y][0], I[y][0])){
            I[x][1] = arr[x] + max(U[y][0], I[y][0]);
            typeI[x][1] = 1, trackI[x][1] = y, trackDetail[x][1] = U[y][0]>I[y][0];
        }

        /// Case 2. 자식 중 하나의 U[0] or U[1]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취할 수 있다.
        for(auto a: child[x]) for(auto b: child[x]){
            if(a==b) continue;
            ll v = pointSum - arr[a] - arr[b] + max(U[a][0], U[a][1]) + max(U[b][1], I[b][1]);
            if(I[x][1] >= v) continue;
            I[x][1] = v, typeI[x][1] = 0, trackI[x][1] = a, trackI2[x][1] = b, trackDetail[x][1] = U[b][1]>I[b][1];
        }

        /// I[x][0]: x 부모에서 출발해 아래로 내려가는 최댓값
        /// Case 0. I[1]을 택한다.
        I[x][0] = I[x][1], typeI[x][0] = -1;

        /// Case 0-2. U[x][0]을 택한다.
        if(I[x][0] > U[x][0]) I[x][0] = U[x][0], typeI[x][0] = -2;

        /// Case 1. 자식 중 하나의 I[0]을 취할 수 있다.
        for(auto y: child[x]) if(I[x][0] < arr[x] + max(U[y][0], I[y][0])){
            I[x][0] = arr[x] + max(U[y][0], I[y][0]);
            typeI[x][0] = 1, trackI[x][0] = y, trackDetail[x][0] = U[y][0]>I[y][0];
        }

        /// Case 2. 자식 중 하나의 U[1]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취할 수 있다.
        for(auto a: child[x]) for(auto b: child[x]){
            if(a==b) continue;
            ll v = pointSum - arr[a] - arr[b] + U[a][1] + max(U[b][1], I[b][1]);
            if(I[x][0] >= v) continue;
            I[x][0] = v, typeI[x][0] = 0, trackI[x][0] = a, trackI2[x][0] = b, trackDetail[x][0] = U[b][1]>I[b][1];
        }

        /// Case 3. 자식 중 하나의 U[1]을 취하고, 다른 자식의 U[0]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취한다.
        for(auto a: child[x]) for(auto b: child[x]) for(auto c: child[x]){
            if(a==b || b==c || a==c) continue;
            ll v = pointSum - arr[a] - arr[b] - arr[c] + U[a][1] + U[b][0] + max(U[c][1], I[c][1]);
            if(I[x][0] >= v) continue;
            I[x][0] = v, typeI[x][0] = 2, trackI[x][0] = a, trackI2[x][0] = b, trackI3[x][0] = c, trackDetail[x][0] = U[c][1]>I[c][1];
        }

        /// Case 4. 자식 중 하나의 U[0]을 취하고, 잔여 노드를 먹고, 남은 자식 중 하나의 I[1]을 취한다.
        for(auto a: child[x]) for(auto b: child[x]){
            if(a==b) continue;
            ll v = pointSum - arr[a] - arr[b] - U[a][0] + max(U[b][1], I[b][1]);
            if(I[x][0] >= v) continue;
            I[x][0] = v, typeI[x][0] = 3, trackI[x][0] = a, trackI2[x][0] = b, trackDetail[x][0] = U[b][1]>I[b][1];
        }

        /// Case 5. 자식 중 하나의 U[1]을 취하고, 다른 자식의 U[0]을 취하고, 잔여 노드를 먹는다.
        for(auto a: child[x]) for(auto b: child[x]){
            if(a==b) continue;
            ll v = pointSum - arr[a] - arr[b] + U[a][1] + U[b][0];
            if(I[x][0] >= v) continue;
            I[x][0] = v, typeI[x][0] = 4, trackI[x][0] = a, trackI2[x][0] = b;
        }
    }

    vector<int> ans;

    void track(int x, char type, int j){
        if(type == 'U'){
            if(child[x].empty()) ans.push_back(x);
            else if(j==0){
                int nxt = trackU[x][j];
                for(auto y: child[x]) if(y != nxt) ans.push_back(y);
                track(nxt, 'U', 1);
                ans.push_back(x);
            }
            else if(j==1){
                int nxt = trackU[x][j];
                ans.push_back(x);
                track(nxt, 'U', 0);
                for(auto y: child[x]) if(y != nxt) ans.push_back(y);
            }
        }
        else if(type == 'I'){
            if(child[x].empty()) ans.push_back(x);
            else if(typeI[x][j] == -2) track(x, 'U', j);
            else if(typeI[x][j] == -1) track(x, 'I', !j);
            else if(j==0){
                if(typeI[x][j] == 0){
                    int a = trackI[x][j], b = trackI2[x][j];
                    track(a, 'U', 1);
                    ans.push_back(x);
                    for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
                    track(b, trackDetail[x][j] ? 'U' : 'I', 1);
                }
                else if(typeI[x][j] == 1){
                    int a = trackI[x][j];
                    ans.push_back(x);
                    track(a, trackDetail[x][j] ? 'U' : 'I', 0);
                }
                else if(typeI[x][j] == 2){
                    int a = trackI[x][j], b = trackI2[x][j], c = trackI3[x][j];
                    track(a, 'U', 1);
                    ans.push_back(x);
                    track(b, 'U', 0);
                    for(auto y: child[x]) if(y!=a && y!=b && y!=c) ans.push_back(y);
                    track(c, trackDetail[x][j] ? 'U' : 'I', 1);
                }
                else if(typeI[x][j] == 3){
                    int a = trackI[x][j], b = trackI2[x][j];
                    ans.push_back(x);
                    track(a, 'U', 0);
                    for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
                    track(b, trackDetail[x][j] ? 'U' : 'I', 1);
                }
                else if(typeI[x][j] == 4){
                    int a = trackI[x][j], b = trackI2[x][j];
                    track(a, 'U', 1);
                    ans.push_back(x);
                    track(b, 'U', 0);
                    for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
                }
                else exit(1);
            }
            else{
                if(typeI[x][j] == 0){
                    int a = trackI[x][j], b = trackI2[x][j];
                    ans.push_back(x);
                    track(a, 'U', 0);
                    for(auto y: child[x]) if(y!=a && y!=b) ans.push_back(y);
                    track(b, trackDetail[x][j] ? 'U' : 'I', 1);
                }
                else if(typeI[x][j] == 1){
                    int a = trackI[x][j];
                    ans.push_back(x);
                    track(a, trackDetail[x][j] ? 'U' : 'I', 0);
                }
                else exit(1);
            }
        }
    }

    void solve(){
        dfsChild(1);
        dfsU(1);
        dfsI(1);

//        printf("Dfs Results\n");
//        for(int i=1; i<=n; i++){
//            printf("%2d: U0 %2lld U1 %2lld I0 %2lld I1 %2lld\n", i, U[i][0], U[i][1], I[i][0], I[i][1]);
//        }

        printf("%lld\n", I[1][1]);
        track(1, 'I', 1);
        printf("%d\n", (int)ans.size());
        for(auto x: ans) printf("%d ", x);

        ll v = 0;
        for(auto x: ans) v += arr[x];
        assert(v == I[1][1]);
        for(int i=0; i<(int)ans.size()-1; i++){
            int x = ans[i], y = ans[i+1];
            assert(x==par[y] || y==par[x] || par[x]==par[y] || par[par[x]]==y || par[par[y]]==x);
        }

        for(int i=1; i<=n; i++) assert(U[i][0] >= U[i][1]);
//        for(int i=1; i<=n; i++) assert(typeI[i][0] < 3);

        set<int> st;
        for(auto x: ans) st.insert(x);
        assert(st.size() == ans.size());
    }
}

void solvek2(){
    solve2::solve();
}

namespace solve3{
    int par[200002];
    ll depth[200002];
    vector<int> child[200002];

    void dfs(int x, int p=-1){
        for(auto y: link[x]){
            if(y==p) continue;
            par[y] = x, depth[y] = depth[x] + 1;
            child[x].push_back(y);
            dfs(y, x);
        }
    }

    vector<int> ans;
    bool chk[200002];

    void solve(){
        dfs(1);

        int coolTime = 1;
        vector<int> vec (1, 1);
        while(!vec.empty()){
            int x = vec.back(); coolTime--;
            if(coolTime == 0){
                assert(!chk[x]);
                chk[x] = 1, ans.push_back(x), coolTime = 3;
            }
            if(child[x].empty()){
                if(!chk[x]) chk[x] = 1, ans.push_back(x), coolTime = 3;
                vec.pop_back();
            }
            else{
                int y = child[x].back(); child[x].pop_back();
                vec.push_back(y);
            }
        }

        printf("%lld\n%d\n", accumulate(arr+1, arr+n+1, 0LL), n);
        for(int x: ans) printf("%d ", x);
    }
}

void solvek3(){
    solve3::solve();
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'int main()':
Main.cpp:16:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   16 |     scanf("%d %d", &n, &k);
      |     ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:19:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   19 |         scanf("%d %d", &x, &y);
      |         ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:23:34: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   23 |     for(int i=1; i<=n; i++) scanf("%lld", &arr[i]);
      |                             ~~~~~^~~~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...