#include "grader.h"
#include <bits/stdc++.h>
using namespace std;
// Upper bound on the number of distinct n-gram features that can receive an id;
// it also fixes the size of the static tables below.
static constexpr int MAX_NODES = 1500000;
// Feature storage, indexed by feature id.
// keys[id] holds the packed 64-bit n-gram key that was assigned `id`.
static uint64_t keys[MAX_NODES];
// weights[id][l] is the 16-bit weight of feature `id` for label `l` (56 labels).
static int16_t weights[MAX_NODES][56];
// Number of ids handed out so far; the next unseen key receives this value.
static int node_cnt = 0;
// Ordered std::map (deliberately not a hash table) from a packed 64-bit n-gram
// key to its unique feature id.
static std::map<uint64_t, int> id_of;
// Maps a packed n-gram key to its feature id, registering the key on first sight.
//
// Returns the existing id when the key is already known. For an unseen key a
// new id is allocated (and the key recorded in `keys`) as long as fewer than
// MAX_NODES ids exist; once the table is full, unseen keys yield -1 so callers
// can simply drop the feature instead of running out of memory.
static int get_id(uint64_t key) {
    // One tree descent serves both the lookup and, if needed, the insertion hint.
    const auto pos = id_of.lower_bound(key);
    const bool already_known = (pos != id_of.end()) && (pos->first == key);
    if (already_known) {
        return pos->second;
    }

    if (node_cnt < MAX_NODES) {
        const int fresh = node_cnt;
        ++node_cnt;
        keys[fresh] = key;
        id_of.emplace_hint(pos, key, fresh);
        return fresh;
    }

    // Table is full: degrade gracefully by reporting "no id".
    return -1;
}
// Saturating add for weights: w += delta, with the result clamped to
// [-30000, 30000] so it always fits in an int16_t without wrapping.
static inline void sat_add(int16_t& w, int delta) {
    constexpr int kLimit = 30000;
    const int sum = static_cast<int>(w) + delta;
    const int clamped = (sum > kLimit) ? kLimit
                      : (sum < -kLimit) ? -kLimit
                                        : sum;
    w = static_cast<int16_t>(clamped);
}
// Classifies one excerpt and performs one online learning step.
//
// E holds 100 symbols; each symbol is truncated to 16 bits when packed. Every
// contiguous 1-, 2-, 3- and 4-gram is packed into a 64-bit key and mapped to a
// feature id through get_id() (n-grams that get no id are dropped). Features
// are presence indicators: each distinct id counts once per excerpt, and an
// n-gram of length n contributes with factor n to both scoring and updates.
//
// The best-scoring label is passed to language(); its return value is treated
// as the correct label. On a wrong guess the weights move towards the correct
// label and away from the guess. On a right guess whose lead over the
// runner-up is below the margin, they move away from the runner-up instead.
void excerpt(int E[100]) {
    constexpr int kSymbols = 100;   // length of E
    constexpr int kMaxOrder = 4;    // longest n-gram used as a feature
    constexpr int kLabels = 56;     // must match the second dimension of `weights`
    constexpr int32_t kMargin = 50; // minimum lead required to skip the update

    // feats[k] collects the ids of all (k+1)-grams present in this excerpt.
    std::vector<int> feats[kMaxOrder];
    for (auto& f : feats) f.reserve(kSymbols);

    // Grow the packed key one symbol at a time so position i yields its
    // 1-gram, 2-gram, 3-gram and 4-gram in that order (ids are allocated in
    // the order get_id() sees new keys, so this order is kept on purpose).
    for (int i = 0; i < kSymbols; i++) {
        uint64_t packed = 0;
        for (int k = 0; k < kMaxOrder && i + k < kSymbols; k++) {
            packed = (packed << 16) | (uint16_t)E[i + k];
            const int id = get_id(packed);
            if (id != -1) feats[k].push_back(id);
        }
    }

    // Presence features: keep each id at most once per order.
    for (auto& f : feats) {
        std::sort(f.begin(), f.end());
        f.erase(std::unique(f.begin(), f.end()), f.end());
    }

    // Score all labels in a single pass over each feature row. A row is 56
    // contiguous int16 values, so walking it once per feature is far more
    // cache-friendly than revisiting every row once per label, and integer
    // addition makes the resulting scores identical.
    int32_t score[kLabels] = {};
    for (int k = 0; k < kMaxOrder; k++) {
        const int32_t order = k + 1;
        for (int id : feats[k]) {
            const int16_t* row = weights[id];
            for (int l = 0; l < kLabels; l++) score[l] += order * (int32_t)row[l];
        }
    }

    // Pick the best label and the runner-up; ties keep the lower label index.
    int best_l = 0, second_l = 1;
    int32_t best_score = INT32_MIN, second_score = INT32_MIN;
    for (int l = 0; l < kLabels; l++) {
        if (score[l] > best_score) {
            second_score = best_score;
            second_l = best_l;
            best_score = score[l];
            best_l = l;
        } else if (score[l] > second_score) {
            second_score = score[l];
            second_l = l;
        }
    }

    const int guess = best_l;
    const int correct = language(guess);

    // Defensive: a label outside the weight table must not be used as a column
    // index, so skip learning rather than write out of bounds.
    if (correct < 0 || correct >= kLabels) return;

    // Perceptron step: push every present feature towards pos_l and away from
    // neg_l, scaled by the n-gram order, saturating at the int16-safe limits.
    auto reinforce = [&](int pos_l, int neg_l) {
        for (int k = 0; k < kMaxOrder; k++) {
            const int step = k + 1;
            for (int id : feats[k]) {
                sat_add(weights[id][pos_l], step);
                sat_add(weights[id][neg_l], -step);
            }
        }
    };

    if (correct != guess) {
        reinforce(correct, guess);
    } else if (best_score - second_score < kMargin) {
        // Correct but not confident enough: widen the gap to the runner-up.
        reinforce(correct, second_l);
    }
}