#include <bits/stdc++.h>
using namespace std;
// Shorthand aliases used throughout this file.
#define ll long long
#define pb push_back
#define sz(a) ((int) a.size())
// Left / right child index of segment-tree node `u`. Both macros expand
// textually, so a variable literally named `u` must be in scope where they
// are used.
#define ls u << 1
#define rs u << 1 | 1
// Capacity of the per-position arrays `a` and `pos`.
const int N = 5e5 + 15;
// Capacity of the per-node segment-tree arrays declared below.
const int M = 4e6 + 19;
// n, k : read from input in main; the tree covers positions [0, 2 * n].
// a[i] : i-th input value (1-based), read in main.
// cur  : stamp that push() compares with ver[]; main assigns a new stamp
//        before each pass so stale nodes get reset lazily.
int n, k, a[N], cur = 1;
// Buckets of input positions, filled and consumed by main.
vector<int> pos[N];
// add[u] : pending constant term of node u (applied by push()).
// tr[u]  : sum stored for node u's segment; up to date once u has been pushed.
ll add[M], tr[M];
// ver[u] : stamp under which node u was last reset (0 = never touched).
// cnt[u] : pending number of "1, 2, 3, ..." progressions on node u.
int ver[M], cnt[M];
inline void push(int u, int l, int r) {
if(ver[u] != cur) {
tr[u] = cnt[u] = add[u] = 0;
ver[u] = cur;
}
int len = r - l + 1;
tr[u] += add[u] * len + (len + 1) * len / 2 * cnt[u];
if(l != r) {
int m = (l + r) / 2;
if(ver[ls] != cur) {
ver[ls] = cur;
tr[ls] = cnt[ls] = add[ls] = 0;
}
if(ver[rs] != cur) {
ver[rs] = cur;
tr[rs] = cnt[rs] = add[rs] = 0;
}
add[ls] += add[u];
add[rs] += add[u] + (m - l + 1) * cnt[u];
cnt[ls] += cnt[u];
cnt[rs] += cnt[u];
}
add[u] = cnt[u] = 0;
return;
}
/**
 * Adds the progression 1, 2, 3, ... to positions ql, ql+1, ..., qr
 * (position p receives p - ql + 1).
 *
 * The node is pushed before the range test so that, when the caller
 * recombines `tr[ls] + tr[rs]`, both children hold up-to-date sums even if
 * one of them lies outside the query.
 *
 * @param ql,qr inclusive query range
 * @param u,l,r current node and its segment (defaults: root over [0, 2n])
 */
void upd(int ql, int qr, int u = 1, int l = 0, int r = 2 * n) {
    push(u, l, r);

    const bool disjoint = r < ql || qr < l;
    if (disjoint) return;

    const bool inside = l >= ql && qr >= r;
    if (!inside) {
        const int mid = (l + r) >> 1;
        upd(ql, qr, ls, l, mid);
        upd(ql, qr, rs, mid + 1, r);
        tr[u] = tr[ls] + tr[rs];
        return;
    }

    // Whole segment covered: its first position is already (l - ql) steps
    // into the progression, the rest follows with slope 1.
    cnt[u] += 1;
    add[u] += l - ql;
    push(u, l, r);
}
/**
 * Adds the constant `x` to every position in [ql, qr].
 *
 * The node is pushed BEFORE the out-of-range test (same order as upd()).
 * Otherwise an out-of-range child could still carry unapplied tags when its
 * parent executes `tr[u] = tr[ls] + tr[rs]`, and the parent's sum would
 * silently drop that pending contribution.
 *
 * @param ql,qr inclusive query range
 * @param x     value added to each covered position
 * @param u,l,r current node and its segment (defaults: root over [0, 2n])
 */
void upd2(int ql, int qr, int x, int u = 1, int l = 0, int r = 2 * n) {
    push(u, l, r);
    if (ql > r || l > qr) return;
    if (ql <= l && r <= qr) {
        add[u] += x;
        push(u, l, r);
        return;
    }
    const int m = (l + r) >> 1;
    upd2(ql, qr, x, ls, l, m);
    upd2(ql, qr, x, rs, m + 1, r);
    tr[u] = tr[ls] + tr[rs];
}
/**
 * Returns the sum of the values stored at positions [ql, qr].
 * An empty range (ql > qr) or a range that misses this node yields 0.
 *
 * @param ql,qr inclusive query range
 * @param u,l,r current node and its segment (defaults: root over [0, 2n])
 */
ll get(int ql, int qr, int u = 1, int l = 0, int r = 2 * n) {
    const bool nothing_to_sum = ql > qr || r < ql || qr < l;
    if (nothing_to_sum) return 0;

    // Bring this node's sum up to date before reading or descending.
    push(u, l, r);
    if (l >= ql && qr >= r) return tr[u];

    const int mid = (l + r) >> 1;
    ll total = get(ql, qr, ls, l, mid);
    total += get(ql, qr, rs, mid + 1, r);
    return total;
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cin >> n >> k;
for(int i = 1; i <= n; i++) {
cin >> a[i];
pos[a[i]].pb(i);
}
ll ans = 0;
for(int i = 1; i <= k; i++) {
if(!sz(pos[i])) continue;
cur = i;
upd(-pos[i][0] + 1 + n, n);
upd2(n + 1, n + n, pos[i][0]);
for(int j = 0; j < sz(pos[i]); j++) {
int cnt = j + 1;
int ps = pos[i][j];
int nt = n + 1;
if(j != sz(pos[i]) - 1) {
nt = pos[i][j + 1];
}
ans += get((cnt << 1) - nt + n, (cnt << 1) - ps + n - 1);
upd((cnt << 1) - nt + n + 1, (cnt << 1) - ps + n);
upd2((cnt << 1) - ps + n + 1, (n << 1), nt - ps);
}
}
cout << ans << '\n';
}
// Grader result (column headers translated from Korean):
// | # | Result        | Execution time | Memory   | Grader output                   |
// | 1 | Runtime error | 17 ms          | 24276 KB | Execution killed with signal 11 |
// | 2 | Halted        | 0 ms           | 0 KB     | -                               |
// Grader result (column headers translated from Korean):
// | # | Result        | Execution time | Memory   | Grader output                   |
// | 1 | Runtime error | 17 ms          | 24276 KB | Execution killed with signal 11 |
// | 2 | Halted        | 0 ms           | 0 KB     | -                               |
// Grader result (column headers translated from Korean):
// | # | Result    | Execution time | Memory   | Grader output        |
// | 1 | Incorrect | 71 ms          | 13888 KB | Output isn't correct |
// | 2 | Halted    | 0 ms           | 0 KB     | -                    |
// Grader result (column headers translated from Korean):
// | # | Result        | Execution time | Memory   | Grader output                   |
// | 1 | Runtime error | 17 ms          | 24276 KB | Execution killed with signal 11 |
// | 2 | Halted        | 0 ms           | 0 KB     | -                               |