#include <bits/stdc++.h>
using namespace std;
// Array capacity: maximum number of vertices plus a little slack.
const int N = 100000 + 10;

// Number of tree vertices and number of weighted paths (read from stdin).
int n, m;

// dp[u]: DP value computed by calc(u) for u's subtree.
int dp[N];
// dp2[u]: sum of dp[] over the children of u.
int dp2[N];

// Heavy-light decomposition data.
int sz[N];    // subtree size
int par[N];   // parent vertex
int head[N];  // top vertex of the heavy chain containing the vertex
int big[N];   // heavy child, or -1 for a leaf
int h[N];     // depth (root has depth 0)
int st[N];    // position of the vertex in tr
int ft[N];    // one past the last position of the vertex's subtree in tr

// adj[u]: neighbours of u; ver[d]: vertices at depth d; tr: vertices in HLD order.
vector<int> adj[N], ver[N], tr;

// vec[w]: paths {endpoint, endpoint, weight} whose LCA is w.
vector<array<int, 3>> vec[N];
// Fenwick (binary indexed) tree over positions 0 .. N-2: point add and
// half-open range sum. Position i is stored at internal index i + 1.
//
// Totals are accumulated in long long. Each value passed to add() is an int,
// but an internal node (and every prefix sum) aggregates many positions; with
// int storage those aggregates can exceed INT_MAX, which is signed-integer
// overflow (undefined behaviour) even when the final range difference fits.
struct FEN {
    long long fen[N];

    // Adds v to position i (0-based).
    void add(int i, int v) {
        for (++i; i < N; i += (i & -i))
            fen[i] += v;
    }

    // Returns res plus the sum of positions [0, i).
    long long get2(int i, long long res = 0) {
        for (; i > 0; i -= (i & -i))
            res += fen[i];
        return res;
    }

    // Returns the sum of positions [l, r); 0 when l == r.
    long long get(int l, int r) {
        return get2(r) - get2(l);
    }
} fn;
// First HLD pass over the subtree rooted at u. For every vertex it fills:
//   sz[]  - subtree size,
//   par[] - parent (each child of u gets par = u),
//   h[]   - depth (child depth = parent depth + 1),
//   big[] - child with the largest subtree (the first one in adj order on
//           ties, because the comparison is strict), or -1 for a leaf.
// par[u] and h[u] of the starting vertex are read, not written, so the caller
// must have them set (for root 0 they are 0 through zero-initialisation).
void dfs_sz(int u) {
    sz[u] = 1;
    big[u] = -1;
    for (int child : adj[u]) {
        if (child == par[u])
            continue;  // do not walk back up to the parent
        h[child] = h[u] + 1;
        par[child] = u;
        dfs_sz(child);
        sz[u] += sz[child];
        const bool heavier = (big[u] == -1) || (sz[child] > sz[big[u]]);
        if (heavier)
            big[u] = child;
    }
}
// Second HLD pass: appends u's subtree to tr in DFS pre-order, visiting the
// heavy child big[u] before any light child so that each heavy chain occupies
// consecutive positions. Afterwards st[u] is u's index in tr and ft[u] is one
// past the last index used by u's subtree, i.e. the subtree is tr[st[u], ft[u]).
// The heavy child inherits head[u]; every light child starts its own chain.
// head[] of the starting vertex is read, not written (0 for root 0 via
// zero-initialisation).
void dfs_hld(int u) {
    st[u] = static_cast<int>(tr.size());
    tr.push_back(u);

    const int heavy = big[u];
    if (heavy != -1) {
        head[heavy] = head[u];
        dfs_hld(heavy);
    }

    for (int child : adj[u]) {
        if (child == par[u] || child == heavy)
            continue;
        head[child] = child;
        dfs_hld(child);
    }

    ft[u] = static_cast<int>(tr.size());
}
// Lowest common ancestor via the heavy-light chains. While u and v are on
// different chains, the vertex whose chain head is deeper (u on ties) jumps
// to the parent of that head. Once both share a chain, the shallower of the
// two is the LCA.
int get_lca(int u, int v) {
    while (head[u] != head[v]) {
        if (h[head[u]] >= h[head[v]])
            u = par[head[u]];
        else
            v = par[head[v]];
    }
    return (h[u] <= h[v]) ? u : v;
}
// Sums the Fenwick values stored at the vertices on the vertical path from v
// up to, but not including, u. Expects u to be an ancestor of v or v itself
// (calc passes the path's LCA); v == u yields 0. It walks chain by chain
// using the contiguous HLD positions st[].
int get_path(int v, int u) {
    int total = 0;
    while (head[v] != head[u]) {
        const int top = head[v];
        total += fn.get(st[top], st[v] + 1);  // whole chain segment top..v
        v = par[top];
    }
    // Same chain now: positions st[u]+1 .. st[v] lie strictly below u.
    total += fn.get(st[u] + 1, st[v] + 1);
    return total;
}
// DP value for vertex u. It assumes dp[] / dp2[] are already final for every
// vertex strictly below u, and that the Fenwick tree holds
// dp2[par[w]] - dp[w] at st[w] for each such w.
//
// Baseline: take no path whose LCA is u, giving dp2[u].
// For each path (a, b, weight) with LCA u: dp2[x] + get_path(x, u) equals the
// sum of dp2 over the vertical leg x..u (u included) minus dp of the leg's
// vertices other than u. With two legs, dp2[u] is counted twice, so one copy
// is subtracted.
int calc(int u) {
    int best = dp2[u];
    for (const array<int, 3>& path : vec[u]) {
        int shallow = path[0];
        int deep = path[1];
        const int weight = path[2];
        if (h[shallow] > h[deep]) {
            const int tmp = shallow;
            shallow = deep;
            deep = tmp;
        }

        int candidate;
        if (shallow == u) {
            // u is itself an endpoint: a single vertical leg deep -> u.
            candidate = dp2[deep] + get_path(deep, u) + weight;
        } else {
            // Two vertical legs meeting at u.
            candidate = dp2[shallow] + dp2[deep] + get_path(shallow, u) + get_path(deep, u) + weight - dp2[u];
        }

        if (candidate > best)
            best = candidate;
    }
    return best;
}
// Reads the vertex count n, then n-1 undirected edges given as 1-based
// endpoint pairs, and stores them 0-based in both directions in adj[].
void input() {
    cin >> n;
    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        cin >> a >> b;
        --a;
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
}
// Builds the HLD rooted at vertex 0, reads m weighted paths (1-based
// endpoints), buckets each path by its LCA, runs the bottom-up DP, and prints
// dp[0].
//
// Depth levels are processed from the deepest up. Within a level, every
// dp[u] is computed first and added into dp2 of its parent. Only after the
// whole level is done is dp2[par[u]] final; at that point
// dp2[par[u]] - dp[u] is written to the Fenwick tree at st[u], so that
// calc() on shallower vertices can read it.
void solve() {
    dfs_sz(0);
    dfs_hld(0);

    cin >> m;
    for (int q = 0; q < m; ++q) {
        int a, b, c;
        cin >> a >> b >> c;
        --a;
        --b;
        vec[get_lca(a, b)].push_back({a, b, c});
    }

    for (int u = 0; u < n; ++u)
        ver[h[u]].push_back(u);

    for (int depth = n; depth >= 0; --depth) {
        const vector<int>& level = ver[depth];

        // Pass 1: DP values for this level, accumulated into parents.
        for (int u : level) {
            dp[u] = calc(u);
            if (u != 0)
                dp2[par[u]] += dp[u];
        }

        // Pass 2: publish per-vertex path contributions now that dp2 is final.
        for (int u : level) {
            if (u != 0)
                fn.add(st[u], dp2[par[u]] - dp[u]);
        }
    }

    cout << dp[0] << '\n';
}
// Entry point: decouple C++ streams from C stdio and untie cin/cout for fast
// I/O, then read the tree and solve.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    input();
    solve();
    return 0;
}
/*
 * Judge submission status table pasted from the grader page (results had not
 * loaded yet). Kept as a comment so the source file still compiles.
 *
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
 */