#include<bits/stdc++.h>
using namespace std;
using ll = long long;
// Capacity for node ids; nodes are indexed 1..n with n < N.
const ll N = 1e3 + 10;
// Largest node value the DP tables track; values outside 1..MAXV never
// produce a nonzero dp entry.
constexpr int MAXV = 100;
// dp[v][lo][hi]: filled by Go(v, ...); see Go's comment for the meaning.
ll dp[N][102][102];
// c[v]: the value attached to node v.
ll c[N];
// adj[v]: neighbours of node v (undirected tree edges).
std::vector<ll> adj[N];

// Go: fills dp[node][lo][hi] for every node in the subtree of `node`, where
// the tree is rooted so that `par` is node's parent (pass a non-node id such
// as 0 for the root).
//
// dp[node][lo][hi] = number of vertex sets S in node's subtree such that
//   * S contains `node` and is built from `node` plus, for some chosen
//     children ch, one set counted in dp[ch][a][b] (so S is connected);
//   * the values in S are exactly lo, lo+1, ..., hi, each appearing once.
// Entries with lo > c[node] or hi < c[node] remain 0.
//
// NOTE(review): counts are plain long long with no modulus. On large trees
// the number of ways can exceed 2^63. If the intended answer is the number of
// distinct value ranges (existence) rather than the number of sets, these
// entries should be 0/1 flags combined with OR/AND instead. Confirm against
// the problem statement.
void Go(ll node, ll par) {
    // Finish every child first so their dp tables are final, and so no large
    // scratch table is live across the recursion.
    for (ll chi : adj[node]) {
        if (chi != par) Go(chi, node);
    }

    const ll p = c[node];
    if (p < 1 || p > MAXV) return;  // no tracked range can contain this value

    // Heap-allocated on purpose: a 102x102 long long table (~83 KB) as a local
    // of this recursive function overflows a typical 8 MB stack on deep trees.
    // Elements are value-initialized, i.e. all zeros.
    std::vector<std::array<ll, 102>> single(102), tiled(102);

    // single[lo][hi]: ways ONE child subtree covers exactly the values lo..hi.
    for (ll chi : adj[node]) {
        if (chi == par) continue;
        for (int lo = 1; lo <= MAXV; ++lo) {
            for (int hi = lo; hi <= MAXV; ++hi) {
                single[lo][hi] += dp[chi][lo][hi];
            }
        }
    }

    // tiled[lo][hi]: ways to cover lo..hi with one or more child blocks, each a
    // consecutive sub-range. Every tiling is counted exactly once by splitting
    // off its FIRST block (from `single`) and tiling the remainder. Blocks in
    // one tiling always come from different children: each child's set
    // contains that child's own value, and the blocks' ranges are disjoint.
    for (int len = 1; len <= MAXV; ++len) {
        for (int lo = 1; lo + len - 1 <= MAXV; ++lo) {
            const int hi = lo + len - 1;
            ll ways = single[lo][hi];
            for (int mid = lo; mid < hi; ++mid) {
                ways += single[lo][mid] * tiled[mid + 1][hi];
            }
            tiled[lo][hi] = ways;
        }
    }

    // `node` supplies value p itself. The children must tile lo..p-1 (below
    // p) and p+1..hi (above p); an empty side contributes a factor of 1.
    for (int lo = 1; lo <= p; ++lo) {
        const ll below = (lo == p) ? 1 : tiled[lo][p - 1];
        for (int hi = static_cast<int>(p); hi <= MAXV; ++hi) {
            const ll above = (hi == p) ? 1 : tiled[p + 1][hi];
            dp[node][lo][hi] = below * above;
        }
    }
}
int main() {
ll n, m, r, x, y, i, j, ans, t;
cin >> n;
for (i = 1; i <= n; i++) {
cin >> c[i];
}
for (i = 1; i < n; i ++) {
cin >> x >> y;
adj[x].push_back(y);
adj[y].push_back(x);
}
Go(1, 1);
ans = 0;
for (i = 1; i<= 100; i ++){
for (j = i; j <= 100; j ++) {
ans += dp[1][i][j];
}
}
cout << ans << endl;
}
/*
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/