#include <bits/stdc++.h>
using namespace std;
//#define int long long

// Fixed-size tuple shorthand, e.g. ar<int, 3> == array<int, 3>.
template <class T, size_t N>
using ar = array<T, N>;

typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;

// Problem-wide constants; maxn bounds every per-vertex array.
constexpr int mod = 1000000007;
constexpr ll inf = 1000000000000000000LL;
constexpr int maxn = 100005;
// Fenwick (binary indexed) tree over integer positions starting at 0.
// Positions are shifted by one internally, so slot 0 of `tree` is never used.
struct fenwick {
    int n;            // update() only touches slots strictly below n
    vector<ll> tree;  // 1-based partial sums

    // Sizes the tree for positions 0.._n (plus slack). resize() keeps any
    // existing contents, so this is meant to be called on a fresh object.
    void init(int _n) {
        n = _n + 10;
        tree.resize(n + 50);
    }

    // Adds v at position p.
    void update(int p, ll v) {
        int slot = p + 1;
        while (slot < n) {
            tree[slot] += v;
            slot += slot & -slot;
        }
    }

    // Returns the sum of all values stored at positions 0..p.
    ll query(int p) {
        ll total = 0;
        int slot = p + 1;
        while (slot != 0) {
            total += tree[slot];
            slot -= slot & -slot;
        }
        return total;
    }
} fwt[2];
// Number of vertices and number of candidate paths (both read in main).
int n, m;
// up[u][j]: the 2^j-th ancestor of u (0 once past the root); dep[u]: depth of u.
int up[maxn][20], dep[maxn];
// Euler-tour entry/exit stamps; `timer` is the next stamp handed out.
int in[maxn], out[maxn];
int timer = 1;
// P[l][0]: paths whose endpoint l is an ancestor of the other, stored as {l, other, cost}.
// P[l][1]: paths whose endpoints lie in two different child subtrees of l, stored as {a, b, cost}.
vector<ar<int, 3> > P[maxn][2];
// dp[u][0]: best total in subtree(u) with u left uncovered;
// dp[u][1]: best total in subtree(u) using a path whose top vertex is u (0 if none).
ll dp[maxn][2];
// Adjacency lists of the tree.
vector<int> G[maxn];
// First pass over the tree: stamps in[u]/out[u] from the shared counter
// `timer`, completes the binary-lifting row up[u][1..19], and sets up[.][0]
// and dep[] for every child before descending. up[u][0] and dep[u] must
// already be set by the caller (both are 0 for the root). Recursion depth
// equals the height of the tree.
void dfs(int u, int p) {
    in[u] = timer++;
    for (int k = 1; k < 20; ++k) {
        const int halfway = up[u][k - 1];
        up[u][k] = up[halfway][k - 1];
    }
    for (const int child : G[u]) {
        if (child != p) {
            up[child][0] = u;
            dep[child] = dep[u] + 1;
            dfs(child, u);
        }
    }
    out[u] = timer++;
}
// Returns the vertex reached by climbing d levels up from u through up[][],
// taking the largest power-of-two jumps first. Only the low 20 bits of d are
// examined; climbing past the root ends on vertex 0, whose table row is all 0.
int jmp(int u, int d) {
    int cur = u;
    for (int bit = 19; bit >= 0; --bit) {
        const int mask = 1 << bit;
        if ((d & mask) != 0) {
            cur = up[cur][bit];
        }
    }
    return cur;
}
// Lowest common ancestor of a and b via binary lifting.
int lca(int a, int b) {
    // Make `a` the deeper vertex, then lift it to b's depth.
    if (dep[b] > dep[a]) swap(a, b);
    a = jmp(a, dep[a] - dep[b]);
    if (a == b) return a;
    // Lift both while their ancestors still differ; they stop just below the LCA.
    for (int j = 19; j >= 0; --j) {
        const int pa = up[a][j];
        const int pb = up[b][j];
        if (pa != pb) {
            a = pa;
            b = pb;
        }
    }
    return up[a][0];
}
void add(int t, int u, ll v) {
fwt[t].update(in[u], v);
fwt[t].update(out[u], -v);
}
// Sum of the vertex weights stored in Fenwick tree t over every vertex on the
// tree path u..v, both endpoints included.
ll query(int t, int u, int v) {
    const int top = lca(u, v);
    fenwick &bit = fwt[t];
    // rootward(x): weights of x and all its ancestors. The LCA's parent is
    // subtracted as well so the LCA itself is counted exactly once; above the
    // root this is vertex 0, whose stamp in[0] == 0 is never updated.
    auto rootward = [&](int x) { return bit.query(in[x]); };
    return rootward(u) + rootward(v) - rootward(top) - rootward(up[top][0]);
}
void dfs2(int u, int p) {
for(int v : G[u]) {
if(v == p) continue;
dfs2(v, u);
dp[u][0] += max(dp[v][0], dp[v][1]);
}
ll sum = dp[u][0];
for(auto [_, v, c] : P[u][0]) {
int x = jmp(v, dep[v] - dep[u] - 1);
sum -= max(dp[x][0], dp[x][1]);
ll val = query(1, x, v);
if(x != v) {
int x2 = jmp(v, dep[v] - dep[x] - 1);
val -= query(0, x2, v);
}
dp[u][1] = max(dp[u][1], sum + val + c);
sum += max(dp[x][0], dp[x][1]);
}
for(auto [v1, v2, c] : P[u][1]) {
int x1 = jmp(v1, dep[v1] - dep[u] - 1);
int x2 = jmp(v2, dep[v2] - dep[u] - 1);
sum -= max(dp[x1][0], dp[x1][1]);
sum -= max(dp[x2][0], dp[x2][1]);
ll val = query(1, x1, v1) + query(1, x2, v2);
if(x1 != v1) {
int x3 = jmp(v1, dep[v1] - dep[x1] - 1);
val -= query(0, x3, v1);
}
if(x2 != v2) {
int x3 = jmp(v2, dep[v2] - dep[x2] - 1);
val -= query(0, x3, v2);
}
dp[u][1] = max(dp[u][1], sum + val + c);
sum += max(dp[x1][0], dp[x1][1]);
sum += max(dp[x2][0], dp[x2][1]);
}
add(0, u, max(dp[u][0], dp[u][1]));
for(int v : G[u]) {
if(v == p) continue;
add(1, u, max(dp[v][0], dp[v][1]));
}
}
// Reads the tree (n vertices, n - 1 edges) and m candidate paths (a, b, cost),
// buckets each path at the LCA of its endpoints, runs both passes and prints
// the best achievable total.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    // Euler stamps run up to 2n, so both trees are sized for 2n positions.
    for (auto &bit : fwt) bit.init(2 * n);

    for (int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    dfs(1, 1);

    cin >> m;
    while (m-- > 0) {
        int a, b, c;
        cin >> a >> b >> c;
        const int top = lca(a, b);
        if (top == a) {
            P[a][0].push_back({ a, b, c });
        } else if (top == b) {
            P[b][0].push_back({ b, a, c });
        } else {
            P[top][1].push_back({ a, b, c });
        }
    }

    dfs2(1, 1);
    cout << max(dp[1][0], dp[1][1]) << '\n';
}
/*
 * Judge results-page residue pasted after the program (not part of the solution):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/