제출 #642085

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
642085 | | vovamr | Transport (COCI19_transport) | C++17 | 91 / 130 | 1085 ms | 21224 KiB
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

using ll = long long;
using pll = pair<ll, ll>;
using pii = pair<int, int>;

// Order-statistics tree used as a multiset: the less_equal comparator lets
// equal keys coexist. Side effects of that comparator relied on below:
//   * order_of_key(k) counts the elements strictly smaller than k;
//   * lower_bound(k) lands on the first element strictly greater than k;
//   * elements must be erased through an iterator, not by key.
template <class T>
using oset = tree<T, null_type, less_equal<T>, rb_tree_tag,
                  tree_order_statistics_node_update>;

const int N = 1e5 + 10;

vector<pii> gr[N];  // gr[v]: (neighbour, edge weight) pairs
int sz[N];          // subtree sizes inside the current component (see dfs1)
int used[N];        // 1 once a vertex has been removed as a centroid
int a[N];           // per-vertex value read from input
                    // NOTE(review): presumably the fuel available at the vertex,
                    // with edge weights being the fuel a move costs -- confirm
                    // against the task statement.

ll ans = 0;         // number of ordered pairs counted so far
int PTR = 0;        // number of valid entries in al[]
pll al[N];          // (cur_sum, mn_sum) of every downward path from the current
                    // centroid, in DFS order grouped by child subtree
int n;
oset<ll> sums;      // mn_sum of downward paths that end outside the subtree
                    // currently being processed by dfs2

// Computes sz[] for the component containing v, treating p as v's parent and
// ignoring vertices that were already removed.
void dfs1(int v, int p) {
    sz[v] = 1;
    for (const auto &[to, w] : gr[v]) {
        if (used[to] || to == p) continue;
        dfs1(to, v);
        sz[v] += sz[to];
    }
}

// Walks from v towards the heavy child until no child subtree holds more than
// total / 2 vertices; returns that vertex. Requires sz[] from dfs1.
int centroid(int v, int p, int total) {
    for (const auto &[to, w] : gr[v]) {
        if (used[to] || to == p || sz[to] <= total / 2) continue;
        return centroid(to, v, total);
    }
    return v;
}

// Enumerates downward paths centroid -> v.
//   cur_sum: sum of (a[x] - w) over the moves made so far, i.e. the balance on
//            arrival at v;
//   mn_sum : minimum of that balance over all arrivals on the path.
// A path whose balance never went negative is counted directly; every path is
// recorded in al[] for later pairing with upward paths.
void dfs(int v, int p, ll cur_sum, ll mn_sum) {
    if (mn_sum >= 0) ++ans;
    al[PTR++] = {cur_sum, mn_sum};
    for (const auto &[to, w] : gr[v]) {
        if (to == p || used[to]) continue;
        const ll next_sum = cur_sum + a[v] - w;
        dfs(to, v, next_sum, min(mn_sum, next_sum));
    }
}

// Enumerates upward paths v -> centroid.
//
// Let x_0 = centroid, x_1, ..., x_k = v be the tree path and
// S_j = sum_{i=1..j} (a[x_i] - weight(x_i, x_{i-1})), S_0 = 0.
// Starting at v, the balance on arrival at x_i is S_k - S_i, so the upward
// walk is feasible iff S_k >= max(S_0..S_{k-1}).
//   tot  = S_k  (balance on arrival at the centroid);
//   peak = max(S_0..S_{k-1}).
//
// This replaces the former lazy segment tree, which stored S_k - S_i per depth
// and whose merge helper returned int, truncating 64-bit minima.
void dfs2(int v, int p, ll tot, ll peak) {
    if (tot >= peak) {
        // +1 for the path ending at the centroid itself, plus the number of
        // downward paths with mn_sum + tot >= 0, i.e. mn_sum >= -tot.
        ans += 1 + static_cast<ll>(sums.size()) -
               static_cast<ll>(sums.order_of_key(-tot));
    }
    const ll child_peak = max(peak, tot);
    for (const auto &[to, w] : gr[v]) {
        if (to == p || used[to]) continue;
        dfs2(to, v, tot + a[to] - w, child_peak);
    }
}

// Centroid decomposition step: removes centroid v, counts all valid ordered
// pairs whose path passes through v, then recurses into the remaining parts.
void cd(int v, int p) {
    used[v] = 1;
    dfs1(v, p);

    // Collect every downward path; paths are stored subtree by subtree, each
    // child `to` contributing exactly sz[to] consecutive entries.
    PTR = 0;
    for (const auto &[to, w] : gr[v]) {
        if (to == p || used[to]) continue;
        dfs(to, v, a[v] - w, a[v] - w);
    }

    sums.clear();
    for (int i = 0; i < PTR; ++i) sums.insert(al[i].second);

    int pos = 0;
    for (const auto &[to, w] : gr[v]) {
        if (to == p || used[to]) continue;
        const int end = pos + sz[to];

        // Temporarily drop this subtree's own downward paths so that start and
        // end of a counted pair lie in different subtrees.
        for (int i = pos; i < end; ++i) {
            // lower_bound(m - 1) is the first element > m - 1, i.e. one equal to m.
            sums.erase(sums.lower_bound(al[i].second - 1));
        }
        dfs2(to, v, a[to] - w, 0);
        for (int i = pos; i < end; ++i) sums.insert(al[i].second);

        pos = end;
    }

    for (const auto &[to, w] : gr[v]) {
        if (to == p || used[to]) continue;
        const int c = centroid(to, v, sz[to]);
        cd(c, v);
    }
}

// Reads the tree (n, then a[0..n-1], then n - 1 lines "v u w" with 1-based
// endpoints) and prints the number of counted ordered pairs.
void solve() {
    cin >> n;
    for (int i = 0; i < n; ++i) cin >> a[i];
    for (int i = 1; i < n; ++i) {
        int v, u, w;
        cin >> v >> u >> w;
        --v;
        --u;
        gr[v].push_back({u, w});
        gr[u].push_back({v, w});
    }
    dfs1(0, 0);
    const int c = centroid(0, 0, n);
    cd(c, c);
    cout << ans;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int q = 1;
    // cin >> q;
    while (q--) solve();
    cerr << fixed << setprecision(3) << "Time execution: " << (double)clock() / CLOCKS_PER_SEC << endl;
}

컴파일 시 표준 에러 (stderr) 메시지

transport.cpp: In function 'void upd(int, int, int, int, int, long long int)':
transport.cpp:76:13: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   76 |  int m = vl + vr >> 1;
      |          ~~~^~~~
transport.cpp: In function 'long long int get(int, int, int, int, int)':
transport.cpp:85:13: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   85 |  int m = vl + vr >> 1;
      |          ~~~^~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...