# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
163144 | arnold518 | Chase (CEOI17_chase) | C++14 | 2951 ms | 346980 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const int MAXN = 1e5;
const int MAXK = 100;
int N, K;
ll P[MAXN+10], S[MAXN+10], dp1[MAXN+10][MAXK+10], dp2[MAXN+10][MAXK+10], ans;
vector<int> adj[MAXN+10];
void dfs1(int now, int bef)
{
int i, j;
for(int nxt : adj[now])
{
if(nxt==bef) continue;
dfs1(nxt, now);
}
for(int nxt : adj[now]) S[now]+=P[nxt];
for(i=1; i<=K; i++)
{
ll val=0;
for(int nxt : adj[now])
{
if(nxt==bef) continue;
val=max(val, dp1[nxt][i-1]);
dp1[now][i]=max(dp1[now][i], dp1[nxt][i]);
}
dp1[now][i]=max(dp1[now][i], S[now]-P[bef]+val);
}
}
vector<pll> V[MAXK+10];
void dfs2(int now, int bef)
{
int i, j;
for(i=0; i<=K; i++) V[i].clear();
for(int nxt : adj[now])
{
if(nxt==bef) continue;
for(i=0; i<=K; i++) V[i].push_back({dp1[nxt][i], nxt});
}
if(bef) for(i=0; i<=K; i++) V[i].push_back({dp2[bef][i], bef});
for(i=0; i<=K; i++) sort(V[i].begin(), V[i].end(), greater<pll>());
for(int nxt : adj[now])
{
if(nxt==bef) continue;
for(i=1; i<=K; i++)
{
for(j=0; j<V[i].size(); j++)
{
if(V[i][j].second==nxt) continue;
dp2[nxt][i]=max(dp2[nxt][i], V[i][j].first);
break;
}
for(j=0; j<V[i-1].size(); j++)
{
if(V[i-1][j].second==nxt) continue;
dp2[nxt][i]=max(dp2[nxt][i], V[i-1][j].first+S[now]-P[nxt]);
break;
}
}
}
for(int nxt : adj[now])
{
if(nxt==bef) continue;
dfs2(nxt, now);
}
}
void dfs3(int now, int bef)
{
for(int nxt : adj[now])
{
if(nxt==bef) continue;
dfs3(nxt, now);
}
if(bef==0)
{
ans=max(ans, dp1[now][K]);
}
else
{
ll val=0;
for(int nxt : adj[now])
{
if(nxt==bef) continue;
val=max(val, dp1[nxt][K]);
}
val=max(val, dp2[now][K]); ans=max(ans, val);
val=0;
for(int nxt : adj[now])
{
if(nxt==bef) continue;
val=max(val, dp1[nxt][K-1]);
}
val=max(val, dp2[now][K-1]);
val+=S[now];
ans=max(ans, val);
}
}
int main()
{
int i, j, k;
scanf("%d%d", &N, &K);
for(i=1; i<=N; i++) scanf("%lld", &P[i]);
for(i=1; i<N; i++)
{
int u, v;
scanf("%d%d", &u, &v);
adj[u].push_back(v);
adj[v].push_back(u);
}
if(K==0) return !printf("0");
for(i=0; i<=K; i++) dp2[1][i]=-1e17;
dfs1(1, 0);
dfs2(1, 0);
dfs3(1, 0);
printf("%lld", ans);
}
컴파일 시 표준 에러 (stderr) 메시지
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |