#include "grader.h"
#include <bits/stdc++.h>
using namespace std;
// Upper bound on the number of distinct n-gram keys that can be stored;
// get_id() returns -1 for unseen keys once this many nodes exist.
static constexpr int MAX_NODES = 1500000;
// Number of hash buckets. get_id() reduces the hash with `& (HASH_CAP - 1)`,
// which only selects a valid bucket uniformly when this is a power of two.
static constexpr int HASH_CAP = 1 << 22; // 4,194,304 (must be power of 2)
// Array-backed chained hash table mapping each packed n-gram key to a node
// index. Full keys are stored and compared, so two different keys never share
// a weight row.
// head[b]: first node of bucket b, or -1 when the bucket is empty (filled with
// -1 on the first call to excerpt()).
static int head[HASH_CAP];
// next_node[i]: next node in the same bucket chain as node i, or -1 at the end.
static int next_node[MAX_NODES];
// keys[i]: packed n-gram key owned by node i.
static uint64_t keys[MAX_NODES];
// weights[i][l]: weight of node i for label l (56 labels). sat_add() keeps
// every value within [-30000, 30000].
static int16_t weights[MAX_NODES][56];
// Number of nodes allocated so far; also the index of the next free node.
static int node_cnt = 0;
// SplitMix64 mixing step: scrambles a 64-bit key so that its low bits are well
// distributed before get_id() masks them down to a bucket index.
static inline uint64_t splitmix64(uint64_t x) {
    uint64_t z = x + 0x9e3779b97f4a7c15ULL;
    z ^= z >> 30;
    z *= 0xbf58476d1ce4e5b9ULL;
    z ^= z >> 27;
    z *= 0x94d049bb133111ebULL;
    z ^= z >> 31;
    return z;
}
// Maps a packed n-gram key to its node index, allocating a new node the first
// time the key is seen. Returns -1 when the key is unseen and the node pool is
// already full.
static int get_id(uint64_t key) {
    const uint32_t bucket = splitmix64(key) & (HASH_CAP - 1);

    // Walk the bucket's chain looking for an exact key match.
    int cur = head[bucket];
    while (cur != -1) {
        if (keys[cur] == key) {
            return cur;
        }
        cur = next_node[cur];
    }

    // Pool exhausted: report "no feature" rather than overrun the arrays.
    if (node_cnt >= MAX_NODES) {
        return -1;
    }

    // Claim the next free node and link it in at the front of the chain.
    const int fresh = node_cnt;
    ++node_cnt;
    keys[fresh] = key;
    next_node[fresh] = head[bucket];
    head[bucket] = fresh;
    return fresh;
}
// Adds delta to a weight and clamps the result to [-30000, 30000], keeping the
// stored value inside the int16_t range.
static inline void sat_add(int16_t& w, int delta) {
    const int limit = 30000;
    const int sum = static_cast<int>(w) + delta;
    w = static_cast<int16_t>(sum > limit ? limit : (sum < -limit ? -limit : sum));
}
void excerpt(int E[100]) {
static bool inited = false;
if (!inited) {
memset(head, -1, sizeof(head));
// Global variables are natively zero-initialized, so weights array is already 0.
inited = true;
}
vector<int> f1, f2, f3, f4;
f1.reserve(100); f2.reserve(100); f3.reserve(100); f4.reserve(100);
// Extract exactly packed 1 to 4-grams (E[i] takes 16 bits, so max 64 bits)
for (int i = 0; i < 100; i++) {
uint64_t h = E[i];
int id = get_id(h);
if (id != -1) f1.push_back(id);
if (i < 99) {
h = (h << 16) | E[i + 1];
id = get_id(h);
if (id != -1) f2.push_back(id);
}
if (i < 98) {
h = (h << 16) | E[i + 2];
id = get_id(h);
if (id != -1) f3.push_back(id);
}
if (i < 97) {
h = (h << 16) | E[i + 3];
id = get_id(h);
if (id != -1) f4.push_back(id);
}
}
// Keep only unique presences to prevent high-frequency noise skewing the dot product
auto make_unique = [](vector<int>& v) {
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
};
make_unique(f1);
make_unique(f2);
make_unique(f3);
make_unique(f4);
int best_l = 0, second_l = 1;
int32_t best_score = -2e9, second_score = -2e9;
// Dot product prediction
for (int l = 0; l < 56; l++) {
int32_t score = 0;
for (int id : f1) score += weights[id][l] * 1;
for (int id : f2) score += weights[id][l] * 2;
for (int id : f3) score += weights[id][l] * 3;
for (int id : f4) score += weights[id][l] * 4;
if (score > best_score) {
second_score = best_score;
second_l = best_l;
best_score = score;
best_l = l;
} else if (score > second_score) {
second_score = score;
second_l = l;
}
}
int guess = best_l;
int correct = language(guess);
// Margin-based multiclass update closure
auto update = [&](int pos_l, int neg_l, int multiplier) {
for (int id : f1) { sat_add(weights[id][pos_l], 1 * multiplier); sat_add(weights[id][neg_l], -1 * multiplier); }
for (int id : f2) { sat_add(weights[id][pos_l], 2 * multiplier); sat_add(weights[id][neg_l], -2 * multiplier); }
for (int id : f3) { sat_add(weights[id][pos_l], 3 * multiplier); sat_add(weights[id][neg_l], -3 * multiplier); }
for (int id : f4) { sat_add(weights[id][pos_l], 4 * multiplier); sat_add(weights[id][neg_l], -4 * multiplier); }
};
if (correct != guess) {
// Correct aggressive mistake
update(correct, guess, 1);
} else {
// If correct but runner-up was close, sharpen margin boundary
if (best_score - second_score < 50) {
update(correct, second_l, 1);
}
}
}