#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity of every per-vertex array below; main() fills indices 1..n.
constexpr int N = 100100;

// adj[v]: (neighbour, edge weight) pairs; main() inserts every edge in both
// directions.
vector<pair<int, int>> adj[N];

// Running count of feasible (start, end) pairs; updated by init(), cal() and
// decom(), printed by main().
ll ans = 0;

// vis[v] becomes true once v has been used as a centroid; all traversals skip
// such vertices.
bool vis[N];

// sz[v]: size of the subtree hanging from v, as last computed by dfs_sz().
int sz[N];

// Number of distinct compressed keys in c, i.e. the active length of fw.
int m;

// NOTE(review): now[] is declared but no function in this file reads or
// writes it.
ll now[N];

// 1-indexed Fenwick tree of counts over positions 1..m of c.
ll fw[N];

// a[v]: per-vertex value read by main(); added to the running balance when a
// walk enters v (presumably the fuel available at v -- confirm against the
// problem statement).
ll a[N];

// f: raw keys (-d) gathered by init() for the current centroid.
vector<ll> f;
// c: sorted, de-duplicated copy of f used for coordinate compression.
vector<ll> c;
// Recomputes sz[] for every vertex of v's still-active component (vertices
// with vis == false), rooting the traversal at v, and returns the size of
// that component. p is the vertex the walk came from (-1 at the root).
int dfs_sz(int v, int p = -1) {
    int subtree = 1;
    for (const auto& edge : adj[v]) {
        const int child = edge.first;
        if (child != p && !vis[child]) {
            subtree += dfs_sz(child, v);
        }
    }
    return sz[v] = subtree;
}
// Starting at v (having arrived from p), keeps stepping into the first
// active neighbour whose sz[] exceeds s, and returns the vertex at which no
// such neighbour exists. decom() calls it with s = component size / 2 right
// after dfs_sz() rooted at the same start vertex, which yields a centroid.
int centroid(int v, int p, int s) {
    while (true) {
        int heavy = -1;
        for (const auto& edge : adj[v]) {
            const int to = edge.first;
            if (to != p && !vis[to] && sz[to] > s) {
                heavy = to;
                break;
            }
        }
        if (heavy == -1) return v;
        p = v;
        v = heavy;
    }
}
void init(int v, int p, ll sum, ll d) {
if (d >= 0) ans++;
f.push_back(-d);
sum += a[v];
for (auto& [x, w] : adj[v]) {
if (x == p || vis[x]) continue;
init(x, v, sum - w, min(d, sum - w));
}
}
// Inserts into the Fenwick tree fw the key -d of v and of every active vertex
// reached from v without passing through p. The (sum, d) recurrence is the
// same one init() uses: sum is the balance on reaching v before a[v] is added,
// and d is the minimum balance along the walk from the centroid.
// Because init() produced the keys with this same walk, -d is always present
// in c, so upper_bound() returns its 1-based Fenwick position.
void add(int v, int p, ll sum, ll d) {
    const int idx = static_cast<int>(upper_bound(c.begin(), c.end(), -d) - c.begin());
    for (int i = idx; i <= m; i += i & -i) fw[i]++;
    sum += a[v];
    for (auto& [x, w] : adj[v]) {
        if (x == p || vis[x]) continue;
        // The previously dead `int W = a[v] - w;` is removed: it was never
        // read, and it narrowed a long long to int.
        add(x, v, sum - w, min(d, sum - w));
    }
}
void rem(int v, int p, ll sum, ll d) {
int idx = upper_bound(c.begin(), c.end(), -d) - c.begin();
for (int i = idx;i <= m;i += i & -i) fw[i]--;
sum += a[v];
for (auto& [x, w] : adj[v]) {
if (x == p || vis[x]) continue;
rem(x, v, sum - w, min(d, sum - w));
}
}
void cal(int v, int p, ll sum, ll mn) {
if (mn >= 0) {
int idx = upper_bound(c.begin(), c.end(), sum) - c.begin();
for (int i = idx;i > 0;i -= i & -i) ans += fw[i];
}
for (auto& [x, w] : adj[v]) {
if (x == p || vis[x]) continue;
ll W = a[x] - w;
cal(x, v, sum + W, min(mn + W, W));
}
}
void decom(int v) {
v = centroid(v, v, dfs_sz(v) / 2);
vis[v] = 1;
init(v, v, 0, 0);
ans--;
// compress
sort(f.begin(), f.end());
for (auto& x : f) {
if (c.empty() || x != c.back()) c.push_back(x);
}
m = c.size();
add(v, v, 0, 0);
for (auto& [x, w] : adj[v]) {
if (vis[x]) continue;
rem(x, v, a[v] - w, min(0ll, a[v] - w));
cal(x, v, a[x] - w, min(0ll, a[x] - w));
add(x, v, a[v] - w, min(0ll, a[v] - w));
}
rem(v, v, 0, 0);
f.clear();
c.clear();
for (auto& [x, w] : adj[v]) {
if (!vis[x]) decom(x);
}
}
// Reads n, the per-vertex values a[1..n] and n - 1 weighted edges "u v w",
// runs the decomposition from vertex 1, and prints the number of pairs found.
int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    int n;
    cin >> n;
    for (int vertex = 1; vertex <= n; ++vertex) {
        cin >> a[vertex];
    }
    for (int edge = 1; edge < n; ++edge) {
        int from, to, cost;
        cin >> from >> to >> cost;
        adj[from].emplace_back(to, cost);
        adj[to].emplace_back(from, cost);
    }

    decom(1);
    cout << ans;
    return 0;
}
Compilation message:
transport.cpp: In function 'void add(int, int, ll, ll)':
transport.cpp:41:13: warning: unused variable 'W' [-Wunused-variable]
   41 |         int W = a[v] - w;
      |             ^
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 8 ms | 5464 KB | Output is correct |
| 2 | Correct | 7 ms | 5812 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 6 ms | 5724 KB | Output is correct |
| 2 | Correct | 10 ms | 6212 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 65 ms | 10200 KB | Output is correct |
| 2 | Correct | 77 ms | 10976 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 98 ms | 11744 KB | Output is correct |
| 2 | Correct | 136 ms | 14676 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 139 ms | 14284 KB | Output is correct |
| 2 | Correct | 249 ms | 18636 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 65 ms | 7000 KB | Output is correct |
| 2 | Correct | 41 ms | 7516 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 60 ms | 8148 KB | Output is correct |
| 2 | Correct | 124 ms | 8916 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 160 ms | 8400 KB | Output is correct |
| 2 | Correct | 189 ms | 10660 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 225 ms | 9604 KB | Output is correct |
| 2 | Correct | 245 ms | 11560 KB | Output is correct |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 306 ms | 10824 KB | Output is correct |
| 2 | Correct | 313 ms | 12756 KB | Output is correct |