답안 #530825

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
530825 2022-02-26T21:59:09 Z Alex_tz307 Lampice (COCI19_lampice) C++17
42 / 110
5000 ms 13556 KB
#include <bits/stdc++.h>
// GCC-specific tuning hints. The directive is spelled `optimize`: GCC does not
// recognise `#pragma GCC optimization` and drops it with a -Wunknown-pragmas
// warning, so the O3 / unroll-loops requests previously had no effect.
#pragma GCC optimize ("Ofast")
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")

using namespace std;

// Upper bound on the number of tree nodes; node labels are 1-based, so the
// per-node arrays below reserve index 0 as a "no parent" sentinel.
constexpr int kN = 50000;
// Modulus and base of the polynomial hashes.
constexpr int mod = 1000000009;
constexpr int base = 29;

// Node letters; testCase() prepends '$' so that s[v] is the letter of node v.
string s;
// Adjacency lists of the tree.
vector<int> g[kN + 1];
// Path length (in nodes) currently being tested by solve()/dfs2().
int len;
// Number of entries in use in `nodes`.
int m;
// p[i] = base^i mod `mod`, filled by precalc().
int p[kN + 1];
// Subtree sizes computed by findSize().
int sz[kN + 1];
// h1: hash of the centroid-to-node string with the centroid letter at the
// highest power; h2: the same string with the centroid letter at power 0.
int h1[kN + 1];
int h2[kN + 1];
// Stack of the nodes on the current centroid-to-node path used by dfs2().
int nodes[kN + 1];
// Centroids in the order build() discovered them.
vector<int> centroids;
// Marks nodes that have been removed as centroids.
bitset<kN + 1> vis;
// preffixes[d] holds the h1 hashes of all currently registered nodes whose
// centroid-to-node path contains d nodes.
unordered_multiset<int> preffixes[kN];

// In-place modular addition: x becomes x + y, reduced by at most one
// subtraction of `mod` (so the result is only fully reduced when the sum is
// below 2 * mod).
void addSelf(int &x, const int &y) {
  const int sum = x + y;
  x = (sum >= mod) ? sum - mod : sum;
}

int add(int x, const int &y) {
  addSelf(x, y);
  return x;
}

// In-place modular multiplication: x becomes (x * y) % mod. The product is
// formed in 64 bits so it cannot overflow before the reduction.
void multSelf(int &x, const int &y) {
  const int64_t wide = static_cast<int64_t>(x) * y;
  x = static_cast<int>(wide % mod);
}

int mult(int x, const int &y) {
  multSelf(x, y);
  return x;
}

// Fills p[0..kN] with successive powers of `base` modulo `mod`.
void precalc() {
  int power = 1;
  p[0] = power;
  for (int exponent = 1; exponent <= kN; ++exponent) {
    power = mult(power, base);
    p[exponent] = power;
  }
}

// Computes sz[] for the part of the tree reachable from u without crossing
// `par` or any node already marked in `vis`.
void findSize(int u, int par) {
  sz[u] = 1;
  for (int child : g[u]) {
    if (vis[child] || child == par) {
      continue;
    }
    findSize(child, u);
    sz[u] += sz[child];
  }
}

// Walks from u towards the first unvisited child whose subtree holds more
// than n / 2 nodes, repeating until no such child exists, and returns the node
// where the walk stops. Expects sz[] to have been filled by findSize() for a
// component of n nodes.
int findCentroid(int u, int par, int n) {
  const int half = n / 2;
  bool moved = true;
  while (moved) {
    moved = false;
    for (int child : g[u]) {
      if (vis[child] || child == par || sz[child] <= half) {
        continue;
      }
      // Descend into the heavy child and rescan from there.
      par = u;
      u = child;
      moved = true;
      break;
    }
  }
  return u;
}

void build(int u) {
  findSize(u, 0);
  int c = findCentroid(u, 0, sz[u]);
  centroids.emplace_back(c);
  vis[c] = true;
  for (int v : g[c]) {
    if (!vis[v]) {
      build(v);
    }
  }
}

