# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
1229747 | BigBadBully | Sorting (IOI15_sorting) | C++20 | 0 ms | 0 KiB |
#include <bits/stdc++.h>
// Competitive-programming shorthands.
// NOTE: `int` is textually replaced by `long long` for the rest of the file, so
// every `int` below -- including the parameter and return types of the
// exported functions -- is actually a 64-bit integer.
#define int long long
// Integer pair type; the `int` inside is itself expanded to `long long`.
#define pii pair<int, int>
// Pair member shorthands: p.ff -> p.first, p.ss -> p.second. Any identifier
// spelled `ff` or `ss` later in the file is rewritten too.
#define ff first
#define ss second
using namespace std;
using ll = long long;

// Prime modulus shared by exp / inv / init / ncr.
const int mod = 998244353;
// Large sentinel: 10^18 (written as an exact integer literal rather than a double).
const int inf = 1'000'000'000'000'000'000;
// Largest index of the factorial tables: 5 * 10^5 + 5.
const int maxn = 500'005;

// fact[i] / invfact[i] for 0 <= i <= maxn; every entry starts at 1 and the
// real values are written by init().
std::vector<int> fact(maxn + 1, 1);
std::vector<int> invfact(maxn + 1, 1);
// Modular exponentiation by repeated squaring: returns x^n modulo `mod`,
// always as a canonical residue in [0, mod).
//   x: any base; negative values are reduced into [0, mod) first (the
//      built-in % keeps the dividend's sign, so exp(-1, 1) used to return -1).
//   n: exponent, must be >= 0 (checked by assert).
ll exp(ll x, ll n)
{
    assert(n >= 0);
    x %= mod;
    if (x < 0)
        x += mod;
    // Both factors of every product are < mod, and mod * mod < 2^63,
    // so the multiplications below cannot overflow ll.
    ll res = 1;
    while (n > 0)
    {
        if (n & 1)
            res = res * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return res;
}
// Modular inverse of x modulo the prime `mod`, via Fermat's little theorem:
// x^(mod-2) * x == 1 (mod `mod`) whenever x is not a multiple of mod.
// Multiples of mod have no inverse; previously they silently produced 0,
// so that case is now rejected by an assert.
int inv(int x)
{
    assert(x % mod != 0 && "inv: x has no inverse modulo mod");
    return exp(x, mod - 2);
}
// Precomputes the factorial tables:
//   fact[i]    = i!        mod `mod`, for 1 <= i <= maxn (fact[0] keeps its initial 1)
//   invfact[i] = (i!)^(-1) mod `mod`, for 0 <= i <= maxn
// Only one modular inverse is computed; the rest follow top-down from
// 1/(k-1)! = k * 1/k!.
void init()
{
    int i = 1;
    while (i <= maxn)
    {
        fact[i] = fact[i - 1] * i % mod;
        ++i;
    }

    invfact[maxn] = inv(fact[maxn]);
    for (int k = maxn; k > 0; --k)
        invfact[k - 1] = invfact[k] * k % mod;
}
int ncr(int n, int r)
{
if (n < r)
return 0ll;
return fact[n] * invfact[r] % mod * invfact[n - r] % mod;
}
// Sorts v in place using the minimum number of transpositions (cycle length
// minus one per cycle) and returns how many transpositions were made.
// v must be a permutation of 0..v.size()-1, because its values are used as indices.
//
// Each cycle is resolved from its smallest index `start`: v[start] is swapped
// with the slot named by its current value, which drops that value into its home.
// When `mode` is true, the same pair of indices is also swapped in `wher`, and the
// post-swap wher[start] and wher[next] are appended to `a` and `b` respectively.
// NOTE(review): the recorded a/b pairs are wher[] entries looked up by positions
// of v; they are not a round-by-round translation of the swaps into another
// coordinate system -- confirm before relying on them.
int sigma(std::vector<int>& a, std::vector<int>& b, std::vector<int>& wher, std::vector<int>& v, bool mode)
{
    int swaps = 0;
    std::vector<char> seen(v.size(), 0);
    const int len = static_cast<int>(v.size());

    for (int start = 0; start < len; ++start)
    {
        if (seen[start])
            continue;
        seen[start] = 1;

        int next = v[start];
        while (!seen[next])
        {
            seen[next] = 1;
            const int following = v[next];
            std::swap(v[start], v[next]);
            ++swaps;
            if (mode)
            {
                std::swap(wher[start], wher[next]);
                a.push_back(wher[start]);
                b.push_back(wher[next]);
            }
            next = following;
        }
    }

    assert(std::is_sorted(v.begin(), v.end()));
    return swaps;
}
namespace
{
// Returns a copy of `start` after applying the first `rounds` position swaps
// (X[i], Y[i]) in order.
std::vector<int> applyPositionSwaps(const std::vector<int>& start,
                                    const std::vector<int>& X,
                                    const std::vector<int>& Y, int rounds)
{
    std::vector<int> arr = start;
    for (int i = 0; i < rounds; ++i)
        std::swap(arr[X[i]], arr[Y[i]]);
    return arr;
}

// Returns a minimum-length list of VALUE pairs (u, w) such that swapping the
// values u and w, in list order, sorts the permutation `arr` of 0..size-1.
// The list has size - (#cycles) entries.
std::vector<std::pair<int, int>> valueSwapsToSort(std::vector<int> arr)
{
    std::vector<std::pair<int, int>> swaps;
    const int len = static_cast<int>(arr.size());
    for (int i = 0; i < len; ++i)
    {
        // Each swap moves the value arr[i] into its home slot, shortening i's cycle.
        while (arr[i] != i)
        {
            const int home = arr[i];
            swaps.emplace_back(arr[i], arr[home]);
            std::swap(arr[i], arr[home]);
        }
    }
    return swaps;
}
} // namespace

// IOI 2015 "Sorting". In round i, Ermek swaps positions (X[i], Y[i]) of S, and
// then Aizhan makes one swap of positions (P[i], Q[i]).
// Returns the minimum number of rounds R after which Aizhan can have S sorted.
// Her swaps are written to P[0..R-1] / Q[0..R-1]. Rounds in which she needs no
// swap get the no-op pair (0, 0).
// P and Q are resized to at least R entries, and every entry is rewritten
// (entries past R are 0). S, X and Y are not modified.
int findSwapPairs(int N, std::vector<int>& S, int M, std::vector<int>& X, std::vector<int>& Y,
                  std::vector<int>& P, std::vector<int>& Q)
{
    // Never read Ermek swaps that do not exist.
    const int usableRounds = std::min<int>(M, static_cast<int>(std::min(X.size(), Y.size())));

    // feasible(t): once Ermek's first t swaps are applied, at most t swaps sort
    // the array. This is monotone in t, because one more Ermek swap changes the
    // needed count by at most 1. t = N - 1 is always feasible, so binary search applies.
    auto feasible = [&](int t) {
        return static_cast<int>(valueSwapsToSort(applyPositionSwaps(S, X, Y, t)).size()) <= t;
    };

    int lo = 0;
    int hi = std::min(N, usableRounds);
    while (lo < hi)
    {
        const int mid = lo + (hi - lo) / 2;
        if (feasible(mid))
            hi = mid;
        else
            lo = mid + 1;
    }
    const int rounds = lo;

    // Plan the sort as swaps of VALUES on the array Ermek leaves after `rounds`.
    // Swapping two values commutes with swapping two positions, so the planned
    // value swaps can be interleaved with Ermek's moves. We only need to know
    // where each value currently sits.
    const std::vector<std::pair<int, int>> plan =
        valueSwapsToSort(applyPositionSwaps(S, X, Y, rounds));

    const std::size_t outLen =
        std::max(std::max(P.size(), Q.size()), static_cast<std::size_t>(rounds));
    P.assign(outLen, 0);
    Q.assign(outLen, 0);

    std::vector<int> cur = S;
    std::vector<int> where(cur.size()); // where[value] = current position
    for (int pos = 0; pos < static_cast<int>(cur.size()); ++pos)
        where[cur[pos]] = pos;

    for (int i = 0; i < rounds; ++i)
    {
        // Ermek's move.
        const int ex = X[i];
        const int ey = Y[i];
        std::swap(cur[ex], cur[ey]);
        where[cur[ex]] = ex;
        where[cur[ey]] = ey;

        // Aizhan's move: the next planned value swap, at the values' current positions.
        if (i < static_cast<int>(plan.size()))
        {
            const int pa = where[plan[i].first];
            const int pb = where[plan[i].second];
            P[i] = pa;
            Q[i] = pb;
            std::swap(cur[pa], cur[pb]);
            where[cur[pa]] = pa;
            where[cur[pb]] = pb;
        }
    }
    return rounds;
}