#include "job.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Disjoint-set forest over jobs. Each root r represents a block of jobs that
// will run back-to-back; their order is a singly linked list that starts at r
// itself (nxt[] links, tail[r] is the last job). Totals of u and d for the
// block are stored at its root.
struct DSU {
    std::vector<int> link;     // parent pointer; link[r] == r for roots
    std::vector<int> nxt;      // next job in the owning block's order, -1 at the end
    std::vector<int> tail;     // last job of the block rooted at r (roots only)
    std::vector<long long> u;  // total u of the block rooted at r (roots only)
    std::vector<long long> d;  // total d of the block rooted at r (roots only)

    // Creates n singleton blocks; U and D supply each job's own u and d.
    DSU(int n, const std::vector<int> &U, const std::vector<int> &D)
        : link(n), nxt(n, -1), tail(n), u(U.begin(), U.end()), d(D.begin(), D.end()) {
        std::iota(link.begin(), link.end(), 0);
        std::iota(tail.begin(), tail.end(), 0);
    }

    // Returns the root of x's block, compressing the visited path.
    int find(int x) {
        int root = x;
        while (link[root] != root) root = link[root];
        while (link[x] != root) {
            const int up = link[x];
            link[x] = root;
            x = up;
        }
        return root;
    }

    // Appends b's block after a's block; a's root stays the root of the union.
    // The direction is significant (a must run before b), so there is no
    // union-by-size here — only path compression in find().
    void merge(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return;
        link[b] = a;
        nxt[tail[a]] = b;  // b's order list starts at b itself
        tail[a] = tail[b];
        u[a] += u[b];
        d[a] += d[b];
    }
};

// Ordered-set entry describing one block: its totals and its root id.
struct s {
    long long u, d;
    int i;
    // Orders by the ratio u/d (via cross-multiplication; assumes d > 0),
    // breaking ties by root id. The tie-break is required: without it two
    // different blocks with equal ratios are "equivalent" and a lookup for
    // one may find and erase the other.
    bool operator<(const s &other) const {
        // 128-bit products: aggregated totals can make 64-bit products overflow.
        const __int128 lhs = static_cast<__int128>(u) * other.d;
        const __int128 rhs = static_cast<__int128>(d) * other.u;
        if (lhs != rhs) return lhs < rhs;
        return i < other.i;
    }
};

// Minimum total weighted completion time sum(u[j] * finish_time[j]) over all
// orders that run each job after its parent p[j], with job 0 as the root
// (p[0] is never read). Greedy: repeatedly take the block with the largest
// u/d ratio and glue it right after the block containing its parent; the
// final order of block 0 is the schedule that gets costed.
// Assumes every d[j] > 0 and 0 <= p[j] < n for j >= 1 — TODO confirm with
// the problem statement.
long long scheduling_cost(std::vector<int> p, std::vector<int> u, std::vector<int> d) {
    const int n = static_cast<int>(p.size());
    if (n == 0) return 0;

    DSU dsu(n, u, d);
    std::set<s> pending;  // every block not yet merged into block 0
    for (int i = 1; i < n; i++) pending.insert(s{u[i], d[i], i});

    while (!pending.empty()) {
        const s best = *std::prev(pending.end());
        pending.erase(std::prev(pending.end()));
        // The parent's current block may already be block 0 even when
        // p[best.i] != 0, and block 0 is never in `pending`: test the root.
        const int target = dsu.find(p[best.i]);
        if (target != 0) pending.erase(s{dsu.u[target], dsu.d[target], target});
        dsu.merge(target, best.i);
        if (target != 0) pending.insert(s{dsu.u[target], dsu.d[target], target});
    }

    long long ans = 0, cur_time = 0;
    for (int j = 0; j != -1; j = dsu.nxt[j]) {  // block 0's order starts at 0
        cur_time += d[j];
        ans += static_cast<long long>(u[j]) * cur_time;
    }
    return ans;
}