// Computes h1/h2 for every unvisited node reachable from u and registers each
// h1 in preffixes[]. `depth` is the number of edges between the traversal root
// and u, so a node lands in the bucket indexed by its node count (depth + 1).
void dfs1(int u, int par, int depth) {
  const int letter = s[u] - 'a' + 1;
  // h1 appends the letter at the low end; h2 adds it at power `depth`.
  h1[u] = add(mult(h1[par], base), letter);
  h2[u] = add(h2[par], mult(p[depth], letter));
  preffixes[depth + 1].insert(h1[u]);
  for (int child : g[u]) {
    if (vis[child] || child == par) {
      continue;
    }
    dfs1(child, u, depth + 1);
  }
}

// Re-registers the h1 hash of every unvisited node in u's subtree, placing u
// in preffixes[depth] and each child one bucket further.
void addPaths(int u, int par, int depth) {
  preffixes[depth].insert(h1[u]);
  for (int child : g[u]) {
    if (vis[child] || child == par) {
      continue;
    }
    addPaths(child, u, depth + 1);
  }
}

// Removes one occurrence of each node's h1 hash from preffixes[] for u's
// subtree, mirroring addPaths(). Every hash must currently be registered:
// the iterator returned by find() is erased without an end() check.
void removePaths(int u, int par, int depth) {
  auto &bucket = preffixes[depth];
  bucket.erase(bucket.find(h1[u]));
  for (int child : g[u]) {
    if (vis[child] || child == par) {
      continue;
    }
    removePaths(child, u, depth + 1);
  }
}

// Searches u's subtree for an endpoint of a path with exactly `len` nodes that
// passes through the current centroid and reads as a palindrome (judged by
// hash equality only).
//
// `depth` is the number of nodes on the centroid-to-u path (solve() starts the
// centroid's children at 2). The function relies on state prepared by solve():
// h1/h2 computed by dfs1(), `nodes` primed with {0, centroid} and m == 2, and
// the hashes of this child's subtree removed from preffixes[] so that matches
// can only come from other subtrees.
//
// Returns true as soon as a match is found; on that path `m` is not restored
// (solve() resets it before every top-level call).
bool dfs2(int u, int par, int depth) {
  if (depth == len) {
    // The centroid-to-u path alone has `len` nodes: it is a palindrome when
    // its two differently-oriented hashes coincide. Deeper nodes are too far.
    return h1[u] == h2[u];
  }
  // With solve()'s priming, nodes[k] is the path node at node count k, so
  // m == depth + 1 after this push.
  nodes[m++] = u;
  // Node count (centroid included) the other arm needs so that both arms,
  // sharing the centroid, total `len` nodes: depth + d - 1 == len.
  int d = len - depth + 1;
  // Only the case where this arm is at least as long as the other is handled
  // here; NOTE(review): the opposite case is presumably covered when solve()
  // runs dfs2 from the other arm's subtree.
  if (d <= depth) {
    // Path node at node count depth - d + 1: the centroid-to-split segment is
    // the middle of the candidate path and must itself be a palindrome.
    int split = nodes[m - d];
    if (h1[split] == h2[split]) {
      // Step one node towards the centroid; when d == depth this is the
      // sentinel 0 stored by solve() (h1[0] is never written in this file).
      split = nodes[m - d - 1];
      // Hash of the last d letters of the centroid-to-u string: the prefix
      // hash up to `split`, shifted by d positions, is subtracted from h1[u]
      // (written as adding `mod - x` to stay non-negative).
      int hsh = add(h1[u], mod - mult(p[d], h1[split]));
      // A registered node with d nodes on its centroid path and this hash
      // supplies the matching other arm.
      if (preffixes[d].count(hsh)) {
        return true;
      }
    }
  }
  for (int v : g[u]) {
    if (!vis[v] && v != par && dfs2(v, u, depth + 1)) {
      return true;
    }
  }
  // Pop u before returning to the parent.
  m -= 1;
  return false;
}

