#include<bits/stdc++.h>
using namespace std;
/// 64-bit integer used for edge weights, path lengths and the answer.
using ll = long long;

/// Capacity of every per-vertex array: 100000 plus a little slack.
constexpr int maxn = 100000 + 10;
/// Modulus under which the answer is printed.
constexpr ll mod = 1000000007LL;

/// Vertex count, read at the start of solve().
int n;
/// Query count, read in solve() after the initial answer is printed.
int q;
/// par[v]: parent of vertex v, read for v = 1..n-1.
int par[maxn];
/// lvl[v]: value assigned in dfs() as lvl[parent] + 1.
int lvl[maxn];
/// d[v]: current weight of the edge between v and par[v].
ll d[maxn];
/// dp[v][0] / dp[v][1]: largest and second-largest downward length recorded for v.
ll dp[maxn][2];
/// cl[v][k]: child of v through which dp[v][k] was obtained (solve() presets it to -1).
ll cl[maxn][2];
/// g[v]: (child, edge weight) pairs for every child of v.
std::vector<std::pair<int, ll>> g[maxn];
void dfs(int v)
{
for (pair < int, ll > nb : g[v])
{
lvl[nb.first] = lvl[v] + 1;
dfs(nb.first);
ll len = nb.second + dp[nb.first][0];
if (len > dp[v][0])
{
dp[v][1] = dp[v][0];
cl[v][1] = cl[v][0];
dp[v][0] = len;
cl[v][0] = nb.first;
}
else
if (len > dp[v][1])
{
dp[v][1] = len;
cl[v][1] = nb.first;
}
}
}
/// Running answer printed by solve(): seeded there from the dp table, then
/// adjusted by update() on every query and reduced modulo `mod`.
ll ans;
void update(int idx, ll dx)
{
/**int cur_big = idx;
while(cur_big != 0 && cl[par[cur_big]][0] == cur_big)
cur_big = par[cur_big];
int cur_sec = cur_big;
while(cur_sec != 0 && cl[par[cur_sec]][1] == cur_sec)
cur_sec = par[cur_sec];
int new_big = cur_big;
while(new_big != 0 && dp[new_big][0] + dx > dp[par[new_big]][0])
new_big = par[new_big];
int new_sec = cur_sec;
while(new_sec != 0 && dp[new_sec][0] + dx > dp[par[new_sec]][1])*/
int cur = idx;
d[idx] += dx;
while(true)
{
if (cur == 0)
break;
///cout << cur << " " << par[cur] << " " << d[cur] << " " << dp[par[cur]][1] << endl;
if (dp[cur][0] + d[cur] > dp[par[cur]][0])
{
if (cl[par[cur]][0] == cur)
{
ans = ans - dp[par[cur]][0];
dp[par[cur]][0] = dp[cur][0] + d[cur];
ans = ans + dp[par[cur]][0];
}
else
if (cl[par[cur]][1] == cur)
{
ans = ans - dp[par[cur]][1];
swap(cl[par[cur]][0], cl[par[cur]][1]);
swap(dp[par[cur]][0], dp[par[cur]][1]);
dp[par[cur]][0] = dp[cur][0] + d[cur];
ans = ans + dp[par[cur]][0];
}
else
{
ans = ans - dp[par[cur]][1];
cl[par[cur]][1] = cl[par[cur]][0];
dp[par[cur]][1] = dp[par[cur]][0];
dp[par[cur]][0] = dp[cur][0] + d[cur];
ans = ans + dp[par[cur]][0];
cl[par[cur]][0] = cur;
}
cur = par[cur];
}
else
if (dp[cur][0] + d[cur] > dp[par[cur]][1] && cl[par[cur]][0] != cur)
{
///cout << "here" << endl;
if (cl[par[cur]][1] == cur)
{
ans = ans - dp[par[cur]][1];
dp[par[cur]][1] = dp[cur][0] + d[cur];
ans = ans + dp[par[cur]][1];
}
else
{
ans = ans - dp[par[cur]][1];
dp[par[cur]][1] = dp[cur][0] + d[cur];
ans = ans + dp[par[cur]][1];
cl[par[cur]][1] = cur;
}
cur = par[cur];
}
else
break;
}
ans = ans % mod;
//for (int i = 0; i < n; i ++)
// cout << i << " " << dp[i][0] << " " << dp[i][1] << endl;
}
/// Reads the whole input from stdin, builds the tree and answers the queries.
///
/// Input, in the order it is consumed: n; the parent of each vertex 1..n-1;
/// the weight of the edge above each vertex 1..n-1; q; then q pairs
/// (idx, dx). Vertex 0 is used as the root of the traversal.
///
/// Output: the initial answer, then the answer after each query, one value
/// per line. The answer is the sum of dp[v][0] + dp[v][1] over all vertices,
/// taken modulo `mod`.
void solve()
{
    cin >> n;
    for (int v = 1; v < n; ++v)
        cin >> par[v];
    for (int v = 1; v < n; ++v)
    {
        cin >> d[v];
        g[par[v]].push_back({v, d[v]});
    }

    // -1 marks "no child recorded yet" in both slots.
    for (int v = 0; v < n; ++v)
        cl[v][0] = cl[v][1] = -1;
    dfs(0);

    // Vertices are numbered 0..n-1, so the seed sum must cover exactly that
    // range. update() adjusts `ans` for changes at vertex 0 as well, which
    // means dp[0] has to be part of the initial total.
    for (int v = 0; v < n; ++v)
        ans = (ans + dp[v][0] + dp[v][1]) % mod;
    cout << ans << '\n';

    cin >> q;
    for (int query = 0; query < q; ++query)
    {
        int idx;
        ll dx;
        cin >> idx >> dx;
        update(idx, dx);
        // '\n' instead of endl: same bytes, without forcing a flush per line.
        cout << ans << '\n';
    }
}
/// Program entry point; all of the work is done by solve().
int main()
{
    solve();
    // Falling off the end of main is equivalent to `return 0;`.
}
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Incorrect | 4 ms | 2644 KB | Output isn't correct
2 | Halted | 0 ms | 0 KB | -
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Incorrect | 255 ms | 10872 KB | Output isn't correct
2 | Halted | 0 ms | 0 KB | -
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Incorrect | 780 ms | 12512 KB | Output isn't correct
2 | Halted | 0 ms | 0 KB | -
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Incorrect | 4 ms | 2644 KB | Output isn't correct
2 | Halted | 0 ms | 0 KB | -