이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Number of vertices. main() reads it from the input; gen() overwrites it for stress runs.
int n;
// Undirected adjacency lists indexed by vertex id; each edge (x, y) is stored in both adj[x] and adj[y].
vector<int> adj[200200];
// Value attached to each vertex, stored at a[1..n].
ll a[200200];
namespace foo{
ll dp[200200];
int nxt[200200];
void dfs(int s, int pa = 0){
dp[s] = a[s], nxt[s] = 0;
for (auto &v:adj[s]) if (v!=pa){
dfs(v, s);
if (dp[s] < dp[v] + a[s]){
dp[s] = dp[v] + a[s];
nxt[s] = v;
}
}
}
void solve1(){
dfs(1);
printf("%lld\n", dp[1]);
vector<int> V;
for (int i=1;i;i=nxt[i]){
V.push_back(i);
}
printf("%d\n", (int)V.size());
for (auto &x:V) printf("%d ", x);
printf("\n");
}
} // namespace foo
namespace bar{
// Magnitude of the sentinel; dfs() stores -INF in dp states it marks as unavailable.
constexpr ll INF = 4e18;
// dp[s][x][y]: value computed by dfs() for state (x, y) at vertex s. Only the
// states (0,1), (0,2), (1,0) and (1,2) are ever written.
// NOTE(review): judging from track(), x / y look like bounds on how far from s
// the first / last vertex of the segment may lie — confirm.
ll dp[200200][2][3];
// Choices recorded by dfs() and replayed by track():
//   v1[s]              candidate id chosen for states (0,1) and (1,0)
//   v2[s][*], typ2[s]  candidate ids and case number chosen for state (0,2)
//   v3[s][*], typ3[s]  candidate ids and case number chosen for state (1,2)
//   nxt[v]             successor of v in the output sequence (0 = end of sequence)
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
// G[s]: children of s in the rooted tree, filled by dfs(); track() erases
// entries as it consumes them.
vector<int> G[200200];
// Bottom-up DP over the tree, rooted at the vertex of the outermost call.
//
// For vertex s (parent pa, 0 for the root) this records the children of s in
// G[s], recurses into them, and fills dp[s][0][1], dp[s][0][2], dp[s][1][2]
// and dp[s][1][0] together with the choice arrays (v1, v2/typ2, v3/typ3) that
// track() reads to rebuild the sequence.
//
// Candidate lists hold (gain, id) pairs sorted in descending order:
//   id in -3..-1 : padding entry with gain 0, meaning "no child chosen";
//   id in 1..n   : a child of s;
//   id = n + v   : (C4 only) child v visited as a single vertex.
void dfs(int s, int pa = 0){
    vector<pair<ll, int>> C1, C2, C3;
    // Three padding entries per list so indices 0..2 always exist after sorting.
    for (int i=1;i<=3;i++){
        C1.emplace_back(0, -i);
        C2.emplace_back(0, -i);
        C3.emplace_back(0, -i);
    }
    ll sum = 0;   // sum of a[v] over all children v of s
    int deg = 0;  // number of children of s
    for (auto &v:adj[s]) if (v!=pa){
        G[s].push_back(v);
        dfs(v, s);
        sum += a[v];
        deg++;
        // C1 / C2: extra value over a[v] alone (a[v] is already part of `sum`).
        C1.emplace_back(dp[v][1][0] - a[v], v);
        C2.emplace_back(dp[v][0][2] - a[v], v);
        // C3: full value; used in the cases that do not add `sum`.
        C3.emplace_back(dp[v][1][2], v);
    }
    sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
    sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
    sort(C3.begin(), C3.end(), greater<pair<ll, int>>());

    // State (0,1): s, all children, plus the best C1 gain. Unavailable for a leaf.
    dp[s][0][1] = a[s] + sum;
    dp[s][0][1] += C1[0].first;
    v1[s] = C1[0].second;
    if (deg==0) dp[s][0][1] = -INF;

    // State (0,2), case typ2 == 2: s, all children, one C1 gain and one C2 gain
    // taken from different entries. The top two of each list suffice for that.
    dp[s][0][2] = -INF;
    for (int i=0;i<2;i++){
        for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
            ll val = a[s] + sum + C1[i].first + C2[j].first;
            if (dp[s][0][2] < val){
                dp[s][0][2] = val;
                typ2[s] = 2;
                v2[s][0] = C1[i].second;
                v2[s][1] = C2[j].second;
            }
        }
    }
    // State (0,2), case typ2 == 1: s followed only by the best C3 entry.
    if (dp[s][0][2] < a[s] + C3[0].first){
        dp[s][0][2] = a[s] + C3[0].first;
        typ2[s] = 1;
        v2[s][0] = C3[0].second;
    }

