제출 #223200

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
223200 | — | cheeheng | Putovanje (COCI20_putovanje) | C++14 | 110 / 110 | 534 ms | 42232 KiB
#include <bits/stdc++.h>
using namespace std;

typedef pair<int, int> ii;

/// EdgeW[u][v] = (c, d): the two values read for edge u-v; stored in both directions.
map<int, ii> EdgeW[200005];

/// Undirected adjacency lists of the tree (nodes are 1-indexed).
vector<int> AdjList[200005];
/// Heavy-light decomposition state, filled by dfs() and decompose().
static int heavy[200005]; ///< child with the largest subtree; 0 means "no child" (nodes are 1-indexed)
static int depth[200005]; ///< number of edges from the root passed to dfs()
static int head[200005];  ///< topmost node of the heavy chain containing this node
static int p[200005];     ///< parent in the rooted tree (left 0 for the root)
static int pos[200005];   ///< position of the node in the segment-tree order
static int sz[200005];    ///< subtree size; 0 also serves as the "not visited yet" mark in dfs()
int cnt = 0; ///< next position handed out by decompose(); 0-based for the segment tree (start at 1 for a 1-indexed Fenwick tree)

long long A[200005]; ///< initial leaf values consumed by init(); never written in this file, so all zero

// Segment tree over positions [0, N-1] supporting range add and range sum.
// Heap layout, 0-based: the children of node i are 2i+1 and 2i+2.
// Invariant: the true sum of node i's range equals
//     val[i] + lazyadd[i] * (length of the range)
// plus whatever additions are still pending in lazyadd[] of i's ancestors.
long long val[(1<<19)+5];
long long lazyadd[(1<<19)+5];

int N;

/// Index of the left child of node x.
inline int left1(int x){
    return (x<<1)+1;
}

/// Index of the right child of node x.
inline int right1(int x){
    return (x<<1)+2;
}

/// Folds node i's pending addition into val[i], pushes it to the children
/// (internal nodes only) and returns the node's sum.
/// @param i node index; @param s,e the inclusive range covered by node i.
long long value(int i, int s, int e){
    if(lazyadd[i] != 0){
        // A pending add of v on every position of [s, e] raises the sum by v * length.
        // (The previous version added v only once, breaking multi-position sums.)
        val[i] += lazyadd[i] * (e - s + 1);
        if(s != e){
            lazyadd[left1(i)] += lazyadd[i];
            lazyadd[right1(i)] += lazyadd[i];
        }
        lazyadd[i] = 0;
    }
    return val[i];
}

/// Builds the tree from A[0..N-1] and clears all pending additions.
void init(int i = 0, int s = 0, int e = N-1){
    int m = (s+e)>>1;
    lazyadd[i] = 0;
    if(s == e){
        val[i] = A[s];
    }else{
        int l = left1(i);
        int r = right1(i);
        init(l, s, m);
        init(r, m+1, e);
        val[i] = val[l]+val[r];
    }
}

/// Returns the sum of positions [x, y]; requires 0 <= x <= y <= N-1.
long long rsq(int x, int y, int i = 0, int s = 0, int e = N-1){
    int m = (s+e)>>1;
    value(i, s, e);  // settle this node before reading or descending
    if(x <= s && e <= y){
        return val[i];
    }

    int l = left1(i);
    int r = right1(i);
    if(x > m){
        return rsq(x, y, r, m+1, e);
    }else if(y <= m){
        return rsq(x, y, l, s, m);
    }else{
        return rsq(x, y, r, m+1, e)+rsq(x, y, l, s, m);
    }
}

/// Adds v to every position in [x, y]; requires 0 <= x <= y <= N-1.
void update(int x, int y, long long v, int i = 0, int s = 0, int e = N-1){
    int m = (s+e)>>1;
    if(x <= s && e <= y){
        lazyadd[i] += v;  // defer; applied by value() when the node is next read
        return;
    }
    int l = left1(i);
    int r = right1(i);
    if(x > m){
        update(x, y, v, r, m+1, e);
    }else if(y <= m){
        update(x, y, v, l, s, m);
    }else{
        update(x, y, v, r, m+1, e);
        update(x, y, v, l, s, m);
    }
    // Recompute from the children; lazyadd[i] itself (if any) stays pending.
    val[i] = value(l, s, m)+value(r, m+1, e);
}

