#include "lang.h"
#include "grader.h"
#include <cmath>
#include <algorithm>
#include <cstring>
#include <unordered_map>
// NL: maximum number of distinct languages the grader may use.
static const int NL = 56;
// MS: size of the symbol alphabet; every excerpt symbol must lie in [0, MS).
static const int MS = 65536;
// ── unigram data (14 MB array for O(1) lookup) ──
// sym[L][s]    = occurrences of symbol s in excerpts labeled language L
// sym_total[L] = total symbols recorded for language L (P(s|L) denominator)
static int sym[NL][MS];
static int sym_total[NL];
// ── n-gram data (2-, 3-, 4-grams stored as frequency maps per language) ──
// Map keys are built by key2/key3/key4 below.
static std::unordered_map<int64_t, int> ng2[NL];
static std::unordered_map<int64_t, int> ng3[NL];
static std::unordered_map<int64_t, int> ng4[NL];
// ── n-gram totals per language (for computing P(ngram|lang)) ──
// NOTE(review): kept up to date by excerpt() but not read by the current
// scorer, which uses saturated tanh counts instead of probabilities.
static int ng2_total[NL];
static int ng3_total[NL];
static int ng4_total[NL];
// ── bookkeeping ──
static int cnt[NL]; // excerpts seen per language
static int N = 0; // total excerpts seen
static int V = 0; // global vocabulary size
static bool seen[MS];
// ── helpers to build n-gram keys ──
static inline int64_t key2(const int *E, int i) {
  // Fold the adjacent symbol pair (E[i], E[i+1]) into one 64-bit key
  // using base-MS positional encoding, high symbol first.
  int64_t k = E[i];
  k = k * MS + E[i + 1];
  return k;
}
static inline int64_t key3(const int *E, int i) {
  // Fold three consecutive symbols into one 64-bit key using base-MS
  // positional encoding, most-significant symbol first.
  int64_t k = 0;
  for (int j = 0; j < 3; j++)
    k = k * MS + E[i + j];
  return k;
}
static inline int64_t key4(const int *E, int i) {
  // Pack four consecutive symbols into one 64-bit key: 16 bits per
  // symbol, most-significant symbol first.  Shift-accumulate is exactly
  // equivalent to the explicit (x<<48)|(y<<32)|... form because left
  // shift distributes over bitwise OR.
  int64_t k = 0;
  for (int j = 0; j < 4; j++)
    k = (k << 16) | (int64_t)E[i + j];
  return k;
}
// Classify one 100-symbol excerpt and learn online from the answer.
//
// E points to 100 symbol codes, each of which must lie in [0, MS).
// The routine:
//   1. scores every previously-seen language with a Lidstone-smoothed
//      unigram Naive-Bayes model,
//   2. adds saturated n-gram match bonuses on top of those scores,
//   3. submits the best-scoring language via language(), and
//   4. updates all per-language statistics with the id language()
//      returned — per the update step below, that is the true label.
void excerpt(int *E) {
int best = 0;
double best_s = -1e100;
if (N > 0) {
const double alpha = 0.5; // Lidstone smoothing constant
const int Vv = V > 0 ? V : 1; // defensive: V > 0 whenever N > 0
// ── Pass 1: unigram Naive-Bayes score for every seen language ──
double scores[NL];
int order[NL];
int nlang = 0;
for (int L = 0; L < NL; L++) {
if (!cnt[L]) { scores[L] = -1e100; continue; } // unseen language → -inf
double s = std::log((double)cnt[L] / N); // prior
double ld = std::log(sym_total[L] + alpha * Vv); // denom
for (int i = 0; i < 100; i++)
s += std::log(sym[L][E[i]] + alpha) - ld;
scores[L] = s;
order[nlang++] = L;
}
// Sort candidates by unigram score (descending)
std::sort(order, order + nlang,
[&](int a, int b) { return scores[a] > scores[b]; });
// ── Pass 2: saturated n-gram scores for top-K candidates ──
// score += W * tanh(count) per n-gram position
// tanh saturates smoothly: count=1→0.76W, count=2→0.96W, count≥3≈W
// NOTE(review): since NL == 56, min(nlang, 100) never limits anything;
// K == nlang and every seen language gets the n-gram pass.
int K = std::min(nlang, 100);
static const double W2 = 7.0; // bigram weight
static const double W3 = 1.0; // trigram weight
static const double W4 = 9.0; // 4-gram weight
for (int k = 0; k < K; k++) {
int L = order[k];
// bigrams (weight W2)
for (int i = 0; i < 99; i++) {
auto it = ng2[L].find(key2(E, i));
int c = (it != ng2[L].end()) ? it->second : 0;
scores[L] += W2 * std::tanh((double)c);
}
// trigrams (weight W3)
for (int i = 0; i < 98; i++) {
auto it = ng3[L].find(key3(E, i));
int c = (it != ng3[L].end()) ? it->second : 0;
scores[L] += W3 * std::tanh((double)c);
}
// 4-grams (weight W4)
for (int i = 0; i < 97; i++) {
auto it = ng4[L].find(key4(E, i));
int c = (it != ng4[L].end()) ? it->second : 0;
scores[L] += W4 * std::tanh((double)c);
}
}
// Pick best among top-K
best = order[0];
best_s = scores[order[0]];
for (int k = 1; k < K; k++) {
if (scores[order[k]] > best_s) {
best_s = scores[order[k]];
best = order[k];
}
}
}
// Submit the guess; the return value is the language id the update
// step below treats as the correct answer.
int c = language(best);
// ── online learning: update stats with correct answer ──
for (int i = 0; i < 100; i++) {
if (!seen[E[i]]) { seen[E[i]] = true; V++; } // grow global vocabulary
sym[c][E[i]]++;
sym_total[c]++;
}
// Record every 2-/3-/4-gram under the true language.  The *_total
// counters are maintained here but never read by the scorer above.
for (int i = 0; i < 99; i++) { ng2[c][key2(E, i)]++; ng2_total[c]++; }
for (int i = 0; i < 98; i++) { ng3[c][key3(E, i)]++; ng3_total[c]++; }
for (int i = 0; i < 97; i++) { ng4[c][key4(E, i)]++; ng4_total[c]++; }
cnt[c]++;
N++;
}