제출 #202079

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
202079 | Milki | Putovanje (COCI20_putovanje) | C++14
110 / 110
166 ms | 23768 KiB
#include<bits/stdc++.h>
using namespace std;

// Shorthand macros. Every argument is parenthesised in the expansion so that
// expression arguments such as `k & 3` or `*ptr` keep their intended
// precedence after substitution.
#define FOR(i, a, b) for(int i = (a); i < (b); ++i)            // i over [a, b)
#define REP(i, n) FOR(i, 0, n)                                 // i over [0, n)
#define _ << " " <<                                            // debug separator: std::cerr << a _ b
#define sz(x) (static_cast<int>((x).size()))                   // container size as int
#define pb(...) push_back(__VA_ARGS__)                         // variadic: v.pb({1, 2}) works
#define TRACE(x) std::cerr << #x << " = " << (x) << '\n'       // debug: print expression and value

using ll = long long;               // 64-bit integer for sums and products
using point = std::pair<int, int>;  // generic (int, int) pair, e.g. (depth, vertex)

// Prime modulus used by the modular-arithmetic helpers below.
const int mod = 1000000007;

// Returns (x + y) mod `mod`; both operands are expected to lie in [0, mod).
int add(int x, int y) {
  const int s = x + y;
  return s < mod ? s : s - mod;
}

// Returns (x - y) mod `mod`; both operands are expected to lie in [0, mod).
int sub(int x, int y) {
  const int d = x - y;
  return d >= 0 ? d : d + mod;
}

// Returns (x * y) mod `mod`; the product is formed in 64 bits so it cannot
// overflow for operands in [0, mod).
int mul(int x, int y) {
  return static_cast<int>(static_cast<ll>(x) * y % mod);
}

const int MAXN = 2e5 + 5, LOG = 19;

// One adjacency-list entry: the neighbouring vertex `x` and the two costs
// read from input for this edge.
struct Edge{
  int x = 0, cost1 = 0, cost2 = 0;
  Edge() = default;
  Edge(int x, int cost1, int cost2) : x(x), cost1(cost1), cost2(cost2) {}
};

// n: vertex count.
// anc[k][v]: 2^k-th ancestor of v (level 0 is filled by dfs, higher levels by main).
// dep[v]: depth of v below the dfs root.
int n, anc[LOG][MAXN], dep[MAXN];
std::vector<Edge> E[MAXN];   // adjacency lists (each undirected edge stored twice)
Edge par[MAXN];              // par[v]: the edge entry leading from v's parent to v

// Roots the tree at x (whose parent is p, -1 for none) and, for every other
// vertex v reachable from x, sets dep[v], anc[0][v] and par[v].
// dep[x] and anc[0][x] are left untouched; for the root 0 they keep their
// zero-initialised values.
// An explicit stack replaces recursion so that a path-shaped tree with ~2e5
// vertices does not need ~2e5 call frames. Each vertex's values depend only
// on its parent, so the visiting order does not affect the result.
void dfs(int x, int p = -1){
  std::vector<std::pair<int, int>> stk;   // pending (vertex, parent) pairs
  stk.emplace_back(x, p);
  while(!stk.empty()){
    const int v = stk.back().first;
    const int pv = stk.back().second;
    stk.pop_back();
    for(const Edge& e : E[v]){
      if(e.x == pv) continue;
      dep[e.x] = dep[v] + 1;
      anc[0][e.x] = v;
      par[e.x] = e;
      stk.emplace_back(e.x, v);
    }
  }
}

// Lowest common ancestor of x and y by binary lifting. Expects dep[] and the
// anc[k][v] table (2^k-th ancestor of v) to be filled before the call.
int get_lca(int x, int y){
  int lo = x, hi = y;   // lo becomes the deeper vertex, hi the shallower one
  if (dep[lo] < dep[hi]) std::swap(lo, hi);

  // Raise lo to hi's depth, trying the largest jumps first.
  for (int k = LOG; k-- > 0; ) {
    if (dep[lo] - (1 << k) >= dep[hi]) lo = anc[k][lo];
  }

  // Already equal: the shallower input is an ancestor of the deeper one.
  if (lo == hi) return lo;

  // Raise both while their 2^k-th ancestors still differ; the parent of the
  // final lo is returned as the LCA.
  for (int k = LOG; k-- > 0; ) {
    const int a = anc[k][lo];
    const int b = anc[k][hi];
    if (a != b) {
      lo = a;
      hi = b;
    }
  }
  return anc[0][lo];
}

ll extra[MAXN], sum[MAXN], after[MAXN];

int main(){
  ios_base::sync_with_stdio(false); cin.tie(0);

  cin >> n;
  REP(i, n - 1){
    int a, b, cost1, cost2; cin >> a >> b >> cost1 >> cost2;
    a --; b --;
    E[a].pb(Edge(b, cost1, cost2));
    E[b].pb(Edge(a, cost1, cost2));
  }

  dfs(0);

  FOR(i, 1, LOG) REP(j, n){
    int x = anc[i - 1][j];
    anc[i][j] = anc[i - 1][x];
  }

  FOR(i, 1, n){
    int x = i, y = i - 1;
    if(dep[x] > dep[y])
      swap(x, y);
    int lca = get_lca(x, y);

    if(lca != x && lca != y){
      extra[x] ++; extra[y] ++;
      extra[lca] -= 2;
    }
    else if(lca == x){
      extra[y] ++;
      extra[lca] --;
    }
    else
      assert(n == 0);
  }

  vector <point> v;
  FOR(i, 1, n)
    v.pb(point( dep[i], i ));
  sort(v.rbegin(), v.rend());

  ll sol = 0;
  for(int i = 0; i < n - 1; ++i){
    int x = v[i].second;
    extra[anc[0][x]] += extra[x];
    ll add1 = (ll)par[x].cost1 * (ll)extra[x];
    ll add2 = (ll)par[x].cost2;
    sol += min(add1, add2);
  }
  cout << sol;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...