Submission #1146668

# | Time | Username | Problem | Language | Result | Execution time | Memory
1146668 | VMaksimoski008 | Election Campaign (JOI15_election_campaign) | C++20
100 / 100
165 ms | 45124 KiB
#include <bits/stdc++.h>
#define ar array
//#define int long long

using namespace std;

// Shorthand integer and pair types used throughout the solution.
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;

// Problem-wide constants.
constexpr int mod = 1'000'000'007;
constexpr ll inf = 1'000'000'000'000'000'000LL;
// Upper bound on vertex ids; sizes every per-vertex global array.
constexpr int maxn = 100'005;

// Fenwick (binary indexed) tree of long long sums over integer positions.
// Public positions are shifted up by one internally, so position 0 is usable.
struct fenwick {
    int n;               // exclusive upper bound on internal indices touched by update()
    vector<ll> tree;     // 1-based internal storage (slack beyond n is never written)

    // Size the tree for positions up to about _n; existing contents are kept by resize().
    void init(int _n) {
        n = _n + 10;
        tree.resize(n + 50);
    }

    // Add v at position p.
    void update(int p, ll v) {
        int idx = p + 1;
        while (idx < n) {
            tree[idx] += v;
            idx += idx & -idx;
        }
    }

    // Return the sum of all values added at positions 0..p.
    ll query(int p) {
        ll total = 0;
        for (int idx = p + 1; idx != 0; idx -= idx & -idx) {
            total += tree[idx];
        }
        return total;
    }
} fwt[2];

int n;                 // number of vertices (read in main)
int m;                 // number of candidate paths (read in main)
int up[maxn][20];      // up[v][k]: 2^k-th ancestor of v; 0 above the root (never assigned)
int dep[maxn];         // depth of each vertex, root at depth 0
int in[maxn];          // Euler-tour entry timestamp
int out[maxn];         // Euler-tour exit timestamp
int timer = 1;         // next timestamp handed out by dfs
// P[l][0]: paths with l as an endpoint, stored {l, other endpoint, value}.
// P[l][1]: paths whose LCA is l but neither endpoint is l, stored {a, b, value}.
vector<ar<int, 3> > P[maxn][2];
// dp[u][0]: best total in u's subtree with no chosen path topped at u;
// dp[u][1]: best total with some chosen path whose LCA is u.
ll dp[maxn][2];
vector<int> G[maxn];   // adjacency lists of the tree

// Recursive Euler-tour DFS from u, whose parent is p.
// Assigns in[u]/out[u] timestamps, fills up[u][1..19] from the already-set
// up[u][0], and sets depth and immediate parent of each child before descending.
void dfs(int u, int p) {
    in[u] = timer;
    ++timer;

    for (int k = 1; k < 20; ++k) {
        const int halfway = up[u][k - 1];
        up[u][k] = up[halfway][k - 1];
    }

    for (int child : G[u]) {
        if (child != p) {
            dep[child] = dep[u] + 1;
            up[child][0] = u;
            dfs(child, u);
        }
    }

    out[u] = timer;
    ++timer;
}

// Climb from u by d levels using the binary-lifting table, testing bits 19..0 of d.
int jmp(int u, int d) {
    int node = u;
    for (int bit = 19; bit >= 0; --bit) {
        const int mask = 1 << bit;
        if ((d & mask) != 0) {
            node = up[node][bit];
        }
    }
    return node;
}

int lca(int a, int b) {
    if(dep[a] < dep[b]) swap(a, b);
    a = jmp(a, dep[a] - dep[b]);
    if(a == b) return a;
    for(int j=19; j>=0; j--)
        if(up[a][j] != up[b][j]) a = up[a][j], b = up[b][j];
    return up[a][0];
}

void add(int t, int u, ll v) {
    fwt[t].update(in[u], v);
    fwt[t].update(out[u], -v);
}

// Sum of the values attached (via add) in tree fwt[t] to the vertices on the
// tree path between u and v, inclusive. Uses root-prefix sums:
// pref(u) + pref(v) - pref(lca) - pref(parent of lca).
ll query(int t, int u, int v) {
    const int meet = lca(u, v);
    const int above = up[meet][0];
    fenwick &bit = fwt[t];

    ll total = bit.query(in[u]);
    total += bit.query(in[v]);
    total -= bit.query(in[meet]);
    total -= bit.query(in[above]);
    return total;
}

void dfs2(int u, int p) {
    for(int v : G[u]) {
        if(v == p) continue;
        dfs2(v, u);
        dp[u][0] += max(dp[v][0], dp[v][1]);
    }

    ll sum = dp[u][0];
    for(auto [_, v, c] : P[u][0]) {
        int x = jmp(v, dep[v] - dep[u] - 1);
        sum -= max(dp[x][0], dp[x][1]);
        ll val = query(1, x, v);

        if(x != v) {
            int x2 = jmp(v, dep[v] - dep[x] - 1);
            val -= query(0, x2, v);
        }

        dp[u][1] = max(dp[u][1], sum + val + c);
        sum += max(dp[x][0], dp[x][1]);
    }

    for(auto [v1, v2, c] : P[u][1]) {
        int x1 = jmp(v1, dep[v1] - dep[u] - 1);
        int x2 = jmp(v2, dep[v2] - dep[u] - 1);
        sum -= max(dp[x1][0], dp[x1][1]);
        sum -= max(dp[x2][0], dp[x2][1]);

        ll val = query(1, x1, v1) + query(1, x2, v2);

        if(x1 != v1) {
            int x3 = jmp(v1, dep[v1] - dep[x1] - 1);
            val -= query(0, x3, v1);
        }

        if(x2 != v2) {
            int x3 = jmp(v2, dep[v2] - dep[x2] - 1);
            val -= query(0, x3, v2);
        }

        dp[u][1] = max(dp[u][1], sum + val + c);

        sum += max(dp[x1][0], dp[x1][1]);
        sum += max(dp[x2][0], dp[x2][1]);
    }

    add(0, u, max(dp[u][0], dp[u][1]));
    for(int v : G[u]) {
        if(v == p) continue;
        add(1, u, max(dp[v][0], dp[v][1]));
    }
}

signed main() {
    ios_base::sync_with_stdio(false);
    cout.tie(0); cin.tie(0);

    cin >> n;
    fwt[0].init(2 * n);
    fwt[1].init(2 * n);
    for(int i=1; i<n; i++) {
        int a, b; cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    dfs(1, 1);
    cin >> m;
    while(m--) {
        int a, b, c; cin >> a >> b >> c;
        int l = lca(a, b);
        if(l == a) P[a][0].push_back({ a, b, c });
        else if(l == b) P[b][0].push_back({ b, a, c });
        else P[l][1].push_back({ a, b, c });
    }

    dfs2(1, 1);
    cout << max(dp[1][0], dp[1][1]) << '\n';
}   
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...