Submission #776689

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 776689 |  | mousebeaver | Catfish Farm (IOI22_fish) | C++17 | 100 / 100 | 346 ms | 55500 KiB |
        #define ll long long
        #define pll pair<ll, ll>
        #define INF numeric_limits<ll>::max()
        #include "fish.h"
        #include <bits/stdc++.h>
        using namespace std;
         
        // One entry of a column's DP table. Each column holds one entry per fish
        // (sorted by height) plus a bottom sentinel (height -1) and a top sentinel
        // (height N). Every field defaults to 0, so no member is ever read while
        // indeterminate, even if it is declared as `fish f;`.
        struct fish
        {
            long long height = 0;    // Row (Y) of the fish or sentinel
            long long weight = 0;    // Weight of the fish (0 for sentinels)
            long long below = 0;     // Sum of weights of the entries before this one in its sorted column
            long long above = 0;     // Sum of weights of the entries after this one in its sorted column
            long long bigPier = 0;   // No bigger pier in the previous column -> first one to be caught
            long long noPier = 0;    // No pier in this column -> first one to be caught
            long long smallPier = 0; // Bigger pier in the previous column -> first one not below a pier
        };
         
        bool operator < (fish a, fish b)
        {
            return (a.height < b.height);
        }
     
        ll transitionBig(vector<vector<fish>>& dp, ll i, ll preind, ll index)
        {
            // Computes a candidate value for dp[i][index].bigPier, using
            // dp[i-1][preind] as the base state of the previous column.
            // A base that is taller than the target entry is rejected with 0.
            const vector<fish>& prevColumn = dp[i-1];
            const fish& base = prevColumn[preind];
            const ll targetHeight = dp[i][index].height;

            if(targetHeight < base.height)
            {
                return 0;
            }

            const ll baseValue = max(base.bigPier, base.noPier);
            if(targetHeight == base.height)
            {
                // Same height: no additional entry of column i-1 is reached.
                return baseValue;
            }

            // Column i-1 is sorted by height, so the entries strictly below
            // targetHeight form a prefix. Find its last element (it is at or
            // after preind, because base itself is below targetHeight).
            auto firstNotBelow = partition_point(prevColumn.begin() + preind, prevColumn.end(),
                [targetHeight](const fish& f) { return f.height < targetHeight; });
            const fish& lastCaught = *(firstNotBelow - 1);

            // Add the weights of entries preind..lastCaught, inclusive.
            return baseValue + (lastCaught.below + lastCaught.weight - base.below);
        }

        ll intersectBig(vector<vector<fish>>& dp, ll i, ll first, ll second)
        {
            // Binary-searches dp[i] for the smallest index at which base `second`
            // gives a bigPier value at least as large as base `first` (both are
            // indices into dp[i-1]). This assumes the comparison flips at most once
            // along dp[i]. Returns INF if `second` is worse even at the last index.
            const ll lastIndex = (ll) dp[i].size() - 1;
            auto secondAtLeastAsGood = [&](ll idx)
            {
                return transitionBig(dp, i, first, idx) <= transitionBig(dp, i, second, idx);
            };

            if(!secondAtLeastAsGood(lastIndex))
            {
                return INF;
            }

            ll lo = 0;
            ll hi = lastIndex;
            while(lo != hi)
            {
                const ll mid = lo + (hi - lo) / 2;
                if(secondAtLeastAsGood(mid))
                    hi = mid;
                else
                    lo = mid + 1;
            }
            return lo;
        }
     
        ll transitionSmall(vector<vector<fish>>& dp, ll i, ll preind, ll index)
        {
            // Computes a candidate value for dp[i][index].smallPier, using
            // dp[i-1][preind] as the base. Only a base strictly taller than the
            // target entry qualifies; anything else yields 0.
            const fish& base = dp[i-1][preind];
            const vector<fish>& column = dp[i];
            const ll baseHeight = base.height;

            if(!(column[index].height < baseHeight))
            {
                return 0;
            }

            // Column i is sorted by height, so the entries below baseHeight form
            // a prefix. Take its last element (at or after index).
            auto firstNotBelow = partition_point(column.begin() + index, column.end(),
                [baseHeight](const fish& f) { return f.height < baseHeight; });
            const fish& lastCaught = *(firstNotBelow - 1);

            // Weights of column-i entries index..lastCaught, inclusive.
            const ll caught = lastCaught.below + lastCaught.weight - column[index].below;
            return max(base.bigPier, base.smallPier) + caught;
        }

        ll intersectSmall(vector<vector<fish>>& dp, ll i, ll high, ll low)
        {
            // Binary-searches dp[i] for the largest index at which base `low`
            // gives a smallPier value at least as large as base `high` (both are
            // indices into dp[i-1]). This assumes the comparison flips at most once
            // along dp[i]. Returns -1 if `low` is worse even at index 0.
            auto lowAtLeastAsGood = [&](ll idx)
            {
                return transitionSmall(dp, i, low, idx) >= transitionSmall(dp, i, high, idx);
            };

            if(!lowAtLeastAsGood(0))
            {
                return -1;
            }

            ll lo = 0;
            ll hi = (ll) dp[i].size() - 1;
            while(lo != hi)
            {
                // Round the midpoint up so that `lo = mid` always makes progress.
                const ll mid = hi - (hi - lo) / 2;
                if(lowAtLeastAsGood(mid))
                    lo = mid;
                else
                    hi = mid - 1;
            }
            return lo;
        }
         
        long long max_weights(int N, int M, std::vector<int> X, std::vector<int> Y, std::vector<int> W) 
        {
            bool sub1 = true;
            bool sub2 = true;
            bool sub3 = true;
         
            for(ll i = 0; i < M; i++)
            {
                if(X[i] % 2 == 1)
                {
                    sub1 = false;
                }
                if(X[i] > 1)
                {
                    sub2 = false;
                }
                if(Y[i] != 0)
                {
                    sub3 = false;
                }
            }
         
            if(sub1)
            {
                ll sum = 0;
                for(int i : W)
                {
                    sum += (ll) i;
                }
                return sum;
            }
         
            if(sub2)
            {
                vector<pll> left(0);
                vector<pll> right(0); //Height, weight
                ll lsum = 0;
                ll rsum = 0;
         
                for(ll i = 0; i < M; i++)
                {
                    if(X[i] == 0)
                    {
                        left.push_back({Y[i], W[i]});
                        lsum += W[i];
                    }
                    else
                    {
                        right.push_back({Y[i], W[i]});
                        rsum += W[i];
                    }
                }
                sort(left.begin(), left.end());
                sort(right.begin(), right.end());
         
                ll output = max(lsum, rsum);
         
                if(N > 2)
                {
                    ll lindex = -1;
                    ll rindex = -1;
                    ll shadow = 0;
                    ll roof = 0;
                    for(ll i = 0; i < N; i++)
                    {
                        while(lindex+1 < (ll) left.size() && left[lindex+1].first <= i)
                        {
                            lindex++;
                            shadow += left[lindex].second;
                        }
                        while(rindex+1 < (ll) right.size() && right[rindex+1].first <= i)
                        {
                            rindex++;
                            roof += right[rindex].second;
                        }
                        output = max(output, shadow + rsum - roof);
                    }
                }
         
                return output;
            }
            
            if(sub3)
            {
                vector<ll> w(N, 0);
                for(ll i = 0; i < M; i++)
                {
                    w[X[i]] = W[i];
                }
         
                vector<vector<ll>> dp(N, vector<ll> (3, 0)); //Pier, no pier + uncaught, no pier + caught
                for(ll i = 1; i < N; i++)
                {
                    //Pier:
                    dp[i][0] = dp[i-1][0];
                    dp[i][0] = max(dp[i][0], dp[i-1][1]+w[i-1]);
                    dp[i][0] = max(dp[i][0], dp[i-1][2]);
         
                    //no pier + uncaught:
                    dp[i][1] = dp[i-1][1];
                    dp[i][1] = max(dp[i][1], dp[i-1][2]);
         
                    //no pier + caught:
                    dp[i][2] = dp[i-1][0]+w[i];
                }
         
                return max(max(dp[N-1][0], dp[N-1][1]), dp[N-1][2]);
            }
         
            vector<vector<fish>> grid(N, vector<fish> (0));
            for(ll i = 0; i < M; i++)
            {
                fish f;
                f.height = Y[i];
                f.weight = W[i];
                f.bigPier = 0;
                f.smallPier = 0;
                f.noPier = 0;
                grid[X[i]].push_back(f);
            }
         
            fish top, bottom;
            top.height = N;
            top.weight = 0;
            top.noPier = 0;
            top.bigPier = 0;
            top.smallPier = 0;
            bottom.height = -1;
            bottom.weight = 0;
            bottom.smallPier = 0;
            bottom.bigPier = 0;
            bottom.noPier = 0;
            for(ll i = 0; i < N; i++)
            {
                grid[i].push_back(top);
                grid[i].push_back(bottom);
                sort(grid[i].begin(), grid[i].end());
                ll sum = 0;
                for(ll j = 0; j < (ll) grid[i].size(); j++)
                {
                    grid[i][j].below = sum;
                    sum += grid[i][j].weight;
                }
                sum = 0;
                for(ll j = grid[i].size()-1; j >= 0; j--)
                {
                    grid[i][j].above = sum;
                    sum += grid[i][j].weight;
                }
            }
         
            for(ll i = 0; i < (ll) grid[0].size(); i++)
            {
                grid[0][i].bigPier = 0;
                grid[0][i].smallPier = 0;
                grid[0][i].noPier = 0;
            }
         
            for(ll i = 1; i < N; i++)
            {   
                //no pier:
                ll preind = 0;
                ll postind = 0;
                ll premax = 0;
                while(preind < (ll) grid[i-1].size())
                {
                    while(postind < (ll) grid[i].size() && grid[i-1][preind].height >= grid[i][postind].height)
                    {
                        grid[i][postind].noPier = max(grid[i-1][preind].bigPier, grid[i-1][preind].smallPier)+grid[i][postind].below;
                        postind++;
                    }
                    premax = max(premax, grid[i-1][preind].noPier);
                    preind++;
                }
                for(ll j = 0; j < (ll) grid[i].size(); j++)
                {
                    grid[i][j].noPier = max(grid[i][j].noPier, premax);
                }
                
                //calculate DP[i][j]:
                vector<pll> opt = {{0, 0}}; //Index in dp[i], index in dp[i-1]
                for(ll j = 1; j < (ll) grid[i-1].size(); j++)
                {
                    while(opt.size() && intersectBig(grid, i, opt.back().second, j) <= opt.back().first)
                    {
                        opt.pop_back();
                    }
                    if(opt.size())
                    {
                        opt.push_back({intersectBig(grid, i, opt.back().second, j), j});
                    }
                    else
                    {
                        opt.push_back({0, j});
                    }
                }
                ll optind = 0;
                for(ll j = 0; j < (ll) grid[i].size(); j++)
                {
                    //bigPier
                    while(optind < (ll) opt.size()-1 && opt[optind+1].first <= j)
                    {
                        optind++;
                    }
                    grid[i][j].bigPier = transitionBig(grid, i, opt[optind].second, j);
                }

                opt = {{grid[i].size()-1, grid[i-1].size()-1}}; //Index in dp[i], index in dp[i-1]
                for(ll j = grid[i-1].size()-2; j >= 0; j--)
                {
                    while(opt.size() && intersectSmall(grid, i, opt.back().second, j) >= opt.back().first)
                    {
                        opt.pop_back();
                    }
                    if(opt.size())
                    {
                        opt.push_back({intersectSmall(grid, i, opt.back().second, j), j});
                    }
                    else
                    {
                        opt.push_back({grid[i].size()-1, j});
                    }
                }
                optind = 0;
                for(ll j = grid[i].size()-1; j >= 0; j--)
                {
                    //smallPier
                    while(optind < (ll) opt.size()-1 && opt[optind+1].first >= j)
                    {
                        optind++;
                    }
                    grid[i][j].smallPier = transitionSmall(grid, i, opt[optind].second, j);
                }
            }
         
            ll output = 0;
            for(fish f : grid[N-1])
            {
                ll val = max(max(f.noPier, f.bigPier), f.smallPier);
                output = max(output, val);
            }
            
            return output;
        }
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...