#include <bits/stdc++.h>
#define pb push_back
#define F first
#define all(x) x.begin(), x.end()
#define debug(x) cerr << #x << " : " << x << '\n'
using namespace std;
// Short type aliases and constants. Mod, Inf and Log are not referenced by
// this solution; they are kept for template compatibility.
using ll = long long;
using ld = long double;
using str = std::string;
using pll = std::pair<ll, ll>;
constexpr ll Mod = 1000000007LL;
constexpr int N = 100000 + 10;   // row capacity of the grid
constexpr int M = 52;            // column capacity of the grid
constexpr ll Inf = 2242545357980376863LL;
constexpr ll Log = 30;

// Grid height and width, read by main.
int n;
int m;
// a[row][col] holds 1 for a '1' cell and 0 otherwise.
int a[N][M];
// Union-find parent array; cell (row, col) lives at index row * m + col.
int par[N * M];
// Running answer accumulated by Solve and printed by main.
ll ans = 0;
// Returns the representative (root) of u's set in the global `par` array.
// Every node on the walked path is re-pointed directly at the root
// (path compression). Done iteratively so long chains cannot overflow the stack.
int Find(int u){
    int root = u;
    while(par[root] != root) root = par[root];
    // Second pass: re-point each node on the path straight to the root.
    while(par[u] != root){
        const int next = par[u];
        par[u] = root;
        u = next;
    }
    return root;
}
// Merges the sets containing u and v, keeping the LARGER root index as the
// new representative. Returns 1 if two distinct sets were merged, 0 if u and
// v were already in the same set (callers subtract this from a component count).
int Unite(int u, int v){
    const int ru = Find(u);
    const int rv = Find(v);
    if(ru == rv) return 0;
    par[std::min(ru, rv)] = std::max(ru, rv);
    return 1;
}
// Merges the sets containing u and v, keeping the SMALLER root index as the
// new representative. Returns 1 if two distinct sets were merged, 0 if u and
// v were already in the same set (callers subtract this from a component count).
int Unite2(int u, int v){
    const int ru = Find(u);
    const int rv = Find(v);
    if(ru == rv) return 0;
    par[std::max(ru, rv)] = std::min(ru, rv);
    return 1;
}
// Scratch storage for Solve's crossing step, indexed by group g.
// Upper sweep (anchored at row mid-1):
//   st[g][j] - DSU root of cell (mid-1, j) for group g,
//   C[g]     - number of start rows in group g,
//   S[g]     - sum of component counts over those start rows.
// st2 / C2 / S2 hold the same data for the lower sweep anchored at row mid.
int st[M][M];
int C[M];
int st2[M][M];
int C2[M];
ll S[M];
ll S2[M];
void Solve(int L, int R){
//ll lans = ans;
if(L + 1 == R){
for(int i = 0; i < m; i++) ans += (a[L][i] == 1);
for(int i = 0; i + 1 < m; i++) ans -= (a[L][i] + a[L][i + 1] == 2);
return ;
}
int mid = (L + R) >> 1;
Solve(L, mid);
Solve(mid, R);
iota(par + L*m, par + R*m, L*m);
int sz1 = 0, sz2 = 0;
memset(C, 0, sizeof C);
memset(S, 0, sizeof S);
int nw = 0;
for(int j = 0; j + 1 < m; j++){
if(a[mid - 1][j] + a[mid - 1][j + 1] == 2) nw -= Unite((mid - 1)*m + j, (mid - 1)*m + j + 1);
nw += a[mid - 1][j];
}
nw += a[mid - 1][m - 1];
int p = 0, fl;
C[0] = 1;
S[0] = nw;
for(int j = 0; j < m; j++) st[p][j] = Find((mid-1)*m + j);
for(int i = mid - 2; i >= L; i--){
for(int j = 0; j < m; j++){
nw += a[i][j];
if(a[i][j] + a[i + 1][j] == 2) nw -= Unite(i*m + j, (i+1)*m + j);
}
for(int j = 0; j + 1 < m; j++){
if(a[i][j] + a[i][j + 1] == 2) nw -= Unite(i*m + j, i*m + j + 1);
}
fl = 0;
for(int j = 0; j < m; j++) if(st[p][j] != Find((mid - 1) * m + j)) fl = 1;
if(fl){
p ++;
for(int j = 0; j < m; j++) st[p][j] = Find((mid-1)*m + j);
}
C[p] ++; S[p] += nw;
}
sz1 = p + 1;
//memset(C, 0, sizeof C);
//memset(S, 0, sizeof S);
nw = 0;
for(int j = 0; j + 1 < m; j++){
if(a[mid][j] + a[mid][j + 1] == 2) nw -= Unite(mid*m + j, mid*m + j + 1);
nw += a[mid][j];
}
nw += a[mid][m - 1];
p = 0;
C2[0] = 1;
S2[0] = nw;
for(int j = 0; j < m; j++) st2[p][j] = Find(mid*m + j);
for(int i = mid + 1; i < R; i++){
for(int j = 0; j < m; j++){
nw += a[i][j];
if(a[i][j] + a[i - 1][j] == 2) nw -= Unite2(i*m + j, (i-1)*m + j);
}
for(int j = 0; j + 1 < m; j++){
if(a[i][j] + a[i][j + 1] == 2) nw -= Unite2(i*m + j, i*m + j + 1);
}
fl = 0;
for(int j = 0; j < m; j++) if(st2[p][j] != Find(mid * m + j)) fl = 1;
if(fl){
p ++;
for(int j = 0; j < m; j++) st2[p][j] = Find(mid*m + j);
}
C2[p] ++; S2[p] += nw;
}
sz2 = p + 1;
//lans = ans;
for(int i1 = 0; i1 < sz1; i1 ++) for(int i2 = 0; i2 < sz2; i2 ++){
int d = 0;
for(int j = 0; j < m; j++) par[(mid - 1)*m + j] = st[i1][j];
for(int j = 0; j < m; j++) par[mid*m + j] = st2[i2][j];
for(int j = 0; j < m; j++) if(a[mid - 1][j] + a[mid][j] == 2) d -= Unite((mid - 1)*m + j, mid*m + j);
ans += C[i1] * S2[i2];
ans += S[i1] * C2[i2];
ans += 1ll * d * C[i1] * C2[i2];
}
//cerr << L << ' ' << R << ' ' << ans - lans << '\n';
}
// Reads n, m and then n * m grid characters (whitespace-separated or not),
// marking '1' cells as 1. Runs Solve over all rows and prints the accumulated answer.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin >> n >> m;
    char c;
    for(int row = 0; row < n; ++row)
        for(int col = 0; col < m; ++col){
            cin >> c;
            a[row][col] = (c == '1') ? 1 : 0;
        }
    Solve(0, n);
    cout << ans << '\n';
}
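/*
 * Grader results:
 * | # | Result    | Execution time | Memory | Grader output        |
 * |---|-----------|----------------|--------|----------------------|
 * | 1 | Correct   | 5 ms           | 384 KB | Output is correct    |
 * | 2 | Incorrect | 5 ms           | 384 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB   | -                    |
 */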
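/*
 * Grader results:
 * | # | Result    | Execution time | Memory | Grader output        |
 * |---|-----------|----------------|--------|----------------------|
 * | 1 | Correct   | 5 ms           | 384 KB | Output is correct    |
 * | 2 | Incorrect | 5 ms           | 384 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB   | -                    |
 */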
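/*
 * Grader results:
 * | # | Result    | Execution time | Memory   | Grader output        |
 * |---|-----------|----------------|----------|----------------------|
 * | 1 | Correct   | 242 ms         | 13944 KB | Output is correct    |
 * | 2 | Incorrect | 324 ms         | 28280 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB     | -                    |
 */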
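/*
 * Grader results:
 * | # | Result    | Execution time | Memory | Grader output        |
 * |---|-----------|----------------|--------|----------------------|
 * | 1 | Correct   | 5 ms           | 384 KB | Output is correct    |
 * | 2 | Incorrect | 5 ms           | 384 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB   | -                    |
 */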