Submission #305425

# | Time | Username | Problem | Language | Result | Execution time | Memory
305425 | (not shown) | llaki | Packing Biscuits (IOI20_biscuits) | Java | 0 / 100 | 636 ms | 58208 KiB
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

public class biscuits {
    long count_tastiness(long x, long[] a) {
        int k = a.length;
        long[] s = new long[k];
        s[0] = a[0];
        for (int i = 1; i < k; i++) {
            s[i] = (1l << i) * a[i] + s[i - 1];
        }
//        ArrayList<Long> ys = new ArrayList<>();
//        ys.add(0l);
        HashMap<Long, long[]> map = new HashMap<>();
        long eval = s[0] / x;
        map.put(eval, new long[] {0l, 0l});
//        System.out.println("Map for layer " + 0 + ": \n");
//        for (long ev : map.keySet()) {
//            System.out.println(ev + ": " + Arrays.toString(map.get(ev)));
//        }
//        System.out.println();
        for (int i = 1; i < k; i++) {
            ArrayList<Long> newYs = new ArrayList<>();
            HashMap<Long, long[]> newMap = new HashMap<>();
            HashMap<Long, ArrayList<long[]>> tempMap = new HashMap<>();
            for (long ev : map.keySet()) {
                long[] range = map.get(ev);
                // range[0] <= y <= range[1]. ev = [(s[i-1]-y*x)/(x * 2^(i-1))], for each y in range.
                long min = eval(s, i, x, range[1]);
                long max = eval(s, i, x, range[0]);
                if (min == max) {
                    if (!tempMap.containsKey(min)) {
                        tempMap.put(min, new ArrayList<>());
                    }
                    tempMap.get(min).add(new long[] { range[0], range[1] });
                } else {
                    // max = min + 1. Find m s.t. eval(m) = eval(m + 1) + 1.
                    // Find max m s.t. eval(m) > min.
                    long low = range[0], high = range[1];
                    while (low < high) {
                        long mid = (low + high + 1) / 2;
                        if (eval(s, i, x, mid) > min) {
                            low = mid;
                        } else {
                            high = mid - 1;
                        }
                    }
                    // m = low.
                    if (!tempMap.containsKey(max)) {
                        tempMap.put(max, new ArrayList<>());
                    }
                    tempMap.get(max).add(new long[] { range[0], low});
                    if (!tempMap.containsKey(min)) {
                        tempMap.put(min, new ArrayList<>());
                    }
                    tempMap.get(min).add(new long[] { low + 1, range[1] });
                }
                if (ev > 0) {
                    long pow2 = (1l << (i - 1));
                    min = eval(s, i, x, range[1] + pow2);
                    max = eval(s, i, x, range[0] + pow2);
                    if (min == max) {
                        long[] newRange = new long[] { range[0] + pow2, range[1] + pow2 };
                        if (!tempMap.containsKey(min)) {
                            tempMap.put(min, new ArrayList<>());
                        }
                        tempMap.get(min).add(newRange);
                    } else {
                        // max = min + 1. Find m s.t. eval(m) = eval(m + 1) + 1.
                        // Find max m s.t. eval(m) > min.
                        long low = range[0], high = range[1];
                        while (low < high) {
                            long mid = (low + high + 1) / 2;
                            if (eval(s, i, x, mid + pow2) > min) {
                                low = mid;
                            } else {
                                high = mid - 1;
                            }
                        }
                        // m = low.
                        if (!tempMap.containsKey(min)) {
                            tempMap.put(min, new ArrayList<>());
                        }
                        tempMap.get(min).add(new long[] { low + 1 + pow2, range[1] + pow2 });

                        if (!tempMap.containsKey(max)) {
                            tempMap.put(max, new ArrayList<>());
                        }
                        tempMap.get(max).add(new long[] { range[0] + pow2, low + pow2 });
                    }
                }
            }
            map = transform(tempMap);
//            System.out.println("Map for layer " + i + ": \n");
//            for (long ev : map.keySet()) {
//                System.out.println(ev + ": " + Arrays.toString(map.get(ev)));
//            }
//            System.out.println();
        }
        long ans = 0;
        for (long ev : map.keySet()) {
            long[] range = map.get(ev);
            //System.out.println(ev + ", range = " + Arrays.toString(range));
            ans += (eval + 1) * (range[1] - range[0] + 1);
        }
        return ans;
    }

    HashMap<Long, long[]> transform(HashMap<Long, ArrayList<long[]>> map) {
        HashMap<Long, long[]> res = new HashMap<>();
        for (long key : map.keySet()) {
            ArrayList<long[]> ranges = map.get(key);
            long min = ranges.get(0)[0], max = ranges.get(0)[1];
            for (long[] range : ranges) {
                min = Math.min(min, range[0]);
                max = Math.max(max, range[1]);
            }
            res.put(key, new long[] { min, max });
        }
        return res;
    }

    long[] mergeRanges(long[] r1, long[] r2) {
        System.out.println("Merging " + Arrays.toString(r1) + " with " + Arrays.toString(r2));
        if (r1[1] == r2[0] - 1) {
            return new long[] { r1[0], r2[1] };
        } else if (r2[1] == r1[0] - 1) {
            return new long[] { r2[0], r1[1] };
        } return null;
    }

    long eval(long[] s, int i, long x, long y) {
        // (s[i] - x * y) / (2^i * x);
        if (y >= s[i] / x) return 0;
        long up = s[i] - x * y;
        return (up / (1l << i)) / x;
    }

    long countRec(long x, long[] a, int index) {
        if (index == a.length - 1) {
            return a[a.length - 1] / x + 1;
        }
        long temp = a[index + 1];
        a[index + 1] = a[index + 1] + a[index] / 2;
        long answer = countRec(x, a, index + 1);
        if (a[index] >= x) {
            a[index + 1] = temp + (a[index] - x) / 2;
            answer += countRec(x, a, index + 1);
            a[index + 1] = temp;
        }
        a[index + 1] = temp;
        return answer;
    }

}
// State: (s[k-1] - i * X) / (2^(k-1)), for 0 <= i < 2^(k - 1).
// For which i is this state valid?
// It is valid if, for each position b such that the b-th bit is set in i,
// (s[b+1] - (2^b + prev(i,b))X) / 2^(b+1) >= X.

#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...