#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/detail/standard_policies.hpp>
using namespace std;
using namespace __gnu_pbds;
// 64-bit Mersenne Twister seeded from the steady clock. The cast happens
// before `int` is redefined below, so the seed is truncated to 32 bits.
mt19937_64 rng(static_cast<int>(std::chrono::steady_clock::now().time_since_epoch().count()));
const int MAXN = 1e5 + 10;  // capacity of the per-vertex arrays (vertices are 1-based)
const int MOD = 1e9 + 7;    // not referenced elsewhere in this file
#define int long long
// Order-statistics red-black tree over pairs (pb_ds): supports order_of_key / find_by_order.
using ordered_set = tree<pair<int, int>, null_type, less<pair<int, int> >, rb_tree_tag, tree_order_statistics_node_update>;
// Uniformly distributed random integer in the closed range [x, y] (requires x <= y).
int rnd(int x, int y) {
    return uniform_int_distribution<int>(x, y)(rng);
}
int n;                           // number of vertices (read in solve)
int a[MAXN];                     // a[v]: per-vertex value read in solve (1-based)
set<pair<int, int> > adj[MAXN];  // adj[v]: {neighbour, negated edge weight}
// A centroid is a vertex whose removal (together with all of its incident edges)
// makes the largest remaining connected component as small as possible.
// Related problem: https://codeforces.com/problemset/problem/1406/C
// Overall running time: O(n log^2 n).
int sz;       // size of the component currently being decomposed
int mi;       // smallest "largest remaining component" seen by dfs1 so far
int ctd;      // centroid selected by dfs1
int ans = 0;  // running total of counted pairs
int dfs0(int node, int prv) {
int tot = 1;
for(auto x: adj[node]) {
if(x.first != prv) {
tot += dfs0(x.first, node);
}
}
return tot;
}
int dfs1(int node, int prv) {
int ret = 1;
int ma = 0;
for(auto x: adj[node]) {
if(x.first != prv) {
int q = dfs1(x.first, node);
ret += q;
ma = max(ma, q);
}
}
int rem = sz - ret;
ma = max(ma, rem);
if(ma < mi) {
mi = ma;
ctd = node;
}
return ret;
}
int id = 0;
void decomp(int st) {
//st: an arbitrary node in the tree in question
//step 1: find the centroid + count number of nodes
sz = dfs0(st, -1);
mi = 1e9, ctd = -1;
dfs1(st, -1);
//step 2: if nodes = 1, break
if(sz == 1) return;
//step 3: root the subtree at centroid, do a bfs
// all weights = {sum, min}
//step 3a: count id[u] < id[v]
vector<pair<int,int> > all;
for(auto x: adj[ctd]) all.push_back(x);
ordered_set s;
for(int i=0; i<all.size(); i++) {
adj[ctd].erase(all[i]);
adj[all[i].first].erase({ctd, all[i].second});
unordered_map<int, pair<int, int> > w1; // weights to count
unordered_map<int, pair<int, int> > w2; // weights to insert
queue<int> q;
unordered_map<int, int> vis;
vector<int> omg;
w1[all[i].first] = {a[ctd] + all[i].second, a[ctd] + all[i].second};
w2[all[i].first] = {a[all[i].first] + all[i].second, a[all[i].first] + all[i].second};
q.push(all[i].first);
while(q.size()) {
int f=q.front(); q.pop();
if(vis[f]) continue;
vis[f] = 1; omg.push_back(f);
if(s.size()>0) {
if(w1[f].second>=0) ans += s.size();
else ans += s.size() - s.order_of_key({-w1[f].second, 0});
}
for(auto x: adj[f]) {
if(vis[x.first]) continue;
q.push(x.first);
w1[x.first] = {w1[f].first + a[f] + x.second, min(w1[f].second, w1[f].first + a[f] + x.second)};
w2[x.first] = {w2[f].first + a[x.first] + x.second, min(w2[f].second, 0ll) + a[x.first] + x.second};
}
}
//insert
for(int x: omg) {
if(w2[x].second >= 0) s.insert({w2[x].first, ++id});
}
adj[ctd].insert(all[i]);
adj[all[i].first].insert({ctd, all[i].second});
}
reverse(all.begin(), all.end());
s.clear();
//step 3b: count id[u] > id[v]
for(int i=0; i<all.size(); i++) {
adj[ctd].erase(all[i]);
adj[all[i].first].erase({ctd, all[i].second});
unordered_map<int, pair<int, int> > w1; // weights to count
unordered_map<int, pair<int, int> > w2; // weights to insert
queue<int> q;
unordered_map<int, int> vis;
vector<int> omg;
w1[all[i].first] = {a[ctd] + all[i].second, a[ctd] + all[i].second};
w2[all[i].first] = {a[all[i].first] + all[i].second, a[all[i].first] + all[i].second};
q.push(all[i].first);
while(q.size()) {
int f=q.front(); q.pop();
if(vis[f]) continue;
vis[f] = 1; omg.push_back(f);
if(s.size()>0) {
if(w1[f].second>=0) ans += s.size();
else ans += s.size() - s.order_of_key({-w1[f].second, 0});
}
for(auto x: adj[f]) {
if(vis[x.first]) continue;
q.push(x.first);
w1[x.first] = {w1[f].first + a[f] + x.second, min(w1[f].second, w1[f].first + a[f] + x.second)};
w2[x.first] = {w2[f].first + a[x.first] + x.second, min(w2[f].second, 0ll) + a[x.first] + x.second};
}
}
//insert
for(int x: omg) {
if(w2[x].second >= 0) {
s.insert({w2[x].first, ++id});
}
}
adj[ctd].insert(all[i]);
adj[all[i].first].insert({ctd, all[i].second});
//step 3c: count u = root (= ctd)
for(int x: omg) {
if(w1[x].second >= 0) {
ans++;
}
}
}
//step 3d: count v = root (= ctd)
ans += s.size();
//step 4: remove all edges from the centroid
for(auto x: adj[ctd]) adj[x.first].erase({ctd, x.second});
for(auto x: adj[ctd]) decomp(x.first);
adj[ctd].clear();
}
// Reads n, the per-vertex values and n - 1 weighted edges. Each edge is stored
// in both directions with its weight negated. Then runs the decomposition from
// vertex 1 and prints the accumulated count. `tc` is unused.
void solve(int tc) {
    cin >> n;
    for (int v = 1; v <= n; ++v) {
        cin >> a[v];
    }
    for (int e = 1; e < n; ++e) {
        int x, y, cost;
        cin >> x >> y >> cost;
        const pair<int, int> to_y{y, -cost};
        const pair<int, int> to_x{x, -cost};
        adj[x].insert(to_y);
        adj[y].insert(to_x);
    }
    decomp(1);
    cout << ans << "\n";
}
// Entry point: untie and desynchronise iostreams, then process exactly one test case.
int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    const int test_count = 1;  // only a single test case is processed
    for (int tc = 1; tc <= test_count; ++tc) {
        solve(tc);
    }
}
Compilation message
transport.cpp: In function 'void decomp(long long int)':
transport.cpp:66:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
66 | for(int i=0; i<all.size(); i++) {
| ~^~~~~~~~~~~
transport.cpp:104:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
104 | for(int i=0; i<all.size(); i++) {
| ~^~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
23 ms |
5708 KB |
Output is correct |
2 |
Correct |
32 ms |
5836 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
42 ms |
6220 KB |
Output is correct |
2 |
Correct |
58 ms |
6476 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
536 ms |
18500 KB |
Output is correct |
2 |
Correct |
407 ms |
15808 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
690 ms |
22308 KB |
Output is correct |
2 |
Correct |
814 ms |
23652 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1087 ms |
28528 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
430 ms |
11804 KB |
Output is correct |
2 |
Correct |
196 ms |
10392 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
482 ms |
14512 KB |
Output is correct |
2 |
Correct |
657 ms |
13912 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
749 ms |
16948 KB |
Output is correct |
2 |
Correct |
749 ms |
17848 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1053 ms |
22180 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1084 ms |
28504 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |