Submission #1256458

# | Username | Problem | Language | Result | Execution time | Memory
1256458 | madamadam3 | Triple Peaks (IOI25_triples) | C++20 | 18 / 100 | 2125 ms | 890208 KiB
#include "triples.h"
#include <bits/stdc++.h>

using namespace std;

// Short aliases for the integer / container types used throughout the file.
typedef long long ll;
using vi = vector<int>;
using vvi = vector<vi>;
using vl = vector<ll>;
using vvl = vector<vl>;
using pi = pair<int, int>;

// Competitive-programming shorthand macros.
// FOR(i, a, b): i = a .. b-1 ascending; ROF(i, a, b): i = a .. b descending
// (inclusive of b); each(a, x): range-for over x binding a by reference.
#define FOR(i, a, b) for (int i = a; i < b; i++)
#define ROF(i, a, b) for (int i = a; i >= b; i--)
#define each(a, x) for (auto &a : x)
// Container helpers expanding to begin()/end() based expressions.
#define all(x) (x).begin(), (x).end()
#define bg(x) (x).begin()
#define en(x) (x).end()
#define rev(x) reverse(all(x))
// sz(x) yields the container size as a signed int.
#define sz(x) int((x).size())
#define srt(x) sort(all(x))
#define cnt(a, x) count(all(x), a)
// trace(x) writes every element of x followed by a space to stdout (cout),
// not to stderr.
#define trace(x) each(a, (x)) cout << a << " "
// NOTE: the four macros below are object-like, so the tokens mp, pb, lb and ub
// are rewritten everywhere after this point; do not use them as identifiers.
#define mp make_pair
#define pb push_back
#define lb lower_bound
#define ub upper_bound

// Disjoint-set union (union-find) over the elements 0..N-1.
// find() compresses paths iteratively (no recursion depth limit) and
// unite() links the smaller component under the larger one.
struct DSU {
  int n = 0;                  // number of elements (0 for a default-constructed DSU)
  std::vector<int> par, siz;  // parent link per element; component size (valid at roots)

  DSU() = default;
  DSU(int N) : n(N), par(N), siz(N, 1) {
    // Every element starts as its own root.
    std::iota(par.begin(), par.end(), 0);
  }

  // Returns the root of v's component and points every node on the path at it.
  int find(int v) {
    int root = v;
    while (par[root] != root) root = par[root];
    while (par[v] != root) {
      const int up = par[v];
      par[v] = root;
      v = up;
    }
    return root;
  }

  // Merges the components containing a and b (no-op if already joined).
  void unite(int a, int b) {
    a = find(a);
    b = find(b);
    if (a == b) return;
    if (siz[a] < siz[b]) std::swap(a, b);
    par[b] = a;
    siz[a] += siz[b];
  }
};

// Reference O(n^3) counter: the number of index triples a < b < c whose
// heights {h[a], h[b], h[c]} equal, as a multiset, the pairwise distances
// {b - a, c - b, c - a}. Meant for verification and tiny searches only.
long long brute(std::vector<int> h) {
  const int len = static_cast<int>(h.size());
  long long total = 0;

  for (int a = 0; a + 2 < len; ++a) {
    for (int b = a + 1; b + 1 < len; ++b) {
      for (int c = b + 1; c < len; ++c) {
        std::array<int, 3> heights{h[a], h[b], h[c]};
        std::array<int, 3> gaps{b - a, c - b, c - a};
        std::sort(heights.begin(), heights.end());
        std::sort(gaps.begin(), gaps.end());
        if (heights == gaps) ++total;
      }
    }
  }
  return total;
}

// Three indices stored in ascending order (i <= j <= k) no matter the order
// passed to the constructor. Ordered lexicographically so it can serve as a
// std::set key when de-duplicating triples found by different rules.
struct Triple {
  int i = 0, j = 0, k = 0;  // zero-initialised so a default Triple is well-defined

