Submission #1262402

# | Time | Username | Problem | Language | Result | Execution time | Memory
1262402 | (not captured) | Canuc80k | Triple Peaks (IOI25_triples) | C++20 | 11 / 100 | 220 ms | 62448 KiB
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Running total of triples found by count_triples(); reset at the start of
// every call and returned at the end.
ll res = 0;

// Records one triple (i, j, k). The indices are not stored; only `res` is
// incremented.
void add(ll i, ll j, ll k) { ++res; }

/// Counts ordered triples (p, q, r) of pairwise-distinct indices with
///     a[q] == a[p],   b[q] == b[r],   a[r] == b[p].
///
/// @param a, b  parallel arrays; only the first min(a.size(), b.size())
///              positions are used.
/// @return the number of such triples (coincident indices excluded).
ll count_pairs_distinct(const std::vector<int>& a, const std::vector<int>& b) {
    const int n = static_cast<int>(std::min(a.size(), b.size()));
    if (n == 0) return 0;

    // Multiplicity of every (a, b) pair and of every value of a and of b.
    std::map<std::pair<int, int>, int> cnt_pair;
    std::map<int, ll> cnt_a;
    std::map<int, ll> cnt_b;
    for (int t = 0; t < n; ++t) {
        ++cnt_pair[{a[t], b[t]}];
        ++cnt_a[a[t]];
        ++cnt_b[b[t]];
    }

    // row[x]: every (y, count) with cnt_pair[(x, y)] > 0.
    // diag[x]: number of indices t with a[t] == b[t] == x.
    std::map<int, std::vector<std::pair<int, int>>> row;
    std::map<int, ll> diag;
    for (const auto& [key, freq] : cnt_pair) {
        row[key.first].emplace_back(key.second, freq);
        if (key.first == key.second) diag[key.first] = freq;
    }

    auto pair_count = [&cnt_pair](int x, int y) -> ll {
        const auto it = cnt_pair.find({x, y});
        return it == cnt_pair.end() ? 0 : it->second;
    };

    // Count every (p, q, r), coincident indices included. Fixing p's pair
    // (x, y), the number of (q, r) is sum_v cnt(x, v) * cnt(y, v); scan the
    // shorter of the two rows and look the other factor up.
    ll total = 0;
    for (const auto& [key, freq] : cnt_pair) {
        const auto ry_it = row.find(key.second);
        if (ry_it == row.end()) continue;  // no r has a[r] == y
        const auto& rx = row.find(key.first)->second;  // present: key itself is in it
        const auto& ry = ry_it->second;
        const bool scan_x = rx.size() <= ry.size();
        const auto& scan = scan_x ? rx : ry;
        const int other = scan_x ? key.second : key.first;
        ll dot = 0;
        for (const auto& [v, c] : scan) dot += c * pair_count(other, v);
        total += freq * dot;
    }

    // Inclusion-exclusion over the coincidences among the counted triples:
    //   p == q  ->  r is diagonal with value b[p]:   sum_x cnt_b[x] * diag[x]
    //   q == r  ->  p is diagonal with value a[q]:   sum_x cnt_a[x] * diag[x]
    //   p == r  ->  p and q are both diagonal, same value: sum_x diag[x]^2
    // Every pairwise intersection is p == q == r (sum_x diag[x]), hence -2T.
    ll p_eq_q = 0;
    ll q_eq_r = 0;
    ll p_eq_r = 0;
    ll all_eq = 0;
    for (const auto& [x, d] : diag) {
        p_eq_q += cnt_b.at(x) * d;
        q_eq_r += cnt_a.at(x) * d;
        p_eq_r += d * d;
        all_eq += d;
    }
    return total - (p_eq_q + q_eq_r + p_eq_r - 2 * all_eq);
}

