/*
답안 #527894 (Submission #527894)

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
(# | Submitted at | User | Problem | Language | Result | Time | Memory)
527894 2022-02-18T16:29:58Z cig32 Transport (COCI19_transport) C++17
91 / 130
1000 ms 28584 KB
*/
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/detail/standard_policies.hpp>
using namespace std;
using namespace __gnu_pbds;
// 64-bit Mersenne Twister seeded from the steady clock. This line comes before
// the `int` macro below, so the seed is narrowed to a plain (32-bit) int here.
mt19937_64 rng((int)std::chrono::steady_clock::now().time_since_epoch().count());
// Array bound for the per-vertex globals (vertices are read as 1..n in solve()).
const int MAXN = 1e5 + 10;
// NOTE(review): MOD is not referenced anywhere in this file.
const int MOD = 1e9 + 7;
// From here on every `int` is really `long long` (presumably so path sums of
// vertex values and edge costs cannot overflow — confirm against the limits).
#define int long long
// GNU pb_ds order-statistics tree over (key, unique id) pairs; the id lets equal
// keys coexist. order_of_key(k) returns how many stored pairs are < k.
typedef tree<pair<int, int>, null_type, less<pair<int, int> >, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
int rnd(int x, int y) {
  int u = uniform_int_distribution<int>(x, y)(rng); return u;
}
// n: number of vertices.
// a[v]: value added to the running total when a path leaves vertex v
//   (presumably the fuel available at city v — confirm against the task).
// adj[v]: set of (neighbour, stored weight) pairs; solve() stores the negated
//   edge cost as the weight, so costs are also "added" along a path.
int n, a[MAXN]; set<pair<int, int> > adj[MAXN];
// Centroid: the vertex whose removal (together with all of its edges) leaves
// the largest remaining connected component as small as possible.
// https://codeforces.com/problemset/problem/1406/C
// Total time complexity: O(n log^2 n)
// sz: size of the component currently being decomposed.
// mi: smallest "largest remaining piece" size found so far by dfs1.
// ctd: the vertex achieving mi, i.e. the centroid.
// ans: running count of ordered pairs, printed by solve().
int sz, mi, ctd, ans = 0; // ctd = centroid
// Counts the vertices reachable from `node` in the current component without
// stepping back onto `prv` (pass -1 for the starting vertex).
int dfs0(int node, int prv) {
  int count = 1;
  for (const auto &edge : adj[node])
    if (edge.first != prv) count += dfs0(edge.first, node);
  return count;
}
int dfs1(int node, int prv) {
  int ret = 1;
  int ma = 0;
  for(auto x: adj[node]) {
    if(x.first != prv) {
      int q = dfs1(x.first, node);
      ret += q;
      ma = max(ma, q);
    }
  }
  int rem = sz - ret;
  ma = max(ma, rem);
  if(ma < mi) {
    mi = ma;
    ctd = node;
  }
  return ret;
}
int id = 0;
void decomp(int st) {
  //st: an arbitrary node in the tree in question
  //step 1: find the centroid + count number of nodes
  sz = dfs0(st, -1);
  mi = 1e9, ctd = -1;
  dfs1(st, -1);
  //step 2: if nodes = 1, break
  if(sz == 1) return;
  //step 3: root the subtree at centroid, do a bfs
  // all weights = {sum, min}

  
  //step 3a: count id[u] < id[v]
  vector<pair<int,int> > all;
  for(auto x: adj[ctd]) all.push_back(x);
  ordered_set s;
  for(int i=0; i<all.size(); i++) {
    adj[ctd].erase(all[i]);
    adj[all[i].first].erase({ctd, all[i].second});
    unordered_map<int, pair<int, int> > w1; // weights to count
    unordered_map<int, pair<int, int> > w2; // weights to insert
    queue<int> q;
    set<int> vis;
    w1[all[i].first] = {a[ctd] + all[i].second, a[ctd] + all[i].second};
    w2[all[i].first] = {a[all[i].first] + all[i].second, a[all[i].first] + all[i].second};
    q.push(all[i].first);
    while(q.size()) {
      int f=q.front(); q.pop();
      int og = *vis.lower_bound(f);
      if(og ==f) continue;
      vis.insert(f);
      if(s.size()>0) {
        if(w1[f].second>=0) ans += s.size();
        else ans += s.size() - s.order_of_key({-w1[f].second, 0});
      }
      for(auto x: adj[f]) {
        og = *vis.lower_bound(x.first);
        if(og == x.first)continue;
        q.push(x.first);
        w1[x.first] = {w1[f].first + a[f] + x.second, min(w1[f].second, w1[f].first + a[f] + x.second)};
        w2[x.first] = {w2[f].first + a[x.first] + x.second, min(w2[f].second, 0ll) + a[x.first] + x.second};
      }
    }
    //insert
    for(int x: vis) {
      if(w2[x].second >= 0) s.insert({w2[x].first, ++id});
    }
    adj[ctd].insert(all[i]);
    adj[all[i].first].insert({ctd, all[i].second});
  }

  
  reverse(all.begin(), all.end());
  s.clear();
  //step 3b: count id[u] > id[v]
  for(int i=0; i<all.size(); i++) {
    adj[ctd].erase(all[i]);
    adj[all[i].first].erase({ctd, all[i].second});
    unordered_map<int, pair<int, int> > w1; // weights to count
    unordered_map<int, pair<int, int> > w2; // weights to insert
    queue<int> q;
    set<int> vis;
    w1[all[i].first] = {a[ctd] + all[i].second, a[ctd] + all[i].second};
    w2[all[i].first] = {a[all[i].first] + all[i].second, a[all[i].first] + all[i].second};
    q.push(all[i].first);
    while(q.size()) {
      int f=q.front(); q.pop();
      int og = *vis.lower_bound(f);
      if(og ==f) continue;
      vis.insert(f);
      if(s.size()>0) {
        if(w1[f].second>=0) ans += s.size();
        else ans += s.size() - s.order_of_key({-w1[f].second, 0});
      }
      for(auto x: adj[f]) {
        og = *vis.lower_bound(x.first);
        if(og == x.first)continue;
        q.push(x.first);
        w1[x.first] = {w1[f].first + a[f] + x.second, min(w1[f].second, w1[f].first + a[f] + x.second)};
        w2[x.first] = {w2[f].first + a[x.first] + x.second, min(w2[f].second, 0ll) + a[x.first] + x.second};
      }
    }
    //insert
    for(int x: vis) {
      if(w2[x].second >= 0) {
        s.insert({w2[x].first, ++id});
      }
    }
    adj[ctd].insert(all[i]);
    adj[all[i].first].insert({ctd, all[i].second});
    //step 3c: count u = root (= ctd)
    for(int x: vis) {
      if(w1[x].second >= 0) {
        ans++;
      }
    }
  }
  //step 3d: count v = root (= ctd)
  ans += s.size(); 
  //step 4: remove all edges from the centroid
  for(auto x: adj[ctd]) adj[x.first].erase({ctd, x.second});
  for(auto x: adj[ctd]) decomp(x.first);
  adj[ctd].clear();
}
void solve(int tc) {
  cin >> n;
  for(int i=1; i<=n; i++) cin >> a[i];
  for(int i=1; i<n; i++) {
    int u, v, w; cin >> u >> v >> w;
    w = -w;
    adj[u].insert({v, w});
    adj[v].insert({u, w});
  }
  decomp(1);
  cout << ans << "\n";
}
// Entry point: fast unsynchronised iostreams, then one test case
// (switch `tests` to a value read from input for multi-test formats).
int32_t main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  const int tests = 1;
  for (int tc = 1; tc <= tests; ++tc) solve(tc);
}

