// Submission #1288936
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1288936 | | tschav_ | 말 (IOI15_horses) | C++20 | 20 / 100 | 929 ms | 44848 KiB
#include <bits/stdc++.h>
using namespace std;
// Note: despite the conventional name, `ll` is an UNSIGNED 64-bit type here, so
// every product, comparison and modulo below is done in unsigned arithmetic.
using ll = unsigned long long;

// Recursive segment tree that combines values with `f` (assumed associative).
// Nodes live implicitly in `t`: the root is t[0] and node p has children
// 2p+1 and 2p+2. The leaves cover [0, size), where size is the smallest power
// of two >= n. `init` fills leaves that receive no input value and is also the
// result of an empty query, so it should be an identity element for `f`
// (e.g. 0 for max over unsigned values).
template<typename T>
class SegTree {
public:
    int size;
    vector<T> t;
    function<T(T, T)> f;
    T init;

    // Builds the tree with leaf i = a[i] for every i < min(size, a.size()).
    // `a` is taken by const reference so const vectors and temporaries work too.
    // The member-initializer list follows declaration order (size, t, f, init).
    SegTree(int n, T init, function<T(T, T)> f, const vector<T> &a)
        : size(1), t(), f(std::move(f)), init(init) {
        while (size < n) size <<= 1;
        t.assign(size << 1, init);
        build(a, 0, 0, size);
    }

    // Fills node `pos`, which covers leaves [tl, tr), and its subtree from `a`.
    void build(const vector<T> &a, int pos, int tl, int tr) {
        if (tr - tl == 1) {
            // Compare as int to avoid a signed/unsigned comparison.
            if (tl < static_cast<int>(a.size())) {
                t[pos] = a[tl];
            }
            return;
        }
        const int tm = (tl + tr) >> 1;
        build(a, 2 * pos + 1, tl, tm);
        build(a, 2 * pos + 2, tm, tr);
        t[pos] = f(t[2 * pos + 1], t[2 * pos + 2]);
    }

    // Rebuilds every node from `a`. Leaves at or beyond a.size() keep their
    // current values.
    void build(const vector<T> &a) {
        build(a, 0, 0, size);
    }

    // Sets leaf i (0 <= i < size) to val inside node `pos` covering [tl, tr),
    // then recombines the ancestors on the way back up.
    void update(int i, T val, int pos, int tl, int tr) {
        if (tr - tl == 1) {
            t[pos] = val;
            return;
        }
        const int mid = (tl + tr) >> 1;
        if (i < mid) {
            update(i, val, 2 * pos + 1, tl, mid);
        } else {
            update(i, val, 2 * pos + 2, mid, tr);
        }
        t[pos] = f(t[2 * pos + 1], t[2 * pos + 2]);
    }

    // Sets leaf i to x.
    void update(int i, T x) {
        update(i, x, 0, 0, size);
    }

    // Folds `f` over the leaves in [l, r) intersected with node `pos`'s range [tl, tr).
    T query(int l, int r, int pos, int tl, int tr) {
        if (r <= tl || tr <= l) return init;
        if (l <= tl && tr <= r) return t[pos];

        const int tm = (tl + tr) >> 1;
        return f(query(l, r, 2 * pos + 1, tl, tm), query(l, r, 2 * pos + 2, tm, tr));
    }

    // Returns `f` folded over leaves [l, r) (half-open), or `init` if the range is empty.
    T query(int l, int r) {
        return query(l, r, 0, 0, size);
    }
};
// Empty vector used only to construct the global tree below before init() runs.
vector<ll> em;
// Range-maximum segment tree over y; replaced in init() and updated in updateY().
SegTree<ll> stree = SegTree<ll>(0ll,0ll,[](ll x,ll y) {return max(x,y);},em);
// Per-position X (multiplier) and Y values, copied from init()'s input arrays.
vector<ll> x, y;
// Number of positions (the N passed to init()).
int n;
// Cached st.size(), refreshed in init() and updateX().
int m;
// Product of x over the elements of st with rank 0..idx, mod MOD (1 when idx < 0).
ll prod = 1;
// Rank (0-based position within st) of the earliest element that can still be
// optimal; -1 when the product of x over st is below 1e9 (see ind()).
int idx = -1;
// Positions with x >= 2 (maintained by init()/updateX()), plus possibly
// position 0, which init() may add as the start of the first run.
set<int> st;
// Modulus applied to every reported answer.
static const ll MOD = 1e9+7;
// Product of the x values, mod MOD (see getprod() and updateX()).
ll fullprod = 1;

