This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Capacity for vertex-indexed arrays (vertices are read as 1-based ids).
constexpr int MAX_N = 100005;

// Subtree sizes of the active component, filled by init_dfs().
int sz[MAX_N];
// Depth (edge count) from the current centroid, filled by init().
int dpt[MAX_N];
// Maximum edge weight on the path from the current centroid, filled by init().
int mx[MAX_N];
// Adjacency lists of (neighbor, edge weight).
vector<pair<int, int>> adj[MAX_N];
// True once a vertex has been used as a centroid and removed from the tree.
bool rem[MAX_N];
// Vertex count.
int n;
// Threshold: a pair qualifies when (max edge weight) - (path length) >= k.
int k;
// Computes sz[] for the part of the active (non-removed) tree hanging below x,
// treating p as x's parent. Returns sz[x], the number of vertices in that part.
int init_dfs(int x, int p) {
    int total = 1;
    for (const auto &edge : adj[x]) {
        const int to = edge.first;
        if (to != p && !rem[to]) {
            total += init_dfs(to, x);
        }
    }
    sz[x] = total;
    return total;
}
// Finds a centroid of the active component containing x, whose size is n.
// Requires sz[] to have been filled by init_dfs(x, p). Starting at x, it
// repeatedly steps into the first child whose subtree holds more than half of
// the n vertices; the first vertex without such a child is returned.
int centroid(int x, int p, int n) {
    int cur = x;
    int par = p;
    while (true) {
        bool moved = false;
        int heavy = 0;
        for (const auto &edge : adj[cur]) {
            const int to = edge.first;
            if (to != par && !rem[to] && sz[to] * 2 > n) {
                heavy = to;
                moved = true;
                break;
            }
        }
        if (!moved) {
            return cur;
        }
        par = cur;
        cur = heavy;
    }
}
// Scratch list of (max edge weight from the centroid, depth) entries handed to calc().
vector<pair<int, int>> arr;
// Running total of qualifying pairs (each pair counted once).
ll ans = 0;
// Explores the active component from x (parent p). The caller must have set
// dpt[x] and mx[x]; solve() calls this on the centroid with both set to 0.
// For every vertex below the root it assigns dpt/mx, appends (mx, dpt) to arr,
// and directly counts the pair (root, vertex) when mx - dpt >= k.
void init(int x, int p) {
    // Only the root has depth 0. Filtering on depth (not on mx[x] != 0) keeps
    // vertices reached through zero-weight edges, exactly matching what sub()
    // collects for the same-branch correction. It also never pairs the root
    // with itself, even when k <= 0.
    if (dpt[x] > 0) {
        arr.emplace_back(mx[x], dpt[x]);
        if (mx[x] - dpt[x] >= k) ++ans;
    }
    for (const auto &elem : adj[x]) {
        int y, w;
        tie(y, w) = elem;
        if (y == p || rem[y]) continue;
        dpt[y] = dpt[x] + 1;
        mx[y] = max(mx[x], w);
        init(y, x);
    }
}
// Largest position the Fenwick tree can hold; positions 0..H are supported.
const int H = 1e5;
// Fenwick tree; position x is stored at index x + 1, so indices 1..H+1 are used.
int bit[H + 5];
// Adds w at position x. Positions outside [0, H] cannot be stored and are
// ignored. The shift by one is needed because at index 0 the step x & -x is 0,
// which would never advance.
void update(int x, int w) {
    if (x < 0 || x > H) return;
    for (int i = x + 1; i <= H + 1; i += i & -i) {
        bit[i] += w;
    }
}
// Returns the total weight stored at positions <= x. Any x is accepted:
// negative x yields 0, and x > H is clamped to H (nothing is stored above H),
// so bit[] is never indexed out of range.
int get(int x) {
    if (x < 0) return 0;
    if (x > H) x = H;
    int total = 0;
    for (int i = x + 1; i > 0; i -= i & -i) {
        total += bit[i];
    }
    return total;
}
ll calc() {
sort(arr.begin(), arr.end());
ll ans = 0;
for (const auto &elem: arr) {
int val, h; tie(val, h) = elem;
int to_find = val - h - k;
// val - h - to_find >= k
// => to_find + h - val <= -k
// => to_find <= val - h - k
ans += get(to_find);
update(h, +1);
}
for (const auto &elem: arr)
update(elem.second, -1);
return ans;
}
// Appends (mx, dpt) for x and every active vertex below it (parent p) to arr,
// in pre-order. Relies on mx/dpt already being filled by init().
void sub(int x, int p) {
    arr.emplace_back(mx[x], dpt[x]);
    const auto &edges = adj[x];
    for (size_t i = 0; i < edges.size(); ++i) {
        const int to = edges[i].first;
        if (to != p && !rem[to]) {
            sub(to, x);
        }
    }
}
// Adds to ans the qualifying pairs whose path passes through centroid c.
// - init() counts pairs with c as an endpoint and collects all other vertices.
// - calc() counts qualifying pairs among the collected vertices.
// - Pairs inside a single branch of c were measured as if routed through c, so
//   they are recounted per branch and subtracted.
void solve(int c) {
    mx[c] = 0;
    dpt[c] = 0;
    // Swap with an empty vector so the old buffer is released.
    decltype(arr)().swap(arr);
    init(c, 0);
    ans += calc();
    for (const auto &edge : adj[c]) {
        const int branch = edge.first;
        if (!rem[branch]) {
            decltype(arr)().swap(arr);
            sub(branch, c);
            ans -= calc();
        }
    }
}
// Centroid decomposition of the active component containing x. It finds the
// component's centroid, marks it removed, counts the pairs through it, then
// recurses into every remaining neighbouring piece.
void build(int x) {
    const int c = centroid(x, 0, init_dfs(x, 0));
    rem[c] = true;
    solve(c);
    for (const auto &edge : adj[c]) {
        if (!rem[edge.first]) build(edge.first);
    }
}
// Reads n and k, then n - 1 edges "x y w". Runs the decomposition from vertex 1
// and prints ans * 2 (ans counts each qualifying pair once), without a newline.
int32_t main() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    cin >> n >> k;
    for (int edges_left = n - 1; edges_left > 0; --edges_left) {
        int u, v, w;
        cin >> u >> v >> w;
        adj[u].push_back(make_pair(v, w));
        adj[v].push_back(make_pair(u, w));
    }
    build(1);
    cout << 2 * ans;
    return 0;
}
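| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |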