/// Roots the tree at `u` and fills sz[], depth[], p[] and heavy[] for every node
/// reachable from it. A node counts as visited once sz[] is non-zero.
/// Uses an explicit stack instead of recursion so a path-shaped tree with ~2e5
/// nodes cannot overflow the call stack. Neighbours are examined in the same
/// order as the recursive version, so the results are identical, including
/// which child becomes heavy on ties (the first strictly largest).
void dfs(int u){
    struct Frame {
        int node;          // node being expanded
        std::size_t next;  // index of the next neighbour of `node` to examine
        int maxChild;      // largest child subtree size seen so far
    };
    std::vector<Frame> stk;
    sz[u] = 1;
    stk.push_back(Frame{u, 0, 0});
    while(!stk.empty()){
        Frame& top = stk.back();
        const int x = top.node;
        if(top.next < AdjList[x].size()){
            const int v = AdjList[x][top.next++];
            if(sz[v] == 0){
                depth[v] = depth[x] + 1;
                sz[v] = 1;
                stk.push_back(Frame{v, 0, 0});  // `top` may dangle now; it is not used again
            }
            continue;
        }
        // All neighbours of x handled: fold its subtree into the parent frame.
        stk.pop_back();
        if(stk.empty()) break;
        Frame& parent = stk.back();
        const int w = parent.node;
        sz[w] += sz[x];
        p[x] = w;
        if(sz[x] > parent.maxChild){
            parent.maxChild = sz[x];
            heavy[w] = x;
        }
    }
}

/// Assigns head[] and pos[] for the subtree of `u`, whose chain head is `h`.
/// Positions are handed out from the global `cnt` in preorder. The heavy child
/// is visited first, so each heavy chain occupies consecutive positions; light
/// children follow in adjacency order, each starting a new chain.
/// Iterative (explicit stack) to avoid O(N) recursion depth. Children are
/// pushed in reverse visiting order, which reproduces the recursive preorder
/// exactly.
void decompose(int u, int h){
    std::vector<std::pair<int, int> > stk;  // (node, head of its chain)
    stk.push_back(std::make_pair(u, h));
    while(!stk.empty()){
        const int x = stk.back().first;
        const int hx = stk.back().second;
        stk.pop_back();
        head[x] = hx;
        pos[x] = cnt;
        cnt++;
        const std::vector<int>& nb = AdjList[x];
        // Light children (smaller subtree, not heavy): pushed last-to-first so the first is popped first.
        for(std::size_t k = nb.size(); k-- > 0; ){
            const int v = nb[k];
            if(sz[v] < sz[x] && v != heavy[x]) stk.push_back(std::make_pair(v, v));
        }
        // Heavy child pushed last, so it is expanded before any light child; it continues x's chain.
        if(heavy[x] != 0) stk.push_back(std::make_pair(heavy[x], hx));
    }
}

/// Adds 1 to every edge on the tree path between a and b.
/// Each edge lives at the segment-tree position of its lower endpoint, which is
/// why the final same-chain step skips the shallower endpoint's own position.
void update1(int a, int b){
    if(depth[b] < depth[a]) swap(a, b);
    while(head[a] != head[b]){
        // Lift whichever endpoint sits on the chain with the deeper head.
        if(depth[head[b]] < depth[head[a]]) swap(a, b);
        const int top = head[b];
        update(pos[top], pos[b], 1);  // whole segment [top .. b] of this chain
        b = p[top];
    }
    if(depth[b] < depth[a]) swap(a, b);
    // a is now the shallower endpoint on the shared chain.
    if(pos[a] != pos[b]) update(pos[a] + 1, pos[b], 1);
}

/// Returns the sum of the edge values on the tree path between a and b.
/// An edge's value is read at the position of its lower endpoint, so the final
/// same-chain step excludes the shallower endpoint's own position.
long long sum1(int a, int b){
    long long total = 0;
    if(depth[b] < depth[a]) swap(a, b);
    while(head[a] != head[b]){
        // Lift whichever endpoint sits on the chain with the deeper head.
        if(depth[head[b]] < depth[head[a]]) swap(a, b);
        const int top = head[b];
        total += rsq(pos[top], pos[b]);  // whole segment [top .. b] of this chain
        b = p[top];
    }
    if(depth[b] < depth[a]) swap(a, b);
    if(pos[a] != pos[b]) total += rsq(pos[a] + 1, pos[b]);
    return total;
}


int main(){
    scanf("%d", &N);

    for(int i = 1; i < N; i ++){
        int a, b, c, d;
        scanf("%d%d%d%d", &a, &b, &c, &d);
        AdjList[a].push_back(b);
        AdjList[b].push_back(a);
        EdgeW[a][b] = ii(c, d);
        EdgeW[b][a] = ii(c, d);
    }

    init();
    dfs(1);
    decompose(1,1);

    for(int i = 1; i < N; i ++){
        update1(i, i+1);
    }

    long long ans = 0;
    for(int i = 1; i <= N; i ++){
        for(int j: AdjList[i]){
            if(j < i){continue;}
            int x, y;
            tie(x, y) = EdgeW[i][j];
            //printf("%d %d\n", i, j);
            long long times = sum1(i, j);
            //printf("%d %d %lld\n", i, j, times);
            ans += min(x*times, (long long)y);
        }
    }

    printf("%lld", ans);
    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

putovanje.cpp: In function 'int main()':
putovanje.cpp:155:10: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
     scanf("%d", &N);
     ~~~~~^~~~~~~~~~
putovanje.cpp:159:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
         scanf("%d%d%d%d", &a, &b, &c, &d);
         ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...