bool solve(int u) {
  for (int c : centroids) {
    vis[c] = true;
    dfs1(c, 0, 0);
    for (int v : g[c]) {
      if (!vis[v]) {
        removePaths(v, c, 2);
        m = 0;
        nodes[m++] = 0;
        nodes[m++] = c;
        if (dfs2(v, c, 2)) {
          addPaths(v, c, 2);
          removePaths(c, 0, 1);
          return true;
        }
        addPaths(v, c, 2);
      }
    }
    removePaths(c, 0, 1);
  }
  return false;
}

// Runs solve() for the current global `len`, then clears the `vis` marks of
// nodes 1..n that solve() set, and returns solve()'s answer.
bool check(int n) {
  const bool found = solve(1);
  for (int node = n; node >= 1; --node) {
    vis[node] = false;
  }
  return found;
}

void testCase() {
  int n;
  cin >> n >> s;
  s = '$' + s;
  bool ok2 = false;
  for (int i = 1; i < n; ++i) {
    int u, v;
    cin >> u >> v;
    g[u].emplace_back(v);
    g[v].emplace_back(u);
    if (s[u] == s[v]) {
      ok2 = true;
    }
  }
  build(1);
  for (int v = 1; v <= n; ++v) {
    vis[v] = false;
  }
  int l = 1, r = (n - 1) / 2, ans1 = 1;
  while (l <= r) {
    int mid = (l + r) / 2;
    len = 2 * mid + 1;
    if (check(n)) {
      ans1 = 2 * mid + 1;
      l = mid + 1;
    } else {
      r = mid - 1;
    }
  }
  int ans2 = 0;
  l = 2, r = n / 2;
  while (l <= r) {
    int mid = (l + r) / 2;
    len = 2 * mid;
    if (check(n)) {
      ans2 = 2 * mid;
      l = mid + 1;
    } else {
      r = mid - 1;
    }
  }
  cout << max({1 + ok2, ans1, ans2}) << '\n';
}

// Entry point: unties the C++ streams from C stdio for faster input,
// precomputes the hash powers, and runs the fixed number of test cases.
int main() {
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);
  precalc();
  const int tests = 1;
  int remaining = tests;
  while (remaining-- > 0) {
    testCase();
  }
  return 0;
}

Compilation message

lampice.cpp:4: warning: ignoring '#pragma GCC optimization' [-Wunknown-pragmas]
    4 | #pragma GCC optimization ("O3")
      | 
lampice.cpp:5: warning: ignoring '#pragma GCC optimization' [-Wunknown-pragmas]
    5 | #pragma GCC optimization ("unroll-loops")
      |
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 4428 KB Output is correct
2 Correct 13 ms 4428 KB Output is correct
3 Correct 48 ms 4556 KB Output is correct
4 Correct 66 ms 4556 KB Output is correct
5 Correct 3 ms 4428 KB Output is correct
6 Correct 3 ms 4428 KB Output is correct
7 Correct 3 ms 4428 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 3068 ms 12244 KB Output is correct
2 Correct 2209 ms 12488 KB Output is correct
3 Correct 1375 ms 12744 KB Output is correct
4 Correct 1722 ms 13132 KB Output is correct
5 Correct 2855 ms 13512 KB Output is correct
6 Correct 278 ms 13556 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 5041 ms 9916 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 4428 KB Output is correct
2 Correct 13 ms 4428 KB Output is correct
3 Correct 48 ms 4556 KB Output is correct
4 Correct 66 ms 4556 KB Output is correct
5 Correct 3 ms 4428 KB Output is correct
6 Correct 3 ms 4428 KB Output is correct
7 Correct 3 ms 4428 KB Output is correct
8 Correct 3068 ms 12244 KB Output is correct
9 Correct 2209 ms 12488 KB Output is correct
10 Correct 1375 ms 12744 KB Output is correct
11 Correct 1722 ms 13132 KB Output is correct
12 Correct 2855 ms 13512 KB Output is correct
13 Correct 278 ms 13556 KB Output is correct
14 Execution timed out 5041 ms 9916 KB Time limit exceeded
15 Halted 0 ms 0 KB -