제출 #554676

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
554676 | | kwongweng | 수열 (APIO14_sequence) | C++17 | 100 / 100 | 1306 ms | 86540 KiB
/*
Solution for APIO 2014 - Sequence
Tags : dp, Convex-hull Trick (CHT)
*/
 
#include <bits/stdc++.h>
using namespace std;
 
// Compiler hints.
// NOTE(review): GCC does not recognise `#pragma GCC optimization` (the accepted
// spelling is `optimize`), so the second and third lines are ignored and GCC
// emits -Wunknown-pragmas warnings for them. Only the `target` pragma has an
// effect. Switching them to `optimize("Ofast")` would also enable fast-math,
// which could change the long-double comparisons used by the hull.
#pragma GCC target ("avx2")
#pragma GCC optimization ("Ofast")
#pragma GCC optimization ("unroll-loops")
 
// Short type aliases used throughout the solution.
using ll  = long long;
using vi  = std::vector<int>;
using ii  = std::pair<ll, ll>;   // note: a pair of long longs, not ints
using vii = std::vector<ii>;
using ld  = long double;
using pll = std::pair<ll, ll>;

// Loop helpers: FOR runs i over [a, b) upwards, ROF runs i from a down to b inclusive.
#define FOR(i, a, b) for(int i = a; i < b; i++)
#define ROF(i, a, b) for(int i = a; i >= b; i--)
#define ms memset
#define pb push_back
#define fi first
#define se second

// Prime modulus used by the modular-arithmetic helpers.
ll MOD = 1000000007;
 
// Computes base^n modulo MOD by recursive binary exponentiation.
// n is expected to be non-negative. The result is always in [0, MOD).
ll power(ll base, ll n){
	// Reduce the base first: otherwise power(b, 1) returns b unreduced and
	// squaring an unreduced base can overflow long long.
	base %= MOD;
	if (base < 0) base += MOD;
	if (n == 0) return 1;
	if (n == 1) return base;
	const ll half = power(base, n / 2);
	const ll sq = half * half % MOD;
	return (n % 2 == 0) ? sq : sq * base % MOD;
}
 
// Modular inverse of n via Fermat's little theorem: MOD is prime, so for n not
// divisible by MOD, n^(MOD-2) is n's multiplicative inverse modulo MOD.
ll inverse(ll n){
	const ll exponent = MOD - 2;
	return power(n, exponent);
}
 
// Returns (a + b) modulo MOD, normalised into [0, MOD).
// Both operands are reduced first, so the sum cannot overflow and negative
// inputs do not produce a negative residue.
ll add(ll a, ll b){
	ll r = (a % MOD + b % MOD) % MOD;
	if (r < 0) r += MOD;
	return r;
}
 
// Returns (a * b) modulo MOD, normalised into [0, MOD).
// Both operands are reduced into [0, MOD) first, so the product stays below
// MOD^2 and fits in long long. Previously only `a` was reduced.
ll mul(ll a, ll b){
	a %= MOD;
	if (a < 0) a += MOD;
	b %= MOD;
	if (b < 0) b += MOD;
	return a * b % MOD;
}
 
// Euclid's algorithm, written as a loop.
// Short-circuits to 1 as soon as a becomes 1, and returns b once a becomes 0.
ll gcd(ll a, ll b){
    for (;;) {
        if (a == 1) return 1;
        if (a == 0) return b;
        const ll rem = b % a;
        b = a;
        a = rem;
    }
}
 
 
// Coefficients of the convex-hull-trick line for each split point l. solve()
// rewrites them for every dp layer: m[l] is the slope, c[l] the intercept.
constexpr int N = 100001;
vector<ll> m(N);
vector<ll> c(N);
 
// Breakpoint of the pair of hull lines l = (l1, l2): the x at which
// m[l1]*x + c[l1] == m[l2]*x + c[l2], i.e. (c[l1] - c[l2]) / (m[l2] - m[l1]).
// When m[l2] > m[l1], line l2 is at least as good as line l1 exactly when x >= g(l).
// Equal slopes have no crossing point. In that case -MOD is returned, so the
// pair compares below every query x = s[i] >= 0.
// NOTE(review): this means l2 is always preferred over l1 in that case, which
// presumably relies on l2's intercept being no worse there.
ld g(ii l){
    const ll numer = c[l.fi] - c[l.se];
    const ll denom = m[l.se] - m[l.fi];
    if (denom == 0) return -MOD;
    return static_cast<ld>(numer) / denom;
}
 
