답안 #746913

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
746913 2023-05-23T08:39:34 Z Sami_Massah Construction of Highway (JOI18_construction) C++17
0 / 100
3 ms 5076 KB
#include <bits/stdc++.h>
using namespace std;



// maxn: capacity of the per-node arrays; lg: number of binary-lifting levels.
const int maxn = 2e5 + 12, lg = 20;
// n        : node count read in main.
// s        : next Euler-tour index handed out by dfs_set.
// L, R     : closed index range covered by each segment-tree node (filled by make_tree).
// sum, mx  : aggregates of the two segment trees that share the L/R layout.
// par[v][i]: ancestor of v reached by 2^i parent steps (built in dfs_set).
// st, en   : first and last Euler-tour index inside v's subtree.
// h        : depth assigned by dfs_set (the DFS root keeps depth 0).
// U, V     : endpoints of the i-th input edge.
// bcnv     : maps a timestamp stored in the mx tree back to the node it belongs to.
// c        : per-node value read from input.
int n, s, L[maxn * 4], R[maxn * 4], sum[maxn * 4], par[maxn][lg], mx[maxn * 4], st[maxn], en[maxn], h[maxn], U[maxn], V[maxn], bcnv[maxn], c[maxn];
// Adjacency lists of the input tree.
vector <int> conn[maxn];
// Output of get_seg: (top node, bottom node, value) triples, ordered from the root downwards.
vector <tuple<int, int, int>> seg;
// Visited flags used by dfs_set.
bitset <maxn> marked;
/// Lays out the segment-tree skeleton: node `ind` covers the closed range
/// [l, r]; its children live at 2*ind and 2*ind+1 and split the range in half.
void make_tree(int l, int r, int ind){
    L[ind] = l;
    R[ind] = r;
    if(l != r){
        const int split = (l + r) / 2;
        const int left_child = 2 * ind;
        make_tree(l, split, left_child);
        make_tree(split + 1, r, left_child + 1);
    }
}
/// Updates the sum tree on [l, r], starting from node `u`.
/// Every tree node fully covered by [l, r] gets its stored sum overwritten
/// with `k` (add == false) or increased by `k` (add == true); ancestors are
/// then recomputed from their two children. Callers in this file only pass
/// single-point ranges.
void update_tree_sum(int l, int r, int u, int k, bool add = 0){
    const bool disjoint = r < L[u] || R[u] < l;
    if(disjoint)
        return;

    const bool covered = l <= L[u] && R[u] <= r;
    if(covered){
        sum[u] = add ? sum[u] + k : k;
        return;
    }

    const int lc = 2 * u;
    const int rc = lc + 1;
    update_tree_sum(l, r, lc, k, add);
    update_tree_sum(l, r, rc, k, add);
    sum[u] = sum[lc] + sum[rc];
}
/// Updates the max tree on [l, r], starting from node `u`: every tree node
/// fully covered by [l, r] is overwritten with `k`, and ancestors are refreshed
/// as the maximum of their two children.
void update_tree_mx(int l, int r, int u, int k){
    const bool disjoint = r < L[u] || R[u] < l;
    if(disjoint)
        return;

    const bool covered = l <= L[u] && R[u] <= r;
    if(covered){
        mx[u] = k;
        return;
    }

    const int lc = 2 * u;
    const int rc = lc + 1;
    update_tree_mx(l, r, lc, k);
    update_tree_mx(l, r, rc, k);
    mx[u] = std::max(mx[lc], mx[rc]);
}
/// Returns the total stored in the sum tree over positions [l, r], descending
/// from node `u`. Ranges that miss the node entirely contribute 0.
int get_sum(int l, int r, int u){
    if(L[u] > r || l > R[u])
        return 0;

    const bool inside = L[u] >= l && r >= R[u];
    if(inside)
        return sum[u];

    const int from_left = get_sum(l, r, 2 * u);
    const int from_right = get_sum(l, r, 2 * u + 1);
    return from_left + from_right;
}
/// Returns the largest value stored in the max tree over positions [l, r],
/// descending from node `u`. Ranges that miss the node entirely yield 0.
int get_max(int l, int r, int u){
    if(L[u] > r || l > R[u])
        return 0;

    const bool inside = L[u] >= l && r >= R[u];
    if(inside)
        return mx[u];

    const int best_left = get_max(l, r, 2 * u);
    const int best_right = get_max(l, r, 2 * u + 1);
    return std::max(best_left, best_right);
}
/// Returns the node reached from `u` by `k` parent steps, using the
/// binary-lifting table: one jump of 2^bit for every set bit of `k`.
int kpar(int u, int k){
    int node = u;
    for(int bit = 0; bit < lg; ++bit){
        if(((k >> bit) & 1) == 0)
            continue;
        node = par[node][bit];
    }
    return node;
}
/// Lowest common ancestor of `a` and `b` via binary lifting.
int lca(int a, int b){
    // Make `a` the deeper endpoint, then lift it to b's depth.
    if(h[b] > h[a])
        std::swap(a, b);
    a = kpar(a, h[a] - h[b]);

    if(a != b){
        // Lift both nodes in lock-step while their ancestors still differ;
        // afterwards they sit directly below the common ancestor.
        for(int level = lg - 1; level >= 0; --level){
            const int up_a = par[a][level];
            const int up_b = par[b][level];
            if(up_a == up_b)
                continue;
            a = up_a;
            b = up_b;
        }
        a = par[a][0];
    }
    return a;
}
/// Depth-first traversal from `u` that fills, for every reached node:
///   - st / en : first and last Euler-tour index of its subtree,
///   - h       : depth (children are one deeper than their parent),
///   - par     : binary-lifting ancestor table.
/// par[u][0] must already be set by the caller (the parent sets it before
/// recursing).
void dfs_set(int u){
    st[u] = s++;
    marked[u] = 1;

    for(int level = 1; level < lg; ++level){
        const int halfway = par[u][level - 1];
        par[u][level] = par[halfway][level - 1];
    }

    for(const int child : conn[u]){
        if(marked[child])
            continue;
        par[child][0] = u;
        h[child] = h[u] + 1;
        dfs_set(child);
    }

    en[u] = s - 1;
}
void get_seg(int u){
    seg.clear();
    int now = 1, res = -1;

    while(res != u){
       // cout << st[now] << ' ' << en[now] << endl;
        res = get_max(st[now], en[now], 1);
        //cout << res << endl;
        res = bcnv[res];
        //cout << u << ' ' << res << endl;
        int nc = c[res];
        res = lca(res, u);
    //    cout << res << endl;
       // cout << now << ' ' << res << endl;

        seg.push_back(make_tuple(now, res, nc));
        if(res != u)
            now = kpar(u, h[u] - h[res] - 1);
        //cout << now << endl;

    }
    return;
}
long long get_ans(){
    long long nat = 0;
    for(auto i: seg){
     //   cout << get<0>(i) << ' ' << get<1>(i) << ' ' << get<2>(i) << endl;
        int len = get<1>(i) - get<0>(i) + 1;
        nat += get_sum(get<2>(i) + 1, n, 1) * len;
        update_tree_sum(get<2>(i), get<2>(i), 1, len, 1);
    }
    for(auto i: seg)
        update_tree_sum(get<2>(i), get<2>(i), 1, 0);
    return nat;
}
/// Reads the node values and the n-1 edges, then answers one query per edge
/// in input order: print the inversion count on the path from node 1 to U[i],
/// then mark V[i] as inserted with timestamp i+2 (node 1 has timestamp 1).
int main(){
    std::ios_base::sync_with_stdio(false), std::cin.tie(0);
    std::cin >> n;
    for(int i = 1; i <= n; i++)
        std::cin >> c[i];

    // Coordinate-compress the values to ranks 1..n. They are used as positions
    // in the sum tree, which only spans [0, n]; raw inputs larger than n would
    // otherwise fall outside the tree and be ignored.
    std::vector<int> ranks(c + 1, c + n + 1);
    std::sort(ranks.begin(), ranks.end());
    ranks.erase(std::unique(ranks.begin(), ranks.end()), ranks.end());
    for(int i = 1; i <= n; i++)
        c[i] = static_cast<int>(std::lower_bound(ranks.begin(), ranks.end(), c[i]) - ranks.begin()) + 1;

    for(int i = 0; i < n - 1; i++){
        std::cin >> U[i] >> V[i];
        conn[U[i]].push_back(V[i]);
        conn[V[i]].push_back(U[i]);
    }

    dfs_set(1);
    make_tree(0, n, 1);

    // Node 1 is present from the start and carries timestamp 1.
    update_tree_mx(st[1], st[1], 1, 1);
    bcnv[1] = 1;

    for(int i = 0; i < n - 1; i++){
        get_seg(U[i]);
        const long long ans = get_ans();

        // Insert V[i]: its timestamp is now the largest in every subtree on
        // the path from node 1, which is how get_seg sees the whole path take
        // V[i]'s value.
        update_tree_mx(st[V[i]], st[V[i]], 1, i + 2);
        bcnv[i + 2] = V[i];

        std::cout << ans << "\n";
    }
}
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 5036 KB Output is correct
2 Incorrect 3 ms 5076 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 5036 KB Output is correct
2 Incorrect 3 ms 5076 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 5036 KB Output is correct
2 Incorrect 3 ms 5076 KB Output isn't correct
3 Halted 0 ms 0 KB -