Submission #1100874

# | Time | Username | Problem | Language | Result | Execution time | Memory
1100874 | — | Kirill22 | Balanced Tree (info1cup18_balancedtree) | C++17 | 100 / 100 | 2882 ms | 83956 KiB
#include "bits/stdc++.h"

using namespace std;

// Solution for "Balanced Tree" (info1cup18_balancedtree).
//
// Input: t test cases. Each gives n, the n - 1 edges of a tree (1-indexed
// endpoints) and one value a[i] per vertex. a[i] == 0 or 1 fixes vertex i's
// colour; a[i] == -1 lets the program choose it (see Data::init).
// Output per test case: -1 if check(n) fails. Otherwise, the smallest K in
// (0, n] accepted by check(), found by binary search (this assumes check() is
// monotone in K), followed by the colouring a[0..n-1] rebuilt by restore().
//
// NOTE(review): dfs() and restore() recurse once per tree level. A path of
// ~5e5 vertices therefore needs a large stack; this relies on the judge
// providing one.

const int N = (int) 5e5 + 22;
const int inf = (int) 1e9;

// Every (i, j, k) index triple of Data::dp.
// merge() and find() both walk the states in this order. find() relies on
// replaying merge() transition-by-transition.
vector<array<int, 3>> build_states() {
    vector<array<int, 3>> res;
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 3; k++) {
                res.push_back({i, j, k});
            }
        }
    }
    return res;
}

// Built during static initialisation, so it is ready before main() and before
// any Data::init() call.
const vector<array<int, 3>> states = build_states();

// Threshold currently being tested by check() / restore().
int K;

// DP table for one (possibly partially merged) subtree.
struct Data {
    // dp[i][j][k] holds a value; inf marks an unreachable state.
    //   i: colour of the subtree root (init() seeds i with the allowed colour).
    //   j: 0 for a lone vertex; set to 1 by merge transitions.
    //   k: author's labels "bad(tomin), good(tomin), empty(0)" for k = 0, 1, 2.
    //      States with k == 2 always hold 0 (see upd()).
    int dp[2][2][3];

    // Marks every state unreachable, then seeds the lone-vertex state for each
    // colour that x allows:
    //   0 or 1 fixes the colour;
    //   -1 allows both;
    //   any other value (merge() passes -2) leaves the table empty.
    void init(int x) {
        for (const auto& [i, j, k] : states) {
            dp[i][j][k] = inf;
        }
        if (x == 0 || x == -1) {
            dp[0][0][2] = 0;
        }
        if (x == 1 || x == -1) {
            dp[1][0][2] = 0;
        }
    }

    // Relaxes dp[i][j][k] with val.
    // "empty" states (k == 2) are forced to 0. "bad" states (k == 0) with
    // val >= K are discarded.
    void upd(int i, int j, int k, int val) {
        if (k == 2) {
            val = 0;
        }
        if (k == 0 && val >= K) {
            return;
        }
        dp[i][j][k] = min(dp[i][j][k], val);
    }
};

int n, a[N];
Data dp[N];
vector<int> g[N];

// Target state and value produced by combining one parent state with one
// child state.
struct Transition {
    int i, j, k, val;
};

// Combines parent state (i, j, k) holding adp with child state (i2, j2, k2)
// holding bdp. bdp already includes the +1 added by for_each_transition().
// Returns nullopt when the pair is infeasible. A returned transition still
// goes through Data::upd(), which may reject it or clamp its value.
optional<Transition> combine(int i, int j, int k, int adp,
                             int i2, int j2, int k2, int bdp) {
    if (i == i2) {
        // Same colour on both sides: the result always has j = 1.
        if (k == 2 && k2 == 2) {
            return Transition{i, 1, 2, 0};
        }
        if (k == 2) {
            return Transition{i, 1, k2, bdp};
        }
        if (k2 == 2) {
            return Transition{i, 1, k, adp};
        }
        if (adp + bdp <= K) {
            k = k2 = 1;
        }
        if (k && k2) {
            return Transition{i, 1, 1, min(adp, bdp)};
        }
        return Transition{i, 1, 0, k ? bdp : k2 ? adp : max(adp, bdp)};
    }

    // Different colours on the two sides.
    // Order matters: the second test reads k, and the final test reads the
    // possibly updated k2, j2 and k.
    if (k2 != 2 && bdp <= K) {
        j = 1;
        k2 = 1;
    }
    if (k != 2 && adp + 1 <= K) {
        j2 = 1;
        k = 1;
    }
    if (!k2 || (!j2 && 1 >= K) || (!k && adp >= K)) {
        return nullopt;
    }
    if (k != 0) {
        return Transition{i, j, j2, 1};
    }
    return Transition{i, j, 0, adp};
}

// Calls visit(parent_state, child_state, transition) for every feasible pair
// of reachable states of a (parent side) and b (child side), in `states`
// order. Stops as soon as visit returns true, and returns whether it stopped
// early.
template <class Visit>
bool for_each_transition(const Data& a, const Data& b, Visit&& visit) {
    for (const auto& sa : states) {
        const int adp = a.dp[sa[0]][sa[1]][sa[2]];
        if (adp == inf) {
            continue;
        }
        for (const auto& sb : states) {
            const int bdp = b.dp[sb[0]][sb[1]][sb[2]];
            if (bdp == inf) {
                continue;
            }
            // The child's value is incremented by one (the edge to the child)
            // before combining.
            const auto t = combine(sa[0], sa[1], sa[2], adp,
                                   sb[0], sb[1], sb[2], bdp + 1);
            if (t && visit(sa, sb, *t)) {
                return true;
            }
        }
    }
    return false;
}

