#include "werewolf.h"
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
// A 2-D point with integer coordinates.
struct point {
    int x;
    int y;
};
// One sweep-line event for pointInRectangleSolution.
// type 0: a point at x = pos with y = val.
// type 1 / type 2: query number val, whose prefix count at x = pos is
// added (type 1) or subtracted (type 2).
struct event {
    int type;
    int pos;
    int val;
};
// Orders sweep events by position. Equal positions are broken by type, so
// point events (type 0) come before query events at the same position.
bool cmp(event q1, event q2) {
    return q1.pos == q2.pos ? q1.type < q2.type : q1.pos < q2.pos;
}
// Axis-aligned query rectangle [x1, x2] x [y1, y2]; bounds are inclusive
// as consumed by pointInRectangleSolution.
// Every member is zero-initialized on default construction, so a rect that
// is created but not yet filled in never holds indeterminate values.
struct rect {
    int x1 = 0, x2 = 0, y1 = 0, y2 = 0;

    rect() = default;

    // Note the argument order: both x bounds first, then both y bounds.
    rect(int x1, int x2, int y1, int y2) : x1(x1), x2(x2), y1(y1), y2(y2) {}
};
class pointInRectangleSolution {
public:
vector < int > getRes(vector < point > p, vector < rect > Q) {
vector < int > res(Q.size(), 0);
vector < event > E;
int mxVal = 0;
for (int i = 0; i < p.size(); i++) {
E.push_back({ 0, p[i].x, p[i].y });
mxVal = max(mxVal, p[i].y);
}
for (int i = 0; i < Q.size(); i++) {
E.push_back({ 1, Q[i].x2, i });
E.push_back({ 2, Q[i].x1 - 1, i });
mxVal = max(mxVal, Q[i].y1);
mxVal = max(mxVal, Q[i].y2);
}
sort(E.begin(), E.end(), cmp);
fenwickTree t;
t.init(mxVal);
for (auto x : E) {
if (x.type == 0) {
t.add(x.val, 1);
}
if (x.type == 1) {
res[x.val] += t.sum(Q[x.val].y1, Q[x.val].y2);
}
if (x.type == 2) {
res[x.val] -= t.sum(Q[x.val].y1, Q[x.val].y2);
}
}
for (int i = 0; i < Q.size(); i++) res[i] = min(res[i], 1);
return res;
}
private:
class fenwickTree {
public:
void init(int sz) {
this->sz = sz;
T.resize(sz + 1, 0);
}
void add(int pos, int val) {
for (int i = pos; i <= sz; i += i & -i) T[i] += val;
}
int sum(int l, int r) {
if (l != 1) return sumPref(r) - sumPref(l - 1);
else return sumPref(r);
}
private:
int sumPref(int pos) {
int res = 0;
for (int i = pos; i > 0; i -= i & -i) res += T[i];
return res;
}
int sz;
vector < int > T;
};
};
// Undirected edge between vertices v1 and v2, carrying an integer weight w.
struct edge {
    int v1;
    int v2;
    int w;
};
// Strict weak ordering on edges: ascending by weight.
bool cmp1(edge e1, edge e2) {
    return e2.w > e1.w;
}
// Strict weak ordering on edges: descending by weight.
bool cmp2(edge e1, edge e2) {
    return e2.w < e1.w;
}
// Rooted forest over node ids 0..2n-1, used as a Kruskal reconstruction tree.
// Callers attach children with addEdge and set internal-node weights in `w`.
// After findOrder():
//   - tin[u] is u's 1-based preorder entry time;
//   - tout[u] is the largest entry time inside u's subtree;
// so u lies in v's subtree iff tin[v] <= tin[u] <= tout[v].
struct tree {
public:
    // Allocates and clears storage for 2n nodes. Calling it again fully resets the tree.
    void init(int n) {
        this->n = n;
        const int total = 2 * n;
        // Jumps of 2^(levels-1), ..., 1 can climb up to 2^levels - 1 steps,
        // which is at least the maximum possible depth, total - 1.
        levels = 1;
        while ((1 << levels) < total) levels++;
        w.assign(total, 0);
        v.assign(total, vector<int>());
        tin.assign(total, 0);
        tout.assign(total, 0);
        binJ.assign(levels, vector<int>(total));
        isRoot.assign(total, true);
    }
    // Makes v2 a child of v1, so v2 is no longer a root.
    void addEdge(int v1, int v2) {
        v[v1].push_back(v2);
        isRoot[v2] = false;
    }
    // Traverses every root in increasing id order, filling tin/tout and the
    // binary-lifting table. A root is its own ancestor at every level.
    void findOrder() {
        timer = 0;
        for (int i = 0; i < 2 * n; i++) {
            if (isRoot[i]) dfs(i);
        }
    }
    // Starting at val, greedily jumps to ancestors whose weight is <= r and
    // returns the highest node reached. The result is the highest such
    // ancestor only if weights are non-decreasing toward the root.
    int lastSmaller(int val, int r) {
        for (int j = levels - 1; j >= 0; j--) {
            if (w[binJ[j][val]] <= r) val = binJ[j][val];
        }
        return val;
    }
    // Starting at val, greedily jumps to ancestors whose weight is >= l and
    // returns the highest node reached. The result is the highest such
    // ancestor only if weights are non-increasing toward the root.
    int lastGreater(int val, int l) {
        for (int j = levels - 1; j >= 0; j--) {
            if (w[binJ[j][val]] >= l) val = binJ[j][val];
        }
        return val;
    }
    vector<int> w;
    vector<int> tin, tout;
private:
    // Records the preorder entry of val, whose parent is prev, and fills
    // val's binary-lifting column. prev's column must already be filled.
    void enter(int val, int prev) {
        tin[val] = ++timer;
        binJ[0][val] = prev;
        for (int j = 1; j < levels; j++) binJ[j][val] = binJ[j - 1][binJ[j - 1][val]];
    }
    // Iterative preorder DFS from root. It visits children in insertion order
    // and uses an explicit stack, so chain-shaped trees (depth ~2n) cannot
    // overflow the call stack.
    void dfs(int root) {
        vector<int> path;     // nodes on the current root-to-node path
        vector<int> nextIdx;  // for each path entry, index of the next child to visit
        enter(root, root);
        path.push_back(root);
        nextIdx.push_back(0);
        while (!path.empty()) {
            const int cur = path.back();
            const int k = nextIdx.back();
            if (k < static_cast<int>(v[cur].size())) {
                nextIdx.back() = k + 1;
                const int child = v[cur][k];
                enter(child, cur);
                path.push_back(child);
                nextIdx.push_back(0);
            } else {
                tout[cur] = timer;
                path.pop_back();
                nextIdx.pop_back();
            }
        }
    }
    int n = 0;
    int timer = 0;
    int levels = 1;
    vector<vector<int> > v;     // children lists
    vector<vector<int> > binJ;  // binJ[j][u] = 2^j-th ancestor of u
    vector<bool> isRoot;
};
// Union-find parent array shared by the tree-building code; comp[v] == v
// marks a set representative.
vector<int> comp;

