제출 #44805

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
44805 |  | model_code | Balanced Tree (info1cup18_balancedtree) | C++11 | 100 / 100 | 1723 ms | 246112 KiB
#include<bits/stdc++.h>

using namespace std;

// Number of nodes in the current test case; nodes are numbered 1..N.
int N;
// depth[x]: distance from node 1, filled by initDfs (node 1 keeps its zero-initialised 0).
int depth[500009];
// ans[x]: colour (0/1) assigned to node x by build().
int ans[500009];
// pattern[x]: input colour of node x: 0, 1, or -1 when the colour is still free.
int pattern[500009];
// initV: undirected adjacency lists as read from input.
vector < int > initV[500009];
// v: children lists of the tree rooted at node 1 (built by initDfs).
vector < int > v[500009];

// code[i][j][k] numbers each DP state (i, j, k); iVal/jVal/kVal invert that
// numbering. main() fills them with the codes 1..12.
int code[2][2][3];
int iVal[20];
int jVal[20];
int kVal[20];
// DP table for one (partial) subtree plus back-pointers for reconstruction.
//   dp[i][j][k]  : value of state (i, j, k); -1 means the state is unreachable.
//   how[i][j][k] : pair of state codes that produced dp[i][j][k]; {-1, -1} if unset.
struct info
{
    int dp[2][2][3];
    pair < int, int > how[2][2][3];

    // Every state starts unreachable.
    info ()
    {
        for (auto &plane : dp)
            for (auto &row : plane)
                for (auto &cell : row)
                    cell = -1;
        for (auto &plane : how)
            for (auto &row : plane)
                for (auto &cell : row)
                    cell = make_pair (-1, -1);
    }

    // Records candidate value newDp (produced by `curr`) for state (i, j, k).
    // A smaller value wins. For k == 2 the value is a placeholder, so the
    // latest candidate always overwrites.
    void update (int i, int j, int k, int newDp, pair < int, int > curr)
    {
        const bool alreadySet = dp[i][j][k] != -1;
        const bool valueMatters = k != 2;
        if (alreadySet && valueMatters && newDp >= dp[i][j][k])
            return;
        dp[i][j][k] = newDp;
        how[i][j][k] = curr;
    }
}sub[500009], prefBro[500009];

/// Roots the input tree at `nod`. Every node reachable from `nod` without
/// passing through `tata` gets:
///   v[x]      : its neighbours except its parent, in initV order;
///   depth[c]  : depth[x] + 1 for each child c.
/// depth[nod] itself is not written.
///
/// An explicit stack replaces the original recursion. A path-shaped tree with
/// ~5*10^5 nodes would otherwise recurse that deep and can overflow the call
/// stack. The filled v[] and depth[] are identical to the recursive version,
/// because each list is appended only while its own node is processed.
void initDfs (int nod, int tata)
{
    vector < pair < int, int > > pending;   // (node, parent)
    pending.push_back ({nod, tata});
    while (!pending.empty ())
    {
        const int cur = pending.back ().first;
        const int par = pending.back ().second;
        pending.pop_back ();
        for (int it : initV[cur])
            if (it != par)
            {
                v[cur].push_back (it);
                depth[it] = depth[cur] + 1;
                pending.push_back ({it, cur});
            }
    }
}

/// Merges one more child subtree into the state accumulated so far for its
/// parent r, for distance threshold D.
///
/// State dp[i][j][k] of a partially processed subtree rooted at r:
///   i  colour of r;
///   j  1 once r has a node of its own colour within distance D among the
///      processed nodes, 0 otherwise;
///   k  summary of the processed nodes whose colour differs from r's:
///        0  some of them are still unsatisfied; dp is the depth of the
///           deepest unsatisfied one (max is kept when two such sets meet);
///        1  all of them are satisfied; dp is the depth of one of them that is
///           kept as a candidate partner for nodes further out (smaller wins);
///        2  there are none; dp is only a placeholder.
///   dp == -1 marks an unreachable state.
///
/// @param D          maximum allowed distance to a same-coloured node
/// @param rootDepth  depth of r
/// @param big        state of r merged with the previous children
/// @param newSon     state of the child subtree being added (its root is at depth rootDepth + 1)
/// @return the merged state. how[...] holds {code of big's state, code of newSon's state}.
///
/// Both states are taken by const reference. info holds ~144 bytes and is
/// only read here, so the former by-value copies were pure overhead.
info addSon (int D, int rootDepth, const info & big, const info & newSon)
{
    info ans = info ();
    for (int v=0; v<2; v++)
    for (int tv=0; tv<2; tv++)
    for (int to=0; to<3; to++)
    if (big.dp[v][tv][to] != -1)
    for (int i=0; i<2; i++)
    for (int j=0; j<2; j++)
    for (int k=0; k<3; k++)
    if (newSon.dp[i][j][k] != -1)
    {
        if (v != i)
        {
            // The child has the other colour. Its tracked nodes (colour of r) are
            // settled against r here. The child itself becomes r's shallowest
            // other-colour node, at depth rootDepth + 1.
            int newJ = 0, newK, valDp = rootDepth + 1;
            if (k == 2) newJ = tv;
            else
            {
                if (newSon.dp[i][j][k] - rootDepth <= D) newJ = 1;///they satisfy each other
                else
                if (k == 0) continue;///then newSon's unsatisfied opposite color node would die unpaired
                else newJ = tv;
            }
            if (to == 2) newK = j;
            else
            {
                if ((rootDepth + 1) + big.dp[v][tv][to] - 2 * rootDepth <= D) newK = 1;///they satisfy each other
                else
                if (to == 0) continue;///then big's unsatisfied opposite color node would die unpaired
                else newK = j;
            }
            ans.update (v, newJ, newK, valDp, {code[v][tv][to], code[i][j][k]});
            continue;
        }
        // The child has r's colour and is adjacent to r, so r is satisfied (newJ = 1).
        // The other-colour summaries of both sides are then combined.
        int newJ = 1, newK, valDp;
        if (k == 2) newK = to, valDp = big.dp[v][tv][to];
        else
        if (to == 2) newK = k, valDp = newSon.dp[i][j][k];
        else
        {
            // Distance between the two tracked nodes, whose paths meet at r.
            if (newSon.dp[i][j][k] + big.dp[v][tv][to] - 2 * rootDepth <= D || k + to == 2)
                newK = 1, valDp = min (big.dp[v][tv][to], newSon.dp[i][j][k]);
            else
            {
                newK = 0;
                if (k + to == 0) valDp = max (big.dp[v][tv][to], newSon.dp[i][j][k]);
                else
                if (k == 0) valDp = newSon.dp[i][j][k];
                else valDp = big.dp[v][tv][to];
            }
        }
        ans.update (v, newJ, newK, valDp, {code[v][tv][to], code[i][j][k]});
    }
    return ans;
}

/// Runs the subtree DP for threshold D on the subtree of `nod` (children taken
/// from v[]). It fills sub[x] for every node x in that subtree. For every child
/// c in it, prefBro[c] is set to the parent's state after merging all children
/// up to and including c.
///
/// Nodes are processed in reverse BFS order, so every child is finished before
/// its parent. This replaces the former recursive post-order, which could
/// exhaust the call stack on a path-shaped tree of ~5*10^5 nodes. Each node's
/// result depends only on its children's sub[], so the tables are identical.
void solve (int nod, int D)
{
    vector < int > order (1, nod);
    for (size_t head = 0; head < order.size (); head++)
    {
        const int cur = order[head];
        for (int son : v[cur])
            order.push_back (son);
    }
    for (size_t idx = order.size (); idx-- > 0; )
    {
        const int cur = order[idx];
        // The node alone: its fixed colour (or both if free), j = 0,
        // k = 2 with placeholder value 3N.
        info aux;
        if (pattern[cur] != -1) aux.dp[pattern[cur]][0][2] = 3 * N;
        else aux.dp[0][0][2] = aux.dp[1][0][2] = 3 * N;
        const vector < int > & kids = v[cur];
        if (kids.empty ())
        {
            sub[cur] = aux;
            continue;
        }
        prefBro[kids[0]] = addSon (D, depth[cur], aux, sub[kids[0]]);
        for (size_t p = 1; p < kids.size (); p++)
            prefBro[kids[p]] = addSon (D, depth[cur], prefBro[kids[p - 1]], sub[kids[p]]);
        sub[cur] = prefBro[kids.back ()];
    }
}

/// Runs the DP for threshold D and reports whether the whole tree admits a
/// valid colouring. A valid final state has the root satisfied (j = 1) and no
/// unsatisfied other-colour node (k = 1 or k = 2), for either root colour.
/// Side effect: leaves sub[]/prefBro[] filled for this D.
bool isAchievable (int D)
{
    solve (1, D);
    for (int color = 0; color < 2; color++)
        for (int kind = 1; kind <= 2; kind++)
            if (sub[1].dp[color][1][kind] != -1)
                return true;
    return false;
}

void build (int nod, int i, int j, int k)
{
    ans[nod] = i;
    if (v[nod].empty ()) return ;
    for (int pos=v[nod].size () - 1; pos>=0; pos--)
    {
        int curr = v[nod][pos];
        pair < int, int > how = prefBro[curr].how[i][j][k];
        build (curr, iVal[how.second], jVal[how.second], kVal[how.second]);
        if (pos == 0) break;
        j = jVal[how.first], k = kVal[how.first], assert (i == iVal[how.first]);
    }
}

void reconstitution (int D)
{
    if (sub[1].dp[0][1][1] != -1 || sub[1].dp[0][1][2] != -1) ans[1] = 0;
    else ans[1] = 1, assert (sub[1].dp[1][1][1] != -1 || sub[1].dp[1][1][2]);
    if (sub[1].dp[ans[1]][1][1] != -1) build (1, ans[1], 1, 1);
    else build (1, ans[1], 1, 2), assert (sub[1].dp[ans[1]][1][2] != -1);
}

void cleanUp ()
{
    for (int i=1; i<=N; i++)
        initV[i].clear (), v[i].clear (),
        sub[i] = prefBro[i] = info ();
}

// paint[x] = {BFS distance from x's owning source, owning source}; {-1, 0} if unvisited.
pair < int, int > paint[500009];

/// Multi-source BFS over the undirected tree initV from every node in `nodes`.
/// For each source s, minDist[s] becomes the distance to the nearest *other*
/// source, or -1 if none was found. All non-source entries stay -1.
/// The candidate for an edge joining regions owned by different sources x and
/// y is dist(x-side endpoint) + dist(y-side endpoint) + 1.
///
/// `nodes` is taken by const reference; the former by-value parameter copied
/// the whole vector on every call.
void calcMinDist (const vector < int > & nodes, int minDist[])
{
    for (int i=1; i<=N; i++)
        paint[i] = {-1, 0}, minDist[i] = -1;
    queue < int > cc;
    for (auto it : nodes)
        cc.push (it), paint[it] = {0, it};
    while (!cc.empty ())
    {
        int nod = cc.front ();
        cc.pop ();
        for (auto it : initV[nod])
            if (paint[it].first == -1)
                paint[it] = {paint[nod].first + 1, paint[nod].second}, cc.push (it);
            else
            if (paint[it].second != paint[nod].second)
            {
                // Two regions meet: the owners are at most d apart.
                int x = paint[nod].second, y = paint[it].second, d = paint[nod].first + paint[it].first + 1;
                if (minDist[x] == -1 || d < minDist[x])
                    minDist[x] = d;
                if (minDist[y] == -1 || d < minDist[y])
                    minDist[y] = d;
            }
    }
}

/// BFS from every free node (pattern == -1) at once. minDist[x] becomes the
/// distance from x to the nearest free node, or -1 if no free node is
/// reachable. Free nodes themselves get 0.
void calcMinDistFromQuestionMark (int minDist[])
{
    queue < int > frontier;
    for (int node = 1; node <= N; node++)
    {
        if (pattern[node] != -1)
        {
            minDist[node] = -1;
            continue;
        }
        minDist[node] = 0;
        frontier.push (node);
    }
    while (!frontier.empty ())
    {
        const int cur = frontier.front ();
        frontier.pop ();
        for (int nxt : initV[cur])
        {
            if (minDist[nxt] != -1)
                continue;
            minDist[nxt] = minDist[cur] + 1;
            frontier.push (nxt);
        }
    }
}

// minD[c][x]: for c in {0,1}, distance from a colour-c node x to the nearest
// other colour-c node; minD[2][x]: distance from x to the nearest free node.
int minD[3][500009];

/// Lower bound on D. Every pre-coloured node needs a partner no farther than
/// its nearest same-coloured node or its nearest free node (whichever is
/// closer). The bound is the maximum of that over pre-coloured nodes, and at
/// least 1.
int getLowerBound ()
{
    vector < int > byColor[2];
    for (int node = 1; node <= N; node++)
        if (pattern[node] != -1)
            byColor[pattern[node]].push_back (node);
    calcMinDist (byColor[0], minD[0]);
    calcMinDist (byColor[1], minD[1]);
    calcMinDistFromQuestionMark (minD[2]);

    int bound = 1;
    for (int node = 1; node <= N; node++)
    {
        if (pattern[node] == -1)
            continue;
        int best = minD[pattern[node]][node];
        const int toFree = minD[2][node];
        if (toFree != -1 && (best == -1 || toFree < best))
            best = toFree;
        bound = max (bound, best);
    }
    return bound;
}

/// Reads one test case: N, N-1 edges into initV, then N colours into pattern
/// (-1 = free). The whole case is always consumed. Returns false for the
/// cases rejected up front: a single node, or a colour count that leaves some
/// node without any possible same-coloured partner.
bool read ()
{
    scanf ("%d", &N);
    for (int edge = 1; edge < N; edge++)
    {
        int a, b;
        scanf ("%d %d", &a, &b);
        initV[a].push_back (b);
        initV[b].push_back (a);
    }
    int unknown = 0, cnt[2] = {0, 0};
    for (int node = 1; node <= N; node++)
    {
        scanf ("%d", &pattern[node]);
        if (pattern[node] == -1)
            unknown++;
        else
            cnt[pattern[node]]++;
    }
    const bool singleNode = N == 1;
    const bool loneFixedColour = unknown == 0 && (cnt[0] == 1 || cnt[1] == 1);
    const bool oneFreeForTwoLoners = unknown == 1 && cnt[0] == 1 && cnt[1] == 1;
    return !(singleNode || loneFixedColour || oneFreeForTwoLoners);
}

int main ()
{
//freopen ("input", "r", stdin);
//freopen ("output", "w", stdout);

int Tests, codeIndex = 0;
scanf ("%d", &Tests);
for (int i=0; i<2; i++)
    for (int j=0; j<2; j++)
        for (int k=0; k<3; k++)
            code[i][j][k] = ++codeIndex, iVal[codeIndex] = i, jVal[codeIndex] = j, kVal[codeIndex] = k;
while (Tests --)
{
    bool anyChance = read ();
    if (anyChance == 0) printf ("-1\n");
    else
    {
        initDfs (1, -1);

        int ras = getLowerBound ();
        if (!isAchievable (ras))
        {
            if (ras > 1)
            {
                ras ++;
                assert (isAchievable (ras));
            }
            else
            {
                ras ++;
                if (!isAchievable (ras))
                    ras ++, assert (isAchievable (ras));
            }
        }

        printf ("%d\n", ras);
        reconstitution (ras);
        for (int i=1; i<=N; i++)
            printf ("%d%c", ans[i], " \n"[i == N]);
    }
    cleanUp ();
}

return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

balancedtree.cpp: In function 'void solve(int, int)':
balancedtree.cpp:104:20: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int i=1; i<v[nod].size (); i++)
                   ~^~~~~~~~~~~~~~~
balancedtree.cpp: In function 'bool read()':
balancedtree.cpp:214:11: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
     scanf ("%d", &N);
     ~~~~~~^~~~~~~~~~
balancedtree.cpp:218:15: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
         scanf ("%d %d", &x, &y);
         ~~~~~~^~~~~~~~~~~~~~~~~
balancedtree.cpp:226:15: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
         scanf ("%d", &pattern[i]);
         ~~~~~~^~~~~~~~~~~~~~~~~~~
balancedtree.cpp: In function 'int main()':
balancedtree.cpp:242:7: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
 scanf ("%d", &Tests);
 ~~~~~~^~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...