/// Counts triples i < j < k whose heights {H[i], H[j], H[k]} equal the
/// distances {j - i, k - j, k - i} as multisets (IOI 2025 "Triple Peaks").
///
/// The largest distance k - i must be matched by exactly one of the three
/// heights, so the triples split into disjoint cases by which index carries it.
/// Assumes every H[t] is in [1, n - 1] (presumably the task constraints; not
/// checked here).
///
/// @param H  the heights.
/// @return the number of such triples; also stored in the global `res`.
long long count_triples(std::vector<int> H) {
    const int n = static_cast<int>(H.size());
    res = 0;

    // Case A: H[k] is the tallest, so H[k] == k - i, i.e. i = k - H[k].
    for (int k = 2; k < n; ++k) {
        const int i = k - H[k];
        if (i < 0 || H[i] >= H[k]) continue;
        int j1 = k - H[i];  // H[i] == k - j; needs H[j] == j - i
        int j2 = i + H[i];  // H[i] == j - i; needs H[j] == k - j
        if (j1 < 0 || H[j1] >= H[k] || H[j1] != j1 - i) j1 = -1;
        if (j2 >= k || H[j2] >= H[k] || H[j2] != k - j2) j2 = -1;
        if (j1 != -1) add(i, j1, k);
        if (j2 != -1 && j1 != j2) add(i, j2, k);  // same j found twice = one triple
    }

    // Case B: H[i] is the tallest, so H[i] == k - i, i.e. k = i + H[i].
    for (int i = 0; i + 2 < n; ++i) {
        const int k = i + H[i];
        if (k >= n || H[k] >= H[i]) continue;
        int j1 = k - H[k];  // H[k] == k - j; needs H[j] == j - i
        int j2 = i + H[k];  // H[k] == j - i; needs H[j] == k - j
        if (j1 < 0 || H[j1] >= H[i] || H[j1] != j1 - i) j1 = -1;
        if (j2 >= k || H[j2] >= H[i] || H[j2] != k - j2) j2 = -1;
        if (j1 != -1) add(i, j1, k);
        if (j2 != -1 && j1 != j2) add(i, j2, k);
    }

    // Case C1: H[j] is the tallest, H[i] == j - i and H[k] == k - j.
    // When k - j == H[i] the triple also satisfies case C2 and is left to it.
    for (int i = 0; i + 2 < n; ++i) {
        const int j = i + H[i];
        if (j >= n || H[j] <= H[i]) continue;
        const int k = i + H[j];
        if (k >= n || H[k] >= H[j] || k - H[k] != j || k - j == H[i]) continue;
        add(i, j, k);
    }

    // Case C2: H[j] == k - i, H[i] == k - j, H[k] == j - i.
    // With a[t] = t + H[t] and b[t] = t - H[t] (note the sign), this is exactly
    //   b[i] == b[j],  a[j] == a[k],  a[i] == b[k],
    // i.e. (p, q, r) = (k, j, i) in count_pairs_distinct. Since H >= 1 gives
    // a[t] - b[t] = 2 * H[t] > 0, every match automatically has i < j < k.
    std::vector<int> a;
    std::vector<int> b;
    a.reserve(n);
    b.reserve(n);
    for (int t = 0; t < n; ++t) {
        a.push_back(t + H[t]);
        b.push_back(t - H[t]);
    }
    res += count_pairs_distinct(a, b);
    return res;
}

/// Builds a candidate height sequence for the construction part of the task.
///
/// Returns {1, 1, 2, ..., M - 1} (length M), the sequence drafted in an earlier
/// commented-out version. The original body had no return statement, which is
/// undefined behavior for a non-void function.
/// NOTE(review): K (the target triple count) is not used yet; this sequence is
/// not tuned to reach it.
///
/// @param M  maximum allowed length; M <= 0 yields an empty vector.
/// @param K  target number of triples (currently ignored).
/// @return the height sequence.
std::vector<int> construct_range(int M, int K) {
    (void)K;
    std::vector<int> heights;
    if (M <= 0) return heights;
    heights.reserve(static_cast<std::size_t>(M));
    heights.push_back(1);
    for (int i = 1; i < M; ++i) heights.push_back(i);
    return heights;
}

Compilation message (stderr)

triples.cpp: In function 'std::vector<int> construct_range(int, int)':
triples.cpp:141:1: warning: no return statement in function returning non-void [-Wreturn-type]
  141 | }
      | ^
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...