// Combines the DP of a partially merged subtree `a` with the DP of one more
// child subtree `b`.
Data merge(const Data& a, const Data& b) {
    Data res;
    res.init(-2);  // every state starts unreachable
    for_each_transition(a, b, [&](const array<int, 3>&, const array<int, 3>&,
                                  const Transition& t) {
        res.upd(t.i, t.j, t.k, t.val);
        return false;  // never stop early
    });
    return res;
}

// Replays merge(a, b). Returns {parent i, j, k, child i2, j2, k2} of the first
// transition after which the replayed table has dp[X][Y][Z] == val.
// Callers pass val = merge(a, b).dp[X][Y][Z], so such a transition exists.
// If none does, the program prints a diagnostic to stderr and aborts.
array<int, 6> find(const Data& a, const Data& b, int X, int Y, int Z, int val) {
    Data res;
    res.init(-2);
    array<int, 6> witness{};
    const bool found = for_each_transition(
        a, b, [&](const array<int, 3>& sa, const array<int, 3>& sb, const Transition& t) {
            res.upd(t.i, t.j, t.k, t.val);
            if (t.i != X || t.j != Y || t.k != Z || res.dp[X][Y][Z] != val) {
                return false;
            }
            witness = array<int, 6>{sa[0], sa[1], sa[2], sb[0], sb[1], sb[2]};
            return true;
        });
    if (!found) {
        // Never fall off the end of a non-void function (undefined behaviour).
        // Fail loudly instead of returning garbage.
        cerr << "find: no transition yields dp[" << X << "][" << Y << "][" << Z
             << "] = " << val << '\n';
        abort();
    }
    return witness;
}

// Fills dp[v] for the subtree of v (parent pr) at the current K.
// The root is called with pr == itself; this is safe because the tree has no
// self-loops.
void dfs(int v, int pr) {
    dp[v].init(a[v]);
    for (int u : g[v]) {
        if (u == pr) {
            continue;
        }
        dfs(u, v);
        dp[v] = merge(dp[v], dp[u]);
    }
}

// Rebuilds colours for the subtree of v, given that v must end in state
// (X, Y, Z); X becomes v's colour.
// Expects dp[u] of every child u to still hold the values from dfs() at the
// current K. Overwrites a[] and dp[] inside the subtree.
void restore(int v, int pr, int X, int Y, int Z) {
    a[v] = X;
    dp[v].init(a[v]);

    // prefix[c] = dp[v] after merging the first c children, in g[v] order.
    vector<Data> prefix;
    prefix.reserve(g[v].size() + 1);
    prefix.push_back(dp[v]);
    for (int u : g[v]) {
        if (u == pr) {
            continue;
        }
        dp[v] = merge(dp[v], dp[u]);
        prefix.push_back(dp[v]);
    }

    // Undo the merges from the last child to the first. Each step recovers the
    // parent-side state (the new target for v) and the child's own state.
    int last = (int) prefix.size() - 1;
    for (int it = (int) g[v].size() - 1; it >= 0; it--) {
        const int u = g[v][it];
        if (u == pr) {
            continue;
        }
        const auto [A, B, C, D, E, F] =
            find(prefix[last - 1], dp[u], X, Y, Z, prefix[last].dp[X][Y][Z]);
        X = A;
        Y = B;
        Z = C;
        restore(u, v, D, E, F);
        last--;
    }
}

// Returns the first reachable accepting state of the root (vertex 0).
// Accepting states are (c, 1, 1) and (c, 1, 2). Scan order: colour 0 before
// colour 1, and k == 1 before k == 2.
optional<array<int, 3>> accepting_root_state() {
    for (int c = 0; c < 2; c++) {
        if (dp[0].dp[c][1][1] != inf) {
            return array<int, 3>{c, 1, 1};
        }
        if (dp[0].dp[c][1][2] != inf) {
            return array<int, 3>{c, 1, 2};
        }
    }
    return nullopt;
}

// Runs the DP with threshold `limit`.
// Returns true if the root has an accepting state.
bool check(int limit) {
    K = limit;
    dfs(0, 0);
    return accepting_root_state().has_value();
}

// Reruns the DP with threshold `limit` and writes a colouring into a[].
// a[] is left unchanged if no accepting state exists; solve() only calls
// this after check(limit) has succeeded.
void restore(int limit) {
    K = limit;
    dfs(0, 0);
    if (const auto s = accepting_root_state()) {
        restore(0, 0, (*s)[0], (*s)[1], (*s)[2]);
    }
}

// Reads one test case and prints its answer (see the header comment).
void solve() {
    cin >> n;
    for (int i = 0; i < n; i++) {
        g[i].clear();
    }
    for (int i = 0; i + 1 < n; i++) {
        int x, y;
        cin >> x >> y;
        x--, y--;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    int l = 0, r = n;
    if (!check(r)) {
        cout << -1 << '\n';
        return;
    }
    // Invariant: check(r) holds, and l is 0 or a value for which check failed.
    while (l + 1 < r) {
        const int m = (l + r) / 2;
        if (check(m)) {
            r = m;
        } else {
            l = m;
        }
    }
    cout << r << '\n';
    restore(r);
    for (int i = 0; i < n; i++) {
        cout << a[i] << " ";
    }
    cout << '\n';
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t = 1;
    cin >> t;
    while (t--) {
        solve();
    }
}

Compilation message (stderr)

balancedtree.cpp: In function 'std::array<int, 6> find(const Data&, const Data&, int, int, int, int)':
balancedtree.cpp:208:1: warning: control reaches end of non-void function [-Wreturn-type]
  208 | }
      | ^
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...