// Returns the representative of val's set. Every node on the walked path is
// re-pointed directly at the representative (path compression).
int getComp(int val) {
    int root = val;
    while (comp[root] != root) root = comp[root];
    while (val != root) {
        const int next = comp[val];
        comp[val] = root;
        val = next;
    }
    return root;
}
vector < int > check_validity(int n, vector < int > x, vector < int > y, vector < int > s, vector < int > e, vector < int > L, vector < int > R) {
int q = s.size();
vector < edge > edges;
for (int i = 0; i < x.size(); i++) {
edges.push_back({ x[i], y[i], max(x[i], y[i]) });
}
tree tmx, tmi;
tmx.init(n);
tmi.init(n);
comp.resize(2 * n);
for (int i = 0; i < 2 * n; i++) comp[i] = i;
sort(edges.begin(), edges.end(), cmp1);
int allNodes = n;
for (int i = 0; i < edges.size(); i++) {
int t1 = getComp(edges[i].v1);
int t2 = getComp(edges[i].v2);
if (t1 == t2) continue;
tmx.addEdge(allNodes, t1);
tmx.addEdge(allNodes, t2);
tmx.w[allNodes] = edges[i].w;
comp[t1] = allNodes;
comp[t2] = allNodes;
allNodes++;
}
edges.clear();
for (int i = 0; i < x.size(); i++) {
edges.push_back({ x[i], y[i], min(x[i], y[i]) });
}
for (int i = 0; i < 2 * n; i++) comp[i] = i;
sort(edges.begin(), edges.end(), cmp2);
allNodes = n;
for (int i = 0; i < edges.size(); i++) {
int t1 = getComp(edges[i].v1);
int t2 = getComp(edges[i].v2);
if (t1 == t2) continue;
tmi.addEdge(allNodes, t1);
tmi.addEdge(allNodes, t2);
tmi.w[allNodes] = edges[i].w;
comp[t1] = allNodes;
comp[t2] = allNodes;
allNodes++;
}
edges.clear();
tmi.findOrder();
tmx.findOrder();
vector < point > points(n);
vector < rect > rQ(q);
for (int i = 0; i < n; i++) {
points[i].x = tmi.tin[i];
points[i].y = tmx.tin[i];
}
vector < bool > flag(q, false);
for (int i = 0; i < q; i++) {
if (s[i] < L[i]) flag[i] = true;
int ver = tmi.lastGreater(s[i], L[i]);
rQ[i].x1 = tmi.tin[ver];
rQ[i].x2 = tmi.tout[ver];
ver = tmx.lastSmaller(e[i], R[i]);
rQ[i].y1 = tmx.tin[ver];
rQ[i].y2 = tmx.tout[ver];
}
pointInRectangleSolution sl;
vector < int > res = sl.getRes(points, rQ);
for (int i = 0; i < q; i++) if (flag[i]) res[i] = 0;
return res;
}