void solve(){
    // CHT to optimise O(n^2 * k) into O(n*k)
    int n, k; cin >> n >> k;
    vector<ll> a(n+1); FOR(i,1,n+1) cin >> a[i];
    vector<ll> s(n+1); FOR(i,1,n+1) s[i]=a[i]+s[i-1];
    ll dp[n+1][2]; // replace usual dp[n+1][k+1] to reduce memory usage for last subtask
    int pos[n+1][k+1];
    ms(dp,-1,sizeof(dp));
    ms(pos,-1,sizeof(pos));
    FOR(i,1,n+1){
        dp[i][0]=s[i]*(s[n]-s[i]);
        // only 1 component
    }
    FOR(j,1,k+1){
        // m_l = 2*s[l]+s[n], m_l non-decreasing
        // x = s[i]
        // c_l = dp[l][j-1] - s[l] * (s[n]+s[l])
        // f_l(x) = x^2 + m_l x + c_l
        // l1 < l2, g(l1, l2) = (c_l1-c_l2)/(m_l2-m_l1)
        // f_l1(x) <= f_l2(x) <==>  g(l1, l2) <= x
        // f_l1(x) >= f_l2(x) || f_l2(x) <= f_l3(x)
        // g(l1, l2) >= x || g(l2, l3) <= x
        // g(l1, l2) >= g(l2, l3) <==> l2 ignored
        
        FOR(i,1,n+1){
            dp[i][j%2]=-1; // dp values to compute later on
            m[i] = 2*s[i]+s[n];
            c[i] = dp[i][(j+1)%2] - s[i] * (s[n] + s[i]);
        }
        int best_l = 1;
        // edge case : i = 2, only 1 possible value
        ll val = dp[1][(j+1)%2] + (s[2]-s[1])*(s[n]-s[2]+s[1]);
        if (val > dp[2][j%2] && dp[1][(j+1)%2] != -1){
            dp[2][j%2]=val;
            pos[2][j]=best_l;
        }
        list <ii> li; // stores pairs with increasing g(l1, l2) from left to right.
        int r_l = 1;
        FOR(i,3,n+1){
            best_l = r_l;
            ld cur;
            if (dp[i-1][(j+1)%2] != -1){
                r_l = i-1;
                best_l = i-1;
                ii c = {i-2, i-1};
                cur = g(c);
                while (!li.empty()){
                    ii u = li.back();
                    ld val = g(u);
                    if (val >= cur){
                        c.fi = u.fi;
                        cur = g(c);
                        // merge u and c since u.se, or c.fi can be ignored
                    }else{
                        break;
                    }
                    li.pop_back();
                }
                li.pb(c);
            }
            best_l = r_l;
            cur = s[i];
            while (!li.empty()){
                if (cur < g(li.front())){
                    ii u = li.front();
                    best_l = u.fi;
                    break;
                }
                li.pop_front();
                // remove values from the left since they are now smaller than cur
            }
            ll val1 = dp[best_l][(j+1)%2] + (s[i]-s[best_l])*(s[n]-s[i]+s[best_l]);
            if (val1 > dp[i][j%2]){
                dp[i][j%2]=val1;
                //cout<<i<<" "<<j<<" "<<val<<" "<<val-dp[l][j-1]<<'\n';
                pos[i][j]=best_l;
            }
            /*
            FOR(l,1,i){
                //O(K*N^2) brute force
                ll val = dp[l][j-1] + (s[i]-s[l])*(s[n]-s[i]+s[l]);
                if (val > dp[i][j] && dp[l][j-1] != -1){
                    dp[i][j]=val;
                    //cout<<i<<" "<<j<<" "<<val<<" "<<val-dp[l][j-1]<<'\n';
                    pos[i][j]=l;
                }
                
            }
            */
        }
    }
    cout << dp[n][k%2]/2 << '\n';
    int cur = n;
    vi b;
    ROF(i,k,1){
        cur = pos[cur][i];
        b.pb(cur);
    }
    ROF(i,k-1,0){
        cout << b[i]<<" ";
    }
    cout << '\n';
    return;
}
 
int main() {
    ios::sync_with_stdio(false);
    int TC = 1;
    //cin >> TC;
    FOR(i, 1, TC+1){
        //cout << "Case #" << i << ": ";
        solve();
    }
}

컴파일 시 표준 에러 (stderr) 메시지

sequence.cpp:10: warning: ignoring '#pragma GCC optimization' [-Wunknown-pragmas]
   10 | #pragma GCC optimization ("Ofast")
      | 
sequence.cpp:11: warning: ignoring '#pragma GCC optimization' [-Wunknown-pragmas]
   11 | #pragma GCC optimization ("unroll-loops")
      |
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...