#include <bits/stdc++.h>
using namespace std;
#ifndef _LOCAL
//#pragma GCC optimize("O3,Ofast")
#else
#pragma GCC optimize("O0")
#endif
// In-place minimum: replaces `a` with `b` when `b` compares smaller.
template<typename T> inline void umin(T &a, const T b) {
    if (b < a) a = b;
}
// In-place maximum: replaces `a` with `b` when `a` compares smaller.
template<typename T> inline void umax(T &a, const T b) {
    if (a < b) a = b;
}
// Short type aliases used throughout the solution.
using pii = pair<int, int>;
using ll = long long;
using ld = long double;
using byte = int8_t;
// Current system_clock time as a raw tick count; only used to seed `rnd`.
ll time() {
    const auto since_epoch = chrono::system_clock::now().time_since_epoch();
    return since_epoch.count();
}
mt19937 rnd(time());
// Shorthand macros: pair members, container size as int, and a container's full range.
#define ft first
#define sd second
#define len(f) int((f).size())
#define bnd(f) (f).begin(), (f).end()
// `a _ b` expands to `a <<' '<< b`, i.e. prints a and b separated by a space (debug output).
#define _ <<' '<<
// "Infinity" sentinels and the modulus used by the MD helpers.
// NOTE(review): 4e18 + 5 is evaluated in double precision, so the "+ 5" is lost to rounding.
const int inf = 1e9 + 5;
const ll inf64 = 4e18 + 5;
const int md = 998244353;
namespace MD {
void add(int &a, const int b) {if((a += b) >= md) a -= md;}
void sub(int &a, const int b) {if((a -= b) < 0) a += md;}
int prod(const int a, const int b) {return ll(a) * b % md;}
};
// Capacity of the per-vertex arrays.
const int N = 100005;
// n: vertex count and k: threshold, both read in solve().
int n, k;
// si[v]: subtree size computed by sizes().
int si[N];
// g[v]: adjacency list of (neighbor, w - k) pairs built in solve().
vector<pii> g[N];
// mk[v]: true once v has been chosen as a centroid; such vertices are skipped afterwards.
bool mk[N];
// Running total of the counts returned by calc() during the decomposition.
ll ans;
// Fills si[x] with the number of vertices in x's subtree for every x reachable from v
// (v treated as the root, pr as its parent), never entering vertices marked in mk.
void sizes(int v, int pr = -1) {
    int total = 1;
    for (const auto &edge : g[v]) {
        const int to = edge.first;
        if (to != pr && !mk[to]) {
            sizes(to, v);
            total += si[to];
        }
    }
    si[v] = total;
}
// Finds a centroid of the component containing v, using si[] from a preceding sizes(v).
// Starting at v, it repeatedly steps into an unmarked child whose subtree holds at least
// half of the n component vertices. It stops at the first vertex with no such child.
int cent(int v, int pr = -1, int n = -1) {
    if (pr < 0) n = si[v];
    for (bool descended = true; descended;) {
        descended = false;
        for (const auto &edge : g[v]) {
            const int to = edge.first;
            if (to == pr || mk[to]) continue;
            if (2 * si[to] >= n) {
                pr = v;
                v = to;
                descended = true;
                break;
            }
        }
    }
    return v;
}
// Per-vertex values gathered by dfs2 for the current calc() call, in DFS preorder.
vector<int> t, h;
// Preorder walk from v that never enters vertices marked in mk.
// F: running maximum of the stored edge weights (w - k) along the path, starting from the
//    caller-supplied value (0 by default).
// H: depth, starting from the caller-supplied value (0 by default).
// For every visited vertex it appends h = H and t = F - H + 1.
void dfs2(int v, int pr = -1, int F = 0, int H = 0) {
    // cerr << v _ F _ H _ F - H + 1 << endl;
    h.push_back(H);
    t.push_back(F - H + 1);
    for (const auto &edge : g[v]) {
        const int to = edge.first;
        const int w = edge.second;
        if (to != pr && !mk[to]) dfs2(to, v, max(F, w), H + 1);
    }
}
// fw: Fenwick tree storage. pt/ph: index permutations sorted by t and h, used by calc().
// mx: size parameter set by calc().
// Values are stored at reversed position mx - i + 1, so fget(i) counts added values >= i.
int fw[N + 228], pt[N], ph[N], mx;
void fadd(int i) {
    for (int p = mx - i + 1; p <= mx + 3; p += p & -p) fw[p]++;
}
int fget(int i) {
    int cnt = 0;
    for (int p = mx - i + 1; p != 0; p -= p & -p) cnt += fw[p];
    return cnt;
}
// Runs dfs2 from v (running max starts at F0, depth at H0) and counts ordered pairs (a, b),
// a != b, of the visited vertices with t[a] <= h[b] and t[b] <= h[a].
// Since t = F - H + 1, that is exactly max(F_a, F_b) < h[a] + h[b].
//
// F0 / H0 (default 0) let dfs() evaluate a child subtree u with the same values it had inside
// calc(centroid): depth shifted by 1 and the running max already including edge (centroid, u).
ll calc(int v, int F0 = 0, int H0 = 0) {
    t.clear();
    h.clear();
    dfs2(v, -1, F0, H0);
    const int m = int(t.size());
    mx = *max_element(h.begin(), h.end()) + 1;
    // Clamp t into [0, mx]. A t <= 0 is <= every depth, just like 0.
    // A t > max depth is <= no depth, just like mx. So the counts are unchanged.
    for (int &x : t) x = min(max(x, 0), mx);
    iota(pt, pt + m, 0);
    sort(pt, pt + m, [&](const int &a, const int &b) { return t[a] < t[b]; });
    iota(ph, ph + m, 0);
    sort(ph, ph + m, [&](const int &a, const int &b) { return h[a] < h[b]; });
    ll res = 0;
    ++mx;
    // Clear every Fenwick slot that fadd/fget can touch (positions 0..mx+4).
    memset(fw, 0, (mx + 5) * sizeof(fw[0]));
    for (int it = 0, it2 = 0; it < m; ++it) {
        const int H = h[ph[it]];
        const int T = t[ph[it]];
        // ph[it] is itself inserted below and matched by fget(T) exactly when T <= H; drop that self-pair.
        if (T <= H) --res;
        // Insert (keyed by depth) every vertex whose t is <= the current depth H.
        while (it2 < m && t[pt[it2]] <= H) {
            fadd(h[pt[it2]]);
            ++it2;
        }
        // Count inserted vertices whose depth is >= T.
        res += fget(T);
    }
    return res;
}
// Centroid decomposition over the component containing v.
// calc(centroid) counts pairs as if their path ran through the centroid. Pairs inside one child
// subtree do not, so they are removed by recomputing that subtree with the same depth/max offsets.
void dfs(int v = 0) {
    sizes(v);
    v = cent(v);
    mk[v] = true;
    ans += calc(v);
    for (const auto &e : g[v]) {
        const int u = e.first;
        if (mk[u]) continue;
        // Inside calc(v), u had depth 1 and running max max(0, weight of (v, u)).
        // Reproduce exactly those starting values so the subtraction cancels the same pairs.
        ans -= calc(u, max(0, e.second), 1);
        dfs(u);
    }
}
void solve() {
cin >> n >> k;
for(int i = 0; i < n; ++i)
g[i].clear();
for(int i = 1; i < n; ++i) {
int x, y, w;
cin >> x >> y >> w;
--x, --y;
g[x].push_back({y, w - k});
g[y].push_back({x, w - k});
}
memset(mk, 0, n);
ans = 0;
dfs();
cout << n * ll(n - 1) - ans << endl;
}
// Entry point. When _LOCAL is not defined (judge build), solve() runs once on stdin.
// When _LOCAL is defined, input is redirected from in.txt and a test count t is read.
// In that case the `while(t--)` line is the loop head for the solve() call after #endif.
signed main() {
// Unsynchronized C++ streams; cin is not tied to cout.
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
#ifndef _LOCAL
// freopen("file.in", "r", stdin);
// freopen("file.out", "w", stdout);
#else
// NOTE(review): presumably a Windows cmd `color` call for the local console — confirm.
system("color a");
freopen("in.txt", "r", stdin);
int t; cin >> t;
while(t--)
#endif
solve();
}
Compilation message
Main.cpp: In function 'll calc(int)':
Main.cpp:88:22: warning: suggest parentheses around '+' inside '<<' [-Wparentheses]
88 | memset(fw, 0, mx + 5 << 2);
| ~~~^~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
2668 KB |
Output is correct |
2 |
Correct |
3 ms |
2668 KB |
Output is correct |
3 |
Incorrect |
3 ms |
2668 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
2668 KB |
Output is correct |
2 |
Correct |
3 ms |
2668 KB |
Output is correct |
3 |
Incorrect |
3 ms |
2668 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
2668 KB |
Output is correct |
2 |
Correct |
3 ms |
2668 KB |
Output is correct |
3 |
Incorrect |
3 ms |
2668 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |