#include <bits/stdc++.h>
// Large sentinel constant; not referenced anywhere in the visible code.
#define inf 2e9
// NOTE: from here on every `int` in this file is really `long long`.
// This is why the entry point below is spelled `int32_t main()`.
#define int long long
// Shorthand for a full iterator range; unused in the visible code.
#define all(v) v.begin(), v.end()
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair <int, int> pii;
// Fixed seed (105), so the generated codes and the coin flips in unite()
// are the same on every run.
mt19937_64 rnd(105);
// Capacity of every global array below.
const int N = 1000000 + 3;
// n, q   - read first in main(); n bounds the position loops, q the query loop.
// a[]    - input values, stored 0-based (main decrements each one after reading).
// p[]    - union-find parent links (see get_anc()).
// ans_4  - running total printed for the fourth query type: for components whose
//          (sum - nsum) differences are exact negatives of each other, the
//          product of their sizes.
// ans_3  - number of component roots with sum == nsum.
// cnt    - current number of components; main() sets it to n and unite() decrements it.
// b[]    - copy of a[] that main() sorts.
// sz[]   - number of positions in the component rooted at each index.
ll n, q, a[N], p[N], ans_4, ans_3, cnt, b[N], sz[N];
// code[v] - random 64-bit code for value v.
// sum[r]  - sum of code[a[i]] over the positions i of the component rooted at r.
// nsum[r] - sum of code[b[i]] over the same positions.
ll code[N], sum[N], nsum[N];
// For each non-zero difference (sum - nsum), the total size of all component
// roots currently having that difference.
map <ll, int> mp;
// Build the initial state in which every position is its own component.
// For each position: the component's current-value sum is the code of a[pos],
// its target sum is the code of b[pos], and the global counters (ans_3, ans_4)
// and the difference map (mp) are updated accordingly.
void init(){
    for (int pos = 0; pos < n; ++pos){
        p[pos] = pos;
        sz[pos] = 1;
        sum[pos] = code[a[pos]];
        nsum[pos] = code[b[pos]];
        if (sum[pos] != nsum[pos]){
            // Pair this singleton with every earlier position whose difference
            // is the exact negative of ours, then register our own difference.
            ans_4 += mp[nsum[pos] - sum[pos]];
            mp[sum[pos] - nsum[pos]]++;
            continue;
        }
        // Already matching: counts as one balanced component.
        ans_3++;
    }
}
// Return the root of the component containing x, re-pointing every node on the
// walked path directly at that root (path compression).
int get_anc(int x){
    // First pass: climb to the root.
    int root = x;
    while (p[root] != root) root = p[root];
    // Second pass: rewire the path so later lookups are short.
    while (p[x] != root){
        int up = p[x];
        p[x] = root;
        x = up;
    }
    return root;
}
// Merge the components containing positions x and y.
// Both roots are first taken out of the global bookkeeping (ans_3 / ans_4 / mp),
// then linked, and the merged root is put back in. Does nothing when x and y
// already share a root.
void unite(int x, int y){
    // Remove one root's contribution: a balanced root only affects ans_3;
    // an unbalanced one leaves mp and gives back the pairs it formed in ans_4.
    auto detach = [](int r){
        if (sum[r] == nsum[r]){
            ans_3--;
            return;
        }
        mp[sum[r] - nsum[r]] -= sz[r];
        ans_4 -= mp[nsum[r] - sum[r]] * sz[r];
    };

    int rx = get_anc(x);
    int ry = get_anc(y);
    if (rx == ry) return;

    cnt--;
    detach(rx);
    detach(ry);

    // Coin flip decides which root survives.
    if (rnd() & 1) swap(rx, ry);
    p[ry] = rx;
    sz[rx] += sz[ry];
    sum[rx] += sum[ry];
    nsum[rx] += nsum[ry];

    // Re-insert the merged component.
    if (sum[rx] != nsum[rx]){
        mp[sum[rx] - nsum[rx]] += sz[rx];
        ans_4 += mp[nsum[rx] - sum[rx]] * sz[rx];
    }else{
        ans_3++;
    }
}
// Exchange the values stored at positions x and y and keep the per-component
// hashes and the global counters (ans_3, ans_4, mp) consistent.
//
// x, y - 0-based positions.
//
// Fix over the previous version: the exchange is now actually applied to a[].
// Before, only the component sums were adjusted, so any later call touching
// either position computed its delta from stale values; the same-component
// early return also skipped the exchange entirely.
void sw(int x, int y){
    int ax = get_anc(x);
    int ay = get_anc(y);
    if (ax == ay){
        // Same component: its value sum does not change, but the array must
        // still reflect the exchange for subsequent queries.
        std::swap(a[x], a[y]);
        return;
    }
    // Take both roots out of the bookkeeping while their sums are about to change.
    if (sum[ax] != nsum[ax]){
        ans_4 -= mp[nsum[ax] - sum[ax]] * sz[ax];
        mp[sum[ax] - nsum[ax]] -= sz[ax];
    }else ans_3--;
    if (sum[ay] != nsum[ay]){
        ans_4 -= mp[nsum[ay] - sum[ay]] * sz[ay];
        mp[sum[ay] - nsum[ay]] -= sz[ay];
    }else ans_3--;
    // x's component loses a[x] and gains a[y]; y's component the opposite.
    sum[ax] += code[a[y]] - code[a[x]];
    sum[ay] += code[a[x]] - code[a[y]];
    // Apply the exchange only after the deltas above were taken from the old values.
    std::swap(a[x], a[y]);
    // Put both roots back with their new sums.
    if (sum[ax] == nsum[ax]) ans_3++;
    else{
        mp[sum[ax] - nsum[ax]] += sz[ax];
        ans_4 += mp[nsum[ax] - sum[ax]] * sz[ax];
    }
    if (sum[ay] == nsum[ay]) ans_3++;
    else{
        mp[sum[ay] - nsum[ay]] += sz[ay];
        ans_4 += mp[nsum[ay] - sum[ay]] * sz[ay];
    }
}
int32_t main()
{
ios_base::sync_with_stdio(false);
cin.tie(0);
#ifdef LOCAL
freopen("input.txt", "r", stdin);
#endif // LOCAL
cin >> n >> q;
for (int i = 0; i < n; i++){
cin >> a[i];
--a[i];
b[i] = a[i];
}
sort(b, b + n);
for (int i = 0; i < n; i++)
code[i] = rnd();
init();
cnt = n;
for (int i = 0; i < q; i++){
int tp, x, y;
cin >> tp;
if (tp == 1){
cin >> x >> y;
--x; --y;
sw(x, y);
}else
if (tp == 2){
cin >> x >> y;
--x; --y;
unite(x, y);
}else
if (tp == 3){
if (ans_3 == cnt) cout << "DA\n";
else cout << "NE\n";
}else{
cout << ans_4 << "\n";
}
}
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
384 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
384 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
384 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
384 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
404 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
6 ms |
1280 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
57 ms |
7416 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1417 ms |
51740 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
2333 ms |
108792 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1082 ms |
67512 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |