Submission #529051

# | Submission time | User | Problem | Language | Result | Execution time | Memory
529051 | — | wiwiho | Construction of Highway (JOI18_construction) | C++14 | 100 / 100 | 1282 ms | 22844 KiB
#include<bits/stdc++.h>

// Debug helper: writes every element of container `a` to stream `b`, space-separated,
// followed by a newline. Only referenced from commented-out debug lines in this file.
#define printv(a, b) { \
    for(auto pv : a) b << pv << " "; \
    b << "\n"; \
}
// Short aliases for common container operations.
#define eb emplace_back
#define pob pop_back()
// Expands to `a.begin(), a.end()` so a whole container can be passed to an algorithm.
#define iter(a) a.begin(), a.end()
// Sorts the whole container.
#define lsort(a) sort(iter(a))
// Removes adjacent duplicates in place (deduplicates a container that is already sorted).
#define uni(a) a.resize(unique(iter(a)) - a.begin())
#define mp make_pair
// Pair member shorthands: x.F is x.first, x.S is x.second.
#define F first
#define S second

using namespace std;

typedef long long ll;

// Modulus constant; not referenced by the code shown here.
const ll MOD = 1000000007;

using pll = pair<ll, ll>;

// Child indices of segment-tree node `id` (root is 0). These expand to expressions
// over whatever variable named `id` is in scope at the point of use.
#define lc (2 * id + 1)
#define rc (2 * id + 2)

// Number of vertices; vertices are numbered 1..n and vertex 1 is the root.
int n;
// c[v]: color of vertex v (coordinate-compressed in main).
vector<int> c;
// tme[p]: index of the query that inserted the vertex at Euler position p (0 if none yet).
vector<int> tme;
// g[u]: children of vertex u.
vector<vector<int>> g;
// pr[v]: parent of v; the root is its own parent.
vector<int> pr;
// in[v] / out[v]: Euler entry time of v and the last entry time inside v's subtree.
vector<int> in;
vector<int> out;
// eu[t]: vertex whose entry time is t.
vector<int> eu;
// dpt[v]: depth of v, with the root at depth 1.
vector<int> dpt;
// anc[k][v]: ancestor 2^k levels above v (stays at the root once it is reached).
vector<vector<int>> anc;
// Euler-tour timestamp counter advanced by dfs.
int ts = 0;
// Number of binary-lifting levels kept in anc.
const int SZ = 20;

void init(){
    g.resize(n + 1);
    pr.resize(n + 1);
    c.resize(n + 1);
    tme.resize(n + 1);
    in.resize(n + 1);
    out.resize(n + 1);
    eu.resize(n + 1);
    dpt.resize(n + 1);
    anc.resize(SZ, vector<int>(n + 1));
}

struct SegmentTree{
    // st[id] is the Euler position inside node id's range [L, R] whose tme[] value is
    // largest; on ties the smaller (left) position wins. The root is id 0.
    vector<int> st;

    // Returns whichever of positions a, b has the larger tme[]; prefers a on ties.
    int argmax(int a, int b){
        return tme[b] > tme[a] ? b : a;
    }

    // Recomputes node id from its two children.
    void pull(int id){
        const int fromLeft = st[lc];
        const int fromRight = st[rc];
        st[id] = argmax(fromLeft, fromRight);
    }

    // Builds node id over positions [L, R]; every leaf stores its own position.
    void build(int L, int R, int id){
        if(L != R){
            const int mid = (L + R) / 2;
            build(L, mid, lc);
            build(mid + 1, R, rc);
            pull(id);
        }
        else{
            st[id] = L;
        }
    }

    // Allocates 4 * n nodes.
    void init(){
        st.resize(n * 4);
    }

    // Re-pulls every ancestor of leaf x after tme[x] changes (leaves never change).
    void upd(int x, int L = 1, int R = n, int id = 0){
        if(L != R){
            const int mid = (L + R) / 2;
            if(x > mid) upd(x, mid + 1, R, rc);
            else upd(x, L, mid, lc);
            pull(id);
        }
    }

    // Position in [l, r] with the largest tme[] (leftmost on ties).
    int query(int l, int r, int L = 1, int R = n, int id = 0){
        if(L >= l && R <= r) return st[id];
        const int mid = (L + R) / 2;
        if(r <= mid) return query(l, r, L, mid, lc);
        if(l > mid) return query(l, r, mid + 1, R, rc);
        const int fromLeft = query(l, r, L, mid, lc);
        const int fromRight = query(l, r, mid + 1, R, rc);
        return argmax(fromLeft, fromRight);
    }
};

int lowbit(int x){
    // Isolates the lowest set bit of x using two's complement (-x == ~x + 1); lowbit(0) == 0.
    const int negated = ~x + 1;
    return x & negated;
}

struct BIT{
    // Fenwick tree of 64-bit sums over indices 1..n; bit[0] is never used.
    vector<ll> bit;

    // Allocates n + 1 zeroed slots (indices 0..n).
    void init(){
        bit.resize(n + 1);
    }

