// Submission #1311957
//
// Username: quereant | Problem: Pairs (IOI07_pairs) | Language: C++20
// Result: 60 / 100 | Execution time: 311 ms | Memory: 589824 KiB
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using cd = complex<double>;
const double PI = acos(-1.0);

// -------------------- FFT (1D) --------------------
// In-place iterative radix-2 FFT of `a`.
//
// The length of `a` is expected to be a power of two (the bit-reversal
// permutation below is built from halving `size`).  The bit-reversal table and
// the twiddle-factor table live in function-local statics and are reused by
// later calls: the permutation is rebuilt whenever the length changes, the
// twiddle table only ever grows.
//
// invert == false: computes A[k] = sum_j a[j] * exp(+2*pi*i*j*k / size).
// invert == true : runs the same transform, then reverses a[1..size-1] and
//                  divides every element by size, which yields the inverse.
void fft_inplace(vector<cd> & a, bool invert) {
    const int size = (int)a.size();

    static vector<int> perm;           // bit-reversal permutation for the current length
    static vector<cd> twiddle{0, 1};   // twiddle[h + j] = exp(2*pi*i*j / (2h)) for 0 <= j < h

    if ((int)perm.size() != size) {
        perm.assign(size, 0);
        int mirrored = 0;
        for (int pos = 0; pos < size; ++pos) {
            perm[pos] = mirrored;
            // advance `mirrored` by one in bit-reversed order
            int mask = size >> 1;
            for (; mirrored & mask; mask >>= 1) mirrored ^= mask;
            mirrored ^= mask;
        }
    }
    for (int pos = 0; pos < size; ++pos) {
        const int partner = perm[pos];
        if (pos < partner) swap(a[pos], a[partner]);
    }

    if ((int)twiddle.size() < size) {
        // twiddle.size() is a power of two, so ctz gives its log2
        int level = __builtin_ctz(twiddle.size());
        twiddle.resize(size);
        for (; (1 << level) < size; ++level) {
            const int half = 1 << level;
            const double step = 2 * PI / (1 << (level + 1));
            for (int i = half >> 1; i < half; ++i) {
                // even entries are inherited from the previous level,
                // odd entries are computed directly from their angle
                twiddle[2 * i] = twiddle[i];
                const double theta = step * (2 * i + 1 - half);
                twiddle[2 * i + 1] = cd(cos(theta), sin(theta));
            }
        }
    }

    for (int half = 1; half < size; half *= 2) {
        for (int start = 0; start < size; start += 2 * half) {
            for (int off = 0; off < half; ++off) {
                const cd top = a[start + off];
                const cd bottom = a[start + off + half] * twiddle[half + off];
                a[start + off] = top + bottom;
                a[start + off + half] = top - bottom;
            }
        }
    }

    if (invert) {
        reverse(a.begin() + 1, a.end());
        for (cd & value : a) value /= size;
    }
}

// -------------------- 3D FFT helpers --------------------
// Flat index of cell (x, y, z) in an L*L*L cube stored with x varying fastest,
// then y, then z: the result equals x + L*y + L*L*z.
static inline int idx3(int x, int y, int z, int L) {
    const int row = z * L + y;   // index of the x-row that contains the cell
    return row * L + x;
}

// Applies fft_inplace to every 1-D line of the L*L*L cube `data` that runs
// along one axis.  Cells are addressed as x + L*(y + L*z), so consecutive
// samples of a line are `step` apart:
//   dim == 0     -> x axis, step 1
//   dim == 1     -> y axis, step L
//   anything else -> z axis, step L*L
// Each line is copied into a scratch buffer, transformed, and written back.
void fft_dim(vector<cd> & data, int L, int dim, bool invert) {
    vector<cd> line(L);

    const bool alongX = (dim == 0);
    const bool alongY = (dim == 1);

    // distance between consecutive samples on the transformed axis
    const int step = alongX ? 1 : (alongY ? L : L * L);
    // strides of the two remaining axes (inner loop first, outer loop second)
    const int innerStride = alongX ? L : 1;
    const int outerStride = (alongX || alongY) ? L * L : L;

    for (int outer = 0; outer < L; ++outer) {
        for (int inner = 0; inner < L; ++inner) {
            const int origin = inner * innerStride + outer * outerStride;
            for (int k = 0; k < L; ++k) line[k] = data[origin + k * step];
            fft_inplace(line, invert);
            for (int k = 0; k < L; ++k) data[origin + k * step] = line[k];
        }
    }
}

// 3D FFT of the L*L*L cube `data`: runs the 1-D transform along x, then y,
// then z.  `invert` is forwarded unchanged to every pass, so invert == true
// performs the inverse transform.
void fft3d(vector<cd> & data, int L, bool invert) {
    for (int axis = 0; axis < 3; ++axis) {
        fft_dim(data, L, axis, invert);
    }
}

// -------------------- Fenwick Tree --------------------
// Binary indexed tree over positions 0..n-1 holding int counters.
// Public positions are 0-based; the internal array `bit` is 1-based.
struct Fenwick {
    int n;
    vector<int> bit;

    Fenwick(int n=0): n(n), bit(n+1,0) {}

    // Adds `delta` to position i (0-based).
    void add(int i, int delta=1) {
        int pos = i + 1;
        while (pos <= n) {
            bit[pos] += delta;
            pos += pos & -pos;
        }
    }

    // Sum of positions 0..i inclusive.  Returns 0 for i < 0; an index past the
    // end is clamped to the last position.
    int sumPrefix(int i) {
        if (i < 0) return 0;
        const int last = (i >= n) ? n - 1 : i;
        int total = 0;
        int pos = last + 1;
        while (pos > 0) {
            total += bit[pos];
            pos -= pos & -pos;
        }
        return total;
    }

    // Sum of positions l..r inclusive; 0 when the range is empty (r < l).
    int rangeSum(int l, int r) {
        if (l > r) return 0;
        const int below = (l == 0) ? 0 : sumPrefix(l - 1);
        return sumPrefix(r) - below;
    }
};