/*
Compilation message (line numbers count from the "#pragma GCC optimize" line,
i.e. the first line of the file as submitted):

transport.cpp: In function 'void decomp(long long int)':
transport.cpp:67:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   67 |   for(int i=0; i<all.size(); i++) {
      |                ~^~~~~~~~~~~
transport.cpp:106:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  106 |   for(int i=0; i<all.size(); i++) {
      |                ~^~~~~~~~~~~
*/
/*
Grader output, one table per test group.
Column header "# 결과 실행 시간 메모리" = "# | Result | Time | Memory".

# 결과 실행 시간 메모리 Grader output
1 Correct 25 ms 5676 KB Output is correct
2 Correct 24 ms 5864 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 36 ms 6220 KB Output is correct
2 Correct 41 ms 6416 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 560 ms 18204 KB Output is correct
2 Correct 453 ms 15708 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 759 ms 22188 KB Output is correct
2 Correct 816 ms 23400 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 1087 ms 28584 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 418 ms 11716 KB Output is correct
2 Correct 199 ms 10180 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 388 ms 14448 KB Output is correct
2 Correct 571 ms 13820 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 890 ms 17012 KB Output is correct
2 Correct 816 ms 17652 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 1100 ms 22092 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 1081 ms 28028 KB Time limit exceeded
2 Halted 0 ms 0 KB -
*/