제출 #1288934

#제출 시각아이디문제언어결과실행 시간메모리
1288934tschav_말 (IOI15_horses)C++20
54 / 100
1596 ms44824 KiB
#include <bits/stdc++.h> using namespace std; using ll = unsigned long long; template<typename T> class SegTree { public: int size; vector<T> t; function<T(T, T)> f; T init; SegTree(int n, T init, function<T(T, T)> f, vector<T> &a): init(init), f(f) { size = 1; while(size < n) size <<= 1; t.assign(size << 1, init); build(a,0,0,size); } void build(vector<T> &a, int pos, int tl, int tr){ if(tr - tl == 1) { if(tl < a.size()){ t[pos] = a[tl]; } } else { int tm = (tr + tl) >> 1; build(a, 2*pos+1, tl, tm); build(a, 2*pos+2, tm, tr); t[pos] = f(t[2*pos+1],t[2*pos+2]); } } void build(vector<T> &a) { build(a,0,0,size); } void update(int i, T val, int pos, int tl, int tr) { if(tr - tl == 1){ t[pos] = val; return; } int mid = (tl + tr) >> 1; if(i < mid){ update(i,val,2*pos+1,tl,mid); } else { update(i,val,2*pos+2,mid,tr); } t[pos] = f(t[2*pos+1],t[2*pos+2]); } void update(int i, T x) { update(i,x,0,0,size); } T query(int l, int r, int pos, int tl, int tr) { if (r <= tl or tr <= l) return init; if (l <= tl and tr <= r) return t[pos]; int tm = (tl + tr) >> 1; return f(query(l,r,2*pos+1,tl,tm),query(l,r,2*pos+2,tm,tr)); } T query(int l, int r) { return query(l,r,0,0,size); } }; vector<ll> em; SegTree<ll> stree = SegTree<ll>(0ll,0ll,[](ll x,ll y) {return max(x,y);},em); vector<ll> x, y; int n; int m; ll prod = 1; int idx = -1; set<int> st; static const ll MOD = 1e9+7; ll fullprod = 1; ll binpow(ll a, ll b, ll m) { a %= m; ll res = 1; while (b > 0) { if (b & 1) res = res * a % m; a = a * a % m; b >>= 1; } return res; } ll gety(int id) { int I = id; int J = n; auto it = st.upper_bound(id); if(it != st.end()){ J = *it; } if(id == *st.begin()) { I = 0; } ll mn = stree.query(I,J); return mn; } ll small() { ll mx = 1; ll curr = 1; for(int i = 0; i < n; ++i) { curr *= x[i]; mx = max(curr * y[i], mx); } return (mx % MOD); } ll solve() { if(idx == -1) { return small(); } if (m == 0) { ll mx = 0; for (int i = 0; i < n; ++i) mx = max(mx, y[i]); ll ans = mx % MOD; return (ans * (prod % MOD)) % MOD; } // cout << prod << "\n"; // cout << idx << "\n"; ll mx = 0ll; if (idx != -1) mx = gety(idx); ll curr = 1; auto rit = st.rbegin(); vector<int> stuff; vector<int> Q; vector<int> ret; for (int i = m - 1; rit != st.rend(); ++rit, --i) { if(i == idx) { mx = gety(*rit); break; } if (i < idx) break; stuff.emplace_back(x[*rit]); Q.emplace_back(*rit); ret.emplace_back(i); } reverse(stuff.begin(),stuff.end()); reverse(ret.begin(),ret.end()); reverse(Q.begin(),Q.end()); for(int i = 0; i < stuff.size() ; ++i){ curr *= stuff[i]; ll G = gety(Q[i]); ll val = curr * G; if (val > mx) mx = val; } ll ans = mx % MOD; return (ans * (prod % MOD)) % MOD; } void ind() { idx = -1; ll P = 1ll; int i = m - 1; for (auto it = st.rbegin(); it != st.rend(); ++it, --i) { int pos = *it; if (P > 1e9) { break; } else { P *= x[pos]; idx = i; } } if(P < 1e9) { idx = -1; } } void getprod() { prod = 1ll; fullprod = 1ll; for (auto it = st.begin(); it != st.end(); ++it) { fullprod = (fullprod * (x[*it] % MOD)) % MOD; } //cerr << endl << fullprod << "-->\n"; if (idx < 0) return; auto it = st.begin(); for (int i = 0; i <= idx and it != st.end(); ++i, ++it) { prod = (prod * (x[*it] % MOD)) % MOD; } } void updateprod() { prod = 1ll; if (idx < 0) return; ll curr = 1ll; auto rit = st.rbegin(); for (int i = m - 1; rit != st.rend(); ++rit, --i) { if (i <= idx) break; curr *= x[*rit]; curr %= MOD; } prod = (fullprod * binpow(curr, MOD-2, MOD)) % MOD; } int init(int N, int X[], int Y[]) { st.clear(); x.clear(); y.clear(); prod = 1; fullprod = 1; idx = -1; n = N; x.resize(n,0ll); y.resize(n,0ll); for(int i = 0; i < n; ++i) { x[i] = X[i]; y[i] = Y[i]; if(x[i] >= 2){ st.insert(i); } } stree = SegTree<ll>(n,0,[&](ll i, ll j) { return max(i,j); }, y); if(!st.empty() and *st.begin() != 0) { st.insert(0); } m = st.size(); ind(); getprod(); //cerr << fullprod << "->\n"; return int(solve()); } int updateX(int pos, int val) { ll old = x[pos]; if(pos == 0) { // . . . } else { if(val == 1 and old != 1) { st.erase(pos); } if(val != 1 and old == 1) { st.insert(pos); } } m = st.size(); //int I = *(next(st.begin(),idx)); // if(pos <= I) { // prod = (prod * binpow(old, MOD-2,MOD)) % MOD; // prod = (prod * val) % MOD; // } x[pos] = val; ind(); // ll oldpr1 = 1ll; // for (auto it = st.begin(); it != st.end(); ++it) { // oldpr1 = (oldpr1 * (x[*it] % MOD)) % MOD; // } fullprod = (fullprod * binpow(old, MOD-2,MOD)) % MOD; fullprod = (fullprod * val) % MOD; // if(fullprod != oldpr1) { // cerr << old << " " << val << "\n"; // cerr << fullprod << " " << oldpr1 << "\n"; // assert(0); // } // ll fullprod1 = 1ll; // for (auto it = st.begin(); it != st.end(); ++it) { // fullprod1 = (fullprod1 * (x[*it] % MOD)) % MOD; // } // if(fullprod1 != fullprod) { // cerr << oldpr1 << "\n"; // cerr << (oldpr1 * binpow(old, MOD-2,MOD)) % MOD << "\n"; // cerr << old << " " << val << "\n"; // cerr << fullprod1 << " " << fullprod << "\n"; // } // assert(fullprod1 == fullprod); updateprod(); //cout << idx << "->"; //cout << m << " "; return int(solve()); } int updateY(int pos, int val) { y[pos] = val; stree.update(pos,val); ind(); return int(solve()); }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...