Submission #376804

#TimeUsernameProblemLanguageResultExecution timeMemory
376804wiwihoMergers (JOI19_mergers)C++14
70 / 100
3093 ms200164 KiB
#pragma opitimize("O3")

#include<bits/stdc++.h>

#define mp make_pair
#define F first
#define S second
#define iter(a) a.begin(), a.end()
#define lsort(a) sort(iter(a))
#define printv(a, b) { \
    for(auto pv : a) b << pv << " "; \
    b << "\n"; \
}
#define eb emplace_back

using namespace std;

typedef long long ll;

using pll = pair<ll, ll>;
using pii = pair<int, int>;

ostream& operator<<(ostream& o, pll p){
    return o << '(' << p.F << ',' << p.S << ')';
}

void waassert(bool t){
    if(!t){
        cout << "OAO\n";
        exit(0);
    }
}

struct Info{
    pii mn, mx, al, lb, alb;
};

Info pull(Info a, Info b){
    Info c;
    c.mn = min(a.mn, b.mn);
    c.mx = max(a.mx, b.mx);
    c.al = max({a.al, b.al, mp(a.mx.F - 2 * b.mn.F, b.mn.S)});
    c.lb = max({a.lb, b.lb, mp(- 2 * a.mn.F + b.mx.F, a.mn.S)});
    c.alb = max({a.alb, b.alb, mp(a.al.F + b.mx.F, a.al.S), mp(a.mx.F + b.lb.F, b.lb.S)});
    return c;
}

ostream& operator<<(ostream& o, Info info){
    return o << '(' << info.mn << ',' << info.mx << ',' << info.al << ',' << info.lb << ',' << info.alb << ')';
}

struct Node{
    int l = -1, r = -1;
    int tag = 0;
    Info info;
    Info rv(){
        Info t = info;
        t.mn.F += tag;
        t.mx.F += tag;
        t.al.F -= tag;
        t.lb.F -= tag;
        return t;
    }
};

struct SegmentTree{
    vector<Node> st;
    int ts = 0;
    int build(int l, int r){
        int id = ts++;
        st[id].info.mn.S = st[id].info.mx.S = l;
        st[id].info.al.S = l;
        st[id].info.lb.S = l;
        st[id].info.alb.S = l;
        if(l == r) return id;
        int m = (l + r) / 2;
        st[id].l = build(l, m);
        st[id].r = build(m + 1, r);
        return id;
    }
    void push(int id){
        st[st[id].l].tag += st[id].tag;
        st[st[id].r].tag += st[id].tag;
        st[id].info = st[id].rv();
        st[id].tag = 0;
    }
    void modify(int l, int r, int v, int L, int R, int id){
        if(l == L && r == R){
            st[id].tag += v;
            return;
        }
        int M = (L + R) / 2;
        if(r <= M) modify(l, r, v, L, M, st[id].l);
        else if(l > M) modify(l, r, v, M + 1, R, st[id].r);
        else{
            modify(l, M, v, L, M, st[id].l);
            modify(M + 1, r, v, M + 1, R, st[id].r);
        }
        st[id].info = pull(st[st[id].l].rv(), st[st[id].r].rv());
    }
    pii query(int l, int r, int L, int R, int id){
        if(l == L && r == R) return st[id].rv().mx;
        push(id);
        int M = (L + R) / 2;
        if(r <= M) return query(l, r, L, M, st[id].l);
        else if(l > M) return query(l, r, M + 1, R, st[id].r);
        else return max(query(l, M, L, M, st[id].l), query(M + 1, r, M + 1, R, st[id].r));
    }
    void print(int l, int r, int id){
        cerr << "print " << l << " " << r << " " << id << " " << st[id].l << " " << st[id].r << " " << st[id].info << "\n";
        if(l == r) return;
        push(id);
        int m = (l + r) / 2;
        print(l, m, st[id].l);
        print(m + 1, r, st[id].r);
    }
};

struct DSU{
    vector<int> dsu;
    vector<int> rk, up;
    int cnt;
    void init(int n){
        dsu.resize(n + 1);
        rk.resize(n + 1, 1);
        up.resize(n + 1);
        cnt = n;
        for(int i = 1; i <= n; i++) dsu[i] = up[i] = i;
    }
    int findDSU(int a){
        if(dsu[a] != a) dsu[a] = findDSU(dsu[a]);
        return dsu[a];
    }
    void unionDSU(int a, int b){
        a = findDSU(a);
        b = findDSU(b);
        if(a == b) return;
        int tmp = up[a];
        if(rk[a] < rk[b]) swap(a, b);
        else if(rk[a] == rk[b]) rk[a]++;
        up[a] = tmp;
        cnt--;
        dsu[b] = a;
    }
};

int n;
vector<vector<int>> g;
vector<int> s;
vector<vector<int>> anc;
vector<int> in, out;
vector<int> et(1);
DSU dsu;
int ts = 0;
SegmentTree st;

void dfs1(int now, int p){
    anc[0][now] = p;
    in[now] = ++ts;
    et.eb(now);
    for(int i : g[now]){
        if(i == p) continue;
        dfs1(i, now);
        ts++;
        et.eb(now);
    }
    out[now] = ts;
}

void buildLCA(){
    for(int i = 1; i < 20; i++){
        for(int j = 1; j <= n; j++){
            anc[i][j] = anc[i - 1][anc[i - 1][j]];
        }
    }
}

bool isAnc(int a, int b){
    return in[a] <= in[b] && out[a] >= out[b];
}

int getLCA(int a, int b){
    if(isAnc(a, b)) return a;
    for(int i = 19; i >= 0; i--){
        if(!isAnc(anc[i][a], b)) a = anc[i][a];
    }
    return anc[0][a];
}

int unionPath(int a, int b){
    int lca = getLCA(a, b);
    a = dsu.up[dsu.findDSU(a)];
    b = dsu.up[dsu.findDSU(b)];
    //cerr << a << " " << b << "\n";
    while(!isAnc(a, lca)){
        //cerr << a << " " << b << "\n";
        st.modify(in[a], out[a], -1, 1, 2 * n - 1, 0);
        dsu.unionDSU(anc[0][a], a);
        a = dsu.findDSU(a);
        a = dsu.up[a];
    }
    while(!isAnc(b, lca)){
        //cerr << a << " " << b << "\n";
        st.modify(in[b], out[b], -1, 1, 2 * n - 1, 0);
        dsu.unionDSU(anc[0][b], b);
        b = dsu.findDSU(b);
        b = dsu.up[b];
    }
    return dsu.up[dsu.findDSU(lca)];
}

int fv = -1, fd = -1;

void dfs2(int now, int p, int dpt){
    if(dpt > fd) fv = now, fd = dpt;
    //cerr << "dfs2 " << now << " " << p << " " << dpt << "\n";
    for(int i : g[now]){
        if(i == p) continue;
        if(dsu.findDSU(now) != dsu.findDSU(i)) dfs2(i, now, dpt + 1);
        else dfs2(i, now, dpt);
    }
}

void owo(){
    //cerr << "owo\n";
    int mid = st.st[0].info.alb.S;
    int t1 = et[st.query(1, mid, 1, 2 * n - 1, 0).S];
    int t2 = et[st.query(mid, 2 * n - 1, 1, 2 * n - 1, 0).S];
    //cerr << mid << " " << t1 << " " << t2 << "\n";
    unionPath(t1, t2);
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);

    int k;
    cin >> n >> k;
    g.resize(n + 1);
    s.resize(n + 1);
    anc.resize(20, vector<int>(n + 1));
    in.resize(n + 1);
    out.resize(n + 1);
    dsu.init(n);
    vector<vector<int>> city(k + 1);
    st.st.resize(4 * n);

    for(int i = 0; i < n - 1; i++){
        int u, v;
        cin >> u >> v;
        g[u].eb(v);
        g[v].eb(u);
    }
    for(int i = 1; i <= n; i++){
        cin >> s[i];
        city[s[i]].eb(i);
    }

    dfs1(1, 1);
    //printv(et, cerr);
    buildLCA();
    st.build(1, 2 * n - 1);
    for(int i = 1; i <= n; i++){
        st.modify(in[i], out[i], 1, 1, 2 * n - 1, 0);
    }

    for(int i = 1; i <= k; i++){
        int v = city[i].front();
        for(int j = 1; j < city[i].size(); j++){
            v = unionPath(v, city[i][j]);
        }
    }
    /*for(int i = 1; i < 2 * n; i++){
        cerr << i << " " << st.query(i, i, 1, 2 * n - 1, 0) << "\n";
    }
    st.print(1, 2 * n - 1, 0);*/

    int ans = 0;
    while(dsu.cnt > 1){
        owo();
        ans++;
    }
    cout << ans << "\n";


    return 0;
}

Compilation message (stderr)

mergers.cpp:1: warning: ignoring #pragma opitimize  [-Wunknown-pragmas]
    1 | #pragma opitimize("O3")
      | 
mergers.cpp: In function 'int main()':
mergers.cpp:269:26: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  269 |         for(int j = 1; j < city[i].size(); j++){
      |                        ~~^~~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...