  Triple() = default;
  Triple(int I, int J, int K) {
    std::array<int, 3> idx{I, J, K};
    std::sort(idx.begin(), idx.end());
    i = idx[0];
    j = idx[1];
    k = idx[2];
  }

  bool operator<(const Triple &other) const {
    return std::tie(i, j, k) < std::tie(other.i, other.j, other.k);
  }
};

ll increasing(vi h) {
  int n = sz(h);
  
  set<Triple> used;
  FOR(i, 0, n) {
    int k = i - h[i];
    if (k < 0) continue;

    int j1 = k + h[k], j2 = i - h[k];
    if (k < j1 && j1 < i) {
      int e = j1 + h[j1];
      if (e == i) {
        used.insert(Triple(i, j1, k));
      }
    }
    if (k < j2 && j2 < i) {
      int e = j2 - h[j2];
      if (e == k) {
        used.insert(Triple(i, j2, k));
      }
    }
  }

  // for (auto &el : used) {
  //   cerr << el.i << " " << el.j << " " << el.k << "\n";
  // }
  return sz(used);
}

// Counts triples (i < j < k) whose heights {h[i], h[j], h[k]} equal, as a
// multiset, the pairwise distances {j - i, k - j, k - i}.
//
// With a = j - i, b = k - j and c = k - i = a + b, every valid triple has
// h[i] in {a, b, c}. The work is split so each triple is counted exactly once:
//   1. h[i] == a or h[i] == c: j (resp. k) is i + h[i], and the last index is
//      fixed by the height found there -- at most 4 candidates per i.
//   2. h[i] == b with h[i] != a (a == b triples were already counted in 1):
//      (b, a, c): h[j] - j == -i and h[k] - k == -i, so j ranges over the
//                 indices x with h[x] - x == -i; these groups are disjoint
//                 across i, giving O(n) work in total.
//      (b, c, a): h[i] - i == h[j] - j and h[k] + k == h[j] + j with
//                 k - i == h[j]; for each middle j only the smaller of the two
//                 diagonals through j is scanned, O(n sqrt n) overall.
// Extra memory is O(n) (two hash maps of index lists).
long long smart(std::vector<int> h) {
  const int n = static_cast<int>(h.size());

  // True iff the indices are in range, strictly increasing, and form a
  // valid triple. Range checks come first so h[] is never read out of bounds.
  auto is_triple = [&](int i, int j, int k) {
    if (i < 0 || !(i < j && j < k && k < n)) return false;
    std::array<int, 3> heights{h[i], h[j], h[k]};
    std::array<int, 3> dists{j - i, k - j, k - i};
    std::sort(heights.begin(), heights.end());
    std::sort(dists.begin(), dists.end());
    return heights == dists;
  };

  long long ans = 0;

  // Phase 1: h[i] is the distance from i to j (== a) or from i to k (== c).
  for (int i = 0; i < n; i++) {
    std::array<std::pair<int, int>, 4> found{};
    int num_found = 0;
    // Records (j, k) once if it is a valid triple; duplicates arise when a == b.
    auto consider = [&](int j, int k) {
      if (!is_triple(i, j, k)) return;
      for (int t = 0; t < num_found; t++)
        if (found[t] == std::make_pair(j, k)) return;
      found[num_found++] = {j, k};
    };

    const int q = i + h[i];  // the index at distance h[i] to the right of i
    if (q > i && q < n) {
      consider(q, i + h[q]);  // q == j, h[j] == c
      consider(q, q + h[q]);  // q == j, h[j] == b
      consider(i + h[q], q);  // q == k, h[k] == a
      consider(q - h[q], q);  // q == k, h[k] == b
    }
    ans += num_found;
  }

  // Index lists per diagonal h[x] - x and anti-diagonal h[x] + x, ascending.
  std::unordered_map<long long, std::vector<int>> diag, anti;
  for (int x = 0; x < n; x++) {
    diag[static_cast<long long>(h[x]) - x].push_back(x);
    anti[static_cast<long long>(h[x]) + x].push_back(x);
  }

  // Phase 2, arrangement (b, a, c): h[j] == j - i, k == j + h[i].
  for (int i = 0; i < n; i++) {
    auto it = diag.find(-static_cast<long long>(i));
    if (it == diag.end()) continue;
    for (int j : it->second) {
      const int k = j + h[i];
      if (is_triple(i, j, k) && h[i] != j - i) ans++;
    }
  }

  // Phase 2, arrangement (b, c, a): h[j] == k - i, with i on j's diagonal and
  // k on j's anti-diagonal. Both groups contain j itself, so .at() is safe.
  for (int j = 0; j < n; j++) {
    const std::vector<int> &same_diag = diag.at(static_cast<long long>(h[j]) - j);
    const std::vector<int> &same_anti = anti.at(static_cast<long long>(h[j]) + j);

    if (same_diag.size() <= same_anti.size()) {
      for (int i : same_diag) {
        if (i >= j) break;  // sorted ascending: only i < j can be the left end
        const int k = i + h[j];
        if (is_triple(i, j, k) && h[i] != j - i && h[k] + k == h[j] + j) ans++;
      }
    } else {
      auto start = std::upper_bound(same_anti.begin(), same_anti.end(), j);
      for (auto p = start; p != same_anti.end(); ++p) {
        const int k = *p;
        const int i = k - h[j];
        if (is_triple(i, j, k) && h[i] != j - i && h[i] - i == h[j] - j) ans++;
      }
    }
  }

  return ans;
}

