제출 #1347174

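| # | 제출 시각 (submission time) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (execution time) | 메모리 (memory) |
| 1347174 |  | model_code | JOI Tour 2 (JOI26_joitour) | C++20 | 100 / 100 | 2476 ms | 112812 KiB |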
// O(N sqrt(M) + M log(N) + N Q) time, O(N log(N) + M + Q) space

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <chrono>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

using Int = long long;

// ---- Debug / contest helpers ----

// Prints a vector as "[a, b, c]"; once 256 elements have been written the rest
// is elided as ", ...".
template <class T> ostream &operator<<(ostream &os, const vector<T> &as);

// Prints a pair as "(first, second)".
template <class T1, class T2> ostream &operator<<(ostream &os, const pair<T1, T2> &a) {
  os << "(" << a.first;
  os << ", " << a.second;
  return os << ")";
}

template <class T> ostream &operator<<(ostream &os, const vector<T> &as) {
  constexpr int kMaxShown = 256;
  os << "[";
  int shown = 0;
  for (const auto &elem : as) {
    if (shown >= kMaxShown) {
      os << ", ...";
      break;
    }
    if (shown > 0) os << ", ";
    os << elem;
    ++shown;
  }
  return os << "]";
}

// Dumps the iterator range [a, b) to stderr, space separated, then ends the line.
template <class T> void pv(T a, T b) {
  while (a != b) {
    cerr << *a << " ";
    ++a;
  }
  cerr << endl;
}

// If f is smaller than t, stores f into t; reports whether t changed.
template <class T> bool chmin(T &t, const T &f) {
  const bool improves = t > f;
  if (improves) t = f;
  return improves;
}

// If f is larger than t, stores f into t; reports whether t changed.
template <class T> bool chmax(T &t, const T &f) {
  const bool improves = t < f;
  if (improves) t = f;
  return improves;
}

// Wraps an SGR code in an ANSI escape sequence, e.g. COLOR("31") -> "\x1b[31m".
#define COLOR(s) ("\x1b[" s "m")


// Answer for query q:
//   \sum[m] \sum[u<v on [S[m], T[m]]] [C[u] + C[v] = D[q]]
// i.e. over all paths m, the number of pairs of distinct vertices on the tree
// path S[m]..T[m] whose values C add up to D[q].

// Capacity bound for the binary-lifting table: N <= 2^E.
constexpr int E = 17;

// Tree: N vertices with values C[0..N-1]; edge i joins A[i] and B[i]
// (stored 0-indexed after reading).
int N;
vector<int> C;
vector<int> A;
vector<int> B;
// M paths, path m running from S[m] to T[m] (0-indexed).
int M;
vector<int> S;
vector<int> T;
// Q queries, each a target sum D[q].
int Q;
vector<int> D;

vector<vector<int>> graph;
vector<int> par, dep;
vector<int> dis, fin;
vector<int> us;
// Roots the subtree at u (whose parent is p, or -1 for the tree root) and
// records an Euler tour. For every vertex w reached:
//   - the parent is erased from graph[w], so afterwards graph[w] lists only children;
//   - par[w] and dep[w] are set (depth 0 when the parent is -1);
//   - us receives the tour: w is appended on entry and again after each child returns;
//   - dis[w] is the index of w's first appearance in us, fin[w] one past its last.
// Implemented with an explicit stack rather than recursion: a path-shaped tree
// with up to 2^E vertices would otherwise nest that many call frames and can
// overflow the call stack.
void dfs(int u, int p) {
  // Each frame: (vertex, index in graph[vertex] of the next child to visit).
  vector<pair<int, size_t>> frames;
  const auto enter = [&](int w, int pw) {
    if (~pw) graph[w].erase(find(graph[w].begin(), graph[w].end(), pw));
    par[w] = pw;
    dep[w] = (~pw) ? (dep[pw] + 1) : 0;
    dis[w] = us.size();
    us.push_back(w);
    frames.emplace_back(w, 0);
  };
  enter(u, p);
  while (!frames.empty()) {
    const int w = frames.back().first;
    size_t &nextChild = frames.back().second;
    if (nextChild < graph[w].size()) {
      const int child = graph[w][nextChild++];
      // enter() may reallocate `frames`; nextChild is not touched afterwards.
      enter(child, w);
    } else {
      fin[w] = us.size();
      frames.pop_back();
      // Back in the parent: it reappears in the tour after each finished child.
      if (!frames.empty()) us.push_back(frames.back().first);
    }
  }
}

// Binary-lifting table: ppar[e][w] is the 2^e-th ancestor of w, or -1 when that
// would climb past the root (filled in by main from par).
int ppar[E][1 << E];
// Lowest common ancestor of u and v, using ppar together with dep and par.
int lca(int u, int v) {
  // Step 1: lift whichever endpoint is deeper until both have the same depth.
  for (int e = E - 1; e >= 0; --e) {
    const int jump = 1 << e;
    if (dep[u] - jump >= dep[v]) u = ppar[e][u];
    if (dep[v] - jump >= dep[u]) v = ppar[e][v];
  }
  // One endpoint was an ancestor of the other.
  if (u == v) return u;
  // Step 2: climb both while their 2^e-th ancestors still differ; they end up
  // as two distinct children of the LCA.
  for (int e = E - 1; e >= 0; --e) {
    const int pu = ppar[e][u];
    const int pv = ppar[e][v];
    if (pu != pv) {
      u = pu;
      v = pv;
    }
  }
  u = par[u];
  v = par[v];
  assert(u == v);
  return u;
}

// Deferred range terms. An entry ((t, (u, v)), coef) stands for coef * "[0, t] * [u, v)":
// the pairs of every vertex on the root path [0, t] with each of u, par[u], ...,
// stopping before v (u lies in subtree(v); the summed lengths dist(u, v) are
// expected to stay small).
vector<pair<pair<int, pair<int, int>>, int>> fs;
// gs[v]: accumulated weight of the term "[0, par[v]] * v".
vector<Int> gs;
// hs[c]: accumulated weight of self pairs (w, w) with C[w] + C[w] == c;
// main adds hs[D[q]] to ans[q].
vector<Int> hs;

// Records coef * "[0, t] * u". Only t == par[u] and t == u are supported; since
// [0, u] is [0, par[u]] plus u itself, the latter also counts the self pair (u, u).
// t == -1 denotes an empty prefix and records nothing.
void add(int t, int u, int coef) {
  if (t == -1) return;
  if (t == par[u]) {
    gs[u] += coef;
    return;
  }
  if (t == u) {
    gs[u] += coef;
    hs[C[u] + C[u]] += coef;
    return;
  }
  assert(false);
}
// Records coef * "[0, t] * [u, v)" as a deferred entry in fs; t == -1 records nothing.
void add(int t, int u, int v, int coef) {
  if (t == -1) return;
  fs.push_back(make_pair(make_pair(t, make_pair(u, v)), coef));
}

// [s, t] -> [s, u]
//
// Replaces the current path [s, t] by [s, u]: s stays fixed and the far
// endpoint moves from t to u. The change in the pair count is recorded with
// weight coef. Every pair (w, v) with w strictly before v on the path (counted
// from s) that disappears gets -coef, and every pair that appears gets +coef.
//
// Notation (as in the add() overloads):
//   [a, b]       the tree path from a to b;
//   [0, b]       the path from the root (vertex 0, the DFS root used by main) to b;
//   "[a, b) * v" the pairs of v with the vertices of [a, b] other than b.
//
// Removing a vertex v that hangs below x on the t side drops its pairs with
//   [s, x) + [x, v) = ([0, s] - [0, x]) + ([0, par[v]] - [0, par[x]]).
// The per-vertex term [0, par[v]] * v goes to add(par[v], v, .) directly. The
// three terms shared by the whole chain (prefixes ending at s, x and par[x]) are
// batched into one fs range entry each. The commented-out add() calls show the
// unbatched form. Appending a vertex uses the same decomposition with +coef.
//
// x = lca(s, t), y = lca(s, u), z = lca(t, u). Among the pairwise LCAs of any
// three tree vertices at least two coincide. So one of the branches below always
// applies (tested in this order), and the final assert(false) is only a sanity check.
void movePath(int s, int t, int u, int coef) {
  const int x = lca(s, t), y = lca(s, u), z = lca(t, u);
  if (x == y) {
// cerr<<"[movePath:"<<__LINE__<<"] s = "<<s<<", t = "<<t<<", u = "<<u<<", coef = "<<coef<<"; x = "<<x<<", y = "<<y<<", z = "<<z<<endl;
    /*
          x    
         / \   
        /   z  
       /   / \ 
      s   t   u
      [s, t] -> [s, z] -> [s, u]
    */
    // Drop the chain t, par[t], ... up to (excluding) z.
    for (int v = t; v != z; v = par[v]) {
      // - [s, x) * v
      // add(s, v, -coef);
      // add(x, v, +coef);
      // - [x, v) * v
      add(par[v], v, -coef);
      // add(par[x], v, +coef);
    }
    // Batched prefix terms of the loop above, for every v in [t, z).
    add(s, t, z, -coef);
    add(x, t, z, +coef);
    add(par[x], t, z, +coef);
    // Append the chain u, par[u], ... up to (excluding) z.
    for (int v = u; v != z; v = par[v]) {
      // + [s, x) * v
      // add(s, v, +coef);
      // add(x, v, -coef);
      // + [x, v) * v
      add(par[v], v, +coef);
      // add(par[x], v, -coef);
    }
    // Batched prefix terms of the loop above, for every v in [u, z).
    add(s, u, z, +coef);
    add(x, u, z, -coef);
    add(par[x], u, z, -coef);
  } else if (x == z) {
// cerr<<"[movePath:"<<__LINE__<<"] s = "<<s<<", t = "<<t<<", u = "<<u<<", coef = "<<coef<<"; x = "<<x<<", y = "<<y<<", z = "<<z<<endl;
    /*
          x    
         / \   
        y   \  
       / \   \ 
      s   u   t
      [s, t] -> [s, x] -> [s, y] -> [s, u]
    */
    // [s, t] -> [s, x]: drop the chain t, par[t], ... up to (excluding) x.
    for (int v = t; v != x; v = par[v]) {
      // - [s, x) * v
      // add(s, v, -coef);
      // add(x, v, +coef);
      // - [x, v) * v
      add(par[v], v, -coef);
      // add(par[x], v, +coef);
    }
    add(s, t, x, -coef);
    add(x, t, x, +coef);
    add(par[x], t, x, +coef);
    // [s, x] -> [s, y]: drop the vertices par[y], ..., x above y. Each removed
    // w = par[v] loses its pairs with [s, v] = [0, s] - [0, par[v]]. Since
    // [0, w] * w = [0, par[w]] * w plus the self pair (w, w), the second part
    // becomes add(par[v], par[v], +coef). The [0, s] parts are batched into
    // the range [par[y], par[x]).
    for (int v = y; v != x; v = par[v]) {
      // - [s, v] * par[v]
      // add(s, par[v], -coef);
      add(par[v], par[v], +coef);
    }
    add(s, par[y], par[x], -coef);
    // [s, y] -> [s, u]: append the chain u, par[u], ... up to (excluding) y.
    for (int v = u; v != y; v = par[v]) {
      // + [s, y) * v
      // add(s, v, +coef);
      // add(y, v, -coef);
      // + [y, v) * v
      add(par[v], v, +coef);
      // add(par[y], v, -coef);
    }
    add(s, u, y, +coef);
    add(y, u, y, -coef);
    add(par[y], u, y, -coef);
  } else if (y == z) {
// cerr<<"[movePath:"<<__LINE__<<"] s = "<<s<<", t = "<<t<<", u = "<<u<<", coef = "<<coef<<"; x = "<<x<<", y = "<<y<<", z = "<<z<<endl;
    /*
          y    
         / \   
        x   \  
       / \   \ 
      s   t   u
      [s, t] -> [s, x] -> [s, y] -> [s, u]
    */
    // [s, t] -> [s, x]: drop the chain t, par[t], ... up to (excluding) x.
    for (int v = t; v != x; v = par[v]) {
      // - [s, x) * v
      // add(s, v, -coef);
      // add(x, v, +coef);
      // - [x, v) * v
      add(par[v], v, -coef);
      // add(par[x], v, +coef);
    }
    add(s, t, x, -coef);
    add(x, t, x, +coef);
    add(par[x], t, x, +coef);
    // [s, x] -> [s, y]: append the vertices par[x], ..., y above x, each paired
    // with [s, v]. This mirrors the removal in the previous branch with the
    // signs flipped.
    for (int v = x; v != y; v = par[v]) {
      // + [s, v] * par[v]
      // add(s, par[v], +coef);
      add(par[v], par[v], -coef);
    }
    add(s, par[x], par[y], +coef);
    // [s, y] -> [s, u]: append the chain u, par[u], ... up to (excluding) y.
    for (int v = u; v != y; v = par[v]) {
      // + [s, y) * v
      // add(s, v, +coef);
      // add(y, v, -coef);
      // + [y, v) * v
      add(par[v], v, +coef);
      // add(par[y], v, -coef);
    }
    add(s, u, y, +coef);
    add(y, u, y, -coef);
    add(par[y], u, y, -coef);
  } else assert(false);
}

int main() {
  for (; ~scanf("%d", &N); ) {
    C.resize(N);
    for (int u = 0; u < N; ++u) scanf("%d", &C[u]);
    A.resize(N - 1);
    B.resize(N - 1);
    for (int i = 0; i < N - 1; ++i) { scanf("%d%d", &A[i], &B[i]); --A[i]; --B[i]; }
    scanf("%d", &M);
    S.resize(M);
    T.resize(M);
    for (int m = 0; m < M; ++m) { scanf("%d%d", &S[m], &T[m]); --S[m]; --T[m]; }
    scanf("%d", &Q);
    D.resize(Q);
    for (int q = 0; q < Q; ++q) scanf("%d", &D[q]);
    
    assert(N <= 1 << E);
    
    graph.assign(N, {});
    for (int i = 0; i < N - 1; ++i) {
      graph[A[i]].push_back(B[i]);
      graph[B[i]].push_back(A[i]);
    }
    par.assign(N, -1);
    dep.assign(N, -1);
    dis.assign(N, -1);
    fin.assign(N, -1);
    us.clear();
    dfs(0, -1);
// cerr<<"dis = "<<dis<<", fin = "<<fin<<", us = "<<us<<endl;
    
    for (int u = 0; u < N; ++u) ppar[0][u] = par[u];
    for (int e = 0; e < E - 1; ++e) {
      for (int u = 0; u < N; ++u) {
        const int p = ppar[e][u];
        ppar[e + 1][u] = (~p) ? ppar[e][p] : -1;
      }
    }
    
    const int blockSize = max<int>((2*N - 1) / sqrt(2*M), 1);
    vector<pair<int, int>> paths(M);
    for (int m = 0; m < M; ++m) paths[m] = minmax(dis[S[m]], dis[T[m]]);
    sort(paths.begin(), paths.end(), [&](const pair<int, int> &lr0, const pair<int, int> &lr1) -> bool {
      const int k0 = lr0.first / blockSize;
      const int k1 = lr1.first / blockSize;
      return ((k0 != k1) ? (k0 < k1) : (k0 & 1) ? (lr0.second > lr1.second) : (lr0.second < lr1.second));
    });
// cerr<<"paths = "<<paths<<endl;
    
    fs.clear();
    gs.assign(N, 0);
    hs.assign(2*N + 1, 0);
    {
      int s = 0, t = 0;
      for (int m = 0; m < M; ++m) {
        const int u = us[paths[m].first];
        const int v = us[paths[m].second];
        movePath(t, s, u, M - m); s = u;
        movePath(u, t, v, M - m); t = v;
      }
    }
cerr<<"|fs| = "<<fs.size()<<endl;
    
    // sort by t
    vector<int> pt(N + 1, 0);
    {
      for (const auto &f : fs) ++pt[f.first.first];
      for (int t = 0; t < N; ++t) pt[t + 1] += pt[t];
      auto ffs = fs;
      for (const auto &f : fs) ffs[--pt[f.first.first]] = f;
      fs.swap(ffs);
    }
// cerr<<"pt = "<<pt<<", fs = "<<fs<<endl;
    
    vector<Int> ans(Q, 0);
    
    // u on [0, t]  <=>  dis[u] <= dis[t] < fin[u]
    vector<Int> pre(2*N + 1, 0);
    for (int i = 0; i < 2*N - 1; ++i) {
      const int u = us[i];
      if (i == dis[u]) {
        for (int q = 0; q < Q; ++q) {
          const int c = D[q] - C[u];
          if (c >= 0) ans[q] -= pre[c];
        }
        for (int k = pt[u]; k < pt[u + 1]; ++k) {
          const auto &f = fs[k];
          // pre[C[f.first.second]] += f.second;
          for (int v = f.first.second.first; v != f.first.second.second; v = par[v]) pre[C[v]] += f.second;
        }
        for (const int v : graph[u]) {
          pre[C[v]] += gs[v];
        }
      }
      if (i == fin[u] - 1) {
        for (int q = 0; q < Q; ++q) {
          const int c = D[q] - C[u];
          if (c >= 0) ans[q] += pre[c];
        }
      }
    }
    
    for (int q = 0; q < Q; ++q) ans[q] += hs[D[q]];
    
    for (int q = 0; q < Q; ++q) printf("%lld\n", ans[q]);
  }
  return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'int main()':
Main.cpp:237:38: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  237 |     for (int u = 0; u < N; ++u) scanf("%d", &C[u]);
      |                                 ~~~~~^~~~~~~~~~~~~
Main.cpp:240:44: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  240 |     for (int i = 0; i < N - 1; ++i) { scanf("%d%d", &A[i], &B[i]); --A[i]; --B[i]; }
      |                                       ~~~~~^~~~~~~~~~~~~~~~~~~~~~
Main.cpp:241:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  241 |     scanf("%d", &M);
      |     ~~~~~^~~~~~~~~~
Main.cpp:244:40: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  244 |     for (int m = 0; m < M; ++m) { scanf("%d%d", &S[m], &T[m]); --S[m]; --T[m]; }
      |                                   ~~~~~^~~~~~~~~~~~~~~~~~~~~~
Main.cpp:245:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  245 |     scanf("%d", &Q);
      |     ~~~~~^~~~~~~~~~
Main.cpp:247:38: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  247 |     for (int q = 0; q < Q; ++q) scanf("%d", &D[q]);
      |                                 ~~~~~^~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...