답안 #289944

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
289944 2020-09-03T08:54:47 Z pichulia Lampice (COCI19_lampice) C++17
17 / 110
1862 ms 18188 KB
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2")
#include<stdio.h>
#include<algorithm>
#include<vector>
#include<unordered_map>
#include<math.h>
#include<map>
#include<time.h>
#include<assert.h>
using namespace std;
using lld = long long int;
using pll = pair<long long int, long long int>;
using pii = pair<int, int>;
using pil = pair<int, long long int>;
using pli = pair<long long int, int>;
using piii = pair<pii, int>;
using lf = double;
using pff = pair<lf, lf>;
// Number of vertices; read from input in main().
int n;
// Per-vertex character labels; main() reads them as one string and XORs every
// used entry with 73 before hashing.
char a[50009];
// centmark[u] is true once u has been chosen as a centroid; getsz(), getcent()
// and build() all skip marked vertices.
bool centmark[50009];
// Adjacency lists with 0-based vertex ids (main() decrements the input ids).
vector<int> v[50009];
// sz[u]: subtree size of u as computed by the most recent getsz() traversal.
int sz[50009];
// dd[u]: depth of u recorded by the most recent build() traversal.
int dd[50009];
// Computes sz[] for the part of the tree reachable from `si` without passing
// through `pi` or through any vertex flagged in centmark[].
// After the call sz[si] is the number of vertices in that part.
void getsz(int si, int pi) {
    sz[si] = 1;
    for (const int nxt : v[si]) {
        // Do not walk back to the parent or into an already removed centroid.
        if (nxt == pi || centmark[nxt]) continue;
        getsz(nxt, si);
        sz[si] += sz[nxt];
    }
}
// Walks from `si` towards the heavy side until no unmarked neighbour (other
// than the one we came from) holds more than half of sz[si] vertices, and
// returns the vertex where the walk stops.
// Expects sz[] to have been filled by getsz(si, -1) beforehand.
int getcent(int si) {
    const int total = sz[si];
    int parent = -1;
    for (;;) {
        int heavy = -1;
        for (size_t idx = 0; idx < v[si].size(); ++idx) {
            const int child = v[si][idx];
            if (child == parent || centmark[child]) continue;
            if (2 * sz[child] > total) {
                heavy = child;
                break;
            }
        }
        if (heavy < 0) return si;  // nothing is heavier than half: stop here
        parent = si;
        si = heavy;
    }
}
// Modular exponentiation by repeated squaring: returns p^q mod m in [0, m).
//
// p : base, may be negative (it is normalised into [0, m) first).
// q : exponent, expected to be >= 0. A negative q is treated like 0 so the
//     loop always terminates (an arithmetic right shift of a negative value
//     never reaches zero).
// m : modulus, must be > 0 and small enough that (m-1)^2 fits in long long.
//
// The result for q == 0 is 1 % m, so it is 0 when m == 1.
long long expp(long long p, long long q, long long m) {
    long long r = 1 % m;
    long long z = p % m;
    if (z < 0) z += m;  // C++ '%' keeps the sign of the dividend
    while (q > 0) {
        if (q & 1) { r = (r * z) % m; }
        q >>= 1;
        // Skip the final squaring: its result would never be used.
        if (q) { z = (z * z) % m; }
    }
    return r;
}
// Longest palindromic path length found so far; process() only ever raises it
// and main() prints it. Starts at 1 (presumably because a single vertex is
// already a palindrome of length 1).
int result = 1;
struct Hash {
    const static int HM = 5;
    int v[HM];
    static int M[HM];
    static int x[HM];
    static int rx[HM];
    Hash() {
        for (int i = 0; i < HM; i++) { v[i] = 0; }
    }
    void push(char z) {
        for (int i = 0; i < HM; i++) {
            v[i] = (1LL * v[i] * x[i] + z) % M[i];
        }
    }
    void rpush(char z) {
        for (int i = 0; i < HM; i++) {
            v[i] = (1LL * v[i] * rx[i] + z) % M[i];
        }
    }
    bool operator != (const Hash& z)const {
        for (int i = 0; i < HM; i++) {
            if (v[i] != z.v[i])return true;
        }
        return false;
    }
    bool operator == (const Hash& z)const {
        for (int i = 0; i < HM; i++) {
            if (v[i] != z.v[i])return false;
        }
        return true;
    }
    bool operator <(const Hash& z)const {
        for (int i = 0; i < HM; i++) {
            if (v[i] - z.v[i])return v[i] < z.v[i];
        }
        return false;
    }
    Hash shift(int z) const {
        Hash res;
        for (int i = 0; i < HM; i++) {
            lld diff = 1;
            if (z > 0) {
                diff = expp(x[i], z, M[i]);
            }
            else {
                diff = expp(rx[i], -z, M[i]);
            }
            res.v[i] = (1LL * v[i] * diff) % M[i];
        }
        return res;
    }
    Hash operator -(const Hash& z)const {
        Hash res;
        for (int i = 0; i < HM; i++) {
            res.v[i] = (v[i] - z.v[i]) % M[i];
            if (res.v[i] < 0)res.v[i] += M[i];
        }
        return res;
    }
    void print() {
        printf("{ ");
        for (int i = 0; i < HM; i++) {
            printf("%d, ", v[i]);
        }
        printf("}\n");
    }
};
// Moduli used by the HM parallel hash components.
// NOTE(review): inil() computes rx[i] as x[i]^(M[i]-2) mod M[i], which is the
// modular inverse only if every M[i] is prime -- confirm the primality of
// these constants.
int Hash::M[Hash::HM] = {
    // Earlier candidate moduli, kept for reference:
    //    1000000007,1000000009, 998244353, 1000100161, 1000100173,
        1000100183, 1000100201, 1000100239, 1000100279,1000100303,
        //1428571429,
};
// Hash bases; zero-initialised here and filled in by inil().
int Hash::x[Hash::HM] = { };
// Values multiplied in by rpush()/negative shifts; filled in by inil().
int Hash::rx[Hash::HM] = { };

// Summary of a root-to-vertex path: its length, the first vertex below the
// root ("head", -1 when unset), and the forward / reverse hashes of its labels.
struct A {
    int head;
    int len;
    Hash h;
    Hash rh;

    // Empty path: no head, zero length, zero hashes.
    A() : head(-1), len(0) {}

    // Returns a copy of this path extended by one character.
    A operator+(char z) const {
        A res(*this);
        res.len += 1;
        res.h.push(z);
        res.rh.rpush(z);
        return res;
    }

    // Orders by length, then head, then forward hash, then reverse hash.
    bool operator <(const A& z)const {
        if (len != z.len) return len < z.len;
        if (head != z.head) return head < z.head;
        return (h != z.h) ? (h < z.h) : (rh < z.rh);
    }
};
// b[u]: path summary from the current root to u, root's label included
// (build() gives the root itself length 1).
A b[50009];
// c[u]: path summary from the current root to u, root's label excluded
// (build() gives the root the empty summary).
A c[50009];
// blist[0..bl): vertices visited by the most recent build(), in visit order.
int blist[50009];
// clist[0..cl): vertices selected and sorted by valid().
int clist[50009];
int bl;
int cl;
// vd[d]: vertices at depth d in the most recent build(); process() clears it.
vector<int> vd[50009];
// Largest and second-largest depth values seen by build() (dmax_1 >= dmax_2).
int dmax_1;
int dmax_2;
void build(int dep, int si, int pi) {
    int i, j;
    dd[si] = dep;
    vd[dep].push_back(si);

    if (dmax_2 < dep) {
        dmax_2 = dep;
        if (dmax_1 < dmax_2) {
            swap(dmax_1, dmax_2);
        }
    }

    blist[bl++] = si;
    if (dep == 0) {
        b[si] = A();
        b[si].len = 1;
        b[si].h.push(a[si]);
        b[si].rh.rpush(a[si]);
        c[si] = A();
    }
    else {
        b[si] = b[pi] + a[si];
        c[si] = c[pi] + a[si];
        if (dep == 1) {
            b[si].head = si;
            c[si].head = si;
        }
    }
    for (i = 0; i < v[si].size(); i++) {
        j = v[si][i];
        if (j == pi)continue;
        if (centmark[j])continue;
        build(dep + 1, j, si);
    }
}


