답안 #894885

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
894885 | 2023-12-29T07:19:48Z | vjudge1 | Palinilap (COI16_palinilap) | C++17 | 100 / 100 | 717 ms | 72588 KB
#include<bits/stdc++.h>

using namespace std;

// Clock-seeded random engine; not referenced anywhere else in the visible code.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
typedef long long ll;
// From here on every `int` token is textually replaced by `long long`
// (which is why the entry point below is declared `signed main()`).
#define int ll
typedef unsigned long long ull;
typedef long double ld;
// Because of the macro above, pii is effectively pair<long long, long long>.
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

// Competitive-programming shorthands.
#define pb push_back
#define all(x) x.begin(), x.end()
// Container size cast to `int` (i.e. long long after macro expansion).
#define sz(x) (int)x.size()
// Fast-I/O setup, invoked at the start of main().
#define mispertion ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define F first
#define S second
#define getlast(s) (*s.rbegin())
// Debug marker print; unused in the visible code.
#define debg cout << "OK\n"

// PI, M, infi and P are not referenced in the visible code.
const ld PI = 3.1415926535;
// Size of the global per-position tables; indices up to sz(s) are used, so
// presumably the input length must stay below N - 1 (TODO confirm limits).
const int N = 2e5 + 2;
const int M = 7e6 + 1;
// The two prime moduli of the double polynomial hash.
int mod0 = 1e9+7, mod1 = 1e9+9;
const int infi = INT_MAX;
const ll infl = LLONG_MAX;
// Same value as the hash base `p` declared later; the code uses `p`, not P.
const int P = 31;

int mult(int a, int b, int mod) {
    // Modular product a * b % mod. The multiplication is widened to 64 bits first.
    // This assumes the operands are already below a ~1e9 modulus, so the product fits.
    const long long product = static_cast<long long>(a) * b;
    return product % mod;
}

int sum(int a, int b, int mod) {
    // Modular addition: reduce a, add b, then apply at most one correction by mod.
    // This assumes b lies in (-mod, mod), e.g. a negated residue when subtracting.
    const int total = a % mod + b;
    if (total < 0)
        return total + mod;
    return total >= mod ? total - mod : total;
}

ll binpow(ll a, ll n, int mod) {
    // Recursive fast exponentiation: a^n reduced modulo `mod`; n == 0 yields 1.
    // Recursion depth is O(log n) (about 60 for n near 1e9).
    if (!n)
        return 1;
    if (n % 2 != 1) {
        // Even exponent: square the half power.
        const ll half = binpow(a, n / 2, mod);
        return half * half % mod;
    }
    // Odd exponent: peel off one factor of a.
    return a * binpow(a, n - 1, mod) % mod;
}

// Base of the polynomial hash.
int p = 31;
// pp[i][k] = p^i and ip[i][k] = (p^i)^-1, modulo mod0 (k == 0) or mod1 (k == 1); filled in main().
int pp[N][2], ip[N][2];
// hss[i][k]: forward prefix hash of s[1..i]; hsr[i][k]: reverse suffix hash of s[i..]; built in solve().
int hss[N][2], hsr[N][2];
// Arithmetic-progression updates used by the sweep in solve():
// ls[start] / rs[end] each hold {other endpoint, {initial value, step}}.
vector<pair<int, pii>> rs[N], ls[N];
// rss[r]: left ends of maximal palindromes ending at r; lss[l]: right ends of those starting at l.
vector<int> rss[N], lss[N];
// Input string; solve() prepends a '#' sentinel so letters sit at positions 1..n.
string s;

pii gets(int l, int r){
    // Forward hash of s[l..r], divided by p^(l-1) so that the first letter gets
    // exponent 0. This makes it comparable with getr() of a reversed substring.
    const int h0 = sum(hss[r][0], -hss[l - 1][0], mod0);
    const int h1 = sum(hss[r][1], -hss[l - 1][1], mod1);
    return pii(mult(h0, ip[l - 1][0], mod0), mult(h1, ip[l - 1][1], mod1));
}

pii getr(int l, int r){
    // Reverse hash of s[l..r] (s[r] gets exponent 0), obtained by rescaling the suffix-hash difference.
    // gets(a, b) == getr(c, d) therefore means s[a..b] equals s[c..d] read backwards
    // (up to hash collisions).
    const int shift = sz(s) - r - 1;
    const int h0 = sum(hsr[l][0], -hsr[r + 1][0], mod0);
    const int h1 = sum(hsr[l][1], -hsr[r + 1][1], mod1);
    return pii(mult(h0, ip[shift][0], mod0), mult(h1, ip[shift][1], mod1));
}

