답안 #231962

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
231962 2020-05-15T14:17:51 Z nicolaalexandra Chase (CEOI17_chase) C++14
0 / 100
560 ms 489208 KB
#include <bits/stdc++.h>
#define DIM 100002
#define INF 2000000000000000000LL
using namespace std;
/// L[x]: every node that shares an edge with x (both directions are pushed in main)
vector <int> L[DIM];
/// v: value read for each node; level: not referenced in the visible code;
/// fth: parent of each node as recorded by dfs
int v[DIM],level[DIM],fth[DIM];
/// sum[x]: v[x] plus v of every neighbour of x (computed in main)
long long sum[DIM];
/// n: number of nodes; nr: second number of the first input line, upper bound of
/// the count index in dp / dp_up; x,y: endpoints of the edge being read in main;
/// i,j,k: file-scope scratch integers (k is not referenced in the visible code)
int n,x,y,i,j,nr,k;
/// One DP cell: the two largest candidate values found for a state, each
/// paired with the child node through which it was obtained.
struct idk{
    long long maxi;   /// largest candidate
    long long maxi2;  /// runner-up candidate
    int fiu;          /// child that produced maxi (0 when none)
    int fiu2;         /// child that produced maxi2 (0 when none)
};

/// dp[node][count][0/1], filled by dfs
idk dp[DIM][101][2];

/// dp_up[node][count][0/1]: dfs2 is the only function that writes this table and
/// main folds it into the final maximum. dfs2 treats the value -INF as "no value".
long long dp_up[DIM][101][2];
/// dp[nod][i][0/1] = care e suma maxima daca incep undeva in subarborele lui nod,
/// ajung in nod, am consumat i boabe si am una in nod sau nu
void dfs (int nod, int tata){

    int ok = 0;
    fth[nod] = tata;
    for (auto vecin : L[nod])
        if (vecin != tata){
            ok = 1;
            dfs (vecin,nod);
        }
    if (!ok){ /// frunza
        dp[nod][1][1].maxi = v[fth[nod]] - v[nod];
    } else {
        /// pun boaba in nod
        dp[nod][1][1].maxi = sum[nod] - v[nod];
        for (int i=2;i<=nr;i++){
            long long maxi = -INF, maxi2 = -INF; int fiu = 0, fiu2 = 0;

            for (auto vecin : L[nod]){
                if (vecin == tata)
                    continue;

                if (dp[vecin][i-1][0].maxi != -INF){
                    long long val = sum[nod] - v[vecin] - v[nod] + dp[vecin][i-1][0].maxi;
                    if (val > maxi){
                        maxi2 = maxi, fiu2 = fiu;
                        maxi = val, fiu = vecin;
                    } else {
                        if (val > maxi2)
                            maxi2 = val, fiu2 = vecin;
                    }}
                if (dp[vecin][i-1][1].maxi != -INF){
                    long long val = sum[nod] - v[vecin] - v[nod] + dp[vecin][i-1][1].maxi;
                    if (val > maxi){
                        maxi2 = maxi, fiu2 = fiu;
                        maxi = val, fiu = vecin;
                    } else {
                        if (val > maxi2)
                            maxi2 = val, fiu2 = vecin;
                    }}}

            dp[nod][i][1].maxi = maxi, dp[nod][i][1].fiu = fiu;
            dp[nod][i][1].maxi2 = maxi2, dp[nod][i][1].fiu2 = fiu2;

        }

        /// nu pun boaba in nod

        for (int i=1;i<=nr;i++){
            long long maxi = -INF, maxi2 = -INF; int fiu = 0, fiu2 = 0;
            for (auto vecin : L[nod]){
                if (vecin == tata)
                    continue;

                if (dp[vecin][i][0].maxi != -INF){
                    long long val = dp[vecin][i][0].maxi - v[nod];
                    if (val > maxi){
                        maxi2 = maxi, fiu2 = fiu;
                        maxi = val, fiu = vecin;
                    } else {
                        if (val > maxi2)
                            maxi2 = val, fiu2 = vecin;
                    }}

                if (dp[vecin][i][1].maxi != -INF){
                    long long val = dp[vecin][i][1].maxi;
                    if (val > maxi){
                        maxi2 = maxi, fiu2 = fiu;
                        maxi = val, fiu = vecin;
                    } else {
                        if (val > maxi2)
                            maxi2 = val, fiu2 = vecin;
                    }}}

            dp[nod][i][0].maxi = maxi, dp[nod][i][0].fiu = fiu;
            dp[nod][i][0].maxi2 = maxi2, dp[nod][i][0].fiu2 = fiu2;
        }}}
