#include <bits/stdc++.h>
#include <cassert>
// GCC-specific optimisation pragmas. Per the GCC manual, `#pragma GCC optimize`
// applies to functions defined after this point in the translation unit.
// `target("avx,avx2,fma")` allows the compiler to emit AVX/AVX2/FMA instructions.
// NOTE(review): such a binary may fault on CPUs without AVX2, and non-x86 GCC
// builds may reject this pragma outright - confirm the target judge supports it.
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define MAX 1010101
#define MAXS 500
#define INF 1000000000000000001
#define bb ' '
#define ln '\n'
#define Ln '\n'
// A[v]: value of vertex v (vertices are 1-indexed; calc() treats A[v] == 0 specially).
// A[0] is a sentinel for the non-existent vertex 0, which is used as "no parent".
// Fix: the original initializer `-1e9` is a double; converting a floating-point value
// to int inside a braced initializer is a narrowing conversion and is ill-formed since
// C++11, so the file did not compile. An integer literal keeps the same value.
int A[MAX] = { -1000000000 };
// F[k]: base value of position k, loaded into the segment tree by segtree::init.
ll F[MAX];
// Number of vertices read from input.
int N;
namespace segtree {
int N;
ll tree[MAX];
ll lazy[MAX];
void init(int s, int e, int loc = 1) {
if (s == e) {
tree[loc] = F[s];
return;
}
int m = s + e >> 1;
init(s, m, loc * 2);
init(m + 1, e, loc * 2 + 1);
tree[loc] = min(tree[loc * 2], tree[loc * 2 + 1]);
}
void prop(int loc) {
for (auto c : { loc * 2, loc * 2 + 1 }) {
tree[c] += lazy[loc];
lazy[c] += lazy[loc];
}
lazy[loc] = 0;
}
void upd(int s, int e, int i, ll x, int loc = 1) {
if (e < i || i < s) return;
if (s == e) {
tree[loc] += x;
return;
}
int m = s + e >> 1;
upd(s, m, i, x, loc * 2);
upd(m + 1, e, i, x, loc * 2 + 1);
tree[loc] = min(tree[loc * 2], tree[loc * 2 + 1]);
}
void upd(int i, ll x) { upd(1, N, i, x); }
ll query(int s, int e, int l, int r, int loc = 1) {
if (s != e) prop(loc);
if (e < l || r < s) return 1e18;
if (l <= s && e <= r) return tree[loc];
int m = s + e >> 1;
return min(query(s, m, l, r, loc * 2), query(m + 1, e, l, r, loc * 2 + 1));
}
ll query(int l, int r) { return query(1, N, l, r); }
};
vector<int> adj[MAX];
ll ans[MAX];
ll aans[MAX]; // x=a, v=0
vector<pll> st[MAX];
ll all[MAX];
int dep[MAX];
int num[MAX];
// Roots the tree at x (p = parent, 0 = none) and fills dep[] (depth, root = 0)
// and num[] (subtree size).
// Fix: the original wrote `dep[v] = dep[x] = 1`, which gave every vertex depth 1.
// That broke calc()'s parent filter (`dep[v] < dep[x]`) and main()'s top-down
// ordering of vertices by depth.
// Recursive: call depth equals the height of the tree.
void dfs(int x, int p = 0) {
    num[x] = 1;
    for (int v : adj[x]) {
        if (v == p) continue;
        dep[v] = dep[x] + 1;
        dfs(v, x);
        num[x] += num[v];
    }
}
int vis[MAX];
// calc(x, p): processes the subtree of x reachable through kept edges and fills
// ans[x], all[x] and st[x] (plus aans[x] when A[x] == 0 and x has a parent p).
// A child v of x is kept when A[v] == 0 or A[v] == A[x] + 1; adj[x] is pruned in place.
//
// Small-to-large merging:
// - Children are visited lightest first (by num[]).
// - Before each next child, the previous child's segment-tree updates are undone.
// - Only the heaviest child's updates survive; the lighter children's are then
//   re-applied on top of them.
// Invariant on return from a non-root call: the segment tree holds exactly the
// updates listed in st[x]. A root call (p == 0) removes them again.
// Recursive: call depth equals the depth of the kept subtree.
void calc(int x, int p = 0) {
    vis[x] = 1;

    // Keep only deeper neighbours (children) whose value continues the chain.
    size_t kept = 0;
    for (size_t i = 0; i < adj[x].size(); i++) {
        int v = adj[x][i];
        if (v == p || dep[v] < dep[x]) continue;
        if (!A[v] || A[v] == A[x] + 1) adj[x][kept++] = v;
    }
    adj[x].resize(kept);

    // Lightest first, so the heaviest child is processed last and its updates are kept.
    sort(adj[x].begin(), adj[x].end(), [&](int u, int w) { return num[u] < num[w]; });

    int pv = 0;
    for (int v : adj[x]) {
        // Clear the previous (lighter) child's updates before recursing into the next.
        if (pv) {
            for (auto& [a, b] : st[pv]) segtree::upd(a, -b);
        }
        calc(v, x);
        pv = v;
    }

    if (pv) {
        // pv is the heaviest child; its updates are still in the tree.
        if (!A[pv]) {
            for (auto& [a, b] : st[pv]) segtree::upd(a, -b);
            st[pv].clear();
            all[pv] = ans[pv];
            st[pv].emplace_back(A[x] + 1, -ans[pv] + aans[pv]);
            segtree::upd(A[x] + 1, -ans[pv] + aans[pv]);
        }
        else {
            // Fix: the original pushed this entry into st[x] just before the swap below.
            // It was moved into st[pv] and never applied, unlike the identical entry
            // applied for a light child, so the result depended on which child was heaviest.
            st[pv].emplace_back(A[pv], ans[pv] - all[pv]);
            segtree::upd(A[pv], ans[pv] - all[pv]);
        }
        all[x] += all[pv];
        swap(st[x], st[pv]);  // x inherits the heavy child's update list in O(1)

        // Re-apply each light child's contribution on top of the heavy child's.
        for (int v : adj[x]) {
            if (v == pv) continue;
            if (!A[v]) {
                all[x] += ans[v];
                // Fix: record the same position that is updated (the original recorded
                // A[x]), so that later undo operations subtract at the right place.
                st[x].emplace_back(A[x] + 1, -ans[v] + aans[v]);
                segtree::upd(A[x] + 1, -ans[v] + aans[v]);
            }
            else {
                all[x] += all[v];
                st[v].emplace_back(A[v], ans[v] - all[v]);
                for (auto& [a, b] : st[v]) st[x].emplace_back(a, b), segtree::upd(a, b);
            }
        }
    }

    ans[x] = all[x] + segtree::query(A[x] + 1, N);
    // Sum of the updates currently stored at position A[p] + 1 (tree value minus its base F).
    if (!A[x] && p) aans[x] = all[x] + segtree::query(A[p] + 1, A[p] + 1) - F[A[p] + 1];
    // A root leaves the segment tree clean for the next component.
    if (!p) {
        for (auto& [a, b] : st[x]) segtree::upd(a, -b);
    }
}
signed main() {
ios::sync_with_stdio(false), cin.tie(0);
cin >> N;
segtree::N = N;
int i, a, b;
for (i = 1; i <= N; i++) cin >> A[i];
for (i = 1; i <= N; i++) cin >> F[i];
segtree::init(1, N);
for (i = 1; i < N; i++) {
cin >> a >> b;
adj[a].push_back(b);
adj[b].push_back(a);
}
dfs(1);
ll sum = 0;
vector<int> v;
for (i = 1; i <= N; i++) v.push_back(i);
sort(v.begin(), v.end(), [&](int i, int j) {return dep[i] < dep[j]; });
for (auto x : v) if (!vis[x]) calc(x), sum += ans[x];
cout << sum << ln;
}
Compilation message
code1.cpp:17:21: error: narrowing conversion of '-1.0e+9' from 'double' to 'int' [-Wnarrowing]
17 | int A[MAX] = { -1e9 };
| ^
code1.cpp: In function 'void segtree::init(int, int, int)':
code1.cpp:29:13: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
29 | int m = s + e >> 1;
| ~~^~~
code1.cpp: In function 'void segtree::upd(int, int, int, ll, int)':
code1.cpp:47:13: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
47 | int m = s + e >> 1;
| ~~^~~
code1.cpp: In function 'll segtree::query(int, int, int, int, int)':
code1.cpp:57:13: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
57 | int m = s + e >> 1;
| ~~^~~
code1.cpp: In function 'void calc(int, int)':
code1.cpp:75:16: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
75 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~