// Counts palindromic substrings of t[1..], skipping the sentinel at t[0].
// Brute force: every [l, r] is checked with two pointers.
static int brute_count_palindromes(const std::string& t) {
    const int n = static_cast<int>(t.size());
    int total = 0;
    for (int l = 1; l < n; l++) {
        for (int r = l; r < n; r++) {
            bool ok = true;
            for (int a = l, b = r; a < b; a++, b--) {
                if (t[a] != t[b]) {
                    ok = false;
                    break;
                }
            }
            total += ok;
        }
    }
    return total;
}

// Brute-force reference for solve(). It tries every position i >= 1 and every
// letter 'a'..'z', including the current letter. It prints each "i c" pair that
// reaches the best palindrome count and returns that count (-LLONG_MAX when s
// has nothing after the sentinel). Very slow; only meant for small cross-checks.
int stupid(std::string s){
    const int n = static_cast<int>(s.size());
    // One count per (i, c) pair, stored in the same order the pairs are printed below.
    // Each count is computed once instead of twice.
    std::vector<int> counts;
    int ans = -LLONG_MAX;
    for (int i = 1; i < n; i++) {
        const char saved = s[i];
        for (char c = 'a'; c <= 'z'; c++) {
            s[i] = c;
            const int cur = brute_count_palindromes(s);
            counts.push_back(cur);
            ans = std::max(ans, cur);
        }
        s[i] = saved;
    }
    std::size_t k = 0;
    for (int i = 1; i < n; i++)
        for (char c = 'a'; c <= 'z'; c++)
            if (counts[k++] == ans)
                std::cout << i << ' ' << c << '\n';
    return ans;
}

// Reads one string and counts its palindromic substrings (`ret`).
// Then, for every position i and every letter c != s[i] in 'a'..'z', it computes the
// count after setting s[i] = c. It prints the largest value seen, starting from the
// unchanged count `ret`. Substring equality is tested with the double
// polynomial hash (gets/getr), so hash collisions are assumed not to occur.
void solve(){
    cin >> s;
    // Sentinel at index 0: letters live at 1..n, and '#' never equals a lowercase letter.
    s = "#" + s;
    // Forward prefix hash: hss[i] = sum_{j<=i} (s[j]-'a'+1) * p^(j-1), for both moduli.
    hss[0][0] = 0;
    hss[0][1] = 0;
    for(int i = 1; i < sz(s); i++){
        hss[i][0] = sum(hss[i - 1][0], mult(s[i] - 'a' + 1, pp[i - 1][0], mod0), mod0);
        hss[i][1] = sum(hss[i - 1][1], mult(s[i] - 'a' + 1, pp[i - 1][1], mod1), mod1);
    }
    // Reverse suffix hash: hsr[i] = sum_{j>=i} (s[j]-'a'+1) * p^(sz(s)-j-1).
    hsr[sz(s)][0] = 0;
    hsr[sz(s)][1] = 0;
    for(int i = sz(s) - 1; i >= 1; i--){
        hsr[i][0] = sum(hsr[i + 1][0], mult(pp[sz(s) - i - 1][0], s[i] - 'a' + 1, mod0), mod0);
        hsr[i][1] = sum(hsr[i + 1][1], mult(pp[sz(s) - i - 1][1], s[i] - 'a' + 1, mod1), mod1);
    }
    // Maximal palindrome [l, r] around every center.
    vector<pii> kakoizhegovnokodyapishu = {};
    for(int i = 1; i < sz(s); i++){
        // Odd center i: binary-search the largest radius lo (hi is exclusive)
        // such that s[i-lo+1..i] mirrors s[i..i+lo-1].
        int lo = 1, hi = min(i, sz(s) - i) + 1;
        while(lo + 1 < hi){
            int m = (lo + hi) / 2;
            if(gets(i - m + 1, i) == getr(i, i + m - 1))
                lo = m;
            else
                hi = m;
        }
        kakoizhegovnokodyapishu.pb({i - lo + 1, i + lo - 1});
        // Even center between i and i+1, only when the two middle letters already match.
        if(i < sz(s) - 1 && s[i] == s[i + 1]){
            lo = 1, hi = min(i, sz(s) - i - 1) + 1;
            while(lo + 1 < hi){
                int m = (lo + hi) / 2;
                if(gets(i - m + 1, i) == getr(i + 1, i + m))
                    lo = m;
                else
                    hi = m;
            }
            kakoizhegovnokodyapishu.pb({i - lo + 1, i + lo});
        }
    }
    int ret = 0;
    for(auto e : kakoizhegovnokodyapishu){
        int l = e.F, r = e.S;
        //cout << l << ' ' << r << '\n';
        // A maximal palindrome [l, r] contains (r - l + 2) / 2 palindromes with the same center.
        ret += ((r - l + 2) / 2);
        rss[r].pb(l);
        lss[l].pb(r);
        // A single letter has no off-center position whose change would break it.
        if(l == r)
            continue;
        // Register "broken palindromes" weights as arithmetic progressions.
        // An off-center position j breaks (distance from j to the nearer end of [l, r]) + 1
        // of these nested palindromes if s[j] changes. That gives 1, 2, ... on the left half
        // and ..., 2, 1 on the right half. The odd center itself is excluded, since
        // changing it keeps every same-center palindrome intact.
        if((r - l + 1) % 2){
            int m = (l + r) / 2;
            ls[l].pb({m - 1, {1, 1}});
            rs[m - 1].pb({l, {1, 1}});
            ls[m + 1].pb({r, {r - m, -1}});
            rs[r].pb({m + 1, {r - m, -1}});
        }else{
            int m1 = (l + r) / 2, m2 = m1 + 1;
            ls[l].pb({m1, {1, 1}});
            rs[m1].pb({l, {1, 1}});
            ls[m2].pb({r, {r - m1, -1}});
            rs[r].pb({m2, {r - m1, -1}});
        }
    }
    //cout << ret << '\n';
    // td[i] = number of counted palindromes that stop being palindromes if s[i] changes.
    // NOTE(review): variable-length array (GCC extension, not standard C++) of sz(s)
    // long longs on the stack; std::vector<int> td(sz(s)) would be portable.
    int td[sz(s)];
    // Sweep over positions: ca = sum of the active progressions' values at i, cd = sum of their steps.
    int ca = 0, cd = 0;
    for(int i = 1; i < sz(s); i++){
        // Advance every already-active progression by one step.
        ca += cd;
        // Progressions starting at i contribute their initial value here.
        for(auto e : ls[i]){
            ca += e.S.F;
            cd += e.S.S;
        }
        td[i] = ca;
        // Drop progressions ending at i; their value at i is init + (i - start) * step.
        for(auto e : rs[i]){
            int r = i, l = e.F;
            ca -= (e.S.F + (r - l) * e.S.S);
            cd -= e.S.S;
        }
    }
    int ans = ret;
    for(int i = 1; i < sz(s); i++){
        for(char c = 'a'; c <= 'z'; c++){
            if(c == s[i])
                continue;
            // Start from the unchanged count minus the palindromes broken at i.
            int cur = ret - td[i];
            //cout << i << ' ' << c << '\n';
            //cout << '\t' << cur << '\n';
            // New even palindromes centered between i and i+1, which exist only now that s[i] == s[i+1].
            // lo = largest extension past that pair; together with the pair itself that gives lo + 1 palindromes.
            if(i < sz(s) - 1 && s[i + 1] == c){
                int lo = 0, hi = min(i, sz(s) - i - 1);
                while(lo + 1 < hi){
                    int m = (lo + hi) / 2;
                    if(gets(i - m, i - 1) == getr(i + 2, i + 1 + m))
                        lo = m;
                    else
                        hi = m;
                }
                cur += (lo + 1);
            }
            // Same for the even center between i-1 and i.
            if(i - 1 >= 1 && s[i - 1] == c){
                int lo = 0, hi = min(i - 1, sz(s) - i);
                while(lo + 1 < hi){
                    int m = (lo + hi) / 2;
                    if(gets(i - m - 1, i - 2) == getr(i + 1, i + m))
                        lo = m;
                    else
                        hi = m;
                }
                cur += (lo + 1);
            }
            // Maximal palindromes [l, i-1] stopped at the pair (l-1, i). If the new letter equals s[l-1],
            // they grow past it and gain lo + 1 longer palindromes.
            for(auto e : rss[i - 1]){
                // `r` is unused (this is the compiler's -Wunused-variable warning).
                int l = e, r = i - 1;
                if(c == s[l - 1]){
                    int lo = 0, hi = min(l - 1, sz(s) - i);
                    //cout << '\t' << l << ' ' << r << '\n';
                    while(lo + 1 < hi){
                        int m = (lo + hi) / 2;
                        if(gets(l - 1 - m, l - 2) == getr(i + 1, i + m))
                            lo = m;
                        else
                            hi = m;
                    }
                    //cout << '\t' << lo + 1 << '\n';
                    cur += (lo + 1);
                }
            }
            // Mirror case: maximal palindromes [i+1, r] stopped at the pair (i, r+1).
            for(auto e : lss[i + 1]){
                int l = i + 1, r = e;
                if(r + 1 != sz(s) && c == s[r + 1]){
                    int lo = 0, hi = min(l - 1, sz(s) - r - 1);
                    //cout << '\t' << l << ' ' << r << '\n';
                    while(lo + 1 < hi){
                        int m = (lo + hi) / 2;
                        if(gets(l - 1 - m, l - 2) == getr(r + 2, r + 1 + m))
                            lo = m;
                        else
                            hi = m;
                    }
                    //cout << '\t' << lo + 1 << '\n';
                    cur += (lo + 1);
                }
            }
            //cout << '\t' << cur << '\n';
            ans = max(ans, cur);
        }
    }
    // Disabled cross-check against the brute force.
    //cout << stupid(s) << '\n';
    cout << ans << '\n';
}   

