#include <bits/stdc++.h>
using namespace std;
// Shorthand alias for long long.
using ll = long long;
// Expands to the begin/end pair of a container, for passing to algorithms.
#define all(x) begin(x), end(x)
// Container size cast to int.
#define sz(x) (int)x.size()
// Shorthand for push_back.
#define pb push_back
// Capacity of the per-node arrays below; main indexes nodes starting at 1.
const int maxn = 10001;
// adj[v]: neighbours of node v (each input edge is stored in both directions).
vector<int> adj[maxn];
// A[v]: value of node v, stored 0-based (main subtracts 1 after reading it).
int A[maxn];
// dp[v]: 100-bit set filled in by dfs for node v; bit A[v] is always set there.
// x[i][j]: mask with exactly bits i..j set, built in main (stays empty when j < i).
bitset<100> dp[maxn], x[100][100];
void dfs(int s, int p){
vector<int> v1, v2;
for(auto u: adj[s]){
if(u == p) continue;
if(A[u] > A[s]) v1.pb(u);
else if(A[u] < A[s]) v2.pb(u);
dfs(u, s);
}
sort(all(v1), [](int a, int b){
return A[a] < A[b];
});
dp[s][A[s]] = 1;
for(auto u: v1){
auto y = (dp[s] & x[A[s]][A[u]-1]) & ((dp[u] & x[A[s]+1][A[u]])>>1);
if(!y.none()){
dp[s] |= dp[u] & x[A[u]][99];
}
}
sort(all(v2), [](int a, int b){
return A[a] > A[b];
});
for(auto u: v2){
auto y = (dp[s] & x[A[u]+1][A[s]]) & ((dp[u] & x[A[u]][A[s]-1])<<1);
if(!y.none()){
dp[s] |= dp[u] & x[0][A[u]];
}
}
}
// Reads n node values and n-1 edges, builds the range masks x[i][j], runs
// dfs from node 1 and prints how many pairs (i, j) with i <= A[1] <= j have
// both bit i and bit j set in dp[1].
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    for(int node = 1; node <= n; ++node){
        cin >> A[node];
        --A[node];  // keep values 0-based so they can index the 100-bit sets
    }
    for(int edge = 1; edge < n; ++edge){
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // x[lo][hi] ends up with exactly bits lo..hi set: each mask is the
    // previous one in the row with one more bit appended.
    for(int lo = 0; lo < 100; ++lo){
        for(int hi = lo; hi < 100; ++hi){
            if(hi > lo){
                x[lo][hi] = x[lo][hi - 1];
            }
            x[lo][hi][hi] = 1;
        }
    }

    dfs(1, 0);

    const int root = A[1];
    int ans = 0;
    for(int i = 0; i <= root; ++i){
        if(!dp[1][i]) continue;
        for(int j = root; j < 100; ++j){
            if(dp[1][j]){
                ++ans;
            }
        }
    }
    cout << ans;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |