#include "grader.h"
#include <bits/stdc++.h>
using namespace std;
// IOI 2010 "Languages": 56 languages, symbols are in [1..65535].
// K is the number of language labels; valid labels are 0..K-1.
static constexpr int K = 56;
// S is the size of every per-symbol table. It is one larger than the largest
// symbol so a symbol value can be used directly as an index (index 0 is unused).
static constexpr int S = 65536; // include index 0 (unused)
// Smoothing constants
// ALPHA is added to every symbol count when scoring; a symbol a language has
// not seen yet contributes ALPHA alone (the "UNK" bucket).
static constexpr double ALPHA = 0.5; // Dirichlet smoothing for seen symbols + UNK bucket
// BETA is added to a language's excerpt count before taking the log prior, so
// a language with no training excerpts still gets a finite score.
static constexpr double BETA = 1.0; // prior smoothing for language frequency
// All tables below have static storage duration, so they start zeroed.
// Counts: cnt[L][symbol] = occurrences of `symbol` in excerpts learned for language L.
static uint32_t cnt[K][S]; // 56 * 65536 * 4 bytes ≈ 14.7MB
static uint32_t total_sym[K]; // total symbols observed for each language
static uint32_t seen_types[K]; // number of distinct symbols observed for each language
static uint32_t num_examples[K];// number of excerpts observed for each language
// Seen flags for distinct counting: seen[L][symbol] is set once `symbol` has
// been learned for language L.
static bitset<S> seen[K]; // 56 * 65536 bits ≈ 0.46MB
// Folds one labelled excerpt into the per-language statistics.
//   L: language label of the excerpt; used as the first index into every table.
//   E: the excerpt's 100 symbols; each value is used directly as a table index.
// Updates num_examples, total_sym, seen, seen_types and cnt for language L only.
static inline void learn(int L, const int E[100]) {
    ++num_examples[L];

    auto& counts = cnt[L];
    auto& known = seen[L];
    std::for_each(E, E + 100, [&](const int sym) {
        ++total_sym[L];
        // First sighting of this symbol for L: flag it and grow the
        // distinct-symbol tally.
        const bool alreadyKnown = known.test(sym);
        if (!alreadyKnown) {
            known.set(sym);
            ++seen_types[L];
        }
        ++counts[sym];
    });
}
// Guesses the language of one 100-symbol excerpt, reports the guess through
// language(), and then trains on the label that language() returns.
//
// Each language is scored by a log prior plus a sum of per-symbol
// log-likelihoods:
//   score(L) = log(num_examples[L] + BETA)
//            + sum_i log(numerator_i) - 100 * log(denominator)
// where numerator_i is cnt[L][E[i]] + ALPHA for a symbol already seen in L and
// ALPHA otherwise, and the denominator smooths over the seen symbols plus one
// extra "unknown" bucket. The lowest-numbered language with the highest score wins.
void excerpt(int E[100]) {
    // Numerator term shared by every symbol a language has not seen yet.
    const double unknownTerm = std::log(ALPHA);

    int guess = 0;
    double guessScore = -1e300;

    for (int lang = 0; lang < K; ++lang) {
        const auto& counts = cnt[lang];
        const auto& known = seen[lang];

        // Gentle prior: languages with more training excerpts get a small boost.
        double logScore = std::log(static_cast<double>(num_examples[lang]) + BETA);

        // Normaliser over (seen_types + 1) categories; identical for all 100
        // symbols, so its log is subtracted once, scaled by 100.
        const double normaliser =
            static_cast<double>(total_sym[lang]) +
            ALPHA * (static_cast<double>(seen_types[lang]) + 1.0);
        logScore -= 100.0 * std::log(normaliser);

        // Per-symbol numerators, accumulated in excerpt order.
        for (const int* p = E; p != E + 100; ++p) {
            const int sym = *p;
            logScore += known.test(sym)
                            ? std::log(static_cast<double>(counts[sym]) + ALPHA)
                            : unknownTerm;
        }

        // Strict comparison keeps the earliest language on ties.
        if (guessScore < logScore) {
            guessScore = logScore;
            guess = lang;
        }
    }

    // language() MUST be called exactly once per excerpt; its return value is
    // the true label, which is learned immediately.
    const int truth = language(guess);
    learn(truth, E);
}