// Entry point: sets up fast I/O, precomputes the hash power tables and runs solve() once.
signed main() {
    mispertion;
    // pp[i] = p^i and ip[i] = p^-i, modulo each of the two primes.
    pp[0][0] = 1, pp[0][1] = 1, ip[0][0] = 1, ip[0][1] = 1;
    // Invert p once via Fermat's little theorem, then build ip[i] = ip[i-1] * p^-1.
    // This gives the same values as a separate binpow(pp[i], mod-2) per index,
    // without ~60 modular multiplications for every i.
    const int inv_p0 = binpow(p, mod0 - 2, mod0);
    const int inv_p1 = binpow(p, mod1 - 2, mod1);
    for(int i = 1; i < N; i++){
        pp[i][0] = mult(pp[i - 1][0], p, mod0);
        pp[i][1] = mult(pp[i - 1][1], p, mod1);
        ip[i][0] = mult(ip[i - 1][0], inv_p0, mod0);
        ip[i][1] = mult(ip[i - 1][1], inv_p1, mod1);
    }
    int t = 1;
    //cin >> t;
    while(t--){
        solve();
    }
    return 0;
}

Compilation message

palinilap.cpp: In function 'void solve()':
palinilap.cpp:221:28: warning: unused variable 'r' [-Wunused-variable]
  221 |                 int l = e, r = i - 1;
      |                            ^
# 결과 실행 시간 메모리 Grader output
1 Correct 175 ms 29592 KB Output is correct
2 Correct 174 ms 29592 KB Output is correct
3 Correct 174 ms 29596 KB Output is correct
4 Correct 174 ms 29592 KB Output is correct
5 Correct 174 ms 29524 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 183 ms 31528 KB Output is correct
2 Correct 186 ms 31704 KB Output is correct
3 Correct 191 ms 30364 KB Output is correct
4 Correct 184 ms 30040 KB Output is correct
5 Correct 189 ms 30288 KB Output is correct
6 Correct 193 ms 31060 KB Output is correct
7 Correct 189 ms 30036 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 694 ms 54448 KB Output is correct
2 Correct 496 ms 72112 KB Output is correct
3 Correct 529 ms 72588 KB Output is correct
4 Correct 662 ms 50212 KB Output is correct
5 Correct 659 ms 49328 KB Output is correct
6 Correct 674 ms 49328 KB Output is correct
7 Correct 663 ms 50088 KB Output is correct
8 Correct 419 ms 72176 KB Output is correct
9 Correct 668 ms 49760 KB Output is correct
10 Correct 667 ms 49896 KB Output is correct
11 Correct 500 ms 72116 KB Output is correct
12 Correct 717 ms 72444 KB Output is correct
13 Correct 659 ms 53544 KB Output is correct
14 Correct 660 ms 49004 KB Output is correct
15 Correct 667 ms 49888 KB Output is correct
16 Correct 587 ms 71064 KB Output is correct
17 Correct 601 ms 41400 KB Output is correct
18 Correct 660 ms 50148 KB Output is correct
19 Correct 604 ms 41148 KB Output is correct