#include "gondola.h"
#include <set>
#include <iostream>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
using ll = long long;
using pii = pair<int,int>;
using piii = tuple<int,int,int>;
#define all(x) begin(x),end(x)
// Checks whether arr[0..n-1] is a possible gondola sequence.
//
// Returns 0 if any of the following holds; otherwise returns 1:
//  - a value is < 1 (non-positive numbers are never valid);
//  - a value appears twice;
//  - the values <= n are not consecutive (wrapping n -> 1) relative to
//    their positions.
//
// Values > n are replacement gondolas and may sit at any position.
int valid(int n, int arr[])
{
    std::set<int> seen;
    // Value required at position i once some value <= n has anchored the
    // rotation; 0 means "not anchored yet".
    int expected = 0;
    for (int i = 0; i < n; ++i) {
        const int g = arr[i];
        // Must come first: the original code tested this only inside the
        // "g > n" branch, where it could never be true.
        if (g < 1) return 0;
        if (!seen.insert(g).second) return 0;  // duplicate number
        if (expected != 0) expected = (expected == n) ? 1 : expected + 1;
        if (g <= n) {
            if (expected == 0) expected = g;       // first original gondola anchors the rotation
            else if (expected != g) return 0;      // out of rotational order
        }
    }
    return 1;
}
//----------------------
// Builds one replacement sequence that turns the initial gondolas 1..n
// (in rotational order) into arr[0..n-1].
//
// finarr[k] receives the gondola number replaced at step k, i.e. the one
// replaced by gondola n+1+k. The return value is the sequence length, which
// is max(arr) - n, or 0 if nothing was replaced. Assumes arr is a valid
// gondola sequence. finarr must have room for the returned number of ints.
int replacement(int n, int arr[], int finarr[])
{
    int n1 = 0;      // sequence length: how far the largest value exceeds n
    int start = -1;  // original gondola at position 0, derived from any unreplaced one
    for (int i = 0; i < n; ++i) {
        n1 = std::max(n1, arr[i] - n);
        if (arr[i] <= n && start == -1) {
            start = arr[i] - i;
            if (start < 1) start += n;
        }
    }
    if (start == -1) start = 1;  // no original gondola left; assume position 0 held gondola 1
    if (n1 == 0) return 0;

    // ans[k] = gondola replaced at step k; 0 = not yet assigned.
    // Heap storage: n1 can be very large, and VLAs are not standard C++.
    std::vector<int> ans(n1, 0);
    for (int i = 0; i < n; ++i) {
        if (start == n + 1) start = 1;
        // A final value > n directly replaced this position's original gondola.
        if (arr[i] > n) ans[arr[i] - n - 1] = start;
        ++start;
    }

    // Step numbers that never appear in arr are chained, in increasing order,
    // into the position holding the largest value (step n1-1):
    // original -> leftover[0] -> ... -> leftover.back() -> largest value.
    std::vector<int> leftover;
    for (int k = 0; k < n1; ++k)
        if (ans[k] == 0) leftover.push_back(k);
    if (!leftover.empty()) {
        const int si = static_cast<int>(leftover.size());
        ans[leftover[0]] = ans[n1 - 1];
        ans[n1 - 1] = leftover[si - 1] + n + 1;
        for (int j = 1; j < si; ++j)
            ans[leftover[j]] = leftover[j - 1] + n + 1;
    }

    std::copy(ans.begin(), ans.end(), finarr);
    return n1;
}
//----------------------
int countReplacement(int n, int arr[])
{
if(!valid(n,arr)) return 0;
ll modval = 1e9+9;
auto fastexpo = [&](ll a,ll b,ll md,auto&& self){
if(b == 0) return 1ll;
else if(b == 1) return a;
ll t = self(a,b/2,md,self);
if(b&1) return (((t*t)%md)*a)%md;
else return (t*t)%md;
};
ll cnt = 0;
vector<int> vc = {n};
for(int i{};i < n;i++){
if(arr[i] > n) cnt++,vc.emplace_back(arr[i]);
}
sort(all(vc));
if(cnt == 0) return 1;
ll sum = 1;
for(int i{1};i <= cnt;i++){
if(vc[i]-vc[i-1]-1 > 0) sum = (sum*fastexpo(cnt-i+1,vc[i]-vc[i-1]-1,modval,fastexpo))%modval;
}
if(cnt == n) return (sum*n)%modval;
else return sum;
}