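Submission #525726

| # | Submitted at | User | Problem | Language | Result | Time | Memory |
|---|---|---|---|---|---|---|---|
| 525726 | 2022-02-12T16:16:59Z | qwerasdfzxcl | Mountains and Valleys (CCO20_day1problem3) | C++14 | 0 / 25 | 19 ms | 15564 KB |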
#include <bits/stdc++.h>

typedef long long ll;
using namespace std;
/// A road read in main whose length z is not 1: it joins x and y and is kept
/// in E as a candidate extra edge on top of the length-1 tree roads.
struct Query{
    int x, y, z;
    Query() {}
    Query(int from, int to, int len) : x(from), y(to), z(len) {}
};
// Adjacency lists of the tree: main only adds roads whose length z == 1.
vector<int> adj[500500];
// Roads with z != 1, each evaluated by calc1/calc2 as a possible extra edge.
vector<Query> E;
// Number of vertices (read in main).
int n;
// Shared scratch min-heaps of (value, vertex). dfs_dp1 / dfs_dp2 use them to
// keep only the few largest per-neighbour values; every call drains them
// completely before returning.
priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> pqD, pqF;

// Results of dfs_dp1 (tree rooted at 0):
//   dep    - depth, par - parent (-1 for the root),
//   dp_D   - longest path inside the vertex's subtree,
//   dp_far - distance to the farthest descendant,
//   D[v], deep[v] - longest path / deepest downward branch inside par[v]'s
//                   subtree that avoid v's subtree.
int dep[500500], par[500500], dp_D[500500], dp_far[500500], D[500500], deep[500500];
/// First DFS over the tree in adj, rooted at the vertex of the outermost call.
/// For the subtree of s it computes
///   par[s]    : parent of s (-1 for the root),
///   dep[v]    : depth of every child v (the root keeps its initial value),
///   dp_D[s]   : longest path (in edges) lying inside s's subtree,
///   dp_far[s] : distance from s to its farthest descendant,
/// and, for every child v of s,
///   D[v]      : longest path inside s's subtree that avoids v's subtree,
///   deep[v]   : longest downward path from s that avoids v's subtree.
/// Uses the shared global heaps pqD / pqF; they must be empty on the outermost
/// call and are left empty. Recursion depth equals the height of the tree.
void dfs_dp1(int s, int pa = -1){
    par[s] = pa;

    // Recurse into every child BEFORE pushing anything into pqD / pqF. The
    // heaps are shared globals and each call drains them completely, so
    // interleaving recursion with the pushes would let a child's call consume
    // (and discard) the entries already pushed for its earlier siblings.
    for (auto &v:adj[s]) if (v!=pa){
        dep[v] = dep[s]+1;
        dfs_dp1(v, s);
    }

    // Keep the 3 largest (value, child) pairs of each kind: enough to exclude
    // one child and still combine the two best remaining branches.
    for (auto &v:adj[s]) if (v!=pa){
        dp_D[s] = max(dp_D[s], dp_D[v]);
        dp_far[s] = max(dp_far[s], dp_far[v]+1);

        pqD.emplace(dp_D[v], v);
        pqF.emplace(dp_far[v]+1, v);
        if (pqD.size()>3) {pqD.pop(); pqF.pop();}
    }

    // Drain the heaps into arrays sorted in descending order; unused slots are {0, -1}.
    pair<int, int> Drr[3], Frr[3];
    for (int i=pqD.size();i<3;i++) Drr[i] = Frr[i] = {0, -1};

    for (int i=(int)pqD.size()-1;i>=0;i--){
        Drr[i] = pqD.top(); pqD.pop();
    }
    for (int i=(int)pqF.size()-1;i>=0;i--){
        Frr[i] = pqF.top(); pqF.pop();
    }

    // A path through s joins its two deepest child branches.
    dp_D[s] = max(dp_D[s], Frr[0].first + Frr[1].first);

    // For each child v: best values of s's subtree with v's subtree removed.
    for (auto &v:adj[s]) if (v!=pa){
        if (v==Drr[0].second) D[v] = Drr[1].first;
        else D[v] = Drr[0].first;

        if (v==Frr[0].second){
            D[v] = max(D[v], Frr[1].first + Frr[2].first);
            deep[v] = Frr[1].first;
        }
        else if (v==Frr[1].second){
            D[v] = max(D[v], Frr[0].first + Frr[2].first);
            deep[v] = Frr[0].first;
        }
        else{
            D[v] = max(D[v], Frr[0].first + Frr[1].first);
            deep[v] = Frr[0].first;
        }
    }
}

// Filled by dfs_dp2: for vertex s, the 4 largest (value, neighbour) pairs over
// ALL neighbours (parent included), sorted descending, padded with {0, -1}.
// Dc holds longest-path values, Fc holds farthest-distance-from-s values.
pair<int, int> Dc[500500][4], Fc[500500][4];
// dp_parD[v] / dp_parF[v]: longest path anywhere outside v's subtree /
// deepest branch from par[v] that avoids v's subtree (unset for the root).
int dp_parD[500500], dp_parF[500500];

/// Second DFS (rerooting); must run after dfs_dp1.
/// For vertex s it collects one entry per neighbour: children contribute
/// (dp_D[v], v) / (dp_far[v]+1, v), and the parent (if any) contributes
/// (dp_parD[s], pa) / (dp_parF[s]+1, pa). The 4 largest of each kind are stored
/// in Dc[s] / Fc[s], sorted descending and padded with {0, -1}.
/// Then, for every child v, it sets dp_parD[v] (longest path avoiding v's
/// subtree) and dp_parF[v] (deepest branch from s avoiding v's subtree), and
/// recurses. The shared heaps are drained before recursing.
void dfs_dp2(int s, int pa = -1){
    // Entry for the upward neighbour: the part of the tree outside s's subtree.
    if (pa!=-1){
        pqD.emplace(dp_parD[s], pa);
        pqF.emplace(dp_parF[s]+1, pa);
    }

    // Min-heaps capped at 4 entries keep the 4 largest values.
    for (auto &v:adj[s]) if (v!=pa){
        pqD.emplace(dp_D[v], v);
        pqF.emplace(dp_far[v]+1, v);
        if (pqD.size()>4) {pqD.pop(); pqF.pop();}
    }


    for (int i=pqD.size();i<4;i++) Dc[s][i] = Fc[s][i] = {0, -1};

    // Popping yields ascending order, so fill from the back: index 0 = largest.
    for (int i=(int)pqD.size()-1;i>=0;i--){
        Dc[s][i] = pqD.top(); pqD.pop();
    }
    for (int i=(int)pqF.size()-1;i>=0;i--){
        Fc[s][i] = pqF.top(); pqF.pop();
    }

    ///calc dp_par
    // Exclude child v and take the best remaining path, or join the two best
    // remaining branches through s.
    for (auto &v:adj[s]) if (v!=pa){
        if (v==Dc[s][0].second) dp_parD[v] = Dc[s][1].first;
        else dp_parD[v] = Dc[s][0].first;

        if (v==Fc[s][0].second){
            dp_parD[v] = max(dp_parD[v], Fc[s][1].first + Fc[s][2].first);
            dp_parF[v] = Fc[s][1].first;
        }
        else if (v==Fc[s][1].second){
            dp_parD[v] = max(dp_parD[v], Fc[s][0].first + Fc[s][2].first);
            dp_parF[v] = Fc[s][0].first;
        }
        else{
            dp_parD[v] = max(dp_parD[v], Fc[s][0].first + Fc[s][1].first);
            dp_parF[v] = Fc[s][0].first;
        }
    }

    // Recurse only after the heaps are empty, so children start clean.
    for (auto &v:adj[s]) if (v!=pa) dfs_dp2(v, s);
}

/// Aggregate over a vertical chain of branches, combined bottom-up: callers
/// add a lower chain segment (left operand) to the segment above it (right).
///   a   - best value of the "lower end" key,
///   b   - best value of the "upper end" key,
///   ans - best a(lower) + b(upper) over two different positions.
struct Node{
    int ans, a, b;
    Node(){}
    Node(int best, int lo, int hi): ans(best), a(lo), b(hi) {}
    Node operator +(const Node &R) const{
        const int cross = a + R.b;                 // lower pick here, upper pick in R
        const int best = max(cross, max(ans, R.ans));
        const int bestA = max(a, R.a);
        const int bestB = max(b, R.b);
        return Node(best, bestA, bestB);
    }
};
// Binary-lifting tables filled by build():
//   sp1[v][j] = (max of D over v and its next 2^j - 1 ancestors,
//                the 2^j-th ancestor of v),
//   sp2[v][j] = Node aggregate over the same 2^j vertices.
// Level j is only written when dep[v] >= 2^j; other entries stay zero.
pair<int, int> sp1[500500][20];
Node sp2[500500][20];

/// Builds the jump tables for vertices 0..n-1 (after dfs_dp1).
/// Level 0: sp1[v][0] = (D[v], par[v]); sp2[v][0] encodes the branch deep[v]
/// hanging off par[v], which sits at depth dep[v]-1.
/// Level k merges the level k-1 jump from v with the level k-1 jump from v's
/// 2^(k-1)-th ancestor, and is written only when dep[v] >= 2^k.
void build(int n){
    for (int v=0;v<n;v++){
        const int attach = dep[v]-1;  // depth of par[v]
        sp1[v][0] = pair<int, int>(D[v], par[v]);
        sp2[v][0] = Node(-1e9, deep[v]-attach, deep[v]+attach);
    }

    for (int k=1;k<20;k++){
        const int jump = 1<<k;
        for (int v=0;v<n;v++){
            if (dep[v]<jump) continue;
            const int mid = sp1[v][k-1].second;  // 2^(k-1)-th ancestor of v
            const pair<int, int> &lower = sp1[v][k-1];
            const pair<int, int> &upper = sp1[mid][k-1];
            sp1[v][k] = pair<int, int>(max(lower.first, upper.first), upper.second);
            sp2[v][k] = sp2[v][k-1] + sp2[mid][k-1];
        }
    }
}

// Side outputs of get_lca (overwritten by every call).
int prV, prW;
/// Lowest common ancestor of v and w, using the sp1 jump table.
/// Side effect on prV / prW:
///  - neither is an ancestor of the other: the two children of the LCA on the
///    paths to v and w (prV on the deeper argument's side, first argument on ties);
///  - one is an ancestor of the other: prV = child of the LCA towards the deeper
///    vertex, prW = the LCA itself;
///  - v == w: both are set to v.
int get_lca(int v, int w){
    if (dep[v]<dep[w]) swap(v, w);

    // Lift v to exactly one level below w.
    if(dep[v]!=dep[w]){
        int dist = dep[v] - dep[w] - 1;
        for (int j=0;dist>0;j++) if (dist&(1<<j)){
            v = sp1[v][j].second;
            dist -= 1<<j;
        }
    }
    prV = v, prW = w;
    if (dep[v]!=dep[w]) v = sp1[v][0].second;

    if (v==w) return v;

    // Standard lifting. Levels never written by build() are zero-initialized,
    // so both sides read the same value there and the jump is skipped.
    for (int j=19;j>=0;j--) if (sp1[v][j].second!=sp1[w][j].second){
        v = sp1[v][j].second;
        w = sp1[w][j].second;
    }
    prV = v, prW = w;
    return sp1[v][0].second;
}

/// First candidate answer for adding road (x, y) of length w.
/// ret   = longest path lying entirely inside one piece left after deleting the
///         edges of the tree path x..y. It is built from:
///           - dp_D of each endpoint that is not the LCA,
///           - D[] of the vertices strictly between the endpoints and the LCA
///             (read via sp1),
///           - at the LCA, the best Dc entry and the sum of the two best Fc
///             branches not leading to the path children.
/// rdist = tree distance between x and y.
/// Returns (n-1+w)*2 - ret - rdist - w. Overwrites prV / prW via get_lca.
/// NOTE(review): the cost model behind this formula is not shown here, and
/// (n-1+w)*2 is evaluated in int -- confirm the bounds on w rule out overflow.
int calc1(int x, int y, int w){
    int lca = get_lca(x, y);
    int ret = 0;

    if (x!=lca) ret = max(ret, dp_D[x]);
    if (y!=lca) ret = max(ret, dp_D[y]);

    int rdist = dep[x] + dep[y] - dep[lca]*2;

    // Lift x to the child of the LCA, folding in D of the vertices passed.
    int dist = dep[x] - dep[lca] - 1;
    for (int j=0;dist>0;j++) if (dist&(1<<j)){
        ret = max(ret, sp1[x][j].first);
        x = sp1[x][j].second;
        dist -= 1<<j;
    }

    // Same for y.
    dist = dep[y] - dep[lca] - 1;
    for (int j=0;dist>0;j++) if (dist&(1<<j)){
        ret = max(ret, sp1[y][j].first);
        y = sp1[y][j].second;
        dist -= 1<<j;
    }

    // At the LCA, x and y are now its path children (or the LCA itself):
    // skip their entries and use the remaining neighbours.
    int S = 0, cnt = 0;
    for (int i=0;i<4;i++) if (Dc[lca][i].second!=x && Dc[lca][i].second!=y) ret = max(ret, Dc[lca][i].first);
    for (int i=0;i<4;i++) if (Fc[lca][i].second!=x && Fc[lca][i].second!=y){
        S += Fc[lca][i].first;
        cnt++;
        if (cnt==2) break;
    }
    ret = max(ret, S);

    return (n-1+w)*2 - ret - rdist - w;
}

/// Second candidate answer for adding road (x, y) of length w.
/// Let L = lca(x, y). Let val be the deepest branch at L that avoids both path
/// neighbours prV / prW (the parent side counts as a branch).
/// ret is the maximum of:
///  - within one side chain (x..L or y..L): two branches of lengths b_lo, b_hi
///    hanging off path vertices at depths d_lo > d_hi (the endpoint's own
///    subtree and val at L included), scored b_lo + b_hi - (d_lo - d_hi);
///  - if neither endpoint is L: one branch from each side, scored
///    b_x + b_y - (d_x - dep[L]) - (d_y - dep[L]).
/// Returns (n-2+w)*2 - (rdist+w+ret), where rdist is the tree distance x..y.
/// Overwrites prV / prW via get_lca.
/// NOTE(review): the cost model behind the formula is not shown here.
int calc2(int x, int y, int w){
    int lca = get_lca(x, y);
    int ret = -1e9, rdist = dep[x] + dep[y] - dep[lca]*2;

    // Deepest branch at the LCA that does not lie on the x..y path.
    int val = -1;
    for (int i=0;i<4;i++) if (Fc[lca][i].second!=prV && Fc[lca][i].second!=prW) {val = Fc[lca][i].first; break;}
    assert(val!=-1);

    ///left chain
    // Node keys: a = branch - depth, b = branch + depth.
    int tx = x, ty = y;
    Node cur(-1e9, dp_far[tx] - dep[tx], dp_far[tx] + dep[tx]);
    Node Top(-1e9, val - dep[lca], val + dep[lca]);

    int dist = dep[tx] - dep[lca] - 1;
    for (int j=0;dist>0;j++) if (dist&(1<<j)){
        cur = cur + sp2[tx][j];
        tx = sp1[tx][j].second;
        dist -= 1<<j;
    }

    // tx != lca exactly when x != lca; then the LCA's branch closes the chain.
    if (tx!=lca) cur = cur + Top;
    ret = max(ret, cur.ans);

    ///right chain
    cur = Node(-1e9, dp_far[ty] - dep[ty], dp_far[ty] + dep[ty]);
    dist = dep[ty] - dep[lca] - 1;
    for (int j=0;dist>0;j++) if (dist&(1<<j)){
        cur = cur + sp2[ty][j];
        ty = sp1[ty][j].second;
        dist -= 1<<j;
    }

    if (ty!=lca) cur = cur + Top;
    ret = max(ret, cur.ans);

    ///left and right
    if (x==lca || y==lca) return (n-2+w) * 2 - (rdist+w+ret);

    // Best "a" key on each side (excluding the LCA); combined across the LCA.
    int L = dp_far[x] - dep[x], R = dp_far[y] - dep[y];
    tx = x, ty = y;

    dist = dep[tx] - dep[lca] - 1;
    for (int j=0;dist>0;j++) if (dist&(1<<j)){
        L = max(L, sp2[tx][j].a);
        tx = sp1[tx][j].second;
        dist -= 1<<j;
    }

    dist = dep[ty] - dep[lca] - 1;
    for (int j=0;dist>0;j++) if (dist&(1<<j)){
        R = max(R, sp2[ty][j].a);
        ty = sp1[ty][j].second;
        dist -= 1<<j;
    }

    ret = max(ret, L+R + dep[lca]*2);

    return (n-2+w) * 2 - (rdist+w+ret);
}

/// Debug helper: prints each per-vertex DP array for vertices 0..n-1 on its own
/// labelled line to stdout.
void Debug(){
    const char *const labels[8] = {"dep: ", "\npar: ", "\ndp_D: ", "\ndp_far: ",
                                   "\nD: ", "\ndeep: ", "\ndp_parD: ", "\ndp_parF: "};
    const int *const rows[8] = {dep, par, dp_D, dp_far, D, deep, dp_parD, dp_parF};

    for (int r=0;r<8;r++){
        fputs(labels[r], stdout);
        for (int v=0;v<n;v++) printf("%d ", rows[r][v]);
    }
    fputs("\n", stdout);
}

/// Reads the graph, runs both DFS passes and builds the jump tables, then
/// prints the minimum over the no-extra-road answer and calc1/calc2 for every
/// road of length != 1.
int main(){
    cin.tie(NULL);
    ios_base::sync_with_stdio(false);
    int m;
    cin >> n >> m;
    for (int i=1;i<=m;i++){
        int x, y, z;
        cin >> x >> y >> z;
        // Length-1 roads go into the tree (the DFS code assumes they form one);
        // every other road is a candidate extra edge.
        if (z!=1) {E.emplace_back(x, y, z); continue;}

        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    dfs_dp1(0);
    dfs_dp2(0);
    // A leftover debugging exit for n > 80000 used to sit here and made the
    // program print nothing on large inputs; every input is now fully solved.
    build(n);

    // Tree only: 2*(n-1) minus the tree diameter dp_D[0].
    int ans = (n-1)*2 - dp_D[0];

    for (auto &e:E){
        ans = min(ans, calc1(e.x, e.y, e.z));
        ans = min(ans, calc2(e.x, e.y, e.z));
    }

    printf("%d\n", ans);
    return 0;
}
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 19 ms 15564 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Result Time Memory Grader output
1 Incorrect 6 ms 12168 KB Output isn't correct
2 Halted 0 ms 0 KB -