#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define double long double
#define f first
#define s second
#define pb push_back
#define all(x) x.begin(),x.end()
#define vi vector<int>
#define vvi vector<vi>
#define vp vector<pii>
using namespace std;
const int N=1e5+5;
// Adjacency lists: g[u] holds (neighbour, edge weight) pairs.
std::vector<std::pair<int,int>> g[N];
// sz[v]: size of v's subtree inside the current component (filled by getsize).
// a[v] : per-node value read from input.
int sz[N],a[N];
// d[] and dp[] accumulate a[] values and edge weights along whole tree paths.
// Each a[] value and each weight is an int and a path can contain up to N nodes,
// so these sums can exceed INT_MAX by orders of magnitude: they must be 64-bit.
long long d[N],dp[N];
// vis[v] set: v is excluded from every traversal below.
bool vis[N]{0};
// anc[c]: every -dp[y] collected by getdp for centroid c. It is the key space of c's
//         Fenwick tree and must be sorted before add/qr (they binary-search it).
// aa[c] : scratch list of nodes handed back by solve().
std::vector<long long> anc[N];
std::vector<int> aa[N];

// Sets sz[] for every node reachable from u without entering p or a vis[] node,
// and returns sz[u].
int getsize(int u,int p){
    sz[u]=1;
    for(const auto& e:g[u]){
        if(e.first!=p&&!vis[e.first])sz[u]+=getsize(e.first,u);
    }
    return sz[u];
}

// Walks from u into any child whose sz[] is more than req/2 and returns the first
// node without such a child: the centroid of a component of size req
// (sz[] must come from getsize on that component).
int getct(int u,int p,int req){
    for(const auto& e:g[u]){
        if(e.first!=p&&!vis[e.first]&&sz[e.first]*2>req)return getct(e.first,u,req);
    }
    return u;
}

// For every node v below u: d[v] = d[parent(v)] - w(parent(v), v) + a[v].
// With d[root] preset to a[root], d[v] is the sum of a[] on the root..v path
// minus the sum of edge weights on it.
void getd(int u,int p){
    for(const auto& e:g[u]){
        const int v=e.first;
        if(v==p||vis[v])continue;
        d[v]=d[u]-e.second+a[v];
        getd(v,u);
    }
}

// For every node v below u: dp[v] = min(dp[parent(v)], d[v]-a[v]). With dp[root]
// preset to 0 this is min(0, min of d[q]-a[q] over the non-root nodes q on root..v).
// In post-order, appends -dp[] of every visited node (u included) to anc[x].
void getdp(int u,int p,int x){
    for(const auto& e:g[u]){
        const int v=e.first;
        if(v==p||vis[v])continue;
        dp[v]=std::min(dp[u],d[v]-a[v]);
        getdp(v,u,x);
    }
    anc[x].push_back(-dp[u]);
}

// Fenwick tree per centroid, 1-indexed over the sorted keys in anc[i].
// fw[i] must have at least anc[i].size()+1 slots.
std::vector<int> fw[N];

// Adds amt at the slot of key j in centroid i's Fenwick tree. j is expected to be
// one of the keys in anc[i]; equal keys share the slot of their last duplicate.
void add(int i,long long j,int amt){
    int pos=static_cast<int>(std::upper_bound(anc[i].begin(),anc[i].end(),j)-anc[i].begin());
    // A key below every stored key has no slot; position 0 would never advance
    // (0 & -0 == 0), so bail out instead of looping forever.
    if(pos==0)return;
    const int lim=static_cast<int>(fw[i].size());
    for(;pos<lim;pos+=pos&-pos)fw[i][pos]+=amt;
}

// Returns res plus the total amount stored at keys <= j in centroid i's tree.
int qr(int i,long long j,int res=0){
    int pos=static_cast<int>(std::upper_bound(anc[i].begin(),anc[i].end(),j)-anc[i].begin());
    for(;pos>0;pos-=pos&-pos)res+=fw[i][pos];
    return res;
}

long long ans=0;

// Removes every node of the subtree hanging from u (parent p) from centroid x's
// Fenwick tree. Also appends to aa[x] each node v whose d[v] is >= mx, where mx is
// the caller's seed maxed with d[] of v's ancestors inside this subtree.
void solve(int u,int p,int x,long long mx){
    add(x,-dp[u],-1);
    if(d[u]>=mx)aa[x].push_back(u);
    const long long nextMx=std::max(mx,d[u]);
    for(const auto& e:g[u]){
        if(vis[e.first]||e.first==p)continue;
        solve(e.first,u,x,nextMx);
    }
}

// Re-inserts every node of the subtree hanging from u (parent p) into centroid x's
// Fenwick tree, undoing the removals made by solve().
void resolve(int u,int p,int x){
    add(x,-dp[u],1);
    for(const auto& e:g[u]){
        if(vis[e.first]||e.first==p)continue;
        resolve(e.first,u,x);
    }
}
// Centroid-decomposition driver for the component containing u.
// Finds the centroid c, prepares d[]/dp[]/anc[c] and c's Fenwick tree, and adds to
// ans the contributions whose paths pass through c. It then marks c as vis[] and
// recurses into each remaining piece.
void play(int u){
    const int c=getct(u,u,getsize(u,u));
    vis[c]=1;

    // Seed the per-centroid path quantities at c itself.
    d[c]=a[c];
    getd(c,c);
    dp[c]=0;
    getdp(c,c,c);

    // Sort the -dp keys (Fenwick key space) and insert every node of the component.
    auto& keys=anc[c];
    std::sort(keys.begin(),keys.end());
    fw[c].resize(keys.size()+2);
    for(auto key:keys)add(c,key,1);

    // c as the starting node: query with c's own entry taken out so it is not
    // counted against itself, then restore it.
    add(c,-dp[c],-1);
    ans+=qr(c,d[c]-a[c]);
    add(c,-dp[c],1);

    // Starting nodes inside one child subtree: remove that subtree from the tree
    // (so only nodes elsewhere, c included, are counted), query once per start node
    // collected by solve(), then put the subtree back.
    for(const auto& e:g[c]){
        const int child=e.first;
        if(vis[child])continue;
        solve(child,c,c,a[c]);
        for(int from:aa[c])ans+=qr(c,d[from]-a[c]);
        aa[c].clear();
        resolve(child,c,c);
    }

    // Recurse into the pieces left after removing c.
    for(const auto& e:g[c]){
        if(!vis[e.first])play(e.first);
    }
}
// Reads n, the n node values a[1..n] and n-1 weighted undirected edges, runs the
// centroid decomposition from node 1, and prints the accumulated answer.
int main(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n;
    std::cin>>n;
    for(int i=1;i<=n;++i)std::cin>>a[i];
    for(int edge=0;edge<n-1;++edge){
        int u,v,w;
        std::cin>>u>>v>>w;
        g[u].push_back({v,w});
        g[v].push_back({u,w});
    }
    play(1);
    std::cout<<ans;
}
//go_dp[u] = max(W[i],d[u]);
//from_dp[u] = min(d[u]-d[lca],w[i]);
Compilation message
transport.cpp: In function 'void add(int, int, int)':
transport.cpp:47:11: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
47 | for(;j<fw[i].size();j+=j&-j)fw[i][j]+=amt;
| ~^~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
8 ms |
12380 KB |
Output is correct |
2 |
Correct |
7 ms |
12432 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
9 ms |
12892 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
108 ms |
24792 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
135 ms |
28632 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
182 ms |
36560 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
75 ms |
17360 KB |
Output is correct |
2 |
Correct |
42 ms |
16068 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
73 ms |
18640 KB |
Output is correct |
2 |
Correct |
128 ms |
20944 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
192 ms |
22756 KB |
Output is correct |
2 |
Correct |
192 ms |
24320 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
255 ms |
26060 KB |
Output is correct |
2 |
Correct |
254 ms |
28036 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
323 ms |
34000 KB |
Output is correct |
2 |
Correct |
344 ms |
33912 KB |
Output is correct |