답안 #1100874

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1100874 2024-10-14T21:30:39Z Kirill22 Balanced Tree (info1cup18_balancedtree) C++17
100 / 100
2882 ms 83956 KB
#include "bits/stdc++.h"

using namespace std;

//#include "debug.h"

// Capacity of the per-vertex global arrays (a, dp, g), with a small margin.
const int N = 500000 + 22;
// Sentinel marking an unreachable DP state.
const int inf = 1000000000;

// All (i, j, k) index triples of Data::dp, filled by main() in lexicographic
// order; iterated by Data::init, merge() and find().
vector<array<int, 3>> states;
// Threshold currently being evaluated; set by check() / restore(int).
int K;

struct Data {
    // dp[i][j][k]: i = colour of the subtree root, j = whether the root is
    // already matched (1) or not (0), k = status of the opposite-coloured
    // frontier: 0 bad, 1 good, 2 empty (value forced to 0). inf = unreachable.
    // (Interpretation of i/j/k follows how merge() uses them.)
    int dp[2][2][3]; // bad(tomin), good(tomin), empty(0)

    // Reset every state to inf, then open the single-vertex states allowed
    // by colour x: 0 or 1 fixes the colour, -1 allows both, any other value
    // (merge() passes -2) leaves everything at inf.
    void init(int x) {
        for (const auto& st : states) {
            dp[st[0]][st[1]][st[2]] = inf;
        }
        const bool allowZero = (x == 0 || x == -1);
        const bool allowOne = (x == 1 || x == -1);
        if (allowZero) {
            dp[0][0][2] = 0;
        }
        if (allowOne) {
            dp[1][0][2] = 0;
        }
    }

    // Relax dp[i][j][k] with val, keeping the minimum. The k == 2 status
    // always stores 0; a k == 0 value that has reached K is discarded.
    void upd(int i, int j, int k, int val) {
        const int stored = (k == 2) ? 0 : val;
        const bool discard = (k == 0 && stored >= K);
        if (!discard) {
            dp[i][j][k] = min(dp[i][j][k], stored);
        }
    }
};

// Per-test state, indexed by vertex 0..n-1.
int n;             // vertex count of the current test (read in solve())
int a[N];          // vertex colours as read; restore() overwrites them with the chosen colouring
Data dp[N];        // DP summary of each vertex's subtree (tree rooted at 0)
vector<int> g[N];  // adjacency lists

// Fold the summary `b` of one child u into the summary `a` of vertex v (v
// together with the children merged so far) and return the combined summary.
// Every pair of finite states (a-state, b-state) is tried; the edge v-u adds 1
// to every distance taken from `b`, and Data::upd keeps the minimum per
// target state. Reads the globals `states` and `K`.
Data merge(const Data& a, const Data& b) {
    Data res;
    res.init(-2); // -2 matches no colour: every state starts at inf
//    debug(b.dp[1][1][0]);
    for (auto [i, j, k] : states) {
        if (a.dp[i][j][k] == inf) {
            continue;
        }
//        debug(a.dp[1][1][2]);
        // The rules below overwrite the copies i/j/k; restore them per pair.
        int _i = i, _j = j, _k = k;
        for (auto [i2, j2, k2] : states) {
            i = _i, j = _j, k = _k;
            if (b.dp[i2][j2][k2] == inf) {
                continue;
            }
            int adp = a.dp[i][j][k];
            int bdp = b.dp[i2][j2][k2];
//            if (i2 == 1 && j2 == 1 && k2 == 0) {
//                debug(bdp);
//            }
            bdp++; // b's distances are measured from u; shift them to v
//            debug(i, j, k, i2, j2, k2, adp, bdp, res.dp[1][1][1]);
            if (i == i2) {
                // v and u share a colour and are adjacent, so v becomes
                // matched (j = 1). NOTE(review): this relies on K >= 1, which
                // solve()'s binary search over (0, n] provides.
                // dp[i][1]
                if (k == 2 && k2 == 2) {
                    res.upd(i, 1, 2, 0);
                } else if (k == 2) {
                    // only b has an opposite-coloured frontier: take it as is
                    res.upd(i, 1, k2, bdp);
                } else if (k2 == 2) {
                    // only a has an opposite-coloured frontier
                    res.upd(i, 1, k, adp);
                } else {
                    // The two frontiers are adp + bdp apart through v; if that
                    // is within K, a bad one is matched by the other side.
                    if (adp + bdp <= K) {
                        k = k2 = 1;
                    }
                    if (k && k2) {
                        // both good: keep the closer one
                        res.upd(i, 1, 1, min(adp, bdp));
                    } else {
                        // some bad remains: keep the bad side's value (larger if both bad)
                        res.upd(i, 1, 0, k ? bdp : k2 ? adp : max(adp, bdp));
                    }
                }
            } else {
                // v and u have different colours: b's frontier (k2) consists
                // of v's colour and a's frontier (k) of u's colour.
//                if (i == 0 && j == 1 && k == 2 && i2 == 1 && j2 == 1 && k2 == 0) {
//                    debug(adp, bdp);
//                }
                // A v-coloured vertex in u's subtree within K of v: v is
                // matched, and if that frontier was bad, v matches it.
                if (k2 != 2 && bdp <= K) {
                    j = 1;
                    k2 = 1;
                }
                // a's u-coloured frontier is adp + 1 away from u: u is
                // matched, and if that frontier was bad, u matches it.
                if (k != 2 && adp + 1 <= K) {
                    j2 = 1;
                    k = 1;
                }
                // Reject the pair when:
                // - b's bad v-coloured frontier is out of v's reach;
                // - u is unmatched and K <= 1 (u's next candidate outside its
                //   subtree is at least 2 away, since v has the other colour);
                // - a's bad frontier is already K or more away from v.
                if (!k2 || (!j2 && 1 >= K) || (!k && adp >= K)) {
                    continue;
                }
                // res[i][j][...]
                // u itself is now the nearest opposite-coloured vertex, at
                // distance 1: good if matched (j2 == 1), bad otherwise.
//                debug(k);
                if (k == 2) {
                    res.upd(i, j, j2, 1);
                } else if (k == 1) {
                    res.upd(i, j, j2, 1);
                } else {
                    // NOTE(review): Data::upd never stores a k == 0 value >= K,
                    // so for inputs built by init()/merge() k == 0 is turned
                    // into 1 above and this branch looks unreachable — confirm.
                    res.upd(i, j, k, adp);
                }
            }
//            debug(i, j, k, i2, j2, k2, adp, bdp, res.dp[1][1][1]);
        }
    }
    return res;
}

// Inverse of merge() for a single target cell. Replays merge(a, b) pair by
// pair and returns the first pair of states — a's (i, j, k) followed by b's
// (i2, j2, k2), as stored in a and b before merge's in-place rewrites — after
// which res.dp[X][Y][Z] equals `val`.
// Every write to the target cell is checked immediately afterwards. So the
// first time the running minimum reaches `val`, the current pair produced it.
// restore() passes val = merge(a, b).dp[X][Y][Z] for a finite cell, so a
// match exists.
// The transition rules below must stay identical to merge().
array<int, 6> find(const Data& a, const Data& b, int X, int Y, int Z, int val) {
    Data res;
    res.init(-2); // every state starts at inf, exactly as in merge()
    for (auto [i, j, k] : states) {
        if (a.dp[i][j][k] == inf) {
            continue;
        }
        // The rules overwrite i/j/k and j2/k2; keep the originals so the
        // returned states are the ones actually stored in a and b.
        int _i = i, _j = j, _k = k;
        for (auto [i2, j2, k2] : states) {
            i = _i, j = _j, k = _k;
            int _j2 = j2, _i2 = i2, _k2 = k2;
            if (b.dp[i2][j2][k2] == inf) {
                continue;
            }
            int adp = a.dp[i][j][k];
            int bdp = b.dp[i2][j2][k2];
            bdp++; // b's distances are measured from u; shift them to v
            if (i == i2) {
                // Same colour on both ends of the edge: the root becomes matched.
                if (k == 2 && k2 == 2) {
                    res.upd(i, 1, 2, 0);
                    if (X == i && Y == 1 && Z == 2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else if (k == 2) {
                    res.upd(i, 1, k2, bdp);
                    if (X == i && Y == 1 && Z == k2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else if (k2 == 2) {
                    res.upd(i, 1, k, adp);
                    if (X == i && Y == 1 && Z == k && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else {
                    // Frontiers within K of each other through v match each other.
                    if (adp + bdp <= K) {
                        k = k2 = 1;
                    }
                    if (k && k2) {
                        res.upd(i, 1, 1, min(adp, bdp));
                        if (X == i && Y == 1 && Z == 1 && res.dp[X][Y][Z] == val) {
                            return {_i, _j, _k, _i2, _j2, _k2};
                        }
                    } else {
                        res.upd(i, 1, 0, k ? bdp : k2 ? adp : max(adp, bdp));
                        if (X == i && Y == 1 && Z == 0 && res.dp[X][Y][Z] == val) {
                            return {_i, _j, _k, _i2, _j2, _k2};
                        }
                    }
                }
            } else {
                // Different colours: b's frontier has the root's colour, a's
                // frontier has the child's colour.
                if (k2 != 2 && bdp <= K) {
                    j = 1;
                    k2 = 1;
                }
                if (k != 2 && adp + 1 <= K) {
                    j2 = 1;
                    k = 1;
                }
                if (!k2 || (!j2 && 1 >= K) || (!k && adp >= K)) {
                    continue;
                }
                if (k == 2) {
                    res.upd(i, j, j2, 1);
                    if (X == i && Y == j && Z == j2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else if (k == 1) {
                    res.upd(i, j, j2, 1);
                    if (X == i && Y == j && Z == j2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else {
                    res.upd(i, j, k, adp);
                    if (X == i && Y == j && Z == k && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                }
            }
        }
    }
    // Only reachable if `val` is not produced by merge(a, b) at (X, Y, Z),
    // i.e. a logic error in the caller or a divergence from merge(). Fail
    // loudly rather than fall off the end of a non-void function, which is
    // undefined behaviour.
    assert(false && "find: no state pair reproduces the target value");
    abort();
}

void dfs(int v, int pr) {
    dp[v].init(a[v]);
    for (auto u : g[v]) {
        if (u != pr) {
            dfs(u, v);
//            if (v == 1 && u == 3) {
//                debug(dp[v].dp[0][1][2]);
//                debug(dp[u].dp[1][1][0]);
//            }
            dp[v] = merge(dp[v], dp[u]);
//            debug(v);
        }
    }
//    debug(v);
}

// Reconstruct a colouring for the subtree of v (parent pr), given that v's
// summary must end in state (X, Y, Z). Expects dp[] of v's children to hold
// the values of the preceding dfs() for the same K (restore(int) runs dfs()
// first). Overwrites a[v] and dp[v].
void restore(int v, int pr, int X, int Y, int Z) {
    // Fix v's colour to the state's root colour and rebuild v's summary.
    a[v] = X;
    dp[v].init(a[v]);
    // tmp[t] = summary of v merged with its first t children (in g[v] order).
    vector<Data> tmp;
    tmp.push_back(dp[v]);
    for (auto u : g[v]) {
        if (u != pr) {
            dp[v] = merge(dp[v], dp[u]);
            tmp.push_back(dp[v]);
        }
    }
    // Peel children off from last to first. find() names the state of the
    // prefix tmp[lst - 1] and the state of child u that produced the current
    // target in tmp[lst]. The prefix state becomes the next target; the child
    // is solved recursively with its own state. find() is called before the
    // recursion overwrites dp[u].
    int lst = (int) tmp.size() - 1;
    for (int it = (int) g[v].size() - 1; it >= 0; it--) {
        int u = g[v][it];
        if (u == pr) {
            continue;
        }
        auto [A, B, C, D, E, F] = find(tmp[lst - 1], dp[u], X, Y, Z, tmp[lst].dp[X][Y][Z]);
        X = A, Y = B, Z = C;
        restore(u, v, D, E, F);
        lst--;
    }
//    debug(v);
}

// Feasibility test for threshold `_`. Runs the DP and accepts if, for either
// root colour, the root can end matched (j == 1) with an opposite-coloured
// status of good (k == 1) or empty (k == 2).
bool check(int _) {
    K = _;
    dfs(0, 0);
    const auto& root = dp[0].dp;
    for (int colour = 0; colour <= 1; ++colour) {
        const bool good = root[colour][1][1] != inf;
        const bool empty = root[colour][1][2] != inf;
        if (good || empty) {
            return true;
        }
    }
    return false;
}

// Recompute the DP for threshold `_`, pick the first accepting root state
// (same acceptance rule as check(); colour 0 before 1, k == 1 before k == 2)
// and write the corresponding colouring into a[] via restore(v, ...).
// Does nothing if no root state is accepted.
void restore(int _) {
    K = _;
    dfs(0, 0);
    for (int colour = 0; colour < 2; ++colour) {
        for (int status = 1; status <= 2; ++status) {
            if (dp[0].dp[colour][1][status] == inf) {
                continue;
            }
            restore(0, 0, colour, 1, status);
            return;
        }
    }
}

// Handle one test: read the tree (1-based edges) and the vertex colours. Print
// -1 if even K = n is infeasible. Otherwise binary-search the smallest
// feasible K, then print it and the reconstructed colouring.
void solve() {
    cin >> n;
    for (int v = 0; v < n; ++v) {
        g[v].clear();
    }
    for (int edge = 1; edge < n; ++edge) {
        int x, y;
        cin >> x >> y;
        --x;
        --y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    for (int v = 0; v < n; ++v) {
        cin >> a[v];
    }

    if (!check(n)) {
        cout << -1 << '\n';
        return;
    }

    // Invariant: check(hi) holds; lo starts at 0, which is never tested and is
    // treated as infeasible. Assumes check() is monotone in K.
    int lo = 0;
    int hi = n;
    while (hi - lo > 1) {
        const int mid = lo + (hi - lo) / 2;
        if (check(mid)) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    cout << hi << '\n';
    restore(hi);
    for (int v = 0; v < n; ++v) {
        cout << a[v] << " ";
    }
    cout << '\n';
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // Enumerate the 2 * 2 * 3 DP index triples in lexicographic (i, j, k) order.
    for (int code = 0; code < 2 * 2 * 3; ++code) {
        states.push_back({code / 6, code / 3 % 2, code % 3});
    }

    int t = 1;
    cin >> t;
    while (t--) {
        solve();
    }
}

Compilation message

balancedtree.cpp: In function 'std::array<int, 6> find(const Data&, const Data&, int, int, int, int)':
balancedtree.cpp:208:1: warning: control reaches end of non-void function [-Wreturn-type]
  208 | }
      | ^
# 결과 실행 시간 메모리 Grader output
1 Correct 7 ms 14928 KB Output is correct
2 Correct 10 ms 14928 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 94 ms 15176 KB Output is correct
2 Correct 220 ms 19016 KB Output is correct
3 Correct 174 ms 15688 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 111 ms 15212 KB Output is correct
2 Correct 314 ms 53384 KB Output is correct
3 Correct 159 ms 31568 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 194 ms 17992 KB Output is correct
2 Correct 135 ms 15176 KB Output is correct
3 Correct 136 ms 15176 KB Output is correct
4 Correct 97 ms 15052 KB Output is correct
5 Correct 112 ms 15200 KB Output is correct
6 Correct 183 ms 15432 KB Output is correct
7 Correct 138 ms 15136 KB Output is correct
8 Correct 100 ms 15092 KB Output is correct
9 Correct 115 ms 15184 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 7 ms 14928 KB Output is correct
2 Correct 10 ms 14928 KB Output is correct
3 Correct 94 ms 15176 KB Output is correct
4 Correct 220 ms 19016 KB Output is correct
5 Correct 174 ms 15688 KB Output is correct
6 Correct 111 ms 15212 KB Output is correct
7 Correct 314 ms 53384 KB Output is correct
8 Correct 159 ms 31568 KB Output is correct
9 Correct 194 ms 17992 KB Output is correct
10 Correct 135 ms 15176 KB Output is correct
11 Correct 136 ms 15176 KB Output is correct
12 Correct 97 ms 15052 KB Output is correct
13 Correct 112 ms 15200 KB Output is correct
14 Correct 183 ms 15432 KB Output is correct
15 Correct 138 ms 15136 KB Output is correct
16 Correct 100 ms 15092 KB Output is correct
17 Correct 115 ms 15184 KB Output is correct
18 Correct 1053 ms 16632 KB Output is correct
19 Correct 1178 ms 23912 KB Output is correct
20 Correct 522 ms 15824 KB Output is correct
21 Correct 2882 ms 54860 KB Output is correct
22 Correct 1404 ms 83956 KB Output is correct