#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// ---------------------------------------------------------------------------
// Shared state for calculate() and main(). Nodes/labels are 0-indexed and the
// labels are visited in order 0, 1, ..., last-1 along tree paths.
// ---------------------------------------------------------------------------
vector<ll> vals;            // vals[v]: number of k in [0, last-2] with exactly one of
                            // k, k+1 in v's subtree, i.e. how many of the walks
                            // k -> k+1 cross the edge between v and its parent.
vector<vector<ll>> adj;     // tree adjacency lists
vector<set<ll>*> sets;      // sets[v]: labels of v's subtree. Non-owning; after
                            // calculate() it may alias a child's set or be emptied.
vector<pair<ll, ll>> costs; // costs[v]: (c, d) of the edge between v and its parent
map<pair<ll, ll>, pair<ll, ll>> cost_map; // (u, v) -> (c, d) for each tree edge
ll last;                    // number of nodes / labels

// Inserts label j into s and keeps `boundaries` equal to the number of pairs
// (k, k+1), 0 <= k < last-1, that have exactly one endpoint in s.
// Does nothing if j is already present.
static void insert_label(set<ll>& s, ll& boundaries, ll j) {
    if (!s.insert(j).second) return;
    const bool has_prev = s.count(j - 1) > 0;
    const bool has_next = s.count(j + 1) > 0;
    if (!has_prev && !has_next) boundaries += 2;      // opens (j-1,j) and (j,j+1)
    else if (has_prev && has_next) boundaries -= 2;   // closes both pairs
    // (-1,0) and (last-1,last) are not real pairs; undo their contribution.
    if (j == 0) boundaries--;
    if (j == last - 1) boundaries--;
}

// Fills costs[v], sets[v] and vals[v] for every node v in the subtree of n,
// where p is n's parent (-1 for the root). Uses small-to-large merging: each
// node reuses its largest child set (first one on ties) and inserts the other
// children's labels into it. Merged light sets are cleared to free memory.
// Iterative (explicit stack), so deep trees cannot overflow the call stack.
// Expects each sets[v] to point at a distinct empty set and vals[v] == 0.
void calculate(ll n, ll p){
    // Pre-order (node, parent) list; walking it backwards visits children first.
    vector<pair<ll, ll>> order;
    vector<pair<ll, ll>> todo(1, make_pair(n, p));
    while (!todo.empty()) {
        const pair<ll, ll> cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        for (ll e : adj[cur.first])
            if (e != cur.second) todo.push_back(make_pair(e, cur.first));
    }

    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const ll v = it->first;
        const ll par = it->second;

        auto edge = cost_map.find(make_pair(v, par));
        if (edge != cost_map.end())
            costs[v] = edge->second;

        // -1 means "no child yet" so the choice never depends on other nodes' sets.
        ll heavy = -1;
        for (ll e : adj[v]) {
            if (e == par) continue;
            if (heavy == -1 || sets[e]->size() > sets[heavy]->size())
                heavy = e;
        }
        if (heavy != -1) {
            sets[v] = sets[heavy];
            vals[v] = vals[heavy];
        }

        for (ll e : adj[v]) {
            if (e == par || e == heavy) continue;
            for (ll j : *sets[e])
                insert_label(*sets[v], vals[v], j);
            sets[e]->clear();   // labels now live in sets[v]; release the copies
        }
        insert_label(*sets[v], vals[v], v);
    }
}
// Reads a tree of n nodes whose n-1 edges carry (c, d) prices, computes how
// often each edge is crossed when visiting nodes 1, 2, ..., n in order, and
// prints the total when every edge pays the cheaper of c per crossing or d once.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    ll n;
    if (!(cin >> n) || n <= 0) {   // no nodes: nothing to traverse or pay
        cout << 0 << '\n';
        return 0;
    }
    last = n;
    costs.resize(n);
    adj.resize(n);
    vals.resize(n);

    // main owns every subtree set; `sets` only holds non-owning pointers that
    // calculate() may re-point at a child's set. Freed automatically on return.
    vector<set<ll>> owned_sets(n);
    sets.resize(n);
    for (ll i = 0; i < n; i++)
        sets[i] = &owned_sets[i];

    for (ll i = 1; i < n; i++) {
        ll a, b, c, d;
        cin >> a >> b >> c >> d;
        --a;   // input node ids are 1-indexed
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
        cost_map[{a, b}] = {c, d};
        cost_map[{b, a}] = {c, d};
    }

    calculate(0, -1);

    ll res = 0;
    for (ll i = 1; i < n; i++)   // node 0 is the root: it has no parent edge
        res += min(costs[i].first * vals[i], costs[i].second);
    cout << res << '\n';
}
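/*
Judge submission results (pasted from the grader page while results were still loading):

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/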