Submission #794304

# | Time | Username | Problem | Language | Result | Execution time | Memory
794304 | | qwerasdfzxcl | Travelling Trader (CCO23_day2problem2) | C++17 | 11 / 25 | 183 ms | 34732 KiB
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

// Travelling Trader (CCO 2023, Day 2, Problem 2).
// Input (see main): n and k, the n-1 tree edges, then the profit a[i] of every
// vertex.  Output: the best total profit of a sequence of distinct vertices that
// starts at vertex 1, with consecutive vertices at tree distance <= k, followed
// by the sequence length and the sequence itself.
//   k == 1 -> foo::solve1, k == 2 -> bar::solve2, otherwise -> baz::solve3.

int n;
vector<int> adj[200200];
ll a[200200];

// k == 1: the route is a path that walks downward from vertex 1.
namespace foo{
ll dp[200200];   // dp[s] = heaviest downward path starting at s
int nxt[200200]; // child chosen after s on that path (0 = path ends at s)

void dfs(int s, int pa = 0){
	dp[s] = a[s], nxt[s] = 0;
	for (auto &v:adj[s]) if (v!=pa){
		dfs(v, s);
		if (dp[s] < dp[v] + a[s]){
			dp[s] = dp[v] + a[s];
			nxt[s] = v;
		}
	}
}

void solve1(){
	dfs(1);
	printf("%lld\n", dp[1]);
	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
	}
	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}
} // namespace foo

