제출 #1330855

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1330855 | normankr07 | Languages (IOI10_languages) | C++20
101 / 100
4985 ms | 41336 KiB
#include "lang.h"
#include "grader.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <unordered_map>

// ── problem dimensions ──
static constexpr int NL = 56;     // number of language ids (0 .. NL-1)
static constexpr int MS = 65536;  // alphabet size: every symbol is an index in [0, MS)

// ── unigram statistics: dense per-language symbol counts (NL * MS ints, ~14 MiB) ──
static int sym[NL][MS];    // sym[L][s]    = occurrences of symbol s in excerpts learned for L
static int sym_total[NL];  // sym_total[L] = total number of symbols learned for L

// ── n-gram statistics: sparse per-language counts keyed by a packed n-gram ──
using NgramCounts = std::unordered_map<int64_t, int>;
static NgramCounts ng2[NL];  // bigram counts
static NgramCounts ng3[NL];  // trigram counts
static NgramCounts ng4[NL];  // 4-gram counts

// ── bookkeeping ──
static int cnt[NL];    // number of excerpts learned per language
static int N = 0;      // number of excerpts learned in total
static int V = 0;      // number of distinct symbols observed so far
static bool seen[MS];  // seen[s] becomes true the first time symbol s is observed

// ── saturating weight for a non-negative count c: c / (c + 1.1) ──
//    0 at c = 0, increasing in c, and always below 1. It is a rational
//    stand-in for tanh that needs no exp().
static inline double hyperb(int x) {
    const double c = x;
    return c / (c + 1.1);
}

// ── helpers to build n-gram keys ──
// An n-gram (n <= 4) is encoded as the base-MS number E[i] E[i+1] ... E[i+n-1].
// Symbols are used as indices into sym[L][...], so each is < MS. That makes the
// encoding exact and collision-free; no hashing or mixing is involved.
// With MS = 2^16 each symbol occupies its own 16-bit field, so a 4-gram is
// (E[i] << 48) | (E[i+1] << 32) | (E[i+2] << 16) | E[i+3].
// The arithmetic is unsigned because a 4-gram whose first symbol is >= 2^15 sets
// bit 63, which would overflow a signed intermediate. The packed value is converted
// to the int64_t key type of the ng2/ng3/ng4 maps at the end.
static_assert(MS <= 65536, "key4 packs four base-MS digits into 64 bits; MS must be <= 2^16");

// Pack the n symbols E[i .. i+n-1] into a single key (n is 2, 3 or 4 below).
static inline int64_t pack_ngram(const int *E, int i, int n) {
    uint64_t k = 0;
    for (int j = 0; j < n; j++)
        k = k * MS + static_cast<uint64_t>(E[i + j]);
    return static_cast<int64_t>(k);
}

static inline int64_t key2(const int *E, int i) { return pack_ngram(E, i, 2); }
static inline int64_t key3(const int *E, int i) { return pack_ngram(E, i, 3); }
static inline int64_t key4(const int *E, int i) { return pack_ngram(E, i, 4); }

// ── classification helpers (used only by excerpt) ──

// Stored frequency of n-gram key k in m, or 0 if it was never recorded.
// find() is used instead of operator[] so that scoring never inserts entries.
static inline int ngram_count(const std::unordered_map<int64_t, int> &m, int64_t k) {
    const auto it = m.find(k);
    return it != m.end() ? it->second : 0;
}

// Unigram Naive-Bayes log-score of excerpt E under language L:
//   log(cnt[L] / N) + sum_i [ log(sym[L][E[i]] + alpha) - log(sym_total[L] + alpha * Vv) ]
// This is the log prior plus add-alpha smoothed symbol log-likelihoods.
// Requires cnt[L] > 0 and N > 0. Vv is the global vocabulary size clamped to >= 1.
static double score_unigrams(const int *E, int L, int Vv) {
    const double alpha = 0.5;
    double s = std::log((double)cnt[L] / N);                // prior
    const double ld = std::log(sym_total[L] + alpha * Vv);  // shared smoothed denominator
    for (int i = 0; i < 100; i++)
        s += std::log(sym[L][E[i]] + alpha) - ld;
    return s;
}

