#include "grader.h"
#include <bits/stdc++.h>
using namespace std;
// Number of labels: excerpt() scores every L in [0, K) and learn() keeps one
// row of statistics per label.
constexpr int K{56};
// Width of the exact unigram table, indexed directly by symbol value.
constexpr int S{65536};  // symbols: 1..65535

// Bucket counts of the hashed tables (trade memory against collisions).
// Both are powers of two, so h2()/h3() reduce a hash with `& (B - 1)`.
constexpr int B2{1 << 18};  // 262144 bigram buckets
constexpr int B3{1 << 17};  // 131072 trigram buckets (kept smaller to limit RAM)

// Tunables. Each ALPHA is added to a raw count before taking its log; BETA
// plays the same role for the per-label excerpt count used as the prior.
constexpr double ALPHA1{0.25};  // unigram smoothing
constexpr double ALPHA2{0.25};  // bigram smoothing
constexpr double ALPHA3{0.20};  // trigram smoothing
constexpr double BETA{1.0};     // prior smoothing

// Multipliers applied to the bigram / trigram log-scores before they are
// added to the prior + unigram score.
constexpr double W2{0.9};  // bigram weight
constexpr double W3{0.7};  // trigram weight (tune)
// All tables below have static storage duration and no initializer, so they
// start out zeroed. The first index is always the label L in [0, K).
// -------------------- Unigram model (exact) --------------------
// cnt1[L][x]: how many times symbol x has been learned for label L.
static uint32_t cnt1[K][S];
// tot1[L]: total number of symbols learned for label L.
static uint32_t tot1[K];
// types1[L]: number of distinct symbols learned for label L.
static uint32_t types1[K];
// ex[L]: number of excerpts learned for label L (feeds the prior in excerpt()).
static uint32_t ex[K];
// seen1[L][x]: set the first time symbol x is learned for label L.
static bitset<S> seen1[K];
// -------------------- Bigram model (hashed) --------------------
// Same layout as the unigram tables, but indexed by the h2() bucket of an
// adjacent symbol pair; distinct pairs that collide share one bucket.
static uint32_t cnt2[K][B2];
static uint32_t tot2[K];
// types2[L]: number of distinct buckets (not distinct pairs) touched for L.
static uint32_t types2[K];
static bitset<B2> seen2[K];
// -------------------- Trigram model (hashed) --------------------
// Same layout again, indexed by the h3() bucket of three consecutive symbols.
static uint32_t cnt3[K][B3];
static uint32_t tot3[K];
// types3[L]: number of distinct buckets (not distinct triples) touched for L.
static uint32_t types3[K];
static bitset<B3> seen3[K];
// 32-bit integer mixer: two rounds of "xor with a right-shifted copy, then
// multiply by an odd constant", followed by a final xor-shift. Intended to
// spread input differences across the whole word before it is masked down
// to a bucket index. All arithmetic wraps modulo 2^32.
static inline uint32_t mix32(uint32_t x) {
    x = (x ^ (x >> 16)) * 0x7feb352du;
    x = (x ^ (x >> 15)) * 0x846ca68bu;
    return x ^ (x >> 16);
}
// Hashes the ordered symbol pair (a, b) to a bigram bucket in [0, B2).
// The two symbols are combined asymmetrically (multiply vs. add), so (a, b)
// and (b, a) generally produce different pre-mix words.
static inline uint32_t h2(uint32_t a, uint32_t b) {
    const uint32_t first = a * 0x9e3779b1u;
    const uint32_t second = b + 0x85ebca6bu;
    // Masking with B2 - 1 keeps the low bits of the mixed word.
    return mix32(first ^ second) & (B2 - 1);
}
// Hashes the ordered symbol triple (a, b, c) to a trigram bucket in [0, B3).
// Each position is scaled by its own odd multiplier (the second and third
// also get an additive offset) before the three words are xor-ed together
// and passed through mix32().
static inline uint32_t h3(uint32_t a, uint32_t b, uint32_t c) {
    const uint32_t first = a * 0x9e3779b1u;
    const uint32_t second = b * 0x85ebca6bu + 0xc2b2ae35u;
    const uint32_t third = c * 0x27d4eb2du + 0x165667b1u;
    // Masking with B3 - 1 keeps the low bits of the mixed word.
    return mix32(first ^ second ^ third) & (B3 - 1);
}
// Bookkeeping for one observed table slot: bump the running total, mark the
// slot as seen (counting it as a new distinct slot) on its first occurrence,
// then bump the slot's own count.
template <class SeenBits>
static inline void tally_slot(SeenBits& seen, uint32_t* counts, uint32_t& total,
                              uint32_t& distinct, size_t slot) {
    ++total;
    if (!seen.test(slot)) {
        seen.set(slot);
        ++distinct;
    }
    ++counts[slot];
}

