#include <bits/stdc++.h>
using namespace std;
// Debug trap: spin forever so a failed internal check shows up as a time-limit verdict.
// The loop reads a volatile flag on every iteration: a side-effect-free `while(true);`
// is undefined behaviour in C++11..C++23 ([intro.progress]) and compilers may remove it,
// letting execution fall through instead of hanging.
void TLE() {
    volatile bool spin = true;
    while (spin) {
    }
}
// Debug trap: terminate immediately with exit status 1 so a failed internal check
// surfaces as a distinct verdict (despite the name, nothing is allocated here).
void MLE() {
    std::exit(1);
}
// LOG: number of levels in the binary-lifting table anc[][] (2^19 exceeds mxn).
const int LOG = 20;
// mxn: capacity of node-indexed arrays; nodes are 1..n and 0 is used as "no successor".
const int mxn = 1e5 + 10;
// adj[i]: i's successor nxt[i] plus every j with nxt[j] == i (each edge stored both ways).
vector<int> adj[mxn];
// st[u]: Euler-tour entry time assigned by euler_dfs (u's subtree spans [st[u], en[u] - 1]).
// cyc[u]: after find_cyc and the -1 reset in main, the representative of u's cycle if u
//   lies on a cycle, otherwise 0.
// ord[u]: position of cycle node u on its cycle (see find_cyc); ord[representative] is
//   the cycle length (used by cyc_dis).
// root[]: segment-tree storage (4 * mxn slots) over Euler positions, read/written by
//   upt/query to record which tree root owns each position.
int st[mxn], cyc[mxn], ord[mxn], root[mxn * 4];
// en[u]: one past the last Euler time inside u's subtree.
// top[u]: the cycle node whose hanging tree contains u (set by cycle_dfs); euler_dfs
//   resets it to 0 once u belongs to a rooted tree.
// lvl[u]: depth of u in the tree built by euler_dfs / cycle_dfs.
// anc[J][u]: 2^J-th ancestor of u, used by LCA.
int en[mxn], top[mxn], lvl[mxn], anc[LOG][mxn];
// TIME: next Euler-tour time to hand out; nxt[i]: successor of i (0 = none); n: node count.
int TIME = 1, nxt[mxn], n;
// rem[u]: set by a type-1 query that cuts u's outgoing edge; u then roots its own tree.
bool rem[mxn];
void upt(int id, int tl, int tr, int l, int r, int val) {
if(tl == l && r == tr) {
root[id] = val;
return;
}
if(root[id]) {
root[id * 2 + 1] = root[id];
root[id * 2 + 2] = root[id];
root[id] = 0;
}
int tm = (tl + tr) / 2;
if(r <= tm) upt(id * 2 + 1, tl, tm, l, r, val);
else if(tm < l) upt(id * 2 + 2, tm + 1, tr, l, r, val);
else {
upt(id * 2 + 1, tl, tm, l, tm, val);
upt(id * 2 + 2, tm + 1, tr, tm + 1, r, val);
}
}
int query(int id, int tl, int tr, int i) {
if(root[id]) return root[id];
assert(tl < tr);
int tm = (tl + tr) / 2;
if(i <= tm) return query(id * 2 + 1, tl, tm, i);
return query(id * 2 + 2, tm + 1, tr, i);
}
// Walks nxt[] from u until it reaches 0 or a node that already carries a label.
// If the walk runs into itself, every node of that loop gets cyc = the node where the
// loop closed and an ord position: 1 for the last node walked, counting up backwards,
// so the closing node's ord equals the loop length.  Every other node on the walk gets
// cyc = -1 (main turns those into 0 afterwards).
void find_cyc(int u) {
    const int ON_PATH = -10;  // temporary mark for nodes on the current walk
    vector<int> path;         // nodes of the current walk, in visiting order
    int cur = u;
    for (; cur > 0 && cyc[cur] == 0; cur = nxt[cur]) {
        path.push_back(cur);
        cyc[cur] = ON_PATH;
    }
    const bool closed_loop = cur > 0 && cyc[cur] == ON_PATH;
    if (closed_loop) {
        // Peel the loop off the tail of the walk, back to and including `cur`.
        int pos = 1;
        for (;;) {
            const int x = path.back();
            path.pop_back();
            cyc[x] = cur;
            ord[x] = pos;
            if (x == cur) break;
            ++pos;
        }
    }
    for (int x : path) cyc[x] = -1;
}
// Numbers the tree below u with Euler-tour times and (re)builds depth / binary-lifting
// data for it.  The children of a node are its predecessors (entries of adj[] other than
// its own successor) whose edge has not been cut (rem).  p is u's parent, 0 for a root.
// Every visited node is detached from cycle bookkeeping (top = cyc = 0).
// With `debug` set, each visited edge is printed as "p -> u".
void euler_dfs(int u, int p, bool debug = false) {
    st[u] = TIME;
    TIME++;
    top[u] = 0;
    cyc[u] = 0;
    anc[0][u] = p;
    // A root sits at depth 1 no matter what lvl[0] may hold.
    lvl[u] = p ? lvl[p] + 1 : 1;
    // Fill every level: entries past the highest real ancestor become 0 instead of
    // keeping values from an earlier rooting of u.
    for (int J = 1; J < LOG; J++) {
        const int A = anc[J - 1][u];
        anc[J][u] = A ? anc[J - 1][A] : 0;
    }
    if (debug) {
        cout << p << " -> " << u << endl;
    }
    // adj[u] holds u's own outgoing edge (nxt[u]) exactly once; every other entry is a
    // predecessor.  Skip only that single entry: in a 2-cycle nxt[u] is also a predecessor
    // of u and must still become u's child once u's own edge has been cut.
    bool skipped_own_edge = false;
    for (int v : adj[u]) {
        if (!skipped_own_edge && v == nxt[u]) {
            skipped_own_edge = true;
            continue;
        }
        if (v == p) continue;   // p can reappear as a predecessor when it is a cut 2-cycle partner
        if (rem[v]) continue;   // cut nodes root their own trees
        euler_dfs(v, u, debug);
    }
    en[u] = TIME;
}
// Marks every node reachable from u without crossing a cycle node (cyc != 0), a cut node
// (rem) or the parent p with top = TOP, and builds depth / binary-lifting data.
// Intended to be started as cycle_dfs(c, 0, c) at a cycle node c, which then becomes the
// depth-1 root of the tree hanging off it.
void cycle_dfs(int u, int p, int TOP) {
    top[u] = TOP;
    anc[0][u] = p;
    // A root sits at depth 1 no matter what lvl[0] may hold.
    lvl[u] = p ? lvl[p] + 1 : 1;
    // Fill every level so no entry survives from an earlier call on u.
    for (int J = 1; J < LOG; J++) {
        const int A = anc[J - 1][u];
        anc[J][u] = A ? anc[J - 1][A] : 0;
    }
    for (int v : adj[u]) {
        // 0 is the "no successor" sentinel, never a real node: entering it would
        // overwrite lvl[0] / anc[*][0], which the depth and LCA code read.
        if (v == 0 || v == p || rem[v] || cyc[v]) continue;
        cycle_dfs(v, u, TOP);
    }
}
// Number of nxt-steps from cycle node u to cycle node v on their shared cycle.
// find_cyc gives ord values that drop by one per step (wrapping from 1 back to the cycle
// length, which is stored in ord of the cycle's representative).
int cyc_dis(int u, int v) {
    assert(cyc[u] == cyc[v]);
    const int diff = ord[u] - ord[v];
    if (diff >= 0) return diff;
    const int cycle_len = ord[cyc[u]];
    return diff + cycle_len;
}
// Returns the node identifying u's component.  For u in a rooted tree (top[u] == 0) this
// is the tree root that owns u's Euler position.  Otherwise it is the representative of
// the cycle at top[u], or top[u] itself if that node is not marked as a cycle node.
// TLE() traps states that should be impossible (u never numbered, or position unowned).
int parent(int u) {
    if (top[u] == 0) {
        if (st[u] == 0) TLE();
        const int owner = query(0, 1, n, st[u]);
        if (owner == 0) TLE();
        return owner;
    }
    const int t = top[u];
    return cyc[t] == 0 ? t : cyc[t];
}
// Lowest common ancestor of u and v, using depths lvl[] and the binary-lifting table anc[][].
int LCA(int u, int v) {
    if (lvl[u] < lvl[v]) swap(u, v);  // let u be the deeper node
    // Lift u to v's depth, consuming the depth difference one bit per level.
    for (int J = 0, diff = lvl[u] - lvl[v]; J < LOG; J++, diff >>= 1) {
        if (diff & 1) u = anc[J][u];
    }
    if (u == v) return u;
    // Climb both nodes as long as their ancestors still differ.
    for (int J = LOG - 1; J >= 0; J--) {
        const int up_u = anc[J][u];
        const int up_v = anc[J][v];
        if (up_u != up_v) {
            u = up_u;
            v = up_v;
        }
    }
    return anc[0][u];
}
// Number of nxt-steps needed to go from u to v, or -1 if v cannot be reached.
// - Different components: -1.
// - Same rooted tree, or same tree hanging off a cycle: v must be an ancestor of u.
// - Otherwise v must be a cycle node: climb from u to top[u], then walk around the cycle.
int find_dis(int u, int v) {
    const int pu = parent(u);
    const int pv = parent(v);
    if (pu == 0 || pv == 0) MLE();  // debug trap for an impossible state
    if (pu != pv) return -1;
    if (cyc[pu] == 0 || top[u] == top[v]) {
        // Following nxt only moves towards the tree root, so v must be an ancestor of u.
        if (LCA(u, v) != v) return -1;
        return lvl[u] - lvl[v];
    }
    if (cyc[v] == 0) return -1;  // v hangs in a different tree off the cycle
    assert(v == top[v]);
    // Depth of u below its own cycle node, plus the steps around the cycle.  Measured
    // against lvl[top[u]] so it does not depend on all cycle nodes sharing one depth.
    return lvl[u] - lvl[top[u]] + cyc_dis(top[u], top[v]);
}
// Reads the functional graph (nxt[i] = successor of i, 0 = none) and q queries:
//   "2 u v"  prints find_dis(u, v): the steps from u to v, or -1;
//   "t u" with any other t cuts u's outgoing edge, making u the root of its own tree.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> nxt[i];
        // Store the edge both ways: adj[i] gets its successor, adj[nxt[i]] a predecessor.
        adj[i].push_back(nxt[i]);
        adj[nxt[i]].push_back(i);
    }
    // Label cycle nodes; find_cyc leaves -1 on the non-cycle nodes it walks over.
    for (int i = 1; i <= n; i++)
        if (cyc[i] == 0)
            find_cyc(i);
    for (int i = 1; i <= n; i++)
        if (cyc[i] == -1)
            cyc[i] = 0;
    // Nodes without a successor root plain trees, which receive Euler-tour positions.
    // Every cycle node roots the tree hanging off it.  cycle_dfs must start only at cycle
    // nodes: started at an ordinary node it would relabel that node's neighbourhood
    // (including its successor side) with the wrong top[].
    for (int i = 1; i <= n; i++) {
        if (nxt[i] == 0) {
            euler_dfs(i, 0);
            upt(0, 1, n, st[i], en[i] - 1, i);
        } else if (cyc[i] > 0) {
            cycle_dfs(i, 0, i);
        }
    }
    int q;
    cin >> q;
    for (int query_no = 0; query_no < q; query_no++) {
        int t, u;
        cin >> t >> u;
        if (t == 2) {
            int v;
            cin >> v;
            cout << find_dis(u, v) << '\n';
            continue;
        }
        // Cutting an edge that is already gone (or never existed) changes nothing; skipping
        // it also avoids re-assigning u's range over subtrees cut in the meantime.
        if (rem[u] || nxt[u] == 0) continue;
        rem[u] = true;
        // A node still attached to a cycle has not been numbered yet: number its subtree now.
        if (top[u]) {
            euler_dfs(u, 0);
        }
        upt(0, 1, n, st[u], en[u] - 1, u);
    }
    return 0;
}
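/*
 * Grader result for this submission (pasted from the judge):
 *
 * | #  | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
 * |----|---------------|------------------|-----------------|----------------------|
 * | 1  | Incorrect     | 23 ms            | 10832 KB        | Output isn't correct |
 * | 2  | Halted        | 0 ms             | 0 KB            | -                    |
 */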