Submission #1100874

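| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1100874 | | Kirill22 | Balanced Tree (info1cup18_balancedtree) | C++17 | 100 / 100 | 2882 ms | 83956 KiB |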
#include "bits/stdc++.h"

using namespace std;

//#include "debug.h"

// Capacity of the per-vertex global arrays (a[], dp[], g[]).
constexpr int N = 500000 + 22;
// Sentinel stored in Data cells for unreachable states.
constexpr int inf = 1000000000;

// Every (i, j, k) index triple of Data::dp, in lexicographic order; filled once in main().
std::vector<std::array<int, 3>> states;
// Limit currently being tested by check() or reconstructed by restore().
int K;

// DP table for one (partial) subtree.
// dp[i][j][k] holds the best value found for that state, or inf if the state is unreachable.
//   i: color of the subtree root (0 or 1; see init()).
//   j: 0 for a freshly initialised vertex; merge() may raise it to 1, and check() accepts
//      only root states with j == 1 (presumably "root already satisfied" -- TODO confirm).
//   k: 0 = bad, 1 = good (both keep a value to minimise), 2 = empty (value is always 0).
struct Data {
    int dp[2][2][3];

    // Resets every state to inf, then seeds the lone-vertex state dp[c][0][2] = 0 for each
    // color c permitted by x: x == 0 or x == 1 fixes the color, x == -1 allows both, and
    // any other value (merge() passes -2) seeds nothing.
    void init(int x) {
        for (const auto& s : states) {
            dp[s[0]][s[1]][s[2]] = inf;
        }
        const bool allowZero = (x == 0 || x == -1);
        const bool allowOne = (x == 1 || x == -1);
        if (allowZero) {
            dp[0][0][2] = 0;
        }
        if (allowOne) {
            dp[1][0][2] = 0;
        }
    }

    // Relaxes dp[i][j][k] towards val. Empty states (k == 2) always store 0, and bad
    // states (k == 0) discard values that already reach the current limit K.
    void upd(int i, int j, int k, int val) {
        const int candidate = (k == 2) ? 0 : val;
        if (k == 0 && candidate >= K) {
            return;
        }
        dp[i][j][k] = std::min(dp[i][j][k], candidate);
    }
};

int n;                  // number of vertices in the current test case
int a[N];               // vertex colors as read (-1 lets init() try both); restore() overwrites them
Data dp[N];             // dp[v]: DP table of v's subtree for the current K
std::vector<int> g[N];  // adjacency lists, 0-indexed

Data merge(const Data& a, const Data& b) {
    Data res;
    res.init(-2);
//    debug(b.dp[1][1][0]);
    for (auto [i, j, k] : states) {
        if (a.dp[i][j][k] == inf) {
            continue;
        }
//        debug(a.dp[1][1][2]);
        int _i = i, _j = j, _k = k;
        for (auto [i2, j2, k2] : states) {
            i = _i, j = _j, k = _k;
            if (b.dp[i2][j2][k2] == inf) {
                continue;
            }
            int adp = a.dp[i][j][k];
            int bdp = b.dp[i2][j2][k2];
//            if (i2 == 1 && j2 == 1 && k2 == 0) {
//                debug(bdp);
//            }
            bdp++;
//            debug(i, j, k, i2, j2, k2, adp, bdp, res.dp[1][1][1]);
            if (i == i2) {
                // dp[i][1]
                if (k == 2 && k2 == 2) {
                    res.upd(i, 1, 2, 0);
                } else if (k == 2) {
                    res.upd(i, 1, k2, bdp);
                } else if (k2 == 2) {
                    res.upd(i, 1, k, adp);
                } else {
                    if (adp + bdp <= K) {
                        k = k2 = 1;
                    }
                    if (k && k2) {
                        res.upd(i, 1, 1, min(adp, bdp));
                    } else {
                        res.upd(i, 1, 0, k ? bdp : k2 ? adp : max(adp, bdp));
                    }
                }
            } else {
//                if (i == 0 && j == 1 && k == 2 && i2 == 1 && j2 == 1 && k2 == 0) {
//                    debug(adp, bdp);
//                }
                if (k2 != 2 && bdp <= K) {
                    j = 1;
                    k2 = 1;
                }
                if (k != 2 && adp + 1 <= K) {
                    j2 = 1;
                    k = 1;
                }
                if (!k2 || (!j2 && 1 >= K) || (!k && adp >= K)) {
                    continue;
                }
                // res[i][j][...]
//                debug(k);
                if (k == 2) {
                    res.upd(i, j, j2, 1);
                } else if (k == 1) {
                    res.upd(i, j, j2, 1);
                } else {
                    res.upd(i, j, k, adp);
                }
            }
//            debug(i, j, k, i2, j2, k2, adp, bdp, res.dp[1][1][1]);
        }
    }
    return res;
}

array<int, 6> find(const Data& a, const Data& b, int X, int Y, int Z, int val) {
    Data res;
    res.init(-2);
//    debug(b.dp[1][1][0]);
    for (auto [i, j, k] : states) {
        if (a.dp[i][j][k] == inf) {
            continue;
        }
//        debug(a.dp[1][1][2]);
        int _i = i, _j = j, _k = k;
        for (auto [i2, j2, k2] : states) {
            i = _i, j = _j, k = _k;
            int _j2 = j2, _i2 = i2, _k2 = k2;
            if (b.dp[i2][j2][k2] == inf) {
                continue;
            }
            int adp = a.dp[i][j][k];
            int bdp = b.dp[i2][j2][k2];
//            if (i2 == 1 && j2 == 1 && k2 == 0) {
//                debug(bdp);
//            }
            bdp++;
//            debug(i, j, k, i2, j2, k2, adp, bdp, res.dp[1][1][1]);
            if (i == i2) {
                // dp[i][1]
                if (k == 2 && k2 == 2) {
                    res.upd(i, 1, 2, 0);
                    if (X == i && Y == 1 && Z == 2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, i2, j2, k2};
                    }
                } else if (k == 2) {
                    res.upd(i, 1, k2, bdp);
                    if (X == i && Y == 1 && Z == k2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, i2, j2, k2};
                    }
                } else if (k2 == 2) {
                    res.upd(i, 1, k, adp);
                    if (X == i && Y == 1 && Z == k && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else {
                    if (adp + bdp <= K) {
                        k = k2 = 1;
                    }
                    if (k && k2) {
                        res.upd(i, 1, 1, min(adp, bdp));
                        if (X == i && Y == 1 && Z == 1 && res.dp[X][Y][Z] == val) {
                            return {_i, _j, _k, _i2, _j2, _k2};
                        }
                    } else {
                        res.upd(i, 1, 0, k ? bdp : k2 ? adp : max(adp, bdp));
                        if (X == i && Y == 1 && Z == 0 && res.dp[X][Y][Z] == val) {
                            return {_i, _j, _k, _i2, _j2, _k2};
                        }
                    }
                }
            } else {
//                if (i == 0 && j == 1 && k == 2 && i2 == 1 && j2 == 1 && k2 == 0) {
//                    debug(adp, bdp);
//                }
                if (k2 != 2 && bdp <= K) {
                    j = 1;
                    k2 = 1;
                }
                if (k != 2 && adp + 1 <= K) {
                    j2 = 1;
                    k = 1;
                }
                if (!k2 || (!j2 && 1 >= K) || (!k && adp >= K)) {
                    continue;
                }
                // res[i][j][...]
//                debug(k);
                if (k == 2) {
                    res.upd(i, j, j2, 1);
                    if (X == i && Y == j && Z == j2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else if (k == 1) {
                    res.upd(i, j, j2, 1);
                    if (X == i && Y == j && Z == j2 && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                } else {
                    res.upd(i, j, k, adp);
                    if (X == i && Y == j && Z == k && res.dp[X][Y][Z] == val) {
                        return {_i, _j, _k, _i2, _j2, _k2};
                    }
                }
            }
//            debug(i, j, k, i2, j2, k2, adp, bdp, res.dp[1][1][1]);
        }
    }
}

// Fills dp[v] for the subtree of v (entered from parent pr) under the current limit K:
// start from v's own seed state, then fold in each child's table in adjacency order.
// NOTE(review): recursion depth equals the subtree height (up to n); relies on a large stack.
void dfs(int v, int pr) {
    dp[v].init(a[v]);
    for (const int child : g[v]) {
        if (child == pr) {
            continue;
        }
        dfs(child, v);
        dp[v] = merge(dp[v], dp[child]);
    }
}

void restore(int v, int pr, int X, int Y, int Z) {
    a[v] = X;
    dp[v].init(a[v]);
    vector<Data> tmp;
    tmp.push_back(dp[v]);
    for (auto u : g[v]) {
        if (u != pr) {
            dp[v] = merge(dp[v], dp[u]);
            tmp.push_back(dp[v]);
        }
    }
    int lst = (int) tmp.size() - 1;
    for (int it = (int) g[v].size() - 1; it >= 0; it--) {
        int u = g[v][it];
        if (u == pr) {
            continue;
        }
        auto [A, B, C, D, E, F] = find(tmp[lst - 1], dp[u], X, Y, Z, tmp[lst].dp[X][Y][Z]);
        X = A, Y = B, Z = C;
        restore(u, v, D, E, F);
        lst--;
    }
//    debug(v);
}

bool check(int _) {
    K = _;
//    debug(K);
    dfs(0, 0);
//    if (K == 6) {
//        debug(dp[0].dp[1][1][1]);
//        debug(dp[3].dp[1][1][0]);
//        debug(dp[1].dp[0][1][1]);
//        debug(dp[2].dp[0][0][2]);
//    }
    for (int it = 0; it < 2; it++) {
        if (dp[0].dp[it][1][1] != inf) {
//            debug(1);exit(0);
            return true;
        }
        if (dp[0].dp[it][1][2] != inf) {
//            debug(2); exit(0);
            return true;
        }
    }
    return false;
}

// Recomputes the DP for limit _ and reconstructs colors from the first reachable root
// state, trying color 0 before 1 and k == 1 before k == 2 (always with j == 1).
void restore(int _) {
    K = _;
    dfs(0, 0);
    for (int color = 0; color < 2; ++color) {
        for (const int last : {1, 2}) {
            if (dp[0].dp[color][1][last] != inf) {
                restore(0, 0, color, 1, last);
                return;
            }
        }
    }
}

// Reads one test case (n, n - 1 edges as 1-indexed pairs, n colors) and prints either -1
// or the smallest feasible K followed by a coloring.
void solve() {
    std::cin >> n;
    for (int v = 0; v < n; ++v) {
        g[v].clear();
    }
    for (int e = 1; e < n; ++e) {
        int x, y;
        std::cin >> x >> y;
        --x;
        --y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    for (int v = 0; v < n; ++v) {
        std::cin >> a[v];
    }
    if (!check(n)) {
        std::cout << -1 << '\n';
        return;
    }
    // Binary search, assuming check() is monotone in K.
    // Invariant: check(hi) holds; lo is 0 or a failing limit.
    int lo = 0;
    int hi = n;
    while (hi - lo > 1) {
        const int mid = lo + (hi - lo) / 2;
        if (check(mid)) {
            hi = mid;
        } else {
            lo = mid;
        }
    }
    std::cout << hi << '\n';
    restore(hi);
    for (int v = 0; v < n; ++v) {
        std::cout << a[v] << " ";
    }
    std::cout << '\n';
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    // Enumerate every (i, j, k) index of Data::dp in lexicographic order. The order also
    // fixes which state pair find() reports first during reconstruction.
    for (int code = 0; code < 2 * 2 * 3; ++code) {
        states.push_back({code / 6, code / 3 % 2, code % 3});
    }
    int t = 1;
    std::cin >> t;
    while (t--) {
        solve();
    }
}

Compilation message (stderr)

balancedtree.cpp: In function 'std::array<int, 6> find(const Data&, const Data&, int, int, int, int)':
balancedtree.cpp:208:1: warning: control reaches end of non-void function [-Wreturn-type]
  208 | }
      | ^
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...