#include "seats.h"
#include <bits/stdc++.h>
using namespace std;
using pii = pair<int,int>;
// Combine two (minimum value, number of positions attaining it) summaries:
// the strictly smaller minimum wins outright; on a tie the counts add up.
pii merge(pii a, pii b){
    if (a.first < b.first) return a;
    if (b.first < a.first) return b;
    return {a.first, a.second + b.second};
}
// Lazy segment tree over the index range [s, e].
// Each node stores `val` = (minimum over its range, number of indices that
// attain that minimum) and supports adding a constant to a sub-range.
// Invariant: `val` is exact except for this node's own pending `lazy`
// (and any lazies still pending at its ancestors); prop() folds `lazy`
// into `val` and pushes it down to the children.
struct node{
    int s,e,m,lazy = 0;
    pii val = {0,1};                  // a leaf starts at value 0, counted once
    node *l = nullptr, *r = nullptr;  // owned children; both null for a leaf
    node(int S, int E){
        s=S, e=E, m=(s+e)/2;
        if(s!=e){
            l = new node(s,m);
            r = new node(m+1,e);
            // An all-zero range has minimum 0 attained (e-s+1) times; derive it
            // from the children instead of leaving the leaf default {0,1}.
            val = merge(l->val, r->val);
        }
    }
    // The tree owns its children: release them recursively (depth O(log n)).
    ~node(){
        delete l;
        delete r;
    }
    // Copying would double-own the child pointers.
    node(const node&) = delete;
    node& operator=(const node&) = delete;
    // Apply this node's pending addition to its own summary and defer it to
    // the children.
    void prop(){
        if(lazy == 0) return;
        val.first += lazy;
        if(s!=e){
            l->lazy += lazy;
            r->lazy += lazy;
        }
        lazy = 0;
    }
    // Add v to every index in [S, E]; requires s <= S <= E <= e.
    void update(int S, int E, int v){
        if(s==S and e==E) lazy += v;  // fully covered: defer
        else{
            prop();
            if(E<=m) l->update(S,E,v);
            else if(S>m) r->update(S,E,v);
            else l->update(S,m,v), r->update(m+1,E,v);
            // Children may hold fresh lazies; fold them in before recombining.
            l->prop(), r->prop();
            val = merge(l->val, r->val);
        }
    }
    // (minimum, count of minimum) over [S, E]; requires s <= S <= E <= e.
    pii query(int S, int E){
        prop();
        if(s==S and e==E) return val;
        else if(E<=m) return l->query(S,E);
        else if(S>m) return r->query(S,E);
        else return merge(l->query(S,m), r->query(m+1,E));
    }
} *root;
// Problem dimensions: h rows, w columns, n = h * w seats.
int n;
int h;
int w;
// grid[x][y]: value seated at row x, column y (1-based). Rows 0 and h+1 and
// columns 0 and w+1 are padding filled with n, larger than any real value.
vector<vector<int>> grid;
// r[v], c[v]: 1-based position of value v.
vector<int> r;
vector<int> c;
// deltas[v]: contribution of value v to the prefix sums kept in the tree.
vector<int> deltas;
// Neighbour offsets walked cyclically around a cell; entry 8 repeats entry 0
// so that entries [2k, 2k+2] are the three other cells of the k-th 2x2
// square containing the cell (k = 0..3).
int dx[] = { 0, -1, -1, -1, 0, 1, 1,  1,  0};
int dy[] = {-1, -1,  0,  1, 1, 1, 0, -1, -1};
int count_square(int x, int y, int st, int en){
int cur = grid[x][y];
int exists = 0;
for(int i=st; i<=en; i++){
int nx = x+dx[i];
int ny = y+dy[i];
if(grid[nx][ny] < cur) exists++;
}
if(exists == 0 or exists == 2) return 1;
else return -1;
}
// Sum of count_square over the four 2x2 squares containing (x, y), taken
// with offset windows [0,2], [2,4], [4,6], [6,8] in that order.
int count_delta(int x, int y){
    int total = 0;
    for (int first = 0; first < 8; first += 2) {
        total += count_square(x, y, first, first + 2);
    }
    return total;
}
void give_initial_chart(int H, int W, std::vector<int> R, std::vector<int> C) {
h = H, w = W, n = h*w;
r = R, c = C;
grid.resize(h+2, vector<int>(w+2, n));
for(int i=0; i<n; i++){
r[i]++, c[i]++;
grid[r[i]][c[i]] = i;
}
root = new node(0,n-1);
int sum = 0;
for(int i=0; i<n; i++){
int x = r[i], y = c[i];
int delta = count_delta(x,y);
deltas.push_back(delta);
sum += delta;
root->update(i,i,sum);
}
}
int swap_seats(int a, int b) {
set<int>s;
s.insert(a);
s.insert(b);
int x = r[a], y = c[a];
for(int i=0; i<8; i++){
int nx = x+dx[i], ny = y+dy[i];
if(nx<1 or ny<1 or nx>h or ny>w) continue;
s.insert(grid[nx][ny]);
}
x = r[b], y = c[b];
for(int i=0; i<8; i++){
int nx = x+dx[i], ny = y+dy[i];
if(nx<1 or ny<1 or nx>h or ny>w) continue;
s.insert(grid[nx][ny]);
}
for(auto i: s) root->update(i, n-1, -1*deltas[i]);
swap(r[a],r[b]);
swap(c[a],c[b]);
swap(grid[r[a]][c[a]], grid[r[b]][c[b]]);
for(auto i: s){
deltas[i] = count_delta(r[i],c[i]);
root->update(i, n-1, deltas[i]);
}
return root->query(0,n-1).second;
}