// -------------------- Main --------------------
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int B;
    int N;
    long long D;
    int M;
    if (!(cin >> B >> N >> D >> M)) return 0;

    if (B == 1) {
        // 1D: just sort and two pointers
        vector<int> a(N);
        for (int i = 0; i < N; ++i) cin >> a[i];
        sort(a.begin(), a.end());
        long long ans = 0;
        int r = 0;
        for (int l = 0; l < N; ++l) {
            while (r < N && (long long)a[r] - a[l] <= D) ++r;
            // counts pairs (l, l+1..r-1) => r - l - 1
            ans += (r - l - 1);
        }
        cout << ans << "\n";
        return 0;
    }

    if (B == 2) {
        // transform to (u,v) where u = x+y, v = x-y, then L1<=D <=> both |du|<=D and |dv|<=D
        vector<pair<int,int>> pts; pts.reserve(N);
        for (int i = 0; i < N; ++i) {
            int x,y; cin >> x >> y;
            int u = x + y;
            int v = x - y;
            pts.emplace_back(u, v);
        }
        // compress v coordinates
        vector<int> allv;
        allv.reserve(N);
        for (auto &p: pts) allv.push_back(p.second);
        sort(allv.begin(), allv.end());
        allv.erase(unique(allv.begin(), allv.end()), allv.end());
        // sort by u
        vector<int> order(N);
        iota(order.begin(), order.end(), 0);
        sort(order.begin(), order.end(), [&](int i, int j){
            if (pts[i].first != pts[j].first) return pts[i].first < pts[j].first;
            return pts[i].second < pts[j].second;
        });
        Fenwick fw((int)allv.size());
        long long ans = 0;
        int left = 0;
        for (int ii = 0; ii < N; ++ii) {
            int idx = order[ii];
            int u = pts[idx].first;
            int v = pts[idx].second;
            // slide left until u - pts[order[left]].first <= D
            while (left < ii && (long long)u - pts[order[left]].first > D) {
                int vremove = pts[order[left]].second;
                int idv = (int)(lower_bound(allv.begin(), allv.end(), vremove) - allv.begin());
                fw.add(idv, -1);
                ++left;
            }
            // query number of existing points with v in [v - D, v + D]
            int vl = v - (int)D;
            int vr = v + (int)D;
            int L = (int)(lower_bound(allv.begin(), allv.end(), vl) - allv.begin());
            int R = (int)(upper_bound(allv.begin(), allv.end(), vr) - allv.begin()) - 1;
            if (L <= R) ans += fw.rangeSum(L, R);
            // insert current
            int idvcur = (int)(lower_bound(allv.begin(), allv.end(), v) - allv.begin());
            fw.add(idvcur, 1);
        }
        cout << ans << "\n";
        return 0;
    }

    // B == 3
    // Build freq grid f[x][y][z] for x,y,z in 1..M -> map to 0..M-1
    vector<int> coords(3);
    int mx = M;
    // grid size L must be power of two >= 2*M
    int need = 2 * M - 1;
    int L = 1;
    while (L < need) L <<= 1;

    // allocate arrays
    int L3 = L * L * L;
    vector<cd> A(L3); // occupancy
    vector<cd> K(L3); // kernel

    // frequency grid stored separately as int for later dot product
    vector<int> freq(L3, 0);

    // read points and populate A and freq
    ll total_animals = 0;
    for (int i = 0; i < N; ++i) {
        int x,y,z; cin >> x >> y >> z;
        --x; --y; --z;
        int id = idx3(x, y, z, L);
        ++freq[id];
        A[id] = cd((double)freq[id], 0); // will set properly later, but safe
        ++total_animals;
    }
    // A currently has counts equal to last freq value only for last incremented ones; better to zero & fill:
    for (int x = 0; x < L; ++x)
        for (int y = 0; y < L; ++y)
            for (int z = 0; z < L; ++z) {
                int id = idx3(x,y,z,L);
                int val = 0;
                if (x < M && y < M && z < M) val = freq[id];
                A[id] = cd((double)val, 0);
            }

    // build kernel K' of size 2*M-1 in each dimension: K'[dx+(M-1), dy+(M-1), dz+(M-1)] = 1 if |dx|+|dy|+|dz| <= D
    int off = M - 1;
    for (int dx = -(M-1); dx <= (M-1); ++dx) {
        for (int dy = -(M-1); dy <= (M-1); ++dy) {
            // remaining allowable |dz| <= D - |dx| - |dy|
            ll rem = (ll)D - (ll)abs(dx) - (ll)abs(dy);
            if (rem < 0) continue;
            // dz in range [-min(M-1, rem) .. +min(M-1, rem)]
            ll dzmax = min<ll>(M-1, rem);
            for (int dz = (int)-dzmax; dz <= (int)dzmax; ++dz) {
                int ix = dx + off;
                int iy = dy + off;
                int iz = dz + off;
                int idk = idx3(ix, iy, iz, L);
                K[idk] = cd(1.0, 0.0);
            }
        }
    }

    // Zero-padding K and A already done for rest of grid
    // Perform forward 3D FFT on both arrays
    fft3d(A, L, false);
    fft3d(K, L, false);

    // pointwise multiply
    for (int i = 0; i < L3; ++i) A[i] *= K[i];

    // inverse 3D FFT
    fft3d(A, L, true);

    // Now A[p] (real part) holds linear convolution result at index p.
    // For each original cell (x,y,z) with 0<=coord<M, the sum S at that cell is located at conv index (x+off, y+off, z+off)
    // Compute total_ordered = sum_{cell} f[cell] * S[cell]
    long double total_ordered_ld = 0.0L;
    for (int x = 0; x < M; ++x) {
        for (int y = 0; y < M; ++y) {
            for (int z = 0; z < M; ++z) {
                int id_src = idx3(x,y,z,L);
                int id_conv = idx3(x + off, y + off, z + off, L);
                long long f = freq[id_src];
                if (!f) continue;
                // round real part to nearest ll
                double val = A[id_conv].real();
                long long cnt = (long long) llround(val);
                total_ordered_ld += (long double)f * (long double)cnt;
            }
        }
    }

    // total_ordered should be integer
    long long total_ordered = (long long) llround(total_ordered_ld);

    // answer = (total_ordered - N) / 2
    long long ans = (total_ordered - total_animals) / 2LL;
    cout << ans << "\n";
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...