ll count_triples(vi h) {
  return smart(h); 
  int n = sz(h);
  
  ll ans = 0;
  set<Triple> triples;

  FOR(i, 0, n) {
    FOR(j, 0, i) {
      // for a given pair (i, j) the possible values for k are i - h[i], i - h[j], j+h[i], j+h[j]
      int dist = i - j;
      vi K = {i-h[i], i-h[j], j+h[i], j+h[j]}; // there are O(4n^2) triples total??
      set<int> seen;

      each(k, K) {
        if (!(j < k && k < i)) continue;
        if (seen.count(k)) continue;
        seen.insert(k);

        vi d = {abs(i-j), abs(i-k), abs(j-k)}, f = {h[i], h[j], h[k]};
        srt(d); srt(f);
        bool valid = d[0] == f[0] && d[1] == f[1] && d[2] == f[2];
        if (!valid) continue;

        ans++;
        triples.insert(Triple(i, j, k));
      }
    }
  }

  for (auto &el : triples) {
    cerr << el.i << " " << el.j << " " << el.k << "\n";
  }
  return ans;
}

vi construct_range(int M, int K) {
  // vi arr = {3, 1, 1, 2, 4, 3, 1, 2, 1, 4, 3, 2};
  // vi ret = {};

  // while (sz(ret) < M) {
  //   each(el, arr) ret.pb(el);
  // }

  // while (sz(ret) > M) ret.pop_back();
  // return ret;

  int best = -1;
  vi bc;

  int LO = 4, HI = 20, MAX_CALL = 4;
  auto dfs = [&](const auto &self, vi cur) {
    if (sz(cur) == M) {
      ll tl = brute(cur);
      if (tl > best) {
        cerr << "Found an array with " << tl << " triples.\n";
        trace(cur); cerr << "\n";
        best = tl;
        bc = cur;
      }
      return;
    }

    FOR(i, 1, MAX_CALL + 1) {
      vi ncur = cur;
      ncur.pb(i);
      self(self, ncur);
    }
  };

  vi initial = {};
  // dfs(dfs, initial);
  int cnt = 0; bool flip = false;
  for (M = LO; M <= HI; M += LO) {
    
    dfs(dfs, initial);
    initial = bc;
  }
  // for (auto &el : bc) {
  //   cerr << el << ", ";
  // }
  // cerr << "\n";
  return bc;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...