// Folds one 100-symbol excerpt E into the statistics of label L.
//
// Updates, in this order: the excerpt count ex[L]; the exact unigram table
// (100 symbols); the hashed bigram table (99 adjacent pairs); the hashed
// trigram table (98 consecutive triples).
//
// L is used directly as a row index and each E[i] as a unigram column index;
// neither is validated here beyond the range check std::bitset::test performs.
static inline void learn(int L, const int E[100]) {
    ++ex[L];

    // Unigrams: the symbol value itself is the slot.
    for (int pos = 0; pos < 100; ++pos) {
        tally_slot(seen1[L], cnt1[L], tot1[L], types1[L], E[pos]);
    }

    // Bigrams: slot is the h2() bucket of (E[pos], E[pos + 1]).
    for (int pos = 0; pos < 99; ++pos) {
        const uint32_t slot = h2(static_cast<uint32_t>(E[pos]),
                                 static_cast<uint32_t>(E[pos + 1]));
        tally_slot(seen2[L], cnt2[L], tot2[L], types2[L], slot);
    }

    // Trigrams: slot is the h3() bucket of (E[pos], E[pos + 1], E[pos + 2]).
    for (int pos = 0; pos < 98; ++pos) {
        const uint32_t slot = h3(static_cast<uint32_t>(E[pos]),
                                 static_cast<uint32_t>(E[pos + 1]),
                                 static_cast<uint32_t>(E[pos + 2]));
        tally_slot(seen3[L], cnt3[L], tot3[L], types3[L], slot);
    }
}
void excerpt(int E[100]) {
// Pre-hash for this excerpt once
uint32_t big[99];
uint32_t tri[98];
for (int i = 0; i < 99; i++) big[i] = h2((uint32_t)E[i], (uint32_t)E[i + 1]);
for (int i = 0; i < 98; i++) tri[i] = h3((uint32_t)E[i], (uint32_t)E[i + 1], (uint32_t)E[i + 2]);
int bestL = 0;
double bestScore = -1e300;
for (int L = 0; L < K; L++) {
// Prior (smoothed)
double score = log((double)ex[L] + BETA);
// ---------------- unigram ----------------
double denom1 = (double)tot1[L] + ALPHA1 * ((double)types1[L] + 1.0);
score -= 100.0 * log(denom1);
for (int i = 0; i < 100; i++) {
int x = E[i];
if (seen1[L].test(x)) score += log((double)cnt1[L][x] + ALPHA1);
else score += log(ALPHA1);
}
// ---------------- bigram ----------------
double s2 = 0.0;
double denom2 = (double)tot2[L] + ALPHA2 * ((double)types2[L] + 1.0);
s2 -= 99.0 * log(denom2);
for (int i = 0; i < 99; i++) {
uint32_t id = big[i];
if (seen2[L].test(id)) s2 += log((double)cnt2[L][id] + ALPHA2);
else s2 += log(ALPHA2);
}
// ---------------- trigram ----------------
double s3 = 0.0;
double denom3 = (double)tot3[L] + ALPHA3 * ((double)types3[L] + 1.0);
s3 -= 98.0 * log(denom3);
for (int i = 0; i < 98; i++) {
uint32_t id = tri[i];
if (seen3[L].test(id)) s3 += log((double)cnt3[L][id] + ALPHA3);
else s3 += log(ALPHA3);
}
score += W2 * s2 + W3 * s3;
if (score > bestScore) {
bestScore = score;
bestL = L;
}
}
int correct = language(bestL);
learn(correct, E);
}