    // State (1,2). C4 extends C1 with one "child v alone" entry (gain 0, id n+v)
    // per child.
    dp[s][1][2] = -INF;
    vector<pair<ll, int>> C4 = C1;
    for (auto &v:adj[s]) if (v!=pa) C4.emplace_back(0, n+v);
    sort(C4.begin(), C4.end(), greater<pair<ll, int>>());

    // Case typ3 == 3: a real child from C4 plus one C1 and one C2 entry, all
    // referring to different children; every child's a[] is covered by `sum`.
    for (int i=0;i<3;i++){
        for (int j=0;j<3;j++){
            for (int k=0;k<3;k++){
                int rc4 = C4[i].second;
                if (rc4 > n) rc4 -= n;  // decode the "alone" ids back to the child
                if (rc4==C1[j].second || rc4==C2[k].second || C1[j].second==C2[k].second) continue;
                if (rc4 < 0) continue;  // the C4 slot must name a real child
                ll val = a[s] + sum + C4[i].first + C1[j].first + C2[k].first;
                if (dp[s][1][2] < val){
                    dp[s][1][2] = val;
                    typ3[s] = 3;
                    v3[s][0] = C4[i].second;
                    v3[s][1] = C1[j].second;
                    v3[s][2] = C2[k].second;
                }
            }
        }
    }

    // Case typ3 == 2: a real child from C4, then s, then one C3 entry; the other
    // children are not visited, so `sum` is not added.
    // NOTE(review): C4 is ordered by gain only, while this case is scored by
    // gain + a[child]; its first two entries may therefore not be the two best
    // choices here — worth confirming.
    for (int i=0;i<2;i++){
        for (int j=0;j<2;j++) if (C4[i].second != C3[j].second && C4[i].second > 0){
            int rc4 = C4[i].second;
            if (rc4 > n) rc4 -= n;
            if (rc4==C3[j].second) continue;
            if (rc4 < 0) continue;
            ll val = a[s] + C4[i].first + C3[j].first;
            // Add the chosen child's own value. It must be indexed by the decoded
            // child rc4: the raw id can be n+v, which is not a vertex and may lie
            // outside a[].
            val += a[rc4];
            if (dp[s][1][2] < val){
                dp[s][1][2] = val;
                typ3[s] = 2;
                v3[s][0] = C4[i].second;
                v3[s][1] = C3[j].second;
            }
        }
    }
    if (deg==0) dp[s][1][2] = -INF;

