답안 #753484

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
753484 2023-06-05T10:01:16 Z keta_tsimakuridze Arboras (RMI20_arboras) C++14
0 / 100
98 ms 29816 KB
#include<bits/stdc++.h>
#define f first
#define s second
#define int long long
#define pii pair<int,int>
using namespace std;
// N: capacity of per-vertex arrays (4 * N for segment-tree arrays);
// mod: modulus applied to the printed answer.
const int N = 2e5 + 5, mod = 1e9 + 7; // !
// dep / lzd: depth segment tree (node max / pending range addition).
// cur: id of the heavy chain currently being opened in dfs(); st[c]: first
// vertex visited on chain c; ch[u]: chain id of u; h[u]: total weight of the
// heavy-child path below u; id[v]: position of v inside V[p[v]] as recorded
// when reading input; depths: running value updated by sz[u] * d per query
// (reduced modulo mod).
int dep[4 * N], lzd[4 * N], cur, st[N], ch[N], h[N], id[N], depths;
// sz: subtree sizes (filled by dfs0); in / out: pre-order interval of each
// subtree; timer: pre-order counter; ans: running sum of mx[u][0].f (mod);
// lz: pending chmax tags of the beats tree (-inf means "none").
int sz[4 * N], in[N], out[N], timer, ans, lz[4 * N];
// mx[u][0], mx[u][1]: best and second-best (value, child) pairs taken over the
// non-first (light) children of u.
pii mx[N][2];
// V[u]: (child, edge weight) pairs of u.
vector<pair<int,int>> V[N];
// p[v]: parent of v, read from input for v = 1 .. n-1.
int p[N];
////////// Depth segment tree: range add (lzd), range max (dep) //////////
// Folds node u's pending addition into its stored maximum and forwards the
// same addition to both children (unless u is a leaf).
void push_d(int u, int l, int r) {
    const int pending = lzd[u];
    dep[u] += pending;
    if(l != r) {
        lzd[2 * u] += pending;
        lzd[2 * u + 1] += pending;
    }
    lzd[u] = 0;
}
// Adds val to every position of [st, en] in the depth tree.
// Node u covers [l, r]; the node maximum is recomputed on the way back up.
void upd_d(int u, int st, int en, int l, int r, int val) {
    if(lzd[u] != 0) push_d(u, l, r);
    const bool disjoint = en < l || st > r;
    if(disjoint) return;
    const bool covered = l >= st && en >= r;
    if(covered) {
        lzd[u] = val;
        push_d(u, l, r);
        return;
    }
    const int mid = (l + r) / 2;
    upd_d(2 * u, st, en, l, mid, val);
    upd_d(2 * u + 1, st, en, mid + 1, r, val);
    dep[u] = max(dep[2 * u], dep[2 * u + 1]);
}
// Maximum of the depth tree over [st, en] intersected with [l, r] (node u).
// Returns 0 for a node that does not meet the query range.
int get(int u, int st, int en, int l, int r) {
    if(lzd[u] != 0) push_d(u, l, r);
    if(en < l || st > r) return 0;
    if(l >= st && en >= r) return dep[u];
    const int mid = (l + r) / 2;
    const int lo_part = get(2 * u, st, en, l, mid);
    const int hi_part = get(2 * u + 1, st, en, mid + 1, r);
    return max(lo_part, hi_part);
}
////////////////////////////////////////////////////////
// Child-edge ordering: the endpoint with the larger subtree comes first.
bool cmp(pii x, pii y) {
    return sz[y.f] < sz[x.f];
}
int H[N];
void dfs0(int u) {
    sz[u] = 1;
    for(int i = 0; i < V[u].size(); i++) {
        int v = V[u][i].f;
        H[v] = H[u] + V[u][i].s;
        dfs0(v);
        sz[u] += sz[v];
    }
    sort(V[u].begin(), V[u].end(), cmp);
}

// Per-node summary of the beats tree. While add[u] is pending, t[u] does not
// yet include it (push2 folds add[u] into mn, mn2 and ans).
struct node {
    int cn, mn, mn2, ans; // cn: how many positions equal mn; mn: minimum; mn2: strictly larger second minimum (build sets 1e18 = "none"); ans: sum of the node's values (push reduces it modulo mod)
} t[4 * N];
const int inf = 1e18;
////////////SEGMENT TREE BEATS/////////////////
// Initialises node u (covering [l, r]) and all its descendants: no pending
// chmax tag, every one of the r - l + 1 positions counted as a minimum, and
// no second minimum yet (1e18).
void build(int u, int l, int r) {
    t[u].mn2 = 1e18;
    t[u].cn = r - l + 1;
    lz[u] = -inf;
    if(l != r) {
        const int mid = (l + r) / 2;
        build(2 * u, l, mid);
        build(2 * u + 1, mid + 1, r);
    }
}
// Combines two child summaries. A summary with cn == 0 counts as empty.
// The result's minimum is the smaller child minimum; its multiplicity and
// second minimum depend on whether both children share that minimum.
node merge(node a, node b) {
    if(a.cn == 0) return b;
    if(b.cn == 0) return a;
    node lo = a, hi = b;
    if(lo.mn > hi.mn) swap(lo, hi);
    node res;
    res.ans = lo.ans + hi.ans;
    res.mn = lo.mn;
    if(lo.mn != hi.mn) {
        res.cn = lo.cn;
        res.mn2 = min(lo.mn2, hi.mn);
    } else {
        res.cn = lo.cn + hi.cn;
        res.mn2 = min(lo.mn2, hi.mn2);
    }
    return res;
}
int add[4 * N];
// Applies the pending chmax tag lz[u] to node u: the cn minimum positions are
// raised to lz[u] and the modular sum ans is adjusted accordingly. If the tag
// can still affect the children (mn did not exceed it), it is forwarded to
// them, corrected by each child's own pending add so that "chmax, then add"
// order is preserved.
// Products are formed from values already reduced modulo mod: mn can be a
// very large depth and mn * cn would overflow signed 64-bit arithmetic.
void push(int u, int l, int r) {
    const int cnt = t[u].cn % mod;
    const int before = (t[u].mn % mod + mod) % mod;
    t[u].mn = max(t[u].mn, lz[u]);
    const int after = (t[u].mn % mod + mod) % mod;
    t[u].ans = ((t[u].ans - before * cnt + after * cnt) % mod + mod) % mod;
    if(t[u].mn > lz[u]) {
        // Every value here already exceeds the tag: it is a no-op, drop it so
        // later visits do not re-push it.
        lz[u] = -inf;
        return;
    }
    if(l != r) {
        lz[2 * u] = max(lz[2 * u], lz[u] - add[2 * u]);
        lz[2 * u + 1] = max(lz[2 * u + 1], lz[u] - add[2 * u + 1]);
    }
    lz[u] = -inf;
}

void push2(int u, int l, int r) {
    t[u].mn += add[u];
    t[u].mn2 += add[u];
    t[u].ans += add[u] * (r - l + 1);
    if(l != r) {
        add[2 * u] += add[u];
        add[2 * u + 1] += add[u];
    }
    add[u] = 0;
}
// Segment-tree-beats range chmax: raises every position in [st, en] to at
// least v. Node u covers [l, r]. Pending chmax (lz) and add (add) tags of u
// are applied first so that t[u] is current before it is inspected.
// Termination relies on build() having set mn2 = 1e18 on every node: with
// mn2 == 0 a leaf never satisfies the tagging condition below and the
// recursion keeps descending past the leaves.
void upd(int u, int st, int en, int l, int r, int v) {
    if(lz[u] != -inf) push(u, l, r);
    if(add[u]) push2(u, l, r);
    // Disjoint range, or every value here is already >= v: nothing to do.
    if(l > en || r < st || t[u].mn >= v) return;
    // Fully covered and only the minima are below v (mn < v <= mn2): tagging
    // this node is enough (the "beats" break condition).
    if(st <= l && r <= en && t[u].mn2 >= v) {

        lz[u] = v;
        push(u, l, r); //cout << u << " +++" << l << " " << r << " " << t[u].mn << " " << lz[2 * u] << " " << lz[2 * u + 1] << endl;
        return;
    }
    int mid = (l + r) / 2;
    upd(2 * u, st, en, l, mid, v); upd(2 * u + 1, st, en, mid + 1, r, v);
    t[u] = merge(t[2 * u], t[2 * u + 1]);
}
// Debug helper: pushes every pending tag down the tree and prints the sum
// stored at each leaf as "l _ r ans".
void go(int u, int l, int r) {
    if(lz[u] != -inf) push(u, l, r);
    if(add[u]) push2(u, l, r);
    if(l != r) {
        const int mid = (l + r) / 2;
        go(2 * u, l, mid);
        go(2 * u + 1, mid + 1, r);
        return;
    }
    cout << l << " _ " << r << " " << t[u].ans << endl;
}
// Range addition on the beats tree: adds v to every position in [st, en].
// Node u covers [l, r]; both pending tag kinds are resolved before descending.
void upd2(int u, int st, int en, int l, int r, int v) {
    if(lz[u] != -inf) push(u, l, r);
    if(add[u]) push2(u, l, r);
    if(en < l || st > r) return;
    const bool covered = st <= l && r <= en;
    if(covered) {
        add[u] = v;
        push2(u, l, r);
        return;
    }
    const int mid = (l + r) / 2;
    upd2(2 * u, st, en, l, mid, v);
    upd2(2 * u + 1, st, en, mid + 1, r, v);
    t[u] = merge(t[2 * u], t[2 * u + 1]);
}
///////////////////////////////////////////////////////////////////////////////////////
int n;
void dfs(int u) {
    if(cur && !st[cur]) st[cur] = u;
    ch[u] = cur; in[u] = ++timer;
    if(V[u].size()) {
        int v = V[u][0].f;
        dfs(v);
        h[u] = h[v] + V[u][0].s;
        upd_d(1, in[v], out[v], 1, n, V[u][0].s);
    }
    for(int i = 1; i < V[u].size(); i++) {
        int v = V[u][i].f, d = V[u][i].s;
        ++cur;
        dfs(v);
        upd_d(1, in[v], out[v], 1, n, d);
        if(mx[v][0].f + d > mx[u][0].f) swap(mx[u][0], mx[u][1]), mx[u][0] = {mx[v][0].f + d, v};
        else mx[u][1] = max(mx[u][1], {mx[v][0].f + d, v});
    }
    ans += mx[u][0].f; ans %= mod;
    upd(1, in[u], in[u], 1, n, max(mx[u][1].f + H[u], h[u]));
    out[u] = timer;
}
void up(int u, int d) {
    while(true) {
        upd(1, in[st[ch[u]]], in[u] - 1, 1, n, d); //cout << st[ch[u]] << " __ " << u << " _ " << d << endl;
        u = st[ch[u]];
        if(u == 0) return;
        int v = u;
        u = p[u];
        ans -= mx[u][0].f; ans = (ans + mod) % mod;
        if(mx[u][0].f <= mx[v][0].f + V[u][id[v]].s) {
            if(mx[u][0].s == v) mx[u][0].f = mx[v][0].f + V[u][id[v]].s;
            else swap(mx[u][0], mx[u][1]), mx[u][0] = {mx[v][0].f + V[u][id[v]].s, v};
        } else if(mx[u][0].s != v) mx[u][1] = max(mx[u][1],  {mx[v][0].f + V[u][id[v]].s, v});
        ans += mx[u][0].f; ans %= mod;
        upd(1, in[u], in[u], 1, n, mx[u][1].f + get(1, in[u], in[u], 1, n));
    }
}
main(){
//    int n;
    cin >> n;
    for(int i = 1; i < n; i++) cin >> p[i];
    int F = 0;
    for(int i = 1; i < n; i++) {
        int d;
        cin >> d;
        id[i] = (int)V[p[i]].size();
        V[p[i]].push_back({i, d});
    }
   dfs0(0);
   dfs(0);
   build(1, 1, n);
   depths %= mod;
   cout << (t[1].ans - depths + ans + mod) % mod << "\n";
   int q; cin >> q;
   set<int> s;
   s.insert({1, 0});
   while(q--) {
    int u, d;
    cin >> u >> d;
    V[p[u]][id[u]].s += d;
    depths += sz[u] * d % mod;
    depths %= mod;
    upd2(1, in[u], out[u], 1, n, d);
    upd_d(1, in[u], out[u], 1, n, d);
    d = get(1, in[u], out[u], 1, n);
    up(u, d);// go(1, 1, n);
    cout << (t[1].ans - depths + ans + mod) % mod << "\n";
   }

}

Compilation message

arboras.cpp: In function 'void dfs0(long long int)':
arboras.cpp:45:22: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   45 |     for(int i = 0; i < V[u].size(); i++) {
      |                    ~~^~~~~~~~~~~~~
arboras.cpp: In function 'void dfs(long long int)':
arboras.cpp:152:22: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  152 |     for(int i = 1; i < V[u].size(); i++) {
      |                    ~~^~~~~~~~~~~~~
arboras.cpp: At global scope:
arboras.cpp:180:1: warning: ISO C++ forbids declaration of 'main' with no type [-Wreturn-type]
  180 | main(){
      | ^~~~
arboras.cpp: In function 'int main()':
arboras.cpp:184:9: warning: unused variable 'F' [-Wunused-variable]
  184 |     int F = 0;
      |         ^
# 결과 실행 시간 메모리 Grader output
1 Runtime error 10 ms 10452 KB Execution killed with signal 11
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 72 ms 21988 KB Execution killed with signal 11
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 98 ms 29816 KB Execution killed with signal 11
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 10 ms 10452 KB Execution killed with signal 11
2 Halted 0 ms 0 KB -