import sys
from bisect import bisect_left
class FenwickTree:
    """Fenwick tree (binary indexed tree) over positions 1..size.

    Supports point updates and prefix/range sums; every operation walks at
    most one node per bit of the index.

    Positions are 1-based: slot 0 of the backing list is unused padding.
    """

    def __init__(self, size):
        """Create an all-zero tree covering positions 1..size."""
        self.size = size
        self.tree = [0] * (size + 1)

    def update(self, index, value):
        """Add ``value`` at position ``index`` (1-based).

        Raises:
            IndexError: if ``index`` is smaller than 1. Such an index can
                never advance past 0 (``0 & -0 == 0``), so without this
                check the loop below would spin forever.

        An ``index`` greater than ``size`` is silently ignored.
        """
        if index < 1:
            raise IndexError(
                "Fenwick tree index must be >= 1, got {}".format(index)
            )
        while index <= self.size:
            self.tree[index] += value
            # index & -index isolates the lowest set bit; adding it moves to
            # the next node whose range covers this position.
            index += index & -index

    def query(self, index):
        """Return the sum of positions 1..index (0 when ``index`` <= 0).

        ``index`` must not exceed ``size``; a larger value raises IndexError
        from the backing list.
        """
        total = 0
        while index > 0:
            total += self.tree[index]
            # Dropping the lowest set bit jumps to the preceding disjoint range.
            index -= index & -index
        return total

    def range_query(self, left, right):
        """Return the sum of positions left..right, both inclusive.

        An empty range (``left == right + 1``) yields 0.
        """
        return self.query(right) - self.query(left - 1)
def count_initial_inversions(A_mapped, K):
    """Count the inversions of ``A_mapped`` with a Fenwick tree.

    An inversion is a pair of indices ``i < j`` with
    ``A_mapped[i] > A_mapped[j]``.

    Args:
        A_mapped: sequence of integer ranks. The ranks are used directly as
            Fenwick tree positions, so they must lie in ``1..K``.
        K: largest possible rank; also the size of the tree.

    Returns:
        A ``(inversion_count, tree)`` tuple where ``tree`` is the
        ``FenwickTree`` holding one unit for every element of ``A_mapped``.
    """
    tree = FenwickTree(K)
    inversions = 0
    for rank in A_mapped:
        # Everything already in the tree sits earlier in the array; those
        # with a strictly larger rank each form an inversion with this one.
        larger_before = tree.range_query(rank + 1, K)
        inversions += larger_before
        # Record the current element so later elements can see it.
        tree.update(rank, 1)
    return inversions, tree
def _count_ordered_pairs(first, second):
    """Count pairs ``(i, j)`` with ``i`` in ``first``, ``j`` in ``second``, ``i < j``.

    Both arguments are ascending lists of array indices with no index in
    common. The shorter list is scanned and the longer one binary-searched.
    """
    if len(first) <= len(second):
        # For each i, the matching j's are the tail of `second` beyond i.
        size = len(second)
        return sum(size - bisect_left(second, i) for i in first)
    # For each j, the matching i's are the prefix of `first` below j.
    return sum(bisect_left(first, j) for j in second)


def swap_and_update(swaps, P, A, K):
    """Write the inversion count of ``A`` after each adjacent swap in ``P``.

    ``P`` defines an ordering of values: a value is "smaller" the earlier it
    appears in ``P``. An inversion is a pair ``i < j`` where ``A[i]`` comes
    after ``A[j]`` in that ordering.

    Swapping two neighbours ``v1, v2`` of ``P`` only changes the relative
    order of those two values, so only pairs of array elements holding
    exactly ``v1`` and ``v2`` can flip. Pairs with ``v2`` before ``v1`` stop
    being inversions and pairs with ``v1`` before ``v2`` become inversions.

    Args:
        swaps: iterable of 1-based positions ``j`` (``1 <= j < K``); each one
            swaps ``P[j - 1]`` and ``P[j]`` (0-based list indices).
        P: list whose first ``K`` entries are distinct values; mutated in
            place by every swap.
        A: sequence of values, each of which must occur in ``P[:K]``.
        K: number of entries of ``P`` that take part in the ordering.

    Returns:
        None. One line per swap is written to ``sys.stdout``; as a single
        trailing newline is always appended, an empty ``swaps`` still emits
        one blank line.

    Raises:
        IndexError: if a swap position is outside ``1..K-1``.
        KeyError: if ``A`` contains a value missing from ``P[:K]``.
    """
    # Ranks are 1-based because the Fenwick tree cannot address position 0.
    pos = {P[i]: i + 1 for i in range(K)}
    A_mapped = [pos[x] for x in A]
    inv_count, _ = count_initial_inversions(A_mapped, K)

    # Ascending index list of every value that occurs in A.
    occurrences = {}
    for i, x in enumerate(A):
        occurrences.setdefault(x, []).append(i)

    # (a, b) -> number of pairs i < j with A[i] == a and A[j] == b.
    pair_cache = {}
    result = []
    for j in swaps:
        if not 1 <= j < K:
            # Without this check j == 0 would silently pair P[-1] with P[0].
            raise IndexError(
                "swap position {} is outside 1..{}".format(j, K - 1)
            )
        v1, v2 = P[j - 1], P[j]
        where_v1 = occurrences.get(v1)
        where_v2 = occurrences.get(v2)
        # If either value is absent from A the swap cannot affect any pair.
        if where_v1 and where_v2:
            total_pairs = len(where_v1) * len(where_v2)
            v1_first = pair_cache.get((v1, v2))
            if v1_first is None:
                v1_first = _count_ordered_pairs(where_v1, where_v2)
                pair_cache[(v1, v2)] = v1_first
                # Indices are distinct, so the remaining pairs have v2 first.
                pair_cache[(v2, v1)] = total_pairs - v1_first
            v2_first = total_pairs - v1_first
            # v1-before-v2 pairs become inversions; v2-before-v1 pairs stop.
            inv_count += v1_first - v2_first
        P[j - 1], P[j] = v2, v1
        result.append(str(inv_count))
    sys.stdout.write("\n".join(result) + "\n")
# Read the whole of stdin in one call and split it into whitespace-separated
# tokens. The read method is called directly instead of being bound to the
# name `input`, which would shadow the builtin for the rest of the module.
data = sys.stdin.read().split()
idx = 0
# Header: N = length of A, K = size of the permutation, Q = number of swaps.
N, K, Q = map(int, data[idx:idx + 3])
idx += 3
# The next N tokens are the array A.
A = list(map(int, data[idx:idx + N]))
idx += N
# The following Q tokens are the swap positions.
swaps = list(map(int, data[idx:idx + Q]))
# The ordering starts out as the identity permutation 1..K.
P = list(range(1, K + 1))
# Apply every swap and print the inversion count after each one.
swap_and_update(swaps, P, A, K)
Compilation message (stdout)
Compiling 'Main.py'...
=======
adding: __main__.pyc (deflated 47%)
=======
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |