이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
// #pragma GCC optimize("O3,Ofast,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
using namespace std;
#define int long long int
#define MP make_pair
#define pb push_back
#define REP(i,n) for(int i = 0; (i) < (n); (i)++)
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
// Speeds up bulk stream I/O: C++ streams stop synchronising with C stdio and
// cin no longer flushes cout before every read.
void fastio() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
// Template constants; only N is referenced elsewhere in this file.
const double EPS = 1e-5;
const int INF = 1000000500;      // 1e9 + 500
const int N = 100005;            // vertex capacity (vertices are 1-indexed)
const int ALPH = 26;
const int LGN = 25;
constexpr int MOD = 1000000007;  // 1e9 + 7

// n: vertex count; m: upper bound of the mark budget k (both read in solve()).
int n, m;

// p[i]: weight read for vertex i; sm[i]: sum of p over the neighbours of i.
vector<int> p(N), sm(N);

// DP tables indexed [k][vertex]; solve() sizes rows 0..m+2, so the fixed
// 105 rows require m <= 102.
array<vector<int>, 105> dpd, dpu;

// Undirected adjacency lists.
vector<vector<int>> adj(N);

// pr[x]: parent of x in the tree rooted by dfs().
vector<int> pr(N);
// Tree DP over the subtree of `node` whose parent is `par` (dfs(1, 0) covers
// the whole tree). Fills pr[] for every vertex there and, for k in 0..m:
//   dpd[k][x] - best value recorded for a path that climbs out of x's subtree
//               and ends at x; marking x when arriving from child y adds
//               sm[x] - p[y].
//   dpu[k][x] - best value recorded for a path that starts at x (x unmarked)
//               and descends into x's subtree; marking child y (entered from
//               x) adds sm[y] - p[x].
// Preconditions (set up by the caller): rows 0..m are sized and zeroed and
// dpd[1][x] = sm[x]. Values are not made monotone in k here; the caller does
// that afterwards.
// Implemented iteratively, so tree depth is not bounded by the call stack.
void dfs(int node, int par) {
    pr[node] = par;

    // Pass 1: preorder with an explicit stack, recording each vertex's parent.
    vector<int> order;
    vector<int> pending(1, node);
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (int y : adj[x]) {
            if (y != pr[x]) {
                pr[y] = x;
                pending.push_back(y);
            }
        }
    }

    // Pass 2: reverse preorder reaches every child before its parent, so each
    // child's tables are final when they are folded into the parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        for (int y : adj[x]) {
            if (y == pr[x]) continue;
            for (int k = 0; k <= m; k++) {
                // Pass through x without marking it.
                dpd[k][x] = max(dpd[k][x], dpd[k][y]);
                dpu[k][x] = max(dpu[k][x], dpu[k][y]);
                if (k > 0) {
                    // Spend one mark at x (upward path) or at y (downward path).
                    dpd[k][x] = max(dpd[k][x], dpd[k - 1][y] + sm[x] - p[y]);
                    dpu[k][x] = max(dpu[k][x], dpu[k - 1][y] + sm[y] - p[x]);
                }
            }
        }
    }
}
// Reads one instance, runs the tree DP and prints the best achievable total.
//
// Input, as read below: n m, then p[1..n], then n-1 undirected edges.
// Marking vertex x when the path arrives from neighbour w contributes
// sm[x] - p[w]; marking the path's first vertex x contributes sm[x]. At most
// m vertices may be marked. Every simple path is examined at its topmost
// vertex i:
//   - a "down" half climbs into i from one child subtree (dpd), possibly
//     marking i;
//   - an "up" half leaves i into a different child subtree (dpu).
// v marks go to the first half and m - v to the second.
inline void solve() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> p[i];
    for (int e = 1; e < n; e++) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // sm[i] = total weight on the neighbours of i.
    for (int i = 1; i <= n; i++) {
        for (int c : adj[i]) sm[i] += p[c];
    }

    // Rows 0..m+2 are initialised; dpd/dpu hold 105 rows, so this needs m <= 102.
    for (int k = 0; k < m + 3; k++) {
        dpd[k].assign(n + 2, 0);
        dpu[k].assign(n + 2, 0);
    }
    // Seed: a path that starts at i and marks i immediately.
    for (int i = 1; i <= n; i++) dpd[1][i] = sm[i];

    dfs(1, 0);

    // Prefix-max over k so that row k means "at most k marks". k-outer order
    // processes each row contiguously (same result as i-outer).
    for (int k = 1; k <= m; k++) {
        for (int i = 1; i <= n; i++) {
            dpd[k][i] = max(dpd[k][i], dpd[k - 1][i]);
            dpu[k][i] = max(dpu[k][i], dpu[k - 1][i]);
        }
    }

    // Keeps the largest value (with its branch id) and the second largest,
    // replacing a per-iteration vector + full sort: only the top two matter.
    struct TopTwo {
        int best = numeric_limits<int>::min();
        int bestId = 0;
        int second = numeric_limits<int>::min();
        void add(int value, int id) {
            if (value > best) {
                second = best;
                best = value;
                bestId = id;
            } else if (value > second) {
                second = value;
            }
        }
    };

    int ans = 0;
    for (int i = 1; i <= n; i++) {
        for (int v = 0; v <= m; v++) {
            const int v2 = m - v;
            TopTwo down, up;
            // Sentinels (ids unique to each side) model a half that is just i:
            // the path starts at i (marked if v > 0) and/or ends at i.
            const int alone = v > 0 ? sm[i] : 0;
            down.add(alone, -1);
            down.add(alone, -4);
            up.add(0, -2);
            up.add(0, -3);
            for (int c : adj[i]) {
                if (c == pr[i]) continue;
                // Climb from c's subtree into i, optionally marking i.
                int d = dpd[v][c];
                if (v > 0) d = max(d, dpd[v - 1][c] + sm[i] - p[c]);
                down.add(d, c);
                // Leave i into c's subtree, optionally marking c.
                int u = dpu[v2][c];
                if (v2 > 0) u = max(u, dpu[v2 - 1][c] + sm[c] - p[i]);
                up.add(u, c);
            }
            // The two halves must use different branches of i for the path to
            // stay simple; if both maxima sit in the same branch, pair one
            // maximum with the other side's runner-up.
            const int cand = down.bestId != up.bestId
                ? down.best + up.best
                : max(down.best + up.second, down.second + up.best);
            ans = max(ans, cand);
        }
    }
    cout << ans << "\n";
}
// Entry point: configures fast I/O and solves a single test case.
// To read a test count, enable the commented read of `remaining`.
signed main() {
    fastio();
    int remaining = 1;
    // cin >> remaining;
    for (; remaining > 0; --remaining) {
        solve();
    }
}
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... | | | | |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... | | | | |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... | | | | |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... | | | | |