답안 #1108009

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1108009 2024-11-02T14:53:07 Z atom Lampice (COCI19_lampice) C++17
0 / 110
1309 ms 11344 KB
#include "bits/stdc++.h"
// @JASPER'S BOILERPLATE
using namespace std;
using ll = long long;

#ifdef JASPER
#include "debug.h"
#else
#define debug(...) 166
#endif

// Capacity of every per-vertex array (vertices are numbered 1..n).
constexpr int N = 50'000 + 5;
// Polynomial rolling-hash parameters.
constexpr int BASE = 256;
constexpr int MOD = 1'000'000'007;

int n;                 // number of vertices
int k;                 // exact path length (in vertices) being tested
string a;              // vertex letters; main prefixes a sentinel so a[i] is vertex i
int sz[N];             // subtree sizes written by dfsSZ
bool rem[N];           // true once a vertex has been used as a centroid
vector<int> adj[N];    // adjacency lists of the tree
int pw[N];             // powers of BASE modulo MOD, filled in main
 
// Computes sz[] for the part of u's live component reachable without passing
// through p (vertices with rem[] set are treated as absent) and returns the
// number of vertices in it.
int dfsSZ(int u, int p) {
    int total = 1;
    for (const int next : adj[u]) {
        if (next == p || rem[next]) continue;
        total += dfsSZ(next, u);
    }
    sz[u] = total;
    return total;
}
 
// Starting at u (entered from p), repeatedly steps into the first live child
// whose subtree holds more than half of the n vertices; the first vertex
// without such a child is returned as the centroid. Reads sz[] as filled by
// dfsSZ from the same root.
int findCentroid(int u, int p, int n) {
    int cur = u;
    int from = p;
    for (;;) {
        int heavy = -1;
        for (const int nb : adj[cur]) {
            if (nb != from && !rem[nb] && 2 * sz[nb] > n) {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) return cur;
        from = cur;
        cur = heavy;
    }
}
 
// Scratch state for the centroid currently being processed (used by add/solve).
int h[N];              // h[x]: number of edges from the current centroid to x
int maxHigh;           // largest h[] value reached under the current centroid
int hashU[N];          // hashU[x]: Horner hash of the path just below the centroid down to x
int hashD[N];          // hashD[x]: hash of centroid..x, letter at distance i times pw[i]
int value[N];          // value[x]: lookup key x publishes into mp[h[x]]
int nNode;             // number of vertices waiting in node[]
int node[N];           // node[1..nNode]: vertices of the subtree being explored
map<int, bool> mp[N];  // mp[d]: keys published by finished subtrees at depth d
 
// Explores the subtree hanging off the current centroid, entering vertex u
// from its parent p, and tests whether u closes a palindromic path of exactly
// k vertices through the centroid:
//   - a hit in mp[] pairs u with a vertex from an earlier, already finished
//     subtree of the same centroid;
//   - the second test covers the path that ends at the centroid itself.
// Each visited (non-pruned) vertex is queued in node[] with its key in
// value[]; solve() publishes them into mp[] only after the whole subtree is
// done, so two vertices of the same subtree are never paired.
// Sets `found` and unwinds as soon as a palindrome is detected.
// Assumes pw[i] holds BASE^i mod MOD.
void add(int u, int p, bool &found) {
    h[u] = h[p] + 1;
    // The path u..centroid already has h[u] + 1 vertices; deeper is too long.
    if (h[u] + 1 > k) return;
    maxHigh = max(maxHigh, h[u]);
    // Hash of the letters strictly below the centroid down to u, Horner form
    // (u's letter gets BASE^0, the centroid is excluded).
    hashU[u] = (1LL * hashU[p] * BASE + (a[u] - 'a' + 1)) % MOD;
    // Hash of centroid..u where the letter at distance i is weighted BASE^i.
    hashD[u] = (hashD[p] + 1LL * pw[h[u]] * (a[u] - 'a' + 1)) % MOD;
    // Joining u (distance du) with v (distance dv = k - du - 1) through the
    // centroid gives a palindrome iff forward hash == backward hash, which
    // rearranges to
    //   BASE^dv * hashD[u] - hashU[u] == BASE^du * hashD[v] - hashU[v].
    // `need` is u's side; it is also the key u publishes for later partners.
    int need = (1LL * pw[k - h[u] - 1] * hashD[u] - hashU[u] + MOD) % MOD;
    if (mp[k - h[u] - 1].find(need) != mp[k - h[u] - 1].end()) {
        found = true;
        return;
    }
    // The path u..centroid alone has k vertices. Matching coefficients of this
    // equality yields letter(i) == letter(h[u] - i) for every distance i,
    // i.e. the path is a palindrome.
    if (h[u] + 1 == k && hashD[u] == (hashU[u] + 1LL * pw[h[u]] * (a[u] - 'a' + 1)) % MOD) {
        found = true;
        return;
    }
    // Queue u; it becomes visible in mp[] only once solve() flushes node[].
    node[++nNode] = u; value[u] = need;
    for (int v: adj[u]) if (v != p && !rem[v]) {
        add(v, u, found);
        if (found) return;
    }
}
 
void solve(int u, int p, bool &found) {
    int n = dfsSZ(u, p), c = findCentroid(u, p, n); rem[c] = true;

    h[c] = 0;
    nNode = 0;
    maxHigh = 0;
    hashU[c] = 0;
    hashD[c] = a[c] - 'a' + 1;
    for (int v: adj[c]) if (v != p && !rem[v]) {
        add(v, c, found);
        if (found) return;
        while (nNode > 0) {
            int u = node[nNode--];
            mp[h[u]][value[u]] = 1;
        }
    }
    for (int i = 0; i <= maxHigh; ++i) {
        mp[i].clear();
    }
 
    for (int v: adj[c]) if (v != p && !rem[v]) {
        solve(v, c, found);
        if (found) return;
    }
}
 
bool func(int x) {
    memset(rem, false, sizeof rem);
    bool found = false; k = x;
    solve(1, 0, found);
    return found;
}

// Reads the tree, precomputes powers of BASE, and binary-searches the longest
// palindromic path separately for odd and even lengths. Within one parity the
// answer is monotone: stripping both ends of a palindrome of length L leaves
// one of length L - 2.
signed main() {
    cin.tie(0) -> sync_with_stdio(0);

    cin >> n >> a;
    a = "@" + a;  // sentinel so vertex i's letter is a[i]

    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // pw[i] = BASE^i mod MOD. The product must be formed in 64 bits:
    // pw[i - 1] can approach MOD (~1e9), and pw[i - 1] * 256 overflows int
    // (already at i = 4, where 256^3 * 256 = 2^32).
    pw[0] = 1;
    for (int i = 1; i <= n; ++i) {
        pw[i] = 1LL * pw[i - 1] * BASE % MOD;
    }

    int ans = 1;  // a single vertex is always a palindrome
    for (int parity = 0; parity <= 1; ++parity) {
        // Invariant: length parity + 2*lo is achievable (lo = 0 trivially),
        // length parity + 2*hi is not (it exceeds n initially).
        int lo = 0, hi = n / 2 + 1;
        while (lo + 1 < hi) {
            int mid = (lo + hi) / 2;
            if (func(parity + 2 * mid))
                lo = mid;
            else
                hi = mid;
        }
        ans = max(ans, parity + 2 * lo);
    }
    cout << ans << "\n";

    return 0;
}

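// Approach: binary search on the palindrome length (separately for odd and even lengths).
// Each candidate length k is checked with centroid decomposition: half-paths of complementary
// lengths (d and k - 1 - d) through each centroid are paired using polynomial hashes.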
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 5368 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 622 ms 11344 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 1309 ms 10060 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 5368 KB Output isn't correct
2 Halted 0 ms 0 KB -