// This submission was migrated from the previous version of oj.uz, which graded on a different machine. It may receive a different result if resubmitted.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
public class biscuits {
    /**
     * Counts how many distinct values y can be given simultaneously to each of x bags, where every
     * bag's total tastiness is exactly y.
     *
     * <p>{@code a[i]} is the number of biscuits of tastiness {@code 2^i}. A value y is achievable
     * iff, for every bit position j, {@code x * (y mod 2^(j+1)) <= s[j]}, where {@code s[j]} is the
     * total tastiness of all biscuits with index at most j (and {@code s[j] = s[k-1]} for
     * {@code j >= k}). The feasible values are counted by {@link #countBelow}, which splits on the
     * highest bit of y and memoizes, instead of enumerating y.
     *
     * <p>Assumes {@code x >= 1} and that the total tastiness {@code sum a[i] * 2^i} fits in a
     * {@code long} (the problem statement bounds it by 10^18).
     *
     * @param x number of bags (must be positive)
     * @param a biscuit counts per tastiness {@code 2^i}; may be empty
     * @return number of distinct achievable y, including y = 0
     */
    long count_tastiness(long x, long[] a) {
        int k = a.length;
        if (k == 0) {
            // No biscuits at all: only y = 0 is achievable.
            return 1;
        }
        // s[i] = total tastiness of biscuits with index <= i.
        long[] s = new long[k];
        s[0] = a[0];
        for (int i = 1; i < k; i++) {
            s[i] = (1L << i) * a[i] + s[i - 1];
        }
        // Every achievable y satisfies x * y <= s[k - 1], so only y in [0, s[k-1] / x] matter.
        long limit = s[k - 1] / x;
        return countBelow(limit + 1, s, x, new HashMap<>());
    }

    /**
     * Returns the number of achievable values y with {@code 0 <= y < n}.
     *
     * <p>Let i be the highest set bit of {@code n - 1}. Values below {@code 2^i} are counted
     * recursively. A value {@code y = 2^i + r} with {@code r < 2^i} is achievable iff r is
     * achievable and {@code x * y <= s[i]}, so those values correspond to the achievable r below
     * {@code min(n, s[i] / x + 1) - 2^i}.
     *
     * @param n exclusive upper bound on y (values {@code <= 0} yield 0)
     * @param s prefix tastiness totals as built by {@link #count_tastiness}; for bit positions
     *     beyond the last index, the last entry is used
     * @param x number of bags
     * @param memo cache from n to result, shared across one top-level call
     * @return count of achievable y in {@code [0, n)}
     */
    private long countBelow(long n, long[] s, long x, HashMap<Long, Long> memo) {
        if (n <= 0) {
            return 0;
        }
        if (n == 1) {
            return 1; // y = 0 is always achievable
        }
        Long cached = memo.get(n);
        if (cached != null) {
            return cached;
        }
        int bit = 63 - Long.numberOfLeadingZeros(n - 1);
        long high = 1L << bit;
        // Biscuits beyond the last index do not exist, so the prefix total stops growing there.
        long prefix = s[Math.min(bit, s.length - 1)];
        long upper = Math.min(n, prefix / x + 1);
        long result = countBelow(high, s, x, memo) + countBelow(upper - high, s, x, memo);
        memo.put(n, result);
        return result;
    }

    /**
     * Collapses each key's list of ranges into the single range {@code [min start, max end]}.
     *
     * <p>NOTE(review): this is only exact when the union of a key's ranges is contiguous. It is
     * no longer used by {@link #count_tastiness} and is kept for compatibility.
     *
     * @param map key to list of inclusive {@code {start, end}} ranges (each list assumed non-empty)
     * @return key to one inclusive {@code {min, max}} range
     */
    HashMap<Long, long[]> transform(HashMap<Long, ArrayList<long[]>> map) {
        HashMap<Long, long[]> res = new HashMap<>();
        for (long key : map.keySet()) {
            ArrayList<long[]> ranges = map.get(key);
            long min = ranges.get(0)[0], max = ranges.get(0)[1];
            for (long[] range : ranges) {
                min = Math.min(min, range[0]);
                max = Math.max(max, range[1]);
            }
            res.put(key, new long[] { min, max });
        }
        return res;
    }

    /**
     * Merges two inclusive ranges when one ends immediately before the other begins.
     *
     * <p>Unused by {@link #count_tastiness}; kept for compatibility.
     *
     * @return the merged range, or {@code null} if the ranges are not adjacent
     */
    long[] mergeRanges(long[] r1, long[] r2) {
        if (r1[1] == r2[0] - 1) {
            return new long[] { r1[0], r2[1] };
        } else if (r2[1] == r1[0] - 1) {
            return new long[] { r2[0], r1[1] };
        }
        return null;
    }

    /**
     * Returns {@code floor((s[i] - x * y) / (2^i * x))}, or 0 whenever {@code y >= s[i] / x}.
     *
     * <p>The early return avoids computing {@code x * y} when it could exceed {@code s[i]}. Note
     * that it also maps infeasible y (where {@code x * y > s[i]}) to 0.
     */
    long eval(long[] s, int i, long x, long y) {
        if (y >= s[i] / x) {
            return 0;
        }
        long up = s[i] - x * y;
        return (up / (1L << i)) / x;
    }

    /**
     * Exponential brute-force counter of achievable y, processing bits from {@code index} upward.
     *
     * <p>At each bit, y either has a 0 (all biscuits of this size pair up and carry to the next
     * size) or a 1 (x biscuits are used first, the rest carry). At the last index, the remaining
     * biscuits allow {@code a[last] / x + 1} choices for the top part of y.
     *
     * <p>Temporarily mutates {@code a[index + 1 ..]} during recursion but restores it before
     * returning. Intended for small inputs (e.g. cross-checking).
     */
    long countRec(long x, long[] a, int index) {
        if (index == a.length - 1) {
            return a[a.length - 1] / x + 1;
        }
        long temp = a[index + 1];
        a[index + 1] = a[index + 1] + a[index] / 2;
        long answer = countRec(x, a, index + 1);
        if (a[index] >= x) {
            a[index + 1] = temp + (a[index] - x) / 2;
            answer += countRec(x, a, index + 1);
        }
        a[index + 1] = temp;
        return answer;
    }
}
// (s[k-1] - i * X) / (2^(k-1)), 0 <= i < 2^(k - 1).
// For which i is this state valid?
// If for each position b s.t. b-th bit is set in i, (s[b+1] - (2^b + prev(i,b))X) / 2^(b+1) >= X.
/*
 * Grader results table scraped from the oj.uz submission page (results had not loaded when captured):
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */