답안 #133983

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
133983 2019-07-21T20:51:45 Z evpipis Cats or Dogs (JOI18_catdog) C++11
0 / 100
4 ms 2680 KB
#include "catdog.h"
#include <bits/stdc++.h>
using namespace std;

// Shorthand used when building the adjacency lists.
#define pb push_back

// Capacity of every per-vertex array; vertex ids are used directly as indices.
const int len = 1e5+5;
// par[v]: parent of v as assigned by fix() (stays 0 for the vertex fix() starts from).
// dep[v]: depth of v as assigned by fix(); initialize() sets dep[0] = -1 so upward walks stop at 0.
// dif[v]: per-vertex balance shifted by upd() and inspected by fin()/change().
//         NOTE(review): presumably a cat/dog count difference for v's subtree - confirm.
// col[v]: 1 after dog(v), -1 after cat(v), 0 after neighbor(v).
int par[len], dep[len], dif[len], col[len];
// ans: running result returned by cat()/dog()/neighbor(); n: vertex count stored by initialize().
int ans, n;
// adj[v]: neighbours of v; both directions of each edge are stored by initialize().
vector<int> adj[len];

void fix(int u){
    for (int j = 0; j < adj[u].size(); j++){
        int v = adj[u][j];
        if (v == par[u])
            continue;

        dep[v] = dep[u]+1;
        par[v] = u;
        fix(v);
    }
}

// Adds x to dif[] of every vertex met while climbing parent links from u,
// for as long as the current vertex is at least as deep as v.
// With u == v this touches dif[v] only.
void upd(int u, int v, int x){
    for (int cur = u; dep[cur] >= dep[v]; cur = par[cur])
        dif[cur] += x;
}

// Climbs parent links from u while the current vertex is non-zero and its
// dif value lies in [l, r]. Returns the last vertex that satisfied the
// condition, or 0 when u itself does not.
int fin(int u, int l, int r){
    int top = 0;
    for (int cur = u; cur != 0; cur = par[cur]){
        if (dif[cur] < l || dif[cur] > r)
            break;
        top = cur;
    }
    return top;
}

