Submission #1328615

#TimeUsernameProblemLanguageResultExecution timeMemory
1328615trimkusRace (IOI11_race)C++20
0 / 100
7 ms9796 KiB
#include "race.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int MAXN = 2e5;
vector<pair<int, int>> adj[MAXN];
vector<pair<int, int>> nadj[MAXN];
bool vis[MAXN];
int par[MAXN];
int siz[MAXN];
int best_path(int N, int K, int H[][2], int L[]);
struct centroid {
    #define pb(x) push_back(x)
  vector<vector<int> > edges;
  vector<bool> vis;
  vector<int> par;
  vector<int> sz;
  int n;

  void init(int s) {
    n = s;
    edges = vector<vector<int> >(n, vector<int>());
    vis = vector<bool>(n, 0);
    par = vector<int>(n);
    sz = vector<int>(n);
  }

  void edge(int a, int b) {
    edges[a].pb(b);
    edges[b].pb(a);
  }

  int find_size(int v, int p = -1) {
    if (vis[v]) return 0;
    sz[v] = 1;

    for (int x: edges[v]) {
      if (x != p) {
        sz[v] += find_size(x, v);
      }
    }

    return sz[v];
  }

  int find_centroid(int v, int p, int n) {
    for (int x: edges[v]) {
      if (x != p) {
        if (!vis[x] && sz[x] > n / 2) {
          return find_centroid(x, v, n);
        }
      }
    }

    return v;
  }

  void init_centroid(int v = 0, int p = -1) {
    find_size(v);

    int c = find_centroid(v, -1, sz[v]);
    vis[c] = true;
    par[c] = p;

    for (int x: edges[c]) {
      if (!vis[x]) {
        init_centroid(x, c);
      }
    }
  }
};


struct LCA {
    int n;
    const int LOG = 20;
    vector<vector<int>> up;
    vector<int> dist, dep;
    void dfs(int v, int p, int d = 0) {
        for (auto& [u, w] : adj[v]) {
            if (u == p) continue;
            up[u][0] = v;
            dist[u] = d + w;
            dep[u] = dep[v] + 1;
            dfs(u, v, d + w);
        }
    }
    void init(int n) {
        this->n = n;
        up = vector<vector<int>>(n, vector<int>(LOG, -1));
        dist = vector<int>(n);
        dep = vector<int>(n);
        dfs(0, -1);
        for (int i = 1; i < LOG; ++i) {
            for (int j = 0; j < n; ++j) {
                if (up[j][i - 1] == -1) continue;
                up[j][i] = up[up[j][i - 1]][i - 1];
            }
        }
    }
    int common(int u, int v) {
        if (u == v) return u;
        if (dep[u] < dep[v]) swap(u, v);
        int K = dep[u] - dep[v];
        for (int i = 0; i < LOG; ++i) {
            if (K >> i & 1) u = up[u][i];
        }
        if (u == v) return u;
        for (int i = LOG - 1; i >= 0; --i) {
            if (up[u][i] == -1) continue;
            if (up[u][i] != up[v][i]) {
                u = up[u][i];
                v = up[v][i];
            }
        }
        assert(up[v][0] == up[u][0]);
        assert(up[v][0] != -1);
        return up[v][0];
    }
    int DIST(int u, int v) {
        return dist[u] + dist[v] - 2 * dist[common(u, v)];
    }
};
map<int, int> dp[MAXN];
int ret = -1;
int need;
void calc(int u, int v, int add) {
    int nneed = need - add;
    if (dp[u].size() < dp[v].size()) swap(u, v);
    for (auto& [dist, val] : dp[v]) {
        if (dp[u].count(nneed - dist)) {
//            cerr << "Found at " << v << " "  << u << " = " << dist << " " <<  nneed - dist << " :: " << dp[u][nneed - dist] << " " << val << endl;
            int now = dp[u][nneed - dist] + val + 1;
            if (ret == -1 || ret > now) ret = now;
        }
    }
}
void dfs(int v, int p) {
    dp[v][0] = 0;
    for (auto& [u, w] : nadj[v]) {
        if (u == p) continue;
        dfs(u, v);
        calc(v, u, w);
        for (auto& [dist, val] : dp[u]) {
            if (!dp[v].count(dist + w)) {
                dp[v][dist + w] = val + 1;
            } else {
                dp[v][dist + w] = min(dp[v][dist + w], val + 1);
            }
        }
        while ((int)dp[v].size() > 0 && prev(dp[v].end())->first > need) dp[v].erase(prev(dp[v].end()));
        dp[u].clear();
    }
//    cerr << "v = " << v << ":\n";
//    for (auto& [u, w] : dp[v]) {
//        cerr << u << " " << w << endl;
//    }
}


int best_path(int N, int K, int H[][2], int L[])
{
    need = K;
    centroid c;
    c.init(N);
    for (int i = 0; i + 1 < N; ++i) {
        int u = H[i][0];
        int v = H[i][1];
//        cerr << u << " " << v << endl;
        c.edge(u, v);
        adj[u].push_back({v, L[i]});
        adj[v].push_back({u, L[i]});
    }
    c.init_centroid();
//    for (int i = 0; i < N; ++i) {
//        cerr << i << " = " << c.par[i] << endl;
//    }
    LCA lca;
    lca.init(N);
//    for (int i =0 ; i < N; ++i) {
//        for (int j = i + 1; j < N; ++j) {
//            cerr << i << " -> " << j << " dist = " << lca.DIST(i, j) << endl;
//        }
//    }
    int root = -1;
    for (int i = 0; i < N; ++i) {
        if (c.par[i] != -1) {
            int p = c.par[i];
            nadj[p].push_back({i, lca.DIST(i, p)});
        } else {
            assert(root == -1);
            root = i;
        }
    }
//    cerr << "root = " << root << endl;
//    for (int i = 0; i < N; ++i) {
//        cerr << i << ":\n";
//        for (auto& [u, w] : nadj[i]) {
//            cerr << u << " " << w << endl;
//        }
//    }
    dfs(root, -1);
    return ret;
}

//#define MAX_N 500000
//
//static int N, K;
//static int H[MAX_N][2];
//static int L[MAX_N];
//static int solution;
//
//inline
//void my_assert(int e) {if (!e) abort();}
//
//void read_input()
//{
//  int i;
//  my_assert(2==scanf("%d %d",&N,&K));
//  for(i=0; i<N-1; i++)
//    my_assert(3==scanf("%d %d %d",&H[i][0],&H[i][1],&L[i]));
//  my_assert(1==scanf("%d",&solution));
//}
//
//int main()
//{
//  int ans;
//  read_input();
//  ans = best_path(N,K,H,L);
//  if(ans==solution)
//    printf("Correct.\n");
//  else
//    printf("Incorrect. Returned %d, Expected %d.\n",ans,solution);
//
////  return 0;
//}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...