// Modular exponentiation by repeated squaring: returns a^b mod m.
// Callers use it with b = MOD - 2 to obtain modular inverses (Fermat).
// Intermediate products are below m*m, so m must be < 2^32 to avoid overflow.
ll binpow(ll a, ll b, ll m) {
    ll base = a % m;
    ll acc = 1;
    for (ll e = b; e != 0; e >>= 1) {
        if (e & 1) {
            acc = acc * base % m;
        }
        base = base * base % m;
    }
    return acc;
}

// Returns the maximum y over the run of positions that belongs to st-element `id`:
// from `id` up to (but excluding) the next element of `st`, or up to `n` when `id`
// is the last element. When `id` is the smallest element of `st`, the run is
// widened to start at position 0. Requires `st` to be non-empty.
ll gety(int id) {
    const auto next = st.upper_bound(id);
    const int hi = (next == st.end()) ? n : *next;
    const int lo = (id == *st.begin()) ? 0 : id;
    return stree.query(lo, hi);
}

// Exact evaluation, used by solve() when idx == -1 (the product of x over st
// is below 1e9). Walks st in increasing order while keeping the running product
// of the x values, scores each element as running product * (max y of its run),
// and returns the best score reduced mod MOD.
ll small() {
    ll best = 0;
    ll running = 1;
    for (auto it = st.begin(); it != st.end(); ++it) {
        running *= x[*it];
        const ll score = running * gety(*it);
        best = max(best, score);
    }
    return best % MOD;
}

// Returns the current answer mod MOD: the best value of
// (running product of x up to a position) * (y at that position).
//
// Relies on globals maintained by init()/updateX()/updateY():
//   st       - positions with x >= 2 (plus position 0 when present),
//   idx      - rank in st of the earliest element that can still be optimal,
//              or -1 when small() can evaluate everything exactly (see ind()),
//   prod     - product of x over st-ranks 0..idx, mod MOD,
//   fullprod - product of the x values, mod MOD.
ll solve() {
    // st can only be empty when no position other than possibly 0 has x != 1.
    // Every prefix product then equals fullprod, so the answer is max(y) * fullprod.
    // Previously this case fell into small(), which iterates an empty set and
    // returned 0.
    if (st.empty()) {
        return (stree.query(0, n) % MOD) * fullprod % MOD;
    }
    if (idx == -1) {
        return small();
    }

    // Collect the positions of ranks idx+1..m-1, walking st from its largest
    // element downward. There are few of them, because ind() stops once their
    // x product exceeds 1e9. Afterwards `rit` points at the element of rank idx.
    vector<int> tail;
    auto rit = st.rbegin();
    for (int rank = m - 1; rank > idx && rit != st.rend(); --rank, ++rit) {
        tail.push_back(*rit);
    }

    // Every remaining candidate equals prod * (exact factor below). The running
    // product `curr` stays <= 1e9 (see ind()). Assuming y <= 1e9 (problem bound),
    // curr * y fits in 64 bits, so the factors are compared exactly.
    ll best = (rit != st.rend()) ? gety(*rit) : 0;
    ll curr = 1;
    for (auto it = tail.rbegin(); it != tail.rend(); ++it) {
        curr *= x[*it];
        best = max(best, curr * gety(*it));
    }

    return (best % MOD) * (prod % MOD) % MOD;
}

void ind() {
    
    idx = -1;
    ll P = 1ll;
    int i = m - 1;
    for (auto it = st.rbegin(); it != st.rend(); ++it, --i) {
        int pos = *it;
        if (P > 1e9) {
            break;
        } else {
            P *= x[pos];
            idx = i;
        }
    }
    if(P < 1e9) {
        idx = -1;
    }
}

void getprod() {
    prod = 1ll;
    fullprod = 1ll;
    for (auto it = st.begin(); it != st.end(); ++it) {
        fullprod = (fullprod * (x[*it] % MOD)) % MOD;
    }
    //cerr << endl << fullprod << "-->\n";
    if (idx < 0) return;
    auto it = st.begin();
    for (int i = 0; i <= idx and it != st.end(); ++i, ++it) {
        prod = (prod * (x[*it] % MOD)) % MOD;
    }
}

void updateprod() {
    prod = 1ll;
    if (idx < 0) return;
    ll curr = 1ll;
    auto rit = st.rbegin();
    for (int i = m - 1; rit != st.rend(); ++rit, --i) {
        if (i <= idx) break; 
        curr *= x[*rit];
        curr %= MOD;
    }
    prod = (fullprod * binpow(curr, MOD-2, MOD)) % MOD;
}

// Loads the initial X and Y arrays (length N), resets all global state, and
// returns the initial answer mod MOD.
int init(int N, int X[], int Y[]) {
    st.clear();
    prod = 1;
    fullprod = 1;
    idx = -1;
    n = N;
    x.assign(n, 0ULL);
    y.assign(n, 0ULL);
    for (int i = 0; i < n; ++i) {
        x[i] = X[i];
        y[i] = Y[i];
        if (x[i] >= 2) {
            st.insert(i);
        }
    }

    // Range-max tree over y. The combiner captures nothing.
    stree = SegTree<ll>(n, 0, [](ll a, ll b) { return max(a, b); }, y);

    // Always keep position 0 in st, even when X[0] == 1. It starts the first run
    // [0, next element of st) used by gety()/small()/solve(), and it keeps st
    // non-empty when every X is 1. Previously 0 was inserted only when st was
    // already non-empty, so an all-ones input answered 0 instead of max(Y).
    // Later updateX() calls could then leave st without position 0, and gety()
    // would scale the leading y values by an x they never receive.
    st.insert(0);
    m = st.size();

    ind();
    getprod();
    return int(solve());
}

int updateX(int pos, int val) {
	ll old = x[pos];

	if(pos == 0) {
        // . . . 
    } else {
        if(val == 1 and old != 1) {
            st.erase(pos);
        }
        if(val != 1 and old == 1) {
            st.insert(pos);
        }
    }
    m = st.size();

	
	//int I = *(next(st.begin(),idx));
	// if(pos <= I) {
    //     prod = (prod * binpow(old, MOD-2,MOD)) % MOD;
    //     prod = (prod * val) % MOD;
	// }
    x[pos] = val;
    ind();
    // ll oldpr1 = 1ll;
    // for (auto it = st.begin(); it != st.end(); ++it) {
    //     oldpr1 = (oldpr1 * (x[*it] % MOD)) % MOD;
    // }
    fullprod = (fullprod * binpow(old, MOD-2,MOD)) % MOD;
    fullprod = (fullprod * val) % MOD;
    // if(fullprod != oldpr1) {
    //     cerr << old << " " << val << "\n";
    //     cerr << fullprod << " " << oldpr1 << "\n";
    //     assert(0);
    // }
	// ll fullprod1 = 1ll;
    // for (auto it = st.begin(); it != st.end(); ++it) {
    //     fullprod1 = (fullprod1 * (x[*it] % MOD)) % MOD;
    // }
    // if(fullprod1 != fullprod) {
    //     cerr << oldpr1 << "\n";
    //     cerr << (oldpr1 * binpow(old, MOD-2,MOD)) % MOD << "\n";
    //     cerr << old << " " << val << "\n";
    //     cerr << fullprod1 << " " << fullprod << "\n";
    // }
    // assert(fullprod1 == fullprod);
    updateprod();
    //cout << idx << "->";
    //cout << m << " ";
	return int(solve());
}

int updateY(int pos, int val) {
	y[pos] = val;
    stree.update(pos,val);
	ind();
	return int(solve());
}
// Grader result tables (still loading when this page was captured):
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...