Submission #1266838

# | Time | Username | Problem | Language | Result | Execution time | Memory
1266838 |  | canhnam357 | Janjetina (COCI21_janjetina) | C++20 | 110 / 110 | 243 ms | 15428 KiB
#include <bits/stdc++.h>
using namespace std;
// Fenwick (binary indexed) tree over positions 1..N-1 with an O(1) "clear":
// incrementing `timer` invalidates every slot at once. A slot whose mark[]
// differs from the current timer is treated as zero and is reset the first
// time add() or get() touches it.
constexpr int N = 100'005;
int bit[N], mark[N]{}, timer = 0;

// Brings slot `pos` into the current generation, zeroing it if it is stale.
static void fenwick_touch(int pos) {
    if (mark[pos] != timer) {
        mark[pos] = timer;
        bit[pos] = 0;
    }
}

// Adds `val` at 1-based position `pos`. Non-positive positions are ignored:
// for them the step `pos & -pos` reaches 0 and the loop would never end.
void add(int pos, int val) {
    if (pos <= 0) return;
    for (; pos < N; pos += pos & -pos) {
        fenwick_touch(pos);
        bit[pos] += val;
    }
}

// Returns the sum of positions 1..pos in the current generation (0 for
// pos <= 0). Positions past the end are clamped to N-1: nothing can be stored
// beyond it, so the prefix sum is unchanged, and bit[] is never over-indexed.
int get(int pos) {
    if (pos >= N) pos = N - 1;
    int res = 0;
    for (; pos > 0; pos -= pos & -pos) {
        fenwick_touch(pos);
        res += bit[pos];
    }
    return res;
}
// Tree input and centroid-decomposition state.
int n;                           // number of vertices (main reads n - 1 edges)
int k;                           // threshold subtracted in the pair condition
vector<pair<int, int>> adj[N];   // adj[u] holds (neighbour, edge weight)
int del[N]{};                    // set to 1 once a vertex has served as a centroid
int sz[N];                       // subtree sizes written by dfs_sz
// Computes the size of the live (non-deleted) subtree hanging off u when it
// is entered from p, stores it in sz[u] (and, recursively, in sz[] of every
// descendant), and returns it.
int dfs_sz(int u, int p) {
    int total = 1;
    for (const auto& edge : adj[u]) {
        const int next = edge.first;
        if (next == p || del[next]) continue;
        total += dfs_sz(next, u);
    }
    return sz[u] = total;
}
// Walks from u toward the centroid of a live component of size comp_size.
// Expects sz[] to hold subtree sizes computed by dfs_sz from the walk's start.
// At each step it moves into the first non-parent live neighbour whose subtree
// holds more than half of the component, and stops when there is none.
int centroid(int u, int comp_size, int p) {
    for (bool moved = true; moved;) {
        moved = false;
        for (const auto& edge : adj[u]) {
            const int next = edge.first;
            if (next == p || del[next] || sz[next] <= comp_size / 2) continue;
            p = u;
            u = next;
            moved = true;
            break;
        }
    }
    return u;
}
vector<pair<long long, int>> s;
void dfs(int u, int p, int d, int h) {
    s.emplace_back(h, d);
    for (auto [v, w] : adj[u]) {
        if (v != p && !del[v]) dfs(v, u, d + 1, max(h, w));
    }
}
const long long inf = 1e18;
long long ans = 0;
void solve(int u) {
    timer++;
    int sz = dfs_sz(u, u);
    int c = centroid(u, sz, u);
    long long tot = 0;
    s = {{-inf, 0}};
    for (auto [v, w] : adj[c]) {
        if (del[v]) continue;
        dfs(v, c, 1, w);
    }
    sort(s.begin(), s.end());
    for (auto [v, d] : s) {
        if (v - d - k >= 0) ans += get(v - d - k + 1);
        add(d + 1, 1);
    }
    for (auto [v, w] : adj[c]) {
        if (del[v]) continue;
        s = {};
        dfs(v, c, 1, w);
        sort(s.begin(), s.end());
        timer++;
        for (auto [v, d] : s) {
            if (v - d - k > 0) ans -= get(v - d - k + 1);
            add(d + 1, 1);
        }
    }
    del[c] = 1;
    for (auto [v, w] : adj[c]) {
        if (!del[v]) solve(v);
    }
}
// Reads n, k and the n - 1 weighted tree edges, runs the decomposition from
// vertex 1 and prints ans * 2. solve counts each unordered pair once, so the
// doubled value is the number of ordered pairs.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> k;
    for (int edge = 1; edge < n; ++edge) {
        int a, b, weight;
        cin >> a >> b >> weight;
        adj[a].push_back({b, weight});
        adj[b].push_back({a, weight});
    }
    solve(1);
    cout << ans * 2;
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...