// k == 2.  dp[s][x][y] is the best profit of a route that stays inside the
// subtree of s, visits s, makes hops of distance <= 2 and
//   x = 0: starts at s           x = 1: starts at a child of s
//   y = 0: ends at s             y = 1: ends at a child of s
//   y = 2: ends anywhere in the subtree of s.
// -INF marks an impossible state.  A parent enters a child's route either at
// the child itself (x = 0) or at a grandchild (x = 1).  v1/v2/v3/typ2/typ3
// remember which transition produced each value so track() can rebuild the
// route as a linked list in nxt[].
namespace bar{
constexpr ll INF = 4e18;
ll dp[200200][2][3];
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
vector<int> G[200200]; // children of each vertex; track() consumes them

void dfs(int s, int pa = 0){
	// Candidate lists of (gain, child).  The (0, -i) sentinels stand for
	// "no child in this role"; negative ids are never real vertices.
	vector<pair<ll, int>> C1, C2, C3;
	for (int i=1;i<=3;i++){
		C1.emplace_back(0, -i);
		C2.emplace_back(0, -i);
		C3.emplace_back(0, -i);
	}
	ll sum = 0;  // total profit of all children of s
	int deg = 0; // number of children of s
	for (auto &v:adj[s]) if (v!=pa){
		G[s].push_back(v);
		dfs(v, s);
		sum += a[v];
		deg++;
		C1.emplace_back(dp[v][1][0] - a[v], v); // extra gain: detour entering below v, ending at v
		C2.emplace_back(dp[v][0][2] - a[v], v); // extra gain: finish the route inside v's subtree, entering at v
		C3.emplace_back(dp[v][1][2], v);        // full route entering below v, never returning
	}
	sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
	sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
	sort(C3.begin(), C3.end(), greater<pair<ll, int>>());

	// calc 0 1: s, optional detour into v1[s], then every other child.
	dp[s][0][1] = a[s] + sum;
	dp[s][0][1] += C1[0].first;
	v1[s] = C1[0].second;
	if (deg==0) dp[s][0][1] = -INF;

	// calc 0 2
	// typ2 = 2: s, optional detour v2[s][0], the remaining children, then
	//           optionally finish in v2[s][1]'s subtree via dp[.][0][2].
	dp[s][0][2] = -INF;
	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
			ll val = a[s] + sum + C1[i].first + C2[j].first;
			if (dp[s][0][2] < val){
				dp[s][0][2] = val;
				typ2[s] = 2;
				v2[s][0] = C1[i].second;
				v2[s][1] = C2[j].second;
			}
		}
	}
	// typ2 = 1: s, then jump straight to a grandchild via dp[v][1][2]; the
	//           other children are unreachable afterwards.
	if (dp[s][0][2] < a[s] + C3[0].first){
		dp[s][0][2] = a[s] + C3[0].first;
		typ2[s] = 1;
		v2[s][0] = C3[0].second;
	}

	// calc 1 2
	// typ3 = 3: start at a child (id n+v = v alone, id v = v's dp[v][0][1]
	//           route), then s, optional detour v3[1], the remaining children,
	//           then optionally finish in v3[2]'s subtree via dp[.][0][2].
	dp[s][1][2] = -INF;
	vector<pair<ll, int>> C4 = C1;
	for (auto &v:adj[s]) if (v!=pa) C4.emplace_back(0, n+v);
	sort(C4.begin(), C4.end(), greater<pair<ll, int>>());
	for (int i=0;i<min(6, (int)C4.size());i++){
		for (int j=0;j<3;j++){
			for (int k=0;k<3;k++){
				int rc4 = C4[i].second;
				if (rc4 > n) rc4 -= n;
				if (rc4==C1[j].second || rc4==C2[k].second || C1[j].second==C2[k].second) continue;
				if (rc4 < 0) continue;
				ll val = a[s] + sum + C4[i].first + C1[j].first + C2[k].first;
				if (dp[s][1][2] < val){
					dp[s][1][2] = val;
					typ3[s] = 3;
					v3[s][0] = C4[i].second;
					v3[s][1] = C1[j].second;
					v3[s][2] = C2[k].second;
				}
			}
		}
	}
	// typ3 = 2: visit every child except `down` (siblings are 2 apart),
	//           optionally ending that prefix with child w's dp[w][0][1] route
	//           (w = v3[0]), then s, then jump to a child of `down` and finish
	//           with dp[down][1][2] (v3[1] = down).
	// All other children must be counted here.  Counting only one of them
	// undercounts: with unit profits and edges 1-2, 2-3, 2-4, 2-5, 5-6, 5-7,
	// 5-8, 6-9, 7-10, 8-11, the route 1 3 4 2 6 9 5 10 7 8 11 collects 11.
	if (deg >= 2){
		for (auto &down:G[s]){
			if (dp[down][1][2] < 0) continue; // down has no children: impossible
			// C1's two best entries always contain one whose id differs from down
			pair<ll, int> ex = (C1[0].second != down) ? C1[0] : C1[1];
			ll val = a[s] + (sum - a[down]) + ex.first + dp[down][1][2];
			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 2;
				v3[s][0] = ex.second;
				v3[s][1] = down;
			}
		}
	}
	if (deg==0) dp[s][1][2] = -INF;

	// calc 1 0: the reverse of a 0 1 route.
	dp[s][1][0] = dp[s][0][1];
	if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;
}

// Rebuilds the route of state dp[s][x][y] into nxt[] and returns its
// (first, last) vertex.  Each vertex is tracked at most once, because it
// erases the children it consumes from G[s].
pair<int, int> track(int s, int x, int y){
	if (x==0 && y==0) return {s, s};
	assert(s > 0);
	assert(dp[s][x][y] >= 0);
	if (x==1 && y==2 && typ3[s]==0) y = 0;

	if (x==0 && y==1){
		pair<int, int> ret = {s, s};
		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 1, 0);
			nxt[s] = l1;
			ret.second = r1;
			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}
		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}
		return ret;
	}

	if (x==0 && y==2){
		pair<int, int> ret = {s, s};
		if (typ2[s]==1){
			if (v2[s][0] > 0){
				auto [l1, r1] = track(v2[s][0], 1, 2);
				nxt[s] = l1;
				ret.second = r1;
			}
			return ret;
		}
		// typ2[s] == 2
		if (v2[s][0] > 0){
			auto [l1, r1] = track(v2[s][0], 1, 0);
			nxt[s] = l1;
			ret.second = r1;
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
		}
		if (v2[s][1] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
		}
		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}
		if (v2[s][1] > 0){
			// the dp[.][0][2] route of v2[s][1] starts at v2[s][1] itself
			nxt[ret.second] = v2[s][1];
			ret.second = track(v2[s][1], 0, 2).second;
		}
		return ret;
	}

	if (x==1 && y==0){
		pair<int, int> ret = {s, s};
		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 0, 1);
			nxt[r1] = s;
			ret.first = l1;
			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}
		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}
		assert(ret.first != s);
		return ret;
	}

	assert(x==1 && y==2);
	if (typ3[s]==2){
		int ex = v3[s][0], down = v3[s][1];
		G[s].erase(find(G[s].begin(), G[s].end(), down));
		if (ex > 0) G[s].erase(find(G[s].begin(), G[s].end(), ex));
		int head = 0, tail = 0;
		auto append = [&](int l, int r){
			if (head==0) head = l;
			else nxt[tail] = l;
			tail = r;
		};
		for (auto &v:G[s]) append(v, v); // plain siblings first
		if (ex > 0){
			auto [l1, r1] = track(ex, 0, 1); // starts at ex, ends at a child of ex
			append(l1, r1);
		}
		append(s, s);
		auto [l2, r2] = track(down, 1, 2); // starts at a child of down
		append(l2, r2);
		assert(head != s);
		return {head, tail};
	}

	// typ3[s] == 3
	pair<int, int> ret = {s, s};
	if (v3[s][0] > 0){
		int rv3 = v3[s][0], ri = 0, rj = 1;
		if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;
		auto [l1, r1] = track(rv3, ri, rj);
		ret.first = l1;
		nxt[r1] = s;
		G[s].erase(find(G[s].begin(), G[s].end(), rv3));
	}
	if (v3[s][1] > 0){
		auto [l2, r2] = track(v3[s][1], 1, 0);
		nxt[s] = l2;
		ret.second = r2;
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
	}
	if (v3[s][2] > 0){
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
	}
	for (auto &v:G[s]){
		nxt[ret.second] = v;
		ret.second = v;
	}
	if (v3[s][2] > 0){
		auto [l3, r3] = track(v3[s][2], 0, 2);
		nxt[ret.second] = l3;
		ret.second = r3;
	}
	assert(ret.first != s);
	return ret;
}

// Runs the DP from vertex 1, prints the answer and the route, and returns the
// answer.  The assert cross-checks the rebuilt route against the DP value.
ll solve2(){
	for (int i=1;i<=n;i++) nxt[i] = 0, G[i].clear();
	dfs(1);
	ll ans = dp[1][0][2];
	printf("%lld\n", ans);
	track(1, 0, 2);
	ll collected = 0;
	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
		collected += a[i];
	}
	assert(ans == collected);
	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
	return ans;
}
} // namespace bar

// k >= 3: every vertex is collected.  dfs1 emits a vertex before its subtree and
// dfs2 after it, alternating by depth; this keeps consecutive vertices within
// distance 3.
namespace baz{
void dfs1(int s, int pa);
void dfs2(int s, int pa);
vector<int> V;

void dfs1(int s, int pa){
	V.push_back(s);
	for (auto &v:adj[s]) if (v!=pa) dfs2(v, s);
}

void dfs2(int s, int pa){
	for (auto &v:adj[s]) if (v!=pa) dfs1(v, s);
	V.push_back(s);
}

void solve3(){
	dfs1(1, 0);
	ll sum = 0;
	for (int i=1;i<=n;i++) sum += a[i];
	printf("%lld\n", sum);
	printf("%d\n", n);
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}
} // namespace baz

// Exponential reference solver for k == 2, used only by stress().
// Intended for tiny n (<= 20) and small profits (int arithmetic).
namespace naive {
constexpr int inf = 1e9 + 100;
int dp[1<<20][20][2], arr[20];
vector<int> gph[20];

int solve() {
	int N = n;
	for (int i = 0; i < N; i++) gph[i].clear();
	for (int i = 1; i <= N; i++) {
		for (int j:adj[i]) {
			gph[i-1].push_back(j-1);
		}
	}
	for (int i = 1; i <= N; i++) arr[i-1] = a[i];
	for (int i = 0; i < (1 << N); i++) {
		for (int j = 0; j < N; j++) {
			dp[i][j][0] = dp[i][j][1] = -inf;
		}
	}
	// state 0: standing on a collected vertex; state 1: passed through one edge
	dp[1][0][0] = arr[0];
	for (int i = 1; i < (1 << N); i++) {
		for (int j = 0; j < N; j++) {
			for (int k:gph[j]) {
				dp[i][k][1] = max(dp[i][k][1], dp[i][j][0]);
			}
		}
		for (int j = 0; j < N; j++) {
			for (int k:gph[j]) {
				if (!(i & (1 << k))) {
					int nxt = i | (1 << k);
					dp[nxt][k][0] = max(dp[nxt][k][0], max(dp[i][j][0], dp[i][j][1]) + arr[k]);
				}
			}
		}
	}
	int ans = 0;
	for (int i = 0; i < (1 << N); i++) {
		for (int j = 0; j < N; j++) {
			ans = max(ans, dp[i][j][0]);
			ans = max(ans, dp[i][j][1]);
		}
	}
	return ans;
}
} // namespace naive

// Union-find used by gen() to build random trees.
struct DSU{
	int path[100100];
	void init(int n){for (int i=1;i<=n;i++) path[i] = i;}
	int find(int s){
		if (s==path[s]) return s;
		return path[s] = find(path[s]);
	}
	int merge(int s, int v){
		s = find(s), v = find(v);
		if (s==v) return 0;
		path[v] = s;
		return 1;
	}
}dsu;

mt19937 seed(1557);
uniform_int_distribution<int> rng(0, 2147483647);
int getrand(int l, int r){return rng(seed) % (r-l+1) + l;}

// Builds a random tree with 1..10 vertices and unit profits into n/adj/a.
void gen(){
	n = getrand(1, 10);
	dsu.init(n);
	for (int i=1;i<=n;i++) adj[i].clear();
	for (int i=1;i<=n;i++) a[i] = getrand(1, 1);
	for (int i=1;i<=n-1;i++){
		int x, y;
		do{
			x = getrand(1, n);
			y = getrand(1, n);
		}while(dsu.find(x)==dsu.find(y));
		dsu.merge(x, y);
		adj[x].push_back(y);
		adj[y].push_back(x);
	}
}

// One randomized comparison of bar::solve2 against naive::solve; prints the
// failing input before asserting.
void stress(int tc){
	gen();
	ll ans = naive::solve();
	ll out = bar::solve2();
	if (ans!=out){
		printf("------------------------\n");
		printf("Stress #%d\n", tc);
		printf("Input:\n");
		printf("%d 2\n", n);
		for (int i=1;i<=n;i++){
			for (auto &j:adj[i]) if (j<i) printf("%d %d\n", j, i);
		}
		for (int i=1;i<=n;i++) printf("%lld ", a[i]);
		printf("\n");
		printf("Answer: %lld\n", ans);
		printf("Output: %lld\n", out);
	}
	assert(ans==out);
	if (tc%10000==0) printf("ok %d done\n", tc);
}

int main(){
	// To stress-test the k == 2 solver instead, replace the body with:
	// for (int i=1;;i++) stress(i);
	int k;
	if (scanf("%d %d", &n, &k) != 2) return 1;
	for (int i=1;i<=n-1;i++){
		int x, y;
		if (scanf("%d %d", &x, &y) != 2) return 1;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}
	for (int i=1;i<=n;i++){
		if (scanf("%lld", a+i) != 1) return 1;
	}
	if (k==1) foo::solve1();
	else if (k==2) bar::solve2();
	else baz::solve3();
}

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:547:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  547 |  scanf("%d %d", &n, &k);
      |  ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:551:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  551 |   scanf("%d %d", &x, &y);
      |   ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:556:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  556 |  for (int i=1;i<=n;i++) scanf("%lld", a+i);
      |                         ~~~~~^~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...