Submission #848313

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 848313 | | qthang2k11 | Janjetina (COCI21_janjetina) | C++17 | 110 / 110 | 303 ms | 16680 KiB |
#include <bits/stdc++.h>
using namespace std;

// 64-bit type used for pair counts.
using ll = long long;

// Capacity of every node-indexed array (nodes are numbered starting at 1).
constexpr int MAX_N = 100000 + 5;

// Per-node scratch data for the centroid decomposition:
//   sz[v]  - subtree size, filled by init_dfs()
//   dpt[v] - number of edges between the current centroid and v
//   mx[v]  - largest edge weight on the path from the current centroid to v
int sz[MAX_N];
int dpt[MAX_N];
int mx[MAX_N];

// adj[v] lists (neighbour, edge weight) pairs.
vector<pair<int, int>> adj[MAX_N];

// rem[v] becomes true once v has been chosen as a centroid.
bool rem[MAX_N];

// Node count and the threshold k a path must reach (max weight - length >= k).
int n;
int k;

// Computes sz[] for the part of the tree reachable from x without passing
// through p or any removed node (p is 0 when x is the root of the walk).
// Returns sz[x].
int init_dfs(int x, int p) {
  int total = 1;
  for (const auto &edge : adj[x]) {
    const int child = edge.first;
    if (child != p && !rem[child]) {
      total += init_dfs(child, x);
    }
  }
  sz[x] = total;
  return total;
}

// Starting at x (whose parent in the last init_dfs() walk is p), keeps
// stepping into the first non-removed child whose subtree holds more than
// half of the n nodes in the component. Returns the node where no such child
// exists, which is a centroid of the component.
int centroid(int x, int p, int n) {
  int cur = x;
  int par = p;
  for (;;) {
    int heavy = -1;  // node ids are non-negative, so -1 means "none found"
    for (const auto &edge : adj[cur]) {
      const int nb = edge.first;
      if (nb != par && !rem[nb] && sz[nb] * 2 > n) {
        heavy = nb;
        break;
      }
    }
    if (heavy == -1) {
      return cur;
    }
    par = cur;
    cur = heavy;
  }
}

// Scratch buffer of (max edge weight on the path from the current centroid,
// depth below the centroid) entries. It is filled by init() or sub() and
// consumed by calc().
vector<pair<int, int>> arr{};

// Running number of unordered qualifying pairs.
ll ans{0};

// Walks the component of the current centroid, starting with init(c, 0).
// The caller has already set dpt[] and mx[] for x itself.
// For every node it visits other than the centroid, it:
//   * appends (mx[x], dpt[x]) to arr, so calc() can pair it with other nodes;
//   * counts the pair (centroid, x) directly, since that path is exactly
//     centroid -> x and qualifies when mx[x] - dpt[x] >= k.
// It also sets dpt[] and mx[] for each child before recursing into it.
// The centroid is recognised structurally: it is the only node visited with
// p == 0, because nodes are numbered from 1. The previous mx[x] != 0 check
// relied on every edge weight being positive. It also disagreed with sub(),
// which records every non-centroid node, and it counted the (c, c)
// "pair" whenever k <= 0.
void init(int x, int p) {
  const bool is_centroid = (p == 0);
  if (!is_centroid) {
    arr.emplace_back(mx[x], dpt[x]);
    if (mx[x] - dpt[x] >= k) ++ans;
  }
  for (const auto &elem : adj[x]) {
    int y, w;
    tie(y, w) = elem;
    if (y == p || rem[y]) continue;
    dpt[y] = dpt[x] + 1;
    mx[y] = max(mx[x], w);
    init(y, x);
  }
}

// 1-indexed Fenwick (binary indexed) tree over positions 1..H. calc() uses it
// to count previously inserted depths that do not exceed a given bound.
constexpr int H = 100000;
int bit[H + 5];

// Adds w at position x. Precondition: x >= 1, because x == 0 would never
// advance (0 & -0 == 0). Positions above H are silently ignored.
void update(int x, int w) {
  for (; x <= H; x += x & -x) {
    bit[x] += w;
  }
}

// Returns the sum over positions 1..x, or 0 when x <= 0.
// Queries above H are clamped to H. update() never writes beyond H, so the
// prefix sum up to any x > H equals the prefix sum up to H. Without the clamp,
// a large x (derived from an edge weight in calc()) would read past the end
// of bit[].
int get(int x) {
  if (x > H) x = H;
  int sum = 0;
  for (; x > 0; x -= x & -x) {
    sum += bit[x];
  }
  return sum;
}

// Counts unordered pairs of entries (a, b) in arr with
//   max(a.first, b.first) - (a.second + b.second) >= k.
// After sorting by weight, the entry being processed carries the larger
// weight of any pair it forms with an earlier entry. The earlier entry's depth
// d must therefore satisfy d <= weight - depth - k. The Fenwick tree holds the
// depths of entries already processed. This function undoes its own
// insertions before returning.
ll calc() {
  sort(arr.begin(), arr.end());
  ll pairs = 0;
  for (size_t i = 0; i < arr.size(); ++i) {
    const int weight = arr[i].first;
    const int depth = arr[i].second;
    const int depth_limit = weight - depth - k;
    pairs += get(depth_limit);
    update(depth, +1);
  }
  // Roll back the insertions so the next call starts from the same state.
  for (const auto &entry : arr) {
    update(entry.second, -1);
  }
  return pairs;
}

// Appends (mx[v], dpt[v]) for every node v reachable from x without passing
// through p or a removed node. mx[] and dpt[] must already have been filled
// by init() for the current centroid.
// This uses an explicit stack. Children are pushed in reverse adjacency order,
// so nodes are appended in the same preorder a recursive walk would produce.
void sub(int x, int p) {
  vector<pair<int, int>> todo;  // (node, parent) pairs still to visit
  todo.emplace_back(x, p);
  while (!todo.empty()) {
    const auto [v, parent] = todo.back();
    todo.pop_back();
    arr.emplace_back(mx[v], dpt[v]);
    for (auto it = adj[v].rbegin(); it != adj[v].rend(); ++it) {
      const int to = it->first;
      if (to != parent && !rem[to]) {
        todo.emplace_back(to, v);
      }
    }
  }
}

void solve(int c) {
  dpt[c] = mx[c] = 0;
  vector<pair<int, int>>().swap(arr);
  init(c, 0);
  ans += calc();
  for (const auto &elem: adj[c]) {
    int y; tie(y, ignore) = elem;
    if (rem[y]) continue;
    vector<pair<int, int>>().swap(arr);
    sub(y, c);
    ans -= calc();
  }
}

// Centroid-decomposition driver. It finds the centroid of the component
// containing x, marks it removed, and counts the pairs through it. It then
// recurses into each remaining piece adjacent to that centroid.
void build(int x) {
  const int comp_size = init_dfs(x, 0);
  const int c = centroid(x, 0, comp_size);
  rem[c] = true;
  solve(c);
  for (const auto &edge : adj[c]) {
    if (!rem[edge.first]) build(edge.first);
  }
}

// Reads n, k and the n - 1 weighted edges, then runs the centroid
// decomposition starting from node 1. ans holds each unordered qualifying
// pair once, so the output (ordered pairs) is twice that.
int32_t main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cin >> n >> k;
  for (int edges_left = n - 1; edges_left > 0; --edges_left) {
    int a, b, w;
    cin >> a >> b >> w;
    adj[a].emplace_back(b, w);
    adj[b].emplace_back(a, w);
  }
  build(1);
  cout << ans * 2;
  return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...