/// dp_up[nod][i][0/1]
void dfs2 (int nod, int tata){

    int nr_fii = 0;
    for (auto vecin : L[nod]){
        if (vecin != tata){
            nr_fii++;
            if (nr_fii > 1)
                break;
        }}

    /// calculez strict din dp_up[tata]
    if (nod != 1){
        /// pun boaba in nod
        for (int i=2;i<=nr;i++){
            if (dp_up[tata][i-1][0] != -INF){
                long long val = dp_up[tata][i-1][0] + sum[nod] - v[nod] - v[tata];
                dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
            }
            if (dp_up[tata][i-1][1] != -INF){
                long long val = dp_up[tata][i-1][1] + sum[nod] - v[tata] - v[nod];
                dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
            }
        }
        /// nu pun boaba in nod
        for (int i=1;i<=nr;i++){
            if (dp_up[tata][i][0] != -INF){
                long long val = dp_up[tata][i][0] - v[nod];
                dp_up[nod][i][0] = max (dp_up[nod][i][0],val);
            }
            if (dp_up[tata][i][1] != -INF)
                dp_up[nod][i][0] = max (dp_up[nod][i][0],dp_up[tata][i][1]);
        }

        /// acum facem in functie de dp
        if (nr_fii > 1 && nod != 1){
            /// pun boaba
            for (int i=2;i<=nr;i++){
                long long val;
                if (dp[tata][i-1][0].fiu == nod)
                    val = dp[tata][i-1][0].maxi2 + sum[nod] - v[nod] - v[tata];
                else val = dp[tata][i-1][0].maxi + sum[nod] - v[nod] - v[tata];

                dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
                ////////////////////////

                if (dp[tata][i-1][1].fiu == nod)
                    val = dp[tata][i-1][1].maxi2 + sum[nod] - v[nod] - v[tata];
                else val = dp[tata][i-1][1].maxi + sum[nod] - v[nod] - v[tata];

                dp_up[nod][i][1] = max (dp_up[nod][i][1],val);
            }

            /// nu pun boaba
            for (int i=1;i<=nr;i++){
                long long val;
                if (dp[tata][i][0].fiu == nod)
                    val = dp[tata][i][0].maxi2 - v[nod];
                else val = dp[tata][i][0].maxi - v[nod];

                dp_up[nod][i][0] = max (dp_up[nod][i][0],val);

                if (dp[tata][i][1].fiu == nod)
                    val = dp[tata][i][1].maxi2 - v[nod];
                else val = dp[tata][i][1].maxi - v[nod];

                dp_up[nod][i][0] = max (dp_up[nod][i][0],val);
            }
        }
    }
    for (auto vecin : L[nod])
        if (vecin != tata)
            dfs2 (vecin,nod);
}
int main (){

    //ifstream cin ("date.in");
    //ofstream cout ("date.out");

    cin>>n>>nr;
    for (i=1;i<=n;i++)
        cin>>v[i];
    for (i=1;i<n;i++){
        cin>>x>>y;
        L[x].push_back(y);
        L[y].push_back(x);
    }

    for (i=1;i<=n;i++){
        sum[i] = v[i];
        for (auto it : L[i])
            sum[i] += v[it];
    }

    if (nr == 1){
        long long maxi = 0;
        for (i=1;i<=n;i++)
            maxi = max (maxi,sum[i] - v[i]);
        cout<<maxi;
        return 0;
    }

    for (i=1;i<=n;i++)
        for (j=1;j<=nr;j++)
            dp[i][j][0].maxi = dp[i][j][1].maxi = -INF;

    dfs (1,0);

    long long maxi = 0;
    for (i=1;i<=n;i++)
        for (j=1;j<=nr;j++){
            maxi = max (maxi,max(dp[i][j][0].maxi,dp[i][j][1].maxi));
            maxi = max (maxi,max(dp_up[i][j][0],dp_up[i][j][1]));
        }

    cout<<maxi;


    return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 2816 KB Output is correct
2 Incorrect 6 ms 2816 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 2816 KB Output is correct
2 Incorrect 6 ms 2816 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 560 ms 489208 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 2816 KB Output is correct
2 Incorrect 6 ms 2816 KB Output isn't correct
3 Halted 0 ms 0 KB -