#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// N: capacity used for every per-node array below.
// mod: declared here but not referenced anywhere in the visible code.
const int N = 1e5 + 5, mod = 1e9 + 7;
// n, m: two counts read in main (the edge loop reads n-1 edges, then m records).
// dem: counter incremented by dfs() to hand out tin[] values.
// tin[u] / tout[u]: visit index of u and the largest visit index inside u's subtree (set in dfs()).
// par[v][i]: ancestor table filled in dfs(); par[v][0] is the parent, par[v][i] = par[par[v][i-1]][i-1].
// h[v]: depth of v, set in dfs() as h[parent] + 1.
int n, m, dem, tin[N], tout[N], par[N][18], h[N];
// g[u]: adjacency list of u; main() pushes both directions of every edge.
vector <int> g[N];
// dp[u], f[u]: values computed by DFS(); f[u] is the sum of dp over u's children,
// dp[u] starts at f[u] and is raised by the records stored in luu[u].
// seg[], lay[]: segment-tree node sums and pending (not yet applied) additions, see pro()/up()/get().
ll dp[N], f[N], seg[4*N], lay[4*N];
// One input record as stored by main(): two tree nodes and an integer value.
// Kept as a plain aggregate so it can be built with {a, b, c}.
struct elec {
    int a;  // first node of the record
    int b;  // second node of the record
    int c;  // value added to the candidate answer when the record is used in DFS()
};
// luu[w]: records whose two nodes have w as their lowest common ancestor
// (main() files each record under lca(a, b)); consumed by DFS(w).
vector <elec> luu[N];
// Applies the pending addition stored at segment-tree node x, which covers [l, r].
// The pending value is added to every position of the range (so the node sum grows
// by value * length), handed down to both children unless x is a leaf, and cleared.
void pro(int x, int l, int r) {
    const long long pending = lay[x];
    if (pending != 0) {
        const int length = r - l + 1;
        seg[x] += pending * length;

        const bool isLeaf = (l == r);
        if (!isLeaf) {
            lay[2 * x + 1] += pending;
            lay[2 * x] += pending;
        }
        lay[x] = 0;
    }
}
// Range addition: adds val to every position in [i, j].
// x is the current node and [l, r] the range it covers; call with (1, 1, n, ...).
void up(int x, int l, int r, int i, int j, ll val) {
    // Settle any pending addition first so seg[x] is accurate before it is used
    // by the parent's recomputation.
    pro(x, l, r);

    const bool disjoint = (j < l) || (i > r);
    if (disjoint) return;

    const bool covered = (i <= l) && (r <= j);
    if (covered) {
        lay[x] += val;
        pro(x, l, r);  // apply immediately so seg[x] already includes val
        return;
    }

    const int mid = (l + r) / 2;
    const int leftChild = 2 * x;
    const int rightChild = 2 * x + 1;
    up(leftChild, l, mid, i, j, val);
    up(rightChild, mid + 1, r, i, j, val);
    seg[x] = seg[leftChild] + seg[rightChild];
}
// Range-sum query: returns the sum of the stored values over positions [i, j].
// x is the current node and [l, r] the range it covers; call with (1, 1, n, ...).
//
// The result is long long because seg[] holds long long sums; returning int
// would silently truncate any total that does not fit in 32 bits.
long long get(int x, int l, int r, int i, int j) {
    // Apply any pending addition so seg[x] is up to date before it is read.
    pro(x, l, r);
    if (l > j || r < i) return 0;          // no overlap with the query range
    if (l >= i && r <= j) return seg[x];   // node fully inside the query range
    int mid = (l + r) / 2;
    return get(2 * x, l, mid, i, j) + get(2 * x + 1, mid + 1, r, i, j);
}
// First traversal: numbers the nodes and builds the ancestor table.
// For every node reached from u it sets
//   tin[]  - visit index (taken from the running counter dem),
//   tout[] - largest visit index found in the node's subtree,
//   h[]    - depth (parent's depth + 1),
//   par[]  - par[v][0] = parent, par[v][i] = par[par[v][i-1]][i-1].
void dfs(int u) {
    tin[u] = ++dem;
    tout[u] = tin[u];

    for (int child : g[u]) {
        if (child != par[u][0]) {
            h[child] = h[u] + 1;

            // The whole ancestor row must be ready before descending, because
            // deeper nodes read it while building their own rows.
            par[child][0] = u;
            for (int level = 1; level < 18; ++level) {
                const int halfway = par[child][level - 1];
                par[child][level] = par[halfway][level - 1];
            }

            dfs(child);

            if (tout[child] > tout[u]) tout[u] = tout[child];
        }
    }
}
// Lowest common ancestor of u and v, using the depth array h[] and the
// ancestor table par[][] filled by dfs().
int lca(int u, int v) {
    // Make u the deeper node, then lift it to v's depth. When the depths are
    // already equal the difference is 0 and the lifting loop does nothing.
    if (h[u] < h[v]) swap(u, v);
    const int diff = h[u] - h[v];
    for (int bit = 0; (1 << bit) <= diff; ++bit) {
        if ((diff >> bit) & 1) u = par[u][bit];
    }

    if (u == v) return u;

    // Both nodes are now at the same depth and differ: jump both up by the
    // largest powers of two that keep them apart.
    for (int bit = __lg(h[u]); bit >= 0; --bit) {
        if (par[u][bit] == par[v][bit]) continue;
        u = par[u][bit];
        v = par[v][bit];
    }
    return par[u][0];
}
void DFS(int u) {
for(int v : g[u]) if(v != par[u][0]) DFS(v), f[u] += dp[v];
dp[u] = f[u];
for(elec p : luu[u]) {
int x = p.a, y = p.b, z = p.c;
if(h[x] > h[y]) swap(x,y);
if(x == u && y == u) dp[u] = max(dp[u], f[u] + z);
else if(x == u) {
int v2 = y;
int k = h[y] - h[u] - 1;
for(int i=0;(1 << i) <= k;++i)
if(k >> i & 1) v2 = par[v2][i];
ll vl = f[u] - dp[v2] + f[y];
vl += get(1,1,n,tin[y],tin[y]);
dp[u] = max(dp[u], vl + z);
//if(u == 1) {
// cout << f[u] << ' ' << dp[v2] << ' ' << f[y] << ' ' << get(1,1,n,tin[y],tin[y]) << '\n';
// cout << y << ' ' << v2 << '\n';
//}
//if(u == 1) for(int i=1;i<=n;++i) cout << get(1,1,n,tin[i],tin[i]) << ' ';
}
else {
int v1 = x, v2 = y;
int k = h[x] - h[u] - 1;
for(int i=0;(1 << i) <= k;++i)
if(k >> i & 1) v1 = par[v1][i];
k = h[y] - h[u] - 1;
for(int i=0;(1 << i) <= k;++i)
if(k >> i & 1) v2 = par[v2][i];
ll vl = f[u] - dp[v1] - dp[v2] + f[x] + f[y];
vl += get(1,1,n,tin[x],tin[x]) + get(1,1,n,tin[y],tin[y]);
dp[u] = max(dp[u], vl + z);
}
}
for(int v : g[u]) if(v != par[u][0]) up(1,1,n,tin[v],tout[v],f[u]-dp[v]);
//cout << u << '\n';
//for(int i=1;i<=n;++i) cout << get(1,1,n,tin[i],tin[i]) << ' '; cout << '\n';
}
// Reads n and n-1 edges, roots the tree at node 1, reads m records (two nodes
// and a value), files each record under the LCA of its nodes, runs the DP and
// prints dp[1].
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    for (int edge = 1; edge < n; ++edge) {
        int from, to;
        cin >> from >> to;
        g[from].push_back(to);
        g[to].push_back(from);
    }

    // Depths and the ancestor table must exist before any lca() call below.
    dfs(1);

    cin >> m;
    while (m-- != 0) {
        int first, second, value;
        cin >> first >> second >> value;
        const int top = lca(first, second);
        luu[top].push_back({first, second, value});
    }

    DFS(1);
    cout << dp[1];
    return 0;
}
/*
Leftover text copied from the submission-results page; not part of the program.

| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
*/