Submission #1233768

# | Time | Username | Problem | Language | Result | Execution time | Memory
1233768 | uranium235 | Boarding Passes (BOI22_passes) | Java
100 / 100
773 ms | 289876 KiB
//package ojuz;

import java.io.*;
import java.util.*;

public class passes {
    /**
     * Solves BOI 2022 "Boarding Passes": reads one line of seat letters ('A' = group 0, 'B' = group 1,
     * ...) from standard input and prints the minimum total penalty for walking past seated people.
     *
     * <p>All costs are tracked doubled so they stay integral; the result is halved on output and
     * printed as {@code "x"} or {@code "x.5"}.
     *
     * <p>Approach: bitmask DP over the set of groups already seated. When group {@code j} boards
     * after the groups in {@code mask}, its passengers in seats {@code [0, pos)} board from the front
     * and those in {@code [pos, n)} from the back; the split is chosen by binary search over the
     * group's own seat positions.
     *
     * @param args ignored
     * @throws IOException if reading standard input fails
     */
    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        String line = reader.readLine();
        // Tolerate EOF and surrounding whitespace: any non-letter would otherwise turn into a huge
        // group index (char wrap-around) and an enormous table allocation below.
        char[] a = line == null ? new char[0] : line.trim().toCharArray();
        int n = a.length;
        if (n == 0) {
            // No passengers, no penalty (the tables below assume n >= 1).
            System.out.println(0);
            return;
        }

        int maxLetter = 0;
        for (int i = 0; i < n; i++) {
            a[i] -= 'A';
            maxLetter = Math.max(maxLetter, a[i]);
        }
        // Number of groups; final so the penalty lambda can capture it.
        final int g = maxLetter + 1;

        // pref[i][j][k]: number of pairs (q, p) with q < p <= k, a[q] == i, a[p] == j, i.e. how many
        // group-i seats the group-j passengers in seats [0, k] walk past when group i is already
        // seated and they all board from the front.
        // suff[i][j][k]: number of pairs (p, q) with k <= p < q, a[p] == j, a[q] == i, i.e. the same
        // for group-j passengers in seats [k, n) who all board from the back.
        // These can reach about n^2 / 4 (2.5e9 for n = 1e5) which exceeds Integer.MAX_VALUE; they are
        // kept as int to save memory and must be read with Integer.toUnsignedLong (exact below 2^32).
        int[][][] pref = new int[g][g][n];
        int[][][] suff = new int[g][g][n];
        for (int i = 0; i < g; i++) {
            for (int j = 0; j < g; j++) {
                if (i != j) {
                    fillPairCounts(a, i, j, pref[i][j], suff[i][j]);
                }
            }
        }

        // total[c]: number of passengers in group c.
        int[] total = new int[g];
        for (char c : a) {
            total[c]++;
        }
        // counts[c][k]: number of group-c passengers in seats [0, k].
        int[][] counts = new int[g][n];
        for (int c = 0; c < g; c++) {
            int running = 0;
            for (int k = 0; k < n; k++) {
                if (a[k] == c) {
                    running++;
                }
                counts[c][k] = running;
            }
        }
        // location[c][t]: seat of the t-th group-c passenger; location[c][total[c]] == n is a
        // sentinel meaning "everyone in the group boards from the front".
        int[][] location = new int[g][];
        for (int c = 0; c < g; c++) {
            location[c] = new int[total[c] + 1];
            location[c][total[c]] = n;
        }
        int[] filled = new int[g];
        for (int k = 0; k < n; k++) {
            location[a[k]][filled[a[k]]++] = k;
        }

        // Doubled cost of group `target` boarding after the groups in `mask`, with seats [0, pos)
        // boarding from the front and [pos, n) from the back. Each pass over a seated passenger
        // counts 2; each pair of same-group passengers boarding from the same side counts 1.
        GetPenalty get = (pos, target, mask) -> {
            long result = 0;
            for (int i = 0; i < g; i++) {
                if ((mask & (1 << i)) == 0) {
                    continue;
                }
                if (pos > 0) {
                    result += 2L * Integer.toUnsignedLong(pref[i][target][pos - 1]);
                }
                if (pos < n) {
                    result += 2L * Integer.toUnsignedLong(suff[i][target][pos]);
                }
            }
            long front = pos > 0 ? counts[target][pos - 1] : 0;
            long back = total[target] - front;
            result += front * (front - 1) / 2 + back * (back - 1) / 2;
            return result;
        };

        long[] dp = new long[1 << g];
        Arrays.fill(dp, Long.MAX_VALUE / 2);
        dp[0] = 0;
        for (int seated = 1; seated < 1 << g; seated++) {
            for (int j = 0; j < g; j++) {
                if ((seated & (1 << j)) == 0) {
                    continue;
                }
                int mask = seated ^ (1 << j);
                int[] seats = location[j];
                // Binary search for the cheapest split index; this assumes the cost is unimodal
                // (non-increasing, then increasing) in the split index.
                int lo = 0;
                int hi = total[j];
                while (lo < hi) {
                    int mid = lo + (hi - lo) / 2;
                    if (get.get(seats[mid], j, mask) < get.get(seats[mid + 1], j, mask)) {
                        hi = mid;
                    } else {
                        lo = mid + 1;
                    }
                }
                dp[seated] = Math.min(dp[seated], dp[mask] + get.get(seats[lo], j, mask));
            }
        }

        long doubled = dp[(1 << g) - 1];
        String out = String.valueOf(doubled / 2);
        if (doubled % 2 == 1) {
            out += ".5";
        }
        System.out.println(out);
    }

    /**
     * Fills {@code pref} and {@code suff} (both of length {@code a.length}) with the cumulative pair
     * counts for group {@code boarding} walking past group {@code seated}:
     * {@code pref[k]} counts pairs (q, p) with q < p <= k, a[q] == seated, a[p] == boarding, and
     * {@code suff[k]} counts pairs (p, q) with k <= p < q, a[p] == boarding, a[q] == seated.
     * Sums wrap modulo 2^32 and are meant to be read back as unsigned values.
     */
    private static void fillPairCounts(char[] a, int seated, int boarding, int[] pref, int[] suff) {
        int n = a.length;
        int seen = 0;
        for (int k = 0; k < n; k++) {
            if (k > 0) {
                pref[k] = pref[k - 1];
            }
            if (a[k] == seated) {
                seen++;
            } else if (a[k] == boarding) {
                pref[k] += seen;
            }
        }
        seen = 0;
        for (int k = n - 1; k >= 0; k--) {
            if (k < n - 1) {
                suff[k] = suff[k + 1];
            }
            if (a[k] == seated) {
                seen++;
            } else if (a[k] == boarding) {
                suff[k] += seen;
            }
        }
    }

    /**
     * Cost oracle used by the DP: doubled cost for split position {@code pos} of group
     * {@code target} when the groups whose bits are set in {@code mask} are already seated.
     */
    @FunctionalInterface
    interface GetPenalty {
        long get(int pos, int target, int mask);
    }
}

Compilation message (stderr)

Note: passes.java uses unchecked or unsafe operations.
Note: Recompile with -Xlint:unchecked for details.

=======
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...