Submission #1200147

# | Time | Username | Problem | Language | Result | Execution time | Memory
1200147 | | zNatsumi | Janjetina (COCI21_janjetina) | C++20 | 110 / 110 | 173 ms | 39040 KiB
#include <bits/stdc++.h>

using namespace std;

// Competitive-programming shortcut: every `int` below is really `long long`,
// presumably so pair counts and shifted weights (w - k) cannot overflow.
#define int long long
// Adjacency entry: (neighbour, shifted edge weight).
using ii = pair<int, int>;

// Array bound for per-vertex tables; vertices appear to be 1-indexed (solve(1)).
const int N = 1e5 + 5;

// n: vertex count; k: threshold subtracted from every edge weight on input.
// sz[u]: subtree size computed by the most recent find_sz() call.
// depth[u]: number of edges from the current centroid to u.
// weight[u]: largest shifted weight (w - k) on the centroid -> u path.
// pa[u]: the centroid's child whose branch contains u.
// res: number of qualifying unordered pairs found so far (main prints 2 * res).
int n, k, sz[N], depth[N], weight[N], pa[N], res;
// del[u]: u was already chosen as a centroid and is cut out of the tree.
bool del[N];
// Vertices gathered by dfs() for the centroid currently being processed.
vector<int> vertex;
// ad[u]: (v, w - k) for every edge u-v of original weight w.
vector<ii> ad[N];

// Recomputes sz[] for the still-alive component reached from `u`, where `p`
// is the vertex we came from (-1 at the top). Vertices marked in del[] are
// treated as absent.
void find_sz(int u, int p){
  int total = 1;
  for(const auto& edge : ad[u]){
    const int nxt = edge.first;
    if(nxt == p || del[nxt]) continue;
    find_sz(nxt, u);
    total += sz[nxt];
  }
  sz[u] = total;
}

// Starting at `u` (entered from `p`), repeatedly steps into any alive child
// whose subtree holds more than s/2 vertices. It stops at the first vertex
// with no such child and returns it. Requires sz[] to be filled by find_sz()
// from the same starting vertex.
int centroid(int u, int p, int s){
  const int half = s / 2;
  for(bool moved = true; moved; ){
    moved = false;
    for(const auto& [v, w] : ad[u]){
      if(v == p || del[v] || sz[v] <= half) continue;
      p = u;
      u = v;
      moved = true;
      break;  // leave the old adjacency list; rescan from the new vertex
    }
  }
  return u;
}

// Fenwick (binary indexed) tree over positions 1..n with point-add and
// prefix-sum queries. Storage is 0-based: position i is kept in bit[i - 1].
struct BIT{
  int n = 0;
  vector<int> bit;

  // Resets the tree to `s` zeroed positions. A negative size is treated as
  // empty instead of being converted to a huge unsigned length.
  void init(int s){
    n = s < 0 ? 0 : s;
    bit.assign(n, 0);
  }
  // Adds `y` at position `i`. Positions outside 1..n are ignored. Without the
  // guard, i <= 0 would index bit[] out of range and then spin forever,
  // because i & -i stays 0 once i reaches 0.
  void update(int i, int y){
    if(i <= 0) return;
    for(; i <= n; i += i & -i) bit[i - 1] += y;
  }
  // Returns the sum of positions 1..i. i is clamped to n; i <= 0 yields 0.
  int get(int i) const {
    int sum = 0;
    if(i > n) i = n;
    for(; i > 0; i -= i & -i) sum += bit[i - 1];
    return sum;
  }
} d[N];  // d[child] covers one branch of the current centroid; d[centroid] covers all branches.

// Walks the branch below the current centroid, starting at `u` (entered
// from `p`); depth[u] and weight[u] must already be set by the caller.
// Effects:
//   - tags each vertex with its branch `root` in pa[];
//   - appends each vertex to `vertex`;
//   - counts the centroid-u path itself when its largest shifted weight is
//     at least its length.
// Returns the greatest depth reached in this branch.
int dfs(int u, int p, int root){
  vertex.push_back(u);
  pa[u] = root;
  res += (weight[u] >= depth[u]) ? 1 : 0;

  int deepest = depth[u];
  for(const auto& [v, w] : ad[u]){
    if(v == p || del[v]) continue;
    depth[v] = depth[u] + 1;
    weight[v] = weight[u] < w ? w : weight[u];
    const int sub = dfs(v, u, root);
    if(sub > deepest) deepest = sub;
  }
  return deepest;
}

// Centroid-decomposition step for the alive component containing `u`. It
// picks the component's centroid and counts every qualifying path through
// it. Then it removes the centroid and recurses into each remaining piece.
//
// Counting cross-branch pairs:
//   - For x and y in different branches, the path has length
//     depth[x] + depth[y] and largest shifted weight max(weight[x], weight[y]).
//   - Processing vertices by increasing weight makes the current x the
//     maximum. So x pairs with every earlier y where
//     depth[y] <= weight[x] - depth[x].
//   - Earlier vertices from x's own branch are subtracted using that
//     branch's tree.
void solve(int u){
  find_sz(u, -1);
  const int c = centroid(u, -1, sz[u]);
  del[c] = true;
  depth[c] = 0;

  vertex.clear();
  int deepest = 0;
  for(const auto& [v, w] : ad[c]){
    if(del[v]) continue;
    depth[v] = 1;
    weight[v] = w;
    const int branch_depth = dfs(v, c, v);
    deepest = max(deepest, branch_depth);
    d[v].init(branch_depth);
  }
  d[c].init(deepest);

  sort(vertex.begin(), vertex.end(), [&](int x, int y){
    return weight[x] < weight[y];
  });
  for(const int x : vertex){
    const int budget = weight[x] - depth[x];
    res += d[c].get(budget) - d[pa[x]].get(budget);
    d[c].update(depth[x], 1);
    d[pa[x]].update(depth[x], 1);
  }

  for(const auto& [v, w] : ad[c]) if(!del[v]) solve(v);
}

int32_t main(){
  cin.tie(0)->sync_with_stdio(0);
//  #define task "test"
//  if(fopen(task ".inp", "r")){
//    freopen(task ".inp", "r", stdin);
//    freopen(task ".out", "w", stdout);
//  }
  cin >> n >> k;
  for(int i = 2; i <= n; i++){
    int u, v, w; cin >> u >> v >> w;
    ad[u].push_back({v, w - k});
    ad[v].push_back({u, w - k});
  }
  solve(1);
  cout << 2 * res << "\n";

  return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...