// Propagates a change of type t upward, starting at vertex u, updating
// dif[] on the way and adjusting the global ans.
//
// Every branch follows the same shape:
//   1. fin(u, lo, hi) finds the topmost vertex of the run of ancestors
//      (starting at u) whose dif lies in [lo, hi].
//   2. If such a run exists, upd() shifts dif on the whole run and the
//      walk continues at the parent of the run's top (returning if that
//      parent is 0). Otherwise the walk stays at u.
//   3. ans is adjusted from col[v] / dif[v] of the vertex v reached, and
//      upd(v, v, delta) shifts dif[v] itself.
//
//   t   | fin range | delta
//   ----+-----------+------
//    3  | [-1, -1]  |  +2
//    2  | [-1,  0]  |  +1
//    1  | [-1,  0]  |  +1
//   -1  | [ 0,  1]  |  -1
//   -2  | [ 0,  1]  |  -1
//   -3  | [ 1,  1]  |  -2
//
// NOTE(review): the judge output recorded with this file marks the
// submission as incorrect, so the case tables below are not known to be
// right - verify them before relying on this function.
void change(int u, int t){
    /*
    3: red-green
    2: white-green
    1: red-white
    0: nothing
    -1: green-white
    -2: white-red
    -3: green-red
    */

    // Vertex 0 is the sentinel above the root: nothing to propagate.
    if (u == 0)
        return;

    //printf("change: u = %d, t = %d\n", u, t);

    if (t == 3){
        int v = fin(u, -1, -1);
        if (v != 0){
            // Whole run of dif == -1 moves to +1.
            upd(u, v, 2);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        if (col[v] == 1)
            ans--;
        else if (col[v] == -1)
            ans++;
        else if (dif[v] <= -3)
            ans++;
        else if (dif[v] == -2)
            // The recursive call runs before dif[v] is shifted below.
            ans++, change(par[v], 1);
        else if (dif[v] == 0)
            ans--, change(par[v], 2);
        else if (dif[v] >= 1)
            ans--;
        upd(v, v, 2);
    }
    else if (t == 2){
        // c records whether the top of the run had dif == -1 (ans already adjusted).
        int v = fin(u, -1, 0), c = 0;
        //printf("v = %d\n", v);
        if (v != 0){
            if (dif[v] == -1)
                ans++, c = 1;
            upd(u, v, 1);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        if (!c){
            if (col[v] == 1)
                ans += 0;
            else if (col[v] == -1)
                ans++;
            else if (dif[v] <= -2)
                ans++;
            else if (dif[v] >= 1)
                ans += 0;
        }
        else{
            if (col[v] == 1)
                ans--;
            else if (col[v] == -1)
                ans += 0;
            else if (dif[v] <= -2)
                ans += 0;
            else if (dif[v] >= 1)
                ans--;
        }
        upd(v, v, 1);
    }
    else if (t == 1){
        // c records whether the top of the run had dif == 0 (ans already adjusted).
        int v = fin(u, -1, 0), c = 0;
        //printf("v = %d\n", v);
        if (v != 0){
            if (dif[v] == 0)
                ans--, c = 1;
            upd(u, v, 1);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        if (!c){
            if (col[v] == 1)
                ans--;
            else if (col[v] == -1)
                ans += 0;
            else if (dif[v] <= -2)
                ans += 0;
            else if (dif[v] >= 1)
                ans--;
        }
        else{
            if (col[v] == 1)
                ans += 0;
            else if (col[v] == -1)
                ans++;
            else if (dif[v] <= -2)
                ans++;
            else if (dif[v] >= 1)
                ans += 0;
        }
        upd(v, v, 1);
    }
    else if (t == -1){
        // Mirror image of t == 1: ranges and signs are negated.
        int v = fin(u, 0, 1), c = 0;
        if (v != 0){
            if (dif[v] == 0)
                ans--, c = 1;
            upd(u, v, -1);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        if (!c){
            if (col[v] == 1)
                ans += 0;
            else if (col[v] == -1)
                ans--;
            else if (dif[v] >= 2)
                ans += 0;
            else if (dif[v] <= -1)
                ans--;
        }
        else{
            if (col[v] == 1)
                ans++;
            else if (col[v] == -1)
                ans += 0;
            else if (dif[v] >= 2)
                ans++;
            else if (dif[v] <= -1)
                ans += 0;
        }
        upd(v, v, -1);
    }
    else if (t == -2){
        // Mirror image of t == 2: ranges and signs are negated.
        int v = fin(u, 0, 1), c = 0;
        if (v != 0){
            if (dif[v] == 1)
                ans++, c = 1;
            upd(u, v, -1);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        if (!c){
            if (col[v] == 1)
                ans++;
            else if (col[v] == -1)
                ans += 0;
            else if (dif[v] >= 2)
                ans++;
            else if (dif[v] <= -1)
                ans += 0;
        }
        else{
            if (col[v] == 1)
                ans += 0;
            else if (col[v] == -1)
                ans--;
            else if (dif[v] >= 2)
                ans += 0;
            else if (dif[v] <= -1)
                ans--;
        }
        upd(v, v, -1);
    }
    else if (t == -3){
        // Mirror image of t == 3: ranges and signs are negated.
        int v = fin(u, 1, 1);
        if (v != 0){
            // Whole run of dif == +1 moves to -1.
            upd(u, v, -2);

            v = par[v];
            if (v == 0) return;
        }
        else
            v = u;

        if (col[v] == 1)
            ans++;
        else if (col[v] == -1)
            ans--;
        else if (dif[v] >= 3)
            ans++;
        else if (dif[v] == 2)
            // The recursive call runs before dif[v] is shifted below.
            ans++, change(par[v], -1);
        else if (dif[v] == 0)
            ans--, change(par[v], -2);
        else if (dif[v] <= -1)
            ans--;
        upd(v, v, -2);
    }
}

// Debug helper: dumps dif[] for vertices 1..n to stdout, followed by a
// blank line.
void print(){
    int i = 1;
    while (i <= n){
        printf("i = %d, dif = %d\n", i, dif[i]);
        i++;
    }
    printf("\n");
}

void initialize(int N, vector<int> A, vector<int> B){
    n = N;
    for (int i = 0; i < n-1; i++){
        int a = A[i], b = B[i];
        adj[a].pb(b);
        adj[b].pb(a);
    }

    fix(1), dep[0] = -1;
}

// Marks u with colour -1 and propagates the effect to its ancestors:
// a positive dif[u] is added to ans and sent up as change type -3,
// a zero dif[u] is sent up as change type -2, a negative one needs no
// propagation. Returns the updated ans.
int cat(int u){
    col[u] = -1;

    const int bal = dif[u];
    if (bal > 0){
        ans += bal;
        change(par[u], -3);
    }
    else if (bal == 0){
        change(par[u], -2);
    }

    return ans;
}

// Marks u with colour 1 and propagates the effect to its ancestors:
// a negative dif[u] has its magnitude added to ans and is sent up as
// change type 3, a zero dif[u] is sent up as change type 2, a positive
// one needs no propagation. Returns the updated ans.
int dog(int u){
    col[u] = 1;

    const int bal = dif[u];
    if (bal < 0){
        ans -= bal;
        change(par[u], 3);
    }
    else if (bal == 0){
        change(par[u], 2);
    }

    return ans;
}

// Clears the colour of u (sets col[u] to 0) after undoing its effect on
// the ancestors. Colour 1 is undone with change types -3 / -1; any other
// colour is undone with types 3 / 1, depending on the sign of dif[u].
// Returns the updated ans.
int neighbor(int u){
    const int bal = dif[u];

    if (col[u] == 1){
        if (bal < 0){
            ans += bal;
            change(par[u], -3);
        }
        else if (bal == 0){
            change(par[u], -1);
        }
    }
    else{
        if (bal > 0){
            ans -= bal;
            change(par[u], 3);
        }
        else if (bal == 0){
            change(par[u], 1);
        }
    }

    // The colour is reset only after propagation, as in the original order.
    col[u] = 0;
    return ans;
}
/*
test cases:
5
1 2
2 3
2 4
1 5
4
1 3
2 4
2 5
3 3
*/

Compilation message

catdog.cpp: In function 'void fix(int)':
catdog.cpp:13:23: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for (int j = 0; j < adj[u].size(); j++){
                     ~~^~~~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Incorrect 4 ms 2680 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 4 ms 2680 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 4 ms 2680 KB Output isn't correct
2 Halted 0 ms 0 KB -