    // State (1,0) takes the same value as (0,1); it is also a fallback for (1,2)
    // (typ3 == 0).
    dp[s][1][0] = dp[s][0][1];
    if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;
}
// Rebuilds, for state (x, y) at vertex s, the vertex sequence chosen by dfs().
// The sequence is written as a linked list through nxt[]; the return value is
// {first vertex, last vertex} of the segment built for s. Children that get
// their own recursive expansion are erased from G[s]; the children left in
// G[s] are spliced in as single vertices.
// State (0,0) is the trivial segment consisting of s alone.
pair<int, int> track(int s, int x, int y){
// printf(" call %d %d %d\n", s, x, y);
if (x==0 && y==0) return {s, s};
assert(s > 0);
assert(dp[s][x][y] >= 0);
// typ3 == 0 means dfs() took dp[s][1][2] from dp[s][1][0]; replay that state instead.
if (x==1 && y==2 && typ3[s]==0) y = 0;
// State (0,1): s first, then the expanded child v1[s] (if it is a real child),
// then every remaining child of s as a single vertex.
if (x==0 && y==1){
pair<int, int> ret = {s, s};
if (v1[s] > 0){
auto [l1, r1] = track(v1[s], 1, 0);
nxt[s] = l1;
ret.second = r1;
G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
}
for (auto &v:G[s]){
nxt[ret.second] = v;
ret.second = v;
}
return ret;
}
if (x==0 && y==2){
pair<int, int> ret = {s, s};
// typ2 == 1: s followed only by the (1,2) segment of child v2[s][0].
if (typ2[s]==1){
if (v2[s][0] > 0){
auto [l1, r1] = track(v2[s][0], 1, 2);
nxt[s] = l1;
ret.second = r1;
}
return ret;
}
// else (typ2[s]==2)
// s, then child v2[s][0] expanded with state (1,0), then the untouched
// children, and finally child v2[s][1] expanded with state (0,2).
if (v2[s][0] > 0){
auto [l1, r1] = track(v2[s][0], 1, 0);
nxt[s] = l1;
ret.second = r1;
G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
}
// Take v2[s][1] out of G[s] first so the loop below does not emit it early.
if (v2[s][1] > 0){
G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
}
for (auto &v:G[s]){
nxt[ret.second] = v;
ret.second = v;
}
if (v2[s][1] > 0){
nxt[ret.second] = v2[s][1];
ret.second = v2[s][1];
// A state with x == 0 starts at its own vertex, so only the end (r2) is needed.
auto [l2, r2] = track(v2[s][1], 0, 2);
ret.second = r2;
}
return ret;
}
// State (1,0): built back to front. The segment ends at s; the expanded
// child v1[s] is placed directly before s and the remaining children are
// prepended one by one.
if (x==1 && y==0){
pair<int, int> ret = {s, s};
if (v1[s] > 0){
auto [l1, r1] = track(v1[s], 0, 1);
nxt[r1] = s;
ret.first = l1;
G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
}
for (auto &v:G[s]){
nxt[v] = ret.first;
ret.first = v;
}
assert(ret.first != s);
return ret;
}
assert(x==1 && y==2);
pair<int, int> ret = {s, s};
// typ3 == 2: [child v3[s][0]] -> s -> [(1,2) segment of child v3[s][1]].
if (typ3[s]==2){
if (v3[s][0] > 0){
// Ids above n encode "child (id - n) as a single vertex" (state (0,0));
// otherwise the child is expanded with state (0,1).
int rv3 = v3[s][0], ri = 0, rj = 1;
if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;
auto [l1, r1] = track(rv3, ri, rj);
ret.first = l1;
nxt[r1] = s;
G[s].erase(find(G[s].begin(), G[s].end(), rv3));
}
if (v3[s][1] > 0){
auto [l2, r2] = track(v3[s][1], 1, 2);
nxt[s] = l2;
ret.second = r2;
}
assert(ret.first != s);
return ret;
}
//typ3[s] == 3
// [child v3[s][0]] -> s -> [(1,0) segment of v3[s][1]] -> remaining children
// -> [(0,2) segment of v3[s][2]].
if (v3[s][0] > 0){
int rv3 = v3[s][0], ri = 0, rj = 1;
if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;
auto [l1, r1] = track(rv3, ri, rj);
ret.first = l1;
nxt[r1] = s;
G[s].erase(find(G[s].begin(), G[s].end(), rv3));
}
if (v3[s][1] > 0){
auto [l2, r2] = track(v3[s][1], 1, 0);
nxt[s] = l2;
ret.second = r2;
G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
}
// Hold v3[s][2] back so that it is appended after the plain children.
if (v3[s][2] > 0){
G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
}
for (auto &v:G[s]){
nxt[ret.second] = v;
ret.second = v;
}
if (v3[s][2] > 0){
auto [l3, r3] = track(v3[s][2], 0, 2);
nxt[ret.second] = l3;
ret.second = r3;
}
assert(ret.first != s);
return ret;
}
void solve2(){
for (int i=1;i<=n;i++) nxt[i] = 0, G[i].clear();
dfs(1);
ll ans = dp[1][0][2];
printf("%lld\n", ans);
track(1, 0, 2);
ll fuck = 0;
vector<int> V;
for (int i=1;i;i=nxt[i]){
V.push_back(i);
fuck += a[i];
// printf(" %d", i);
}
// printf("\n");
assert(ans == fuck);
printf("%d\n", (int)V.size());
for (auto &x:V) printf("%d ", x);
printf("\n");
}
} // namespace bar
namespace baz{
void dfs1(int s, int pa);
void dfs2(int s, int pa);

// Output order produced by dfs1/dfs2.
vector<int> V;

// Emits s BEFORE its subtree; each child is handled by dfs2.
void dfs1(int s, int pa){
    V.push_back(s);
    for (const int child : adj[s]){
        if (child == pa) continue;
        dfs2(child, s);
    }
}

// Emits s AFTER its subtree; each child is handled by dfs1.
void dfs2(int s, int pa){
    for (const int child : adj[s]){
        if (child == pa) continue;
        dfs1(child, s);
    }
    V.push_back(s);
}

// Prints the sum of all a[1..n], then n, then the order built from vertex 1.
void solve3(){
    dfs1(1, 0);

    ll sum = 0;
    int v = 1;
    while (v <= n){
        sum += a[v];
        ++v;
    }

    printf("%lld\n", sum);
    printf("%d\n", n);
    for (size_t idx = 0; idx < V.size(); ++idx){
        printf("%d ", V[idx]);
    }
    printf("\n");
}
} // namespace baz
// Disjoint-set union over ids 1..n with path compression.
struct DSU{
    // Parent links; a root points to itself. Sized like the file's other
    // per-vertex arrays (200200) so init(n) is valid for every n they support.
    int path[200200];

    // Makes each of 1..n its own singleton set.
    void init(int n){for (int i=1;i<=n;i++) path[i] = i;}

    // Returns the representative of s's set. Iterative on purpose: merge() does
    // no union by rank, so parent chains can be as long as n and a recursive
    // walk could exhaust the stack.
    int find(int s){
        int root = s;
        while (root != path[root]) root = path[root];
        // Second pass: point every vertex on the walked chain at the root.
        while (path[s] != root){
            const int up = path[s];
            path[s] = root;
            s = up;
        }
        return root;
    }

    // Joins the sets of s and v; the representative of s's set becomes the
    // representative of the union. Returns 1 if two sets were joined, 0 if
    // s and v were already together.
    int merge(int s, int v){
        s = find(s), v = find(v);
        if (s==v) return 0;
        path[v] = s;
        return 1;
    }
}dsu;
// Fixed-seed generator, so every run draws the same sequence.
mt19937 seed(1557);
// Raw draws are uniform over [0, 2147483647].
uniform_int_distribution<int> rng(0, 2147483647);

// Returns a pseudo-random integer in [l, r] (expects l <= r) by reducing one
// raw draw modulo the width of the range.
int getrand(int l, int r){
    const int width = r - l + 1;
    const int raw = rng(seed);
    return l + raw % width;
}
// Builds a random instance in the globals: n in [1, 5], every a[i] drawn from
// [1, 1], and n-1 edges that each join two previously unconnected components,
// so the result is a tree.
void gen(){
    n = getrand(1, 5);
    dsu.init(n);

    int v = 1;
    while (v <= n){
        adj[v].clear();
        ++v;
    }
    // The draw is still performed even though the range is a single value, so
    // the generator advances exactly once per vertex.
    for (v = 1; v <= n; ++v){
        a[v] = getrand(1, 1);
    }

    for (int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft){
        int x, y;
        // Redraw until the endpoints lie in different components.
        for (;;){
            x = getrand(1, n);
            y = getrand(1, n);
            if (dsu.find(x) != dsu.find(y)) break;
        }
        dsu.merge(x, y);
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
}
// One stress iteration: generates a random instance, echoes it in the input
// format with k fixed to 2, and runs bar::solve2() on it. No comparison
// against a reference answer is performed here; the only check is the
// assertion inside solve2().
void stress(int tc){
    printf("------------------------\n");
    printf("Stress #%d\n", tc);

    gen();

    printf("Input:\n");
    printf("%d 2\n", n);
    // Each undirected edge is printed once, smaller endpoint first.
    for (int v = 1; v <= n; ++v){
        for (const int u : adj[v]){
            if (u < v) printf("%d %d\n", u, v);
        }
    }
    for (int v = 1; v <= n; ++v){
        printf("%lld ", a[v]);
    }
    printf("\n");

    bar::solve2();

    // Progress marker every 10000 cases.
    if (tc % 10000 == 0){
        printf("ok %d done\n", tc);
    }
}
// Reads "n k", the n-1 edges and the n vertex values from stdin, then runs the
// solver selected by k (1 -> foo::solve1, 2 -> bar::solve2, otherwise
// baz::solve3). Returns 1 without solving when the input is truncated,
// malformed or out of range, instead of continuing with indeterminate values.
int main(){
    // for (int i=1;;i++) stress(i);
    int k;
    if (scanf("%d %d", &n, &k) != 2){
        fprintf(stderr, "failed to read n and k\n");
        return 1;
    }
    // The per-vertex arrays hold 200200 entries and vertices are 1-based.
    if (n < 1 || n >= 200200){
        fprintf(stderr, "n out of range\n");
        return 1;
    }
    for (int i=1;i<=n-1;i++){
        int x, y;
        if (scanf("%d %d", &x, &y) != 2){
            fprintf(stderr, "failed to read edge %d\n", i);
            return 1;
        }
        // Endpoints index adj[] directly, so reject anything outside 1..n.
        if (x < 1 || x > n || y < 1 || y > n){
            fprintf(stderr, "edge %d has an endpoint outside 1..n\n", i);
            return 1;
        }
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    for (int i=1;i<=n;i++){
        if (scanf("%lld", a+i) != 1){
            fprintf(stderr, "failed to read value %d\n", i);
            return 1;
        }
    }
    if (k==1) foo::solve1();
    else if (k==2) bar::solve2();
    else baz::solve3();
    return 0;
}
컴파일 시 표준 에러 (stderr) 메시지
Main.cpp: In function 'int main()':
Main.cpp:473:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
473 | scanf("%d %d", &n, &k);
| ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:477:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
477 | scanf("%d %d", &x, &y);
| ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:482:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
482 | for (int i=1;i<=n;i++) scanf("%lld", a+i);
| ~~~~~^~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |