Submission #971798

# | Time | Username | Problem | Language | Result | Execution time | Memory
971798 | (not shown) | Cookie | Winter Driving (CCO19_day1problem3) | C++14 | 25 / 25 | 104 ms | 59860 KiB
// Solution for "Winter Driving" (CCO 2019, Day 1, Problem 3).
//
// Input (as read below): n, then a[1..n], then p[2..n], where town i is joined
// by a road to town p[i]. Output: the maximum, over all "hub" towns c, of
//     sum_i a[i] * (a[i] - 1)   (pairs of people inside the same town)
//   + dp[c]                     (ancestor/descendant pairs, tree rooted at c)
//   + X * (outside - X)         (pairs from the "in" branches to the "out" ones)
// where outside = (total population) - a[c] and X is a subset sum of the
// populations of the branches hanging off c. Presumably this is the best
// number of ordered reachable (person, person) pairs once every road is given
// a direction — confirm against the problem statement.
//
// NOTE(review): values are combined with products of population sums in
// long long; this assumes the statement's bounds on n and a[i] keep those
// products below 2^63 — confirm.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using ll = long long;

namespace {

// Writes every subset sum of weights[lo, hi) into `out` (unsorted,
// 2^(hi - lo) entries including the empty subset 0).
void collect_subset_sums(const std::vector<ll>& weights, std::size_t lo,
                         std::size_t hi, std::vector<ll>& out) {
    out.assign(1, 0);
    out.reserve(std::size_t{1} << (hi - lo));
    for (std::size_t k = lo; k < hi; ++k) {
        const std::size_t current = out.size();
        for (std::size_t t = 0; t < current; ++t) {
            out.push_back(out[t] + weights[k]);
        }
    }
}

// Returns max over subset sums X of `branches` of X * (outside - X).
// Precondition: `branches` is non-empty.
//
// If one branch alone holds at least floor(outside / 2), taking exactly that
// branch is optimal. Any subset containing it is at least as far from the
// midpoint, and any subset without it sums to at most outside - largest.
// Otherwise a meet-in-the-middle subset-sum search is used. It is exponential
// in branches.size(), so it assumes the problem bounds a town's degree.
// NOTE(review): reportedly at most 36 neighbours — confirm.
ll best_split(const std::vector<ll>& branches, ll outside) {
    const ll largest = *std::max_element(branches.begin(), branches.end());
    if (largest >= outside / 2) {
        return largest * (outside - largest);
    }

    const std::size_t mid = branches.size() / 2;
    std::vector<ll> left;
    std::vector<ll> right;
    collect_subset_sums(branches, 0, mid, left);
    collect_subset_sums(branches, mid, branches.size(), right);
    std::sort(left.begin(), left.end());
    std::sort(right.begin(), right.end());

    const ll half = outside / 2;
    ll best = 0;
    const auto consider = [&](ll x) { best = std::max(best, x * (outside - x)); };

    // Two pointers: `left` ascends and `r` only moves down. For each x we
    // evaluate the sums at or above `half` that we step past, plus the largest
    // sum below `half`. Sums skipped for a later x are dominated by sums
    // already evaluated for an earlier x, because x * (outside - x) is concave.
    std::ptrdiff_t r = static_cast<std::ptrdiff_t>(right.size()) - 1;
    for (const ll x : left) {
        while (r >= 0 && x + right[r] >= half) {
            consider(x + right[r]);
            --r;
        }
        if (r >= 0) {
            consider(x + right[r]);
        }
    }
    return best;
}

}  // namespace

// Reads one test case from stdin and prints the answer to stdout.
void solve() {
    int n = 0;
    std::cin >> n;
    if (n <= 0) {
        std::cout << 0 << "\n";
        return;
    }

    std::vector<ll> a(n + 1, 0);
    ll same_town = 0;  // ordered pairs of distinct people in the same town
    for (int i = 1; i <= n; ++i) {
        std::cin >> a[i];
        same_town += a[i] * (a[i] - 1);
    }

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 2; i <= n; ++i) {
        int p = 0;
        std::cin >> p;
        adj[p].push_back(i);
        adj[i].push_back(p);
    }

    // Iterative preorder from town 1. An explicit stack avoids recursion
    // depth n on path-shaped trees. parent[1] = 0 is never a valid town id.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{1};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (const int u : adj[v]) {
            if (u != parent[v]) {
                parent[u] = v;
                pending.push_back(u);
            }
        }
    }

    // sub[v] = population of v's subtree (rooted at town 1). Reverse
    // preorder visits every child before its parent.
    std::vector<ll> sub(a);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        if (parent[v] != 0) {
            sub[parent[v]] += sub[v];
        }
    }
    const ll total = sub[1];

    // dp[c] = sum over v of a[v] * (population strictly below v when rooted at c).
    std::vector<ll> dp(n + 1, 0);
    for (int v = 1; v <= n; ++v) {
        dp[1] += a[v] * (sub[v] - a[v]);
    }
    // Reroot from parent p to child v. v gains everything outside its old
    // subtree, and p loses v's subtree. Preorder guarantees dp[p] is ready.
    for (const int v : order) {
        if (v == 1) continue;
        const int p = parent[v];
        dp[v] = dp[p] + a[v] * (total - sub[v]) - sub[v] * a[p];
    }

    ll best = 0;
    std::vector<ll> branches;
    for (int c = 1; c <= n; ++c) {
        // Population of each component left after removing town c.
        branches.clear();
        for (const int u : adj[c]) {
            branches.push_back(u == parent[c] ? total - sub[c] : sub[u]);
        }
        // An isolated town (n == 1) has no branches and thus no cross pairs.
        const ll cross = branches.empty() ? 0 : best_split(branches, total - a[c]);
        best = std::max(best, dp[c] + cross);
    }

    std::cout << best + same_town << "\n";
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...