이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
#include <cassert>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
// Capacity of every per-vertex table; vertices are read as 1..N in main, so N
// must stay below MAX (presumably sized to the problem's N limit).
#define MAX 202300
// Number of ancestor levels in k2::sp (filled only when k2::DEBUG is non-zero).
#define MAXS 20
// Large magnitude; k2 uses -INF as the "impossible / unused" sentinel.
#define INF 100000000000000001
// Output separators.
#define bb ' '
#define ln '\n'
#define Ln '\n'
// C[v]: value of vertex v (read 1-indexed in main); answers sum these.
int C[MAX];
// adj[v]: undirected adjacency lists built from the N - 1 edges read in main;
// the solvers walk it as a tree rooted at vertex 1.
vector<int> adj[MAX];
// N: number of vertices; K: selects which solver (k1, k2 or k3) main runs.
int N, K;
namespace k1 {
// sum[x]: C[x] + sum[mv[x]], i.e. the total C along the downward chain that
// starts at x and keeps stepping to mv[]. sum[0] is never written (vertices are
// numbered from 1) and therefore stays 0.
ll sum[MAX];
// mv[x]: the child v of x with the largest sum[v], provided sum[v] > sum[0];
// 0 otherwise, which terminates the chain.
int mv[MAX];

// Fills sum[] and mv[] for every vertex in the subtree of x, where p is the
// parent of x (0 for the root).
// An explicit stack replaces recursion so that a path-shaped tree with about
// N levels cannot overflow the call stack.
void dfs(int x, int p = 0) {
    // (vertex, parent) pairs in pre-order: each vertex appears after its parent.
    vector<pair<int, int>> order;
    vector<pair<int, int>> stk(1, make_pair(x, p));
    while (!stk.empty()) {
        const pair<int, int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        for (int v : adj[cur.first])
            if (v != cur.second) stk.push_back(make_pair(v, cur.first));
    }
    // Walking the pre-order backwards finalizes every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = it->first;
        const int pu = it->second;
        sum[u] = C[u];
        // Children are compared in adjacency order with a strict '>', so the
        // first child reaching the maximum wins ties (as in the recursive form).
        for (int v : adj[u])
            if (v != pu && sum[v] > sum[mv[u]]) mv[u] = v;
        sum[u] += sum[mv[u]];
    }
}

// Prints the chain that starts at vertex 1 and follows mv[]: its total C,
// its length, and its vertices in order (space-separated).
void solve() {
    dfs(1);
    vector<int> ansv;
    ll ans = 0;
    for (int v = 1; v != 0; v = mv[v]) {
        ansv.push_back(v);
        ans += C[v];
    }
    cout << ans << ln;
    cout << ansv.size() << ln;
    for (int u : ansv) cout << u << bb;
}
}
namespace k2 {
typedef pair<ll, int> pli;
ll dp[MAX];
ll end[2][MAX];
int dpath[MAX]; // dp path
pii epath[MAX]; // end path
int chk[MAX];
int e1path[MAX][3];
// 0 : child -> calc(c) -> another calc(c) -> down(c, 0)
// 1 : child -> calc(c) -> down(c, 1)
const int DEBUG = 0;
int sp[MAX][MAXS];
int dep[MAX] = { 0, 1 };
void dfs(int x, int p = 0) {
if (DEBUG) {
sp[x][0] = p;
int i;
for (i = 1; i < MAXS; i++) sp[x][i] = sp[sp[x][i - 1]][i - 1];
}
pli me[3]; //max end
pli me1[3]; //max end
pli md[3]; //max dp
int i, j, k;
for (i = 0; i < 3; i++) me1[i] = me[i] = md[i] = pli(-INF, -1);
dp[x] += C[x];
end[0][x] += C[x];
end[1][x] += C[x];
if (p && adj[x].size() == 1) {
end[1][x] = -INF;
return;
}
int pv = 0;
pli mme1 = pli(-INF, -1);
for (auto v : adj[x]) if (v != p) {
pv = v;
if (DEBUG) dep[v] = dep[x] + 1;
dfs(v, x);
dp[x] += C[v];
end[0][x] += C[v];
end[1][x] += C[v];
pli d = pli(dp[v] - C[v], v);
pli e = pli(end[0][v] - C[v], v);
pli e1 = pli(end[1][v] - C[v], v);
me[2] = max(me[2], e);
md[2] = max(md[2], d);
me1[2] = max(me1[2], e1);
mme1 = max(mme1, pli(end[1][v], v));
for (i = 2; i >= 1; i--) if (me[i] > me[i - 1]) swap(me[i], me[i - 1]);
for (i = 2; i >= 1; i--) if (me1[i] > me1[i - 1]) swap(me1[i], me1[i - 1]);
for (i = 2; i >= 1; i--) if (md[i] > md[i - 1]) swap(md[i], md[i - 1]);
}
dpath[x] = md[0].second;
dp[x] += md[0].first;
epath[x].second = pv;
int c = 0;
if (p && adj[x].size() <= 2) c = 1;
if (!p && adj[x].size() == 1) c = 1;
if (c) {
end[0][x] += me[0].first;
if (end[0][x] < end[1][pv] + C[x]) {
epath[x] = pii(-1, pv);
end[0][x] = end[1][pv] + C[x];
}
end[1][x] = -INF;
return;
}
ll mx = -INF;
if (md[0].second != me[0].second) {
mx = md[0].first + me[0].first;
epath[x] = pii(md[0].second, me[0].second);
}
else {
if (mx < md[0].first + me[1].first) {
mx = md[0].first + me[1].first;
epath[x] = pii(md[0].second, me[1].second);
assert(md[0].second != me[1].second);
}
if (mx < md[1].first + me[0].first) {
mx = md[1].first + me[0].first;
epath[x] = pii(md[1].second, me[0].second);
assert(md[1].second != me[0].second);
}
}
//check mme1
if (end[0][x] + mx < mme1.first + C[x]) {
end[0][x] = mme1.first + C[x];
epath[x] = pii(-1, mme1.second);
}
else end[0][x] += mx;
mx = -INF;
for (i = 0; i < 3; i++) for (j = i + 1; j < 3; j++) {
if (!~me[i].second) continue;
if (!~me[j].second) continue;
for (k = 0; k < 3; k++) {
if (!~me[k].second) continue;
if (me[k].second == md[i].second) continue;
if (me[k].second == md[j].second) continue;
ll sum = me[k].first + md[i].first + md[j].first;
if (mx < sum) {
mx = sum;
e1path[x][0] = md[i].second;
e1path[x][1] = md[j].second;
e1path[x][2] = me[k].second;
}
}
}
for (i = 0; i < 2; i++) {
if (!~md[i].second) continue;
for (j = 0; j < 2; j++) {
if (!~me1[j].second) continue;
if (md[i].second == me1[j].second) continue;
ll sum = md[i].first + me1[j].first;
if (mx < sum) {
mx = sum;
chk[x] = 1;
e1path[x][0] = md[i].second;
e1path[x][1] = me1[j].second;
}
}
}
end[1][x] += mx;
}
inline int lca(int u, int v) {
int i;
if (dep[u] != dep[v]) {
if (dep[u] > dep[v]) swap(u, v);
int d = dep[v] - dep[u];
for (i = 0; i < MAXS; i++) if (d >> i & 1) v = sp[v][i];
}
if (u == v) return u;
for (i = MAXS - 1; i >= 0; i--) if (sp[u][i] != sp[v][i]) u = sp[u][i], v = sp[v][i];
return sp[u][0];
}
int dis(int u, int v) {
return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}
vector<int> ansv;
ll sum = 0;
void calc(int x, int c, int p = 0) {
if (adj[x].size() == 1) {
ansv.push_back(x);
sum += C[x];
return;
}
if (!c) ansv.push_back(x), calc(dpath[x], c ^ 1, x), sum += C[x];
for (auto v : adj[x]) if (v != p && dpath[x] != v) ansv.push_back(v), sum += C[v];
if (c) calc(dpath[x], c ^ 1, x), ansv.push_back(x), sum += C[x];
}
void down(int x, int c, int p = 0) {
if (!c) {
ansv.push_back(x);
sum += C[x];
if (p && adj[x].size() == 1) return;
if (!~epath[x].first) {
down(epath[x].second, 1, x);
return;
}
if (epath[x].first) calc(epath[x].first, 1, x);
for (auto v : adj[x]) if (v != p) {
if (v == epath[x].first) continue;
if (v == epath[x].second) continue;
ansv.push_back(v);
sum += C[v];
}
down(epath[x].second, 0, x);
}
else {
assert(adj[x].size() > 2);
for (auto v : adj[x]) if (v != p) {
if (v == e1path[x][0]) continue;
if (v == e1path[x][1]) continue;
if (v == e1path[x][2]) continue;
ansv.push_back(v);
sum += C[v];
}
calc(e1path[x][0], 0, x);
ansv.push_back(x);
sum += C[x];
if (chk[x]) down(e1path[x][1], 1, x);
else {
calc(e1path[x][1], 1, x);
down(e1path[x][2], 0, x);
}
}
}
void solve() {
dfs(1);
down(1, 0);
cout << sum << ln;
cout << ansv.size() << ln;
for (auto v : ansv) cout << v << bb;
int i;
for (i = 1; i < ansv.size(); i++) {
if (DEBUG) cout << i << ln;
assert(dis(ansv[i], ansv[i - 1]) <= 2);
}
vector<int> cpy = ansv;
sort(cpy.begin(), cpy.end());
cpy.erase(unique(cpy.begin(), cpy.end()), cpy.end());
assert(cpy.size() == ansv.size());
ll asdfsum = 0;
for (auto v : ansv) asdfsum += C[v];
assert(asdfsum == end[0][1]);
}
}
namespace k3 {
vector<int> ansv;
void dfs(int x, int c, int p = 0) {
if (c) ansv.push_back(x);
for (auto v : adj[x]) if (v != p) dfs(v, c ^ 1, x);
if (!c) ansv.push_back(x);
}
void solve() {
ll sum = 0;
int i;
for (i = 1; i <= N; i++) sum += C[i];
dfs(1, 1);
cout << sum << ln;
cout << N << Ln;
for (auto v : ansv) cout << v << bb;
}
}
// Reads N and K, the N - 1 undirected edges, and the values C[1..N], then
// runs the solver selected by K (1, 2 or 3); any other K prints nothing.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> N >> K;
    for (int edge = 1; edge < N; ++edge) {
        int u, w;
        cin >> u >> w;
        adj[u].push_back(w);
        adj[w].push_back(u);
    }
    for (int v = 1; v <= N; ++v) cin >> C[v];
    switch (K) {
    case 1: k1::solve(); break;
    case 2: k2::solve(); break;
    case 3: k3::solve(); break;
    default: break;
    }
}
컴파일 시 표준 에러 (stderr) 메시지
Main.cpp: In function 'void k2::solve()':
Main.cpp:239:17: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
239 | for (i = 1; i < ansv.size(); i++) {
| ~~^~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |