제출 #387206

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
387206 |  | phathnv | Janjetina (COCI21_janjetina) | C++11 | 110 / 110 | 432 ms | 18016 KiB
#include <bits/stdc++.h>
using namespace std;

// 64-bit integer alias used for pair counts, which can exceed int range.
using ll = long long;

// Capacity of every vertex-indexed array and of the Fenwick tree
// (presumably max n = 1e5 plus slack — confirm against the problem limits).
const int N = 100000 + 7;

// One directed half of an undirected weighted edge, as stored in an
// adjacency list: the neighbouring vertex `v` and the edge weight `c`.
struct Edge{
    int v;
    int c;

    Edge(int _v, int _c) : v(_v), c(_c) {}
};

struct Bit{
    int d[N];
    void update(int x, int val){
        for(; x < N; x += x & -x)
            d[x] += val;
    }
    int get(int x){
        int res = 0;
        for(; x > 0; x -= x & -x)
            res += d[x];
        return res;
    }
};

// Global state shared by the input reader and the centroid decomposition.
int n;                 // number of vertices (main reads n - 1 edges)
int k;                 // threshold read from input; used as an index offset in Calc
int sz[N];             // subtree sizes written by the latest DfsForCentroid pass
vector<Edge> adj[N];   // adjacency lists indexed by vertex id
ll answer = 0;         // running pair count accumulated by Solve
Bit bit;               // Fenwick tree shared by every Calc call
bool done[N];          // vertices already used as a centroid

void DfsForCentroid(int u, int p){
    sz[u] = 1;
    for(Edge e : adj[u]){
        int v = e.v;
        if (v == p || done[v])
            continue;
        DfsForCentroid(v, u);
        sz[u] += sz[v];
    }
}

int FindCentroid(int u, int p, int treeSize){
    for(Edge e : adj[u]){
        int v = e.v;
        if (v == p || done[v])
            continue;
        if (sz[v] * 2 > treeSize)
            return FindCentroid(v, u, treeSize);
    }
    return u;
}

void Dfs(int u, int p, int len, int maxW, vector<pair<int, int>> &a){
    if (done[u])
        return;

    a.push_back({maxW, len});
    for(Edge e : adj[u]){
        int v = e.v;
        int c = e.c;
        if (v == p)
            continue;
        Dfs(v, u, len + 1, max(maxW, c), a);
    }
}

// Counts unordered pairs of entries x, y in `a`, where each entry is
// (maxW, len), such that
//     max(x.maxW, y.maxW) - (x.len + y.len) >= k.
// After sorting, the later entry of a pair has the larger maxW. Each entry is
// therefore matched only against earlier ones, which are held in `bit` keyed by
// len + k. `a` is sorted in place, and every insertion into `bit` is undone
// before returning.
ll Calc(vector<pair<int, int>> &a){
    sort(a.begin(), a.end());

    ll good = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        const int maxW = a[i].first;
        const int len = a[i].second;
        // An earlier entry y qualifies iff y.len + k <= maxW - len.
        good += bit.get(maxW - len);
        bit.update(len + k, 1);
    }

    // Remove this call's insertions so the shared tree is clean again.
    for (size_t i = a.size(); i-- > 0; )
        bit.update(a[i].second + k, -1);

    return good;
}

void Solve(int u){
    if (done[u])
        return;

    DfsForCentroid(u, -1);
    u = FindCentroid(u, -1, sz[u]);
    vector<pair<int, int>> a;
    Dfs(u, -1, 0, 0, a);
    answer += Calc(a);
    for(Edge e : adj[u]){
        a.clear();
        int v = e.v;
        int c = e.c;
        Dfs(v, u, 1, c, a);
        answer -= Calc(a);
    }

    done[u] = 1;
    for(Edge e : adj[u])
        Solve(e.v);
}

// Reads n, k and the n - 1 weighted edges, then runs the decomposition from
// vertex 1. It prints twice the accumulated count (presumably converting
// unordered pairs into ordered ones — confirm against the problem statement).
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> k;
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int from, to, weight;
        cin >> from >> to >> weight;
        adj[from].emplace_back(to, weight);
        adj[to].emplace_back(from, weight);
    }

    Solve(1);
    cout << answer * 2;

    return 0;
}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...