// bb[u] / cc[u]: per-vertex values computed by valid() as
// h - rh.shift(len - 1) from b[u] and c[u] respectively.
Hash bb[50009];
Hash cc[50009];
// ch[d]: index in clist where the group of vertices with c[].len == d starts
// (ch[d + 1] is one past its end); rebuilt by every valid() call.
int ch[50009];
bool valid(int len) {
    //printf("len : %d, dmax_1 : %d\n", len, dmax_1);
    int i, j, k;
    int di, ei;
    int dsi = max(0, len - dmax_1 - 1);
    int dei = min(dmax_1, len - 1);
    // (ph * x + qr) * x^(ei-1) = (pr + qh * x) * x^(di)
    // ph * x^ ei - pr * x ^ di == qh * x ^ (di + 1) - qr * x ^ (ei - 1)
    cl = 0;
    for (i = 0; i < bl; i++) {
        int qi = blist[i];
        int di = dd[qi];
        int ei = len - di - 1;
        if (di >= dsi && di <= dei) {
            clist[cl++] = blist[i];
            bb[qi] = b[qi].h - b[qi].rh.shift(len - 1);
            cc[qi] = c[qi].h - c[qi].rh.shift(len - 1);
        }
        /*
        printf("about %d\n", qi);
        printf("%d %d\n", di, ei);
        printf("%d %d\n", b[qi].len, b[qi].head);
        b[qi].h.print();
        b[qi].rh.print();
        printf("%d %d\n", c[qi].len, c[qi].head);
        c[qi].h.print();
        c[qi].rh.print();

        bb[qi].print();
        cc[qi].print();
        */
    }
    sort(clist, clist + cl, [](int i, int j) {
        if (c[i].len - c[j].len)return c[i].len < c[j].len;
        if (cc[i] != cc[j])return cc[i] < cc[j];
        return c[i].head < c[j].head;
        });
    k = 0;
    ch[k] = 0;
    for (i = 0; i < cl;) {
        for (j = i; j < cl; j++) {
            if (c[clist[j]].len != k) {
                break;
            }
        }
        ch[k + 1] = j;
        k++;
        i = j;
    }
    for (di = dsi; di <= dei; di++) {
        ei = len - di - 1;
        // check b where dep == di and check c where dep == ci
        if (ei < 0 || ei >= k || ch[ei] == ch[ei + 1])continue;
        for (auto bi : vd[di]) {
            Hash& me = bb[bi];
            int bhead = b[bi].head;
            int l, r;
            l = ch[ei];
            r = ch[ei + 1];
            while (l < r) {
                int m = (l + r) / 2;
                int mi = clist[m];
                if (cc[mi] < me) l = m + 1;
                else r = m;
            }
            if (l < ch[ei + 1] && cc[clist[l]] == me) {
                if (c[clist[l]].head == bhead) {
                    if (ch[ei + 1] - l < 10) {
                        l++;
                        for (; l < ch[ei + 1]; l++) {
                            if (cc[clist[l]] != me)break;
                            if (c[clist[l]].head != bhead)return true;
                        }
                    }
                    else {
                        int ll = l + 1;
                        int rr = ch[ei + 1];
                        while (ll < rr) {
                            int mm = (ll + rr) / 2;
                            int mi = clist[mm];
                            if (me < cc[mi]) { rr = mm; }
                            else if (c[mi].head == bhead) { ll = mm + 1; }
                            else return true;
                        }
                        if (ll < ch[ei + 1] && cc[clist[ll]] == me) {
                            if (c[clist[ll]].head != bhead)return true;
                        }
                    }
                }
                else {
                    return true;
                }
            }
        }
    }
    return false;
}
void process(int si) {
    getsz(si, -1);
    if (sz[si] <= result)return;
    si = getcent(si);

    bl = 0;
    dmax_1 = dmax_2 = 0;
    build(0, si, -1);

    int l, r;
    // odd
    l = result + 1;
    r = dmax_1 + dmax_2 + 2;
    l = l / 2;
    r = r / 2;
    if (l < r && valid(l * 2 + 1)) {
        result = l * 2 + 1;
        l++;
        while (l < r) {
            int mid = (l + r) / 2;
            if (valid(mid * 2 + 1)) {
                //printf("centroid %d : mid %d success!!\n", si, mid);
                if (result < mid * 2 + 1)result = mid * 2 + 1;
                l = mid + 1;
            }
            else {
                //printf("centroid %d : mid %d fail\n", si, mid);
                r = mid;
            }
        }
    }
    // even
    l = result + 1;
    r = dmax_1 + dmax_2 + 2;
    l = (l + 1) / 2;
    r = (r + 1) / 2;
    if (l < r && valid(2 * l)) {
        result = 2 * l;
        l++;
        while (l < r) {
            int mid = (l + r) / 2;
            if (valid(mid * 2)) {
                //printf("centroid %d : mid %d success!!\n", si, mid);
                if (result < mid * 2)result = mid * 2;
                l = mid + 1;
            }
            else {
                //printf("centroid %d : mid %d fail\n", si, mid);
                r = mid;
            }
        }
    }

    for (int i = 0; i <= dmax_1; i++) {
        vd[i].clear();
    }

    centmark[si] = true;
    for (int i = 0; i < v[si].size(); i++) {
        int j = v[si][i];
        if (centmark[j])continue;
        process(j);
    }
}
// Initialises the static hash parameters: every component uses the same base,
// and rx[i] is base^(M[i] - 2) mod M[i].
void inil() {
    const int base = 257679;
    int i = 0;
    while (i < Hash::HM) {
        const int mod = Hash::M[i];
        Hash::x[i] = base;
        Hash::rx[i] = expp(base, mod - 2, mod);
        ++i;
    }
}
int main() {
    inil();
    int i, j, k, l;
    int t = 1, tv = 0;
    while (t--) {
        scanf("%d", &n);
        scanf("%s", a);
        a[0] ^= 73;
        for (i = 1; i < n; i++) {
            a[i] ^= 73;
            scanf("%d %d", &j, &k);
            j--; k--;
            v[j].push_back(k);
            v[k].push_back(j);
        }
        process(0);
        printf("%d\n", result);
    }
}