// Saturated n-gram evidence for language L. Each bigram, trigram and 4-gram position
// of E (99, 98 and 97 positions respectively) adds W * hyperb(count), where count is
// how often that exact n-gram occurred in L's learned excerpts (0 if never).
// hyperb() stays below 1, so one position adds less than its W however frequent the
// n-gram is. Longer n-grams carry larger weights.
static double score_ngrams(const int *E, int L) {
    constexpr double W2 = 7.0;  // bigram weight
    constexpr double W3 = 8.0;  // trigram weight
    constexpr double W4 = 9.0;  // 4-gram weight

    double ng = 0;
    for (int i = 0; i < 99; i++) ng += W2 * hyperb(ngram_count(ng2[L], key2(E, i)));
    for (int i = 0; i < 98; i++) ng += W3 * hyperb(ngram_count(ng3[L], key3(E, i)));
    for (int i = 0; i < 97; i++) ng += W4 * hyperb(ngram_count(ng4[L], key4(E, i)));
    return ng;
}

// Classify excerpt E using everything learned so far. Requires N > 0, so at least
// one language has cnt[L] > 0.
//   1. Compute the unigram Naive-Bayes score for every language with cnt[L] > 0.
//   2. Add the n-gram evidence divided by log(cnt[L] + 2). This scales the n-gram
//      bonus down more for languages with more learned excerpts. Every candidate is
//      rescored; there is no top-K cutoff.
//   3. Return the arg-max. Candidates are visited in descending unigram order, and
//      only a strictly greater total replaces the current best. Ties on the final
//      score therefore go to the candidate ranked earlier by unigram score.
static int classify_excerpt(const int *E) {
    const int Vv = V > 0 ? V : 1;

    double scores[NL];
    int order[NL];
    int nlang = 0;

    for (int L = 0; L < NL; L++) {
        if (!cnt[L]) { scores[L] = -1e100; continue; }  // never a candidate
        scores[L] = score_unigrams(E, L, Vv);
        order[nlang++] = L;
    }

    std::sort(order, order + nlang,
              [&](int a, int b) { return scores[a] > scores[b]; });

    for (int k = 0; k < nlang; k++) {
        const int L = order[k];
        scores[L] += score_ngrams(E, L) / std::log(cnt[L] + 2.0);
    }

    int best = order[0];
    double best_s = scores[best];
    for (int k = 1; k < nlang; k++) {
        if (scores[order[k]] > best_s) {
            best_s = scores[order[k]];
            best = order[k];
        }
    }
    return best;
}

// Online update: fold excerpt E into the unigram, n-gram and bookkeeping
// statistics of language c.
static void learn_excerpt(const int *E, int c) {
    for (int i = 0; i < 100; i++) {
        if (!seen[E[i]]) { seen[E[i]] = true; V++; }
        sym[c][E[i]]++;
        sym_total[c]++;
    }
    for (int i = 0; i < 99; i++) ng2[c][key2(E, i)]++;
    for (int i = 0; i < 98; i++) ng3[c][key3(E, i)]++;
    for (int i = 0; i < 97; i++) ng4[c][key4(E, i)]++;

    cnt[c]++;
    N++;
}

// Grader entry point, called once per excerpt. E holds 100 symbols, each used
// as an index in [0, MS).
// Steps:
//   1. Guess the excerpt's language.
//   2. Report the guess through language().
//   3. Learn from the language id that call returns, treated as the correct answer.
void excerpt(int *E) {
    // Before any excerpt has been learned there is nothing to score; guess 0.
    const int best = N > 0 ? classify_excerpt(E) : 0;
    const int c = language(best);
    learn_excerpt(E, c);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...