    // Adds v at index x. Requires x >= 1: lowbit(0) == 0, so x == 0 would never
    // advance and the loop would spin forever. The assert flags misuse in debug
    // builds; the early return keeps NDEBUG builds from hanging.
    void modify(int x, ll v){
        assert(x > 0);
        if(x <= 0) return;
        // Compare as int to avoid the signed/unsigned (-Wsign-compare) mismatch.
        const int sz = (int)bit.size();
        for(; x < sz; x += lowbit(x)) bit[x] += v;
    }

    // Returns the sum of indices 1..x (0 when x <= 0).
    ll query(int x){
        ll ans = 0;
        for(; x > 0; x -= lowbit(x)) ans += bit[x];
        return ans;
    }
};

void dfs(int now){
    in[now] = ++ts;
    eu[ts] = now;
    dpt[now] = dpt[pr[now]] + 1;
    for(int i : g[now]) dfs(i);
    out[now] = ts;
}

void buildAnc(){
    for(int i = 1; i < SZ; i++){
        for(int j = 1; j <= n; j++){
            anc[i][j] = anc[i - 1][anc[i - 1][j]];
        }
    }
}

bool isAnc(int a, int b){
    // a is an ancestor of b (or b itself) iff b's Euler interval nests inside a's.
    if(in[b] < in[a]) return false;
    return out[a] >= out[b];
}

int lca(int a, int b){
    if(isAnc(a, b)) return a;
    // Lift a as far as possible while staying off b's ancestor chain; the parent
    // of the resulting vertex is the lowest common ancestor.
    int cur = a;
    for(int k = SZ; k-- > 0; ){
        const int up = anc[k][cur];
        if(!isAnc(up, b)) cur = up;
    }
    return anc[0][cur];
}

int under(int a, int b){
    // Climbs from a to its highest ancestor that is not an ancestor of b (a itself is
    // returned if no climb is possible). When b is a proper ancestor of a, the result
    // is b's child on the path towards a.
    int cur = a;
    for(int k = SZ; k-- > 0; ){
        const int up = anc[k][cur];
        if(!isAnc(up, b)) cur = up;
    }
    return cur;
}

SegmentTree st;

int getp(int v){
    // Euler position inside v's subtree with the largest tme[] (leftmost on ties).
    return st.query(in[v], out[v]);
}

int getv(int v){
    // Color of the vertex found by getp(v). Not called from elsewhere in the code shown.
    const int vertex = eu[getp(v)];
    return c[vertex];
}

BIT bit;
// Walks the path from the root towards v (stopping before v) in runs. Each run is
// cur..top with count dpt[top] - dpt[cur] + 1, attributed color c[rep], where rep is
// the latest-inserted vertex in cur's subtree. Returns the sum over runs of
// (count of earlier, root-side entries with a strictly larger color) * run length.
// The BIT is restored to its prior contents before returning.
ll calc(int v){
    ll inversions = 0;
    vector<pll> undo;  // (color, count) pairs added to the BIT during this call
    for(int cur = 1; cur != v; ){
        const int rep = eu[getp(cur)];
        const int color = c[rep];
        const int top = lca(v, rep);
        const int cnt = dpt[top] - dpt[cur] + 1;
        const ll above = bit.query(n) - bit.query(color);
        inversions += above * cnt;
        bit.modify(color, cnt);
        undo.push_back(pll(color, cnt));
        cur = under(v, top);
    }
    for(size_t k = 0; k < undo.size(); ++k) bit.modify(undo[k].first, -undo[k].second);
    return inversions;
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cerr.tie(0);

    cin >> n;
    init();

    for(int i = 1; i <= n; i++){
        cin >> c[i];
    }

    // Coordinate-compress the colors of vertices 1..n to ranks 1..k (k <= n).
    // Ranks are shifted to start at 1 explicitly because BIT::modify cannot take
    // index 0. This does not rely on the unused c[0] == 0 sorting below every
    // input color.
    vector<int> tmp(c.begin() + 1, c.end());
    lsort(tmp);
    uni(tmp);
    for(int i = 1; i <= n; i++){
        c[i] = int(lower_bound(iter(tmp), c[i]) - tmp.begin()) + 1;
    }

    // Edge i (2..n) hangs new vertex v below u; remember the insertion order.
    vector<int> qry(n + 1);
    pr[1] = anc[0][1] = 1;
    for(int i = 2; i <= n; i++){
        int u, v;
        cin >> u >> v;
        g[u].eb(v);
        pr[v] = anc[0][v] = u;
        qry[i] = v;
    }

    dfs(1);
    buildAnc();

    st.init();
    st.build(1, n, 0);
    bit.init();

    // Answer each insertion, then stamp the new vertex with the latest time so
    // that subsequent getp queries on its ancestors return it.
    for(int i = 2; i <= n; i++){
        int now = qry[i];
        cout << calc(now) << "\n";
        tme[in[now]] = i;
        st.upd(in[now]);
    }

    return 0;
}

Compiler standard error (stderr) messages

construction.cpp: In member function 'void BIT::modify(int, ll)':
construction.cpp:102:17: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  102 |         for(; x < bit.size(); x += lowbit(x)) bit[x] += v;
      |               ~~^~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...