Compilation message

lampice.cpp: In function 'void getsz(int, int)':
lampice.cpp:31:19: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   31 |     for (i = 0; i < v[si].size(); i++) {
      |                 ~~^~~~~~~~~~~~~~
lampice.cpp:29:15: warning: unused variable 'k' [-Wunused-variable]
   29 |     int i, j, k;
      |               ^
lampice.cpp: In function 'int getcent(int)':
lampice.cpp:44:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   44 |         for (int i = 0; i < v[si].size(); i++) {
      |                         ~~^~~~~~~~~~~~~~
lampice.cpp: In function 'void build(int, int, int)':
lampice.cpp:206:19: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  206 |     for (i = 0; i < v[si].size(); i++) {
      |                 ~~^~~~~~~~~~~~~~
lampice.cpp: In function 'bool valid(int)':
lampice.cpp:230:13: warning: unused variable 'ei' [-Wunused-variable]
  230 |         int ei = len - di - 1;
      |             ^~
lampice.cpp: In function 'void process(int)':
lampice.cpp:373:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  373 |     for (int i = 0; i < v[si].size(); i++) {
      |                     ~~^~~~~~~~~~~~~~
lampice.cpp: In function 'int main()':
lampice.cpp:387:18: warning: unused variable 'l' [-Wunused-variable]
  387 |     int i, j, k, l;
      |                  ^
lampice.cpp:388:16: warning: unused variable 'tv' [-Wunused-variable]
  388 |     int t = 1, tv = 0;
      |                ^~
lampice.cpp:390:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
  390 |         scanf("%d", &n);
      |         ~~~~~^~~~~~~~~~
lampice.cpp:391:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
  391 |         scanf("%s", a);
      |         ~~~~~^~~~~~~~~
lampice.cpp:395:18: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
  395 |             scanf("%d %d", &j, &k);
      |             ~~~~~^~~~~~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 8 ms 9344 KB Output is correct
2 Correct 8 ms 9344 KB Output is correct
3 Correct 27 ms 9472 KB Output is correct
4 Correct 23 ms 9472 KB Output is correct
5 Correct 6 ms 9344 KB Output is correct
6 Correct 6 ms 9344 KB Output is correct
7 Correct 7 ms 9344 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 629 ms 17860 KB Output is correct
2 Incorrect 1862 ms 18188 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 827 ms 14732 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 8 ms 9344 KB Output is correct
2 Correct 8 ms 9344 KB Output is correct
3 Correct 27 ms 9472 KB Output is correct
4 Correct 23 ms 9472 KB Output is correct
5 Correct 6 ms 9344 KB Output is correct
6 Correct 6 ms 9344 KB Output is correct
7 Correct 7 ms 9344 KB Output is correct
8 Correct 629 ms 17860 KB Output is correct
9 Incorrect 1862 ms 18188 KB Output isn't correct
10 Halted 0 ms 0 KB -