제출 #985559

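| # | 제출 시각 (submission time) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory) |
|---|---|---|---|---|---|---|---|
| 985559 | (not captured) | maomao90 | JOI tour (JOI24_joitour) | C++17 | 86 / 100 | 3030 ms | 527048 KiB |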

// Hallelujah, praise the one who set me free
// Hallelujah, death has lost its grip on me
// You have broken every chain, There's salvation in your name
// Jesus Christ, my living hope
#include <bits/stdc++.h> 
#include "joitour.h"
using namespace std;

// Ascending loop: var runs over [from, to).
#define REP(var, from, to) for (int var = (from); var < (to); ++var)
// Descending loop: var runs from `from` down to `to`, inclusive.
#define RREP(var, from, to) for (int var = (from); var >= (to); --var)
// Shrink a to b when b is smaller; reports whether a was modified.
template <class T>
inline bool mnto(T& a, T b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Grow a to b when b is larger; reports whether a was modified.
template <class T>
inline bool mxto(T& a, T b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}

// Short aliases for common scalar and container types.
using ull = unsigned long long;
using ll = long long;
using ld = long double;
#define FI first
#define SE second
using ii = pair<int, int>;
using pll = pair<ll, ll>;
using iii = tuple<int, int, int>;
#define ALL(_a) _a.begin(), _a.end()
#define SZ(_a) (int) _a.size()
#define pb push_back
using vi = vector<int>;
using vll = vector<ll>;
using vii = vector<ii>;
using viii = vector<iii>;

// Outside DEBUG builds, every `cerr << ...` statement becomes dead code.
#ifndef DEBUG
#define cerr if (0) cerr
#endif

constexpr int INF = 1000000005;
constexpr ll LINF = 1000000000000000005ll;
// Extent of the per-vertex global arrays.
constexpr int MAXN = 200005;
// Number of centroid-decomposition levels the pre/ipre/pst tables can hold.
constexpr int MAXL = 19;

// Number of vertices (set by init()).
int n;
// Current label of each vertex; the code compares labels against 0, 1 and 2.
int f[MAXN];
// Undirected adjacency lists of the tree.
vi adj[MAXN];
// Running answer, maintained by init()/change() and returned by num_tours().
ll ans;

// Scratch filled by fsub() for the component currently being processed:
//   sub[u]     - size of u's subtree within the component.
//   cl[u]      - centroid level of u, or -1 while u is not yet a centroid.
//   pre[l][u]  - preorder index of u in the level-l DFS.
//   ipre[l][i] - vertex whose preorder index is i at level l (inverse of pre).
//   pst[l][u]  - one past the last preorder index of u's subtree at level l.
//   oc[u]      - label-1 vertices on the DFS path from the root down to u,
//                both ends included (when the root's oc starts at 0, as
//                build() arranges for the centroid).
//   ptr        - next preorder index to hand out.
int sub[MAXN], cl[MAXN], pre[MAXL][MAXN], ipre[MAXL][MAXN], pst[MAXL][MAXN], oc[MAXN], ptr;
// Per-vertex subtree aggregates (z = 0: label 0, z = 1: label 2).
// NOTE(review): no active code in this file reads these arrays.
ll bsm[MAXN][2], sm[MAXN][2];
// DFS over the current component (vertices with cl == -1), rooted at u.
// Records level-l preorder data and subtree sizes.
//   u - current vertex
//   p - its DFS parent (-1 for the root)
//   l - level whose pre/ipre/pst tables are filled
// oc[u] must be preset by the caller: a parent copies its own oc into the
// child, and build() zeroes it at the centroid. This call then adds u's own
// label. Returns sub[u], the size of u's subtree in the component.
// Only sizes, preorder positions and oc are produced here. Solver rebuilds
// its label aggregates itself from ipre/oc/f, so no per-vertex label sums
// are accumulated (they were never read).
// NOTE(review): recursion depth equals the component height (up to n on a
// path graph); this relies on the judge providing a large stack.
int fsub(int u, int p, int l) {
    ipre[l][ptr] = u;
    pre[l][u] = ptr++;
    oc[u] += f[u] == 1;
    sub[u] = 1;
    for (int v : adj[u]) {
        if (cl[v] != -1 || v == p) {
            continue;
        }
        oc[v] = oc[u];
        sub[u] += fsub(v, u, l);
    }
    // Preorder indices [pre[l][u], pst[l][u]) form u's subtree.
    pst[l][u] = ptr;
    return sub[u];
}
// Locate the centroid of the current component (vertices with cl == -1),
// starting at u whose DFS parent is p; s is the component size.
// Uses sub[] from an fsub() pass rooted at the starting vertex. It keeps
// stepping into the first neighbor (in adjacency order) whose subtree holds
// more than s / 2 vertices, and stops at the first vertex with none.
int fcent(int u, int p, int s) {
    const int half = s / 2;
    while (true) {
        int next = -1;
        for (int v : adj[u]) {
            if (v != p && cl[v] == -1 && sub[v] > half) {
                next = v;
                break;
            }
        }
        if (next == -1) {
            return u;
        }
        p = u;
        u = next;
    }
}
// cp[u]: parent of centroid u in the centroid tree (-1 for the top centroid); set by build().
int cp[MAXN];
// Per-centroid counter. A segment tree over the preorder positions
// (pre[cl[r]][*]) of the component rooted at centroid r stores, for
// label-0 (z = 0) and label-2 (z = 1) vertices:
//   - how many there are (sm)
//   - the sum of their oc values (bsm)
// Here oc is the number of label-1 vertices on the path from r to the
// vertex, both ends included.
// The component is split into "ranges": r itself, then each child subtree
// of r. calc() counts pairs with one endpoint inside a range and the other
// outside it, i.e. pairs whose connecting path passes through r.
struct Solver {
    // Centroid owning this structure.
    int r;
    // Number of (label 0, label 2) pairs lying in different ranges of r.
    ll pairs;
    // Range boundaries: ch[0] = 0 (r itself), then the preorder start of
    // each child subtree of r, then sub[r] as a sentinel.
    vi ch;
    struct Node {
        ll bsm[2];   // sum of oc over label-0 / label-2 vertices in range
        int sm[2];   // count of label-0 / label-2 vertices in range
        int lz, oc;  // pending oc increment; oc is meaningful at leaves only
    };
    vector<Node> seg;

    // Range aggregate returned by queries.
    struct Value {
        ll bsm[2];
        int sm[2];
        Value() {
            bsm[0] = bsm[1] = sm[0] = sm[1] = 0;
        }
        Value& operator+= (const Value &o) {
            REP (z, 0, 2) {
                bsm[z] += o.bsm[z];
                sm[z] += o.sm[z];
            }
            return *this;
        }
        Value operator+ (const Value &o) const {
            Value res = *this;
            return res += o;
        }
        Value& operator-= (const Value &o) {
            REP (z, 0, 2) {
                bsm[z] -= o.bsm[z];
                sm[z] -= o.sm[z];
            }
            return *this;
        }
        friend ostream& operator<<(ostream &os, const Value &o) {
            return os << "({" << o.bsm[0] << ", " << o.bsm[1] << "}, {" << o.sm[0] << ", " << o.sm[1] << "})";
        }
    };

    // Copy a node's aggregates into a Value.
    static Value toValue(const Node &nd) {
        Value res;
        REP (z, 0, 2) {
            res.bsm[z] = nd.bsm[z];
            res.sm[z] = nd.sm[z];
        }
        return res;
    }

    // Add x to the oc of every vertex covered by node u (lazily).
    void apply(int u, int lo, int hi, int x) {
        Node &nd = seg[u];
        nd.lz += x;
        nd.oc += x;
        REP (z, 0, 2) {
            // The accumulated lazy value times a count can exceed int range.
            nd.bsm[z] += (ll) x * nd.sm[z];
        }
    }
    // Recompute node u's aggregates from its children.
    void pull(int u, int lo, int hi) {
        const int lc = u << 1, rc = lc | 1;
        REP (z, 0, 2) {
            seg[u].bsm[z] = seg[lc].bsm[z] + seg[rc].bsm[z];
            seg[u].sm[z] = seg[lc].sm[z] + seg[rc].sm[z];
        }
    }
    // Push node u's pending increment down to its children.
    void propo(int u, int lo, int hi) {
        if (!seg[u].lz) {
            return;
        }
        const int mid = (lo + hi) >> 1, lc = u << 1, rc = lc | 1;
        apply(lc, lo, mid, seg[u].lz);
        apply(rc, mid + 1, hi, seg[u].lz);
        seg[u].lz = 0;
    }
    // Build from the oc/f values left by the latest fsub() pass at level cl[r].
    void init(int u, int lo, int hi) {
        seg[u].lz = 0;
        if (lo == hi) {
            const int v = ipre[cl[r]][lo];
            seg[u].oc = oc[v];
            REP (z, 0, 2) {
                seg[u].sm[z] = (z << 1) == f[v];
                seg[u].bsm[z] = (ll) seg[u].oc * seg[u].sm[z];
            }
            return;
        }
        const int mid = (lo + hi) >> 1, lc = u << 1, rc = lc | 1;
        init(lc, lo, mid);
        init(rc, mid + 1, hi);
        pull(u, lo, hi);
    }
    void init() {
        init(1, 0, sub[r] - 1);
    }
    // Add x to oc on preorder positions [s, e].
    void incre(int s, int e, int x, int u, int lo, int hi) {
        if (lo >= s && hi <= e) {
            apply(u, lo, hi, x);
            return;
        }
        const int mid = (lo + hi) >> 1, lc = u << 1, rc = lc | 1;
        propo(u, lo, hi);
        if (s <= mid) {
            incre(s, e, x, lc, lo, mid);
        }
        if (e > mid) {
            incre(s, e, x, rc, mid + 1, hi);
        }
        pull(u, lo, hi);
    }
    void incre(int s, int e, int x) {
        incre(s, e, x, 1, 0, sub[r] - 1);
    }
    // Set the label of the vertex at preorder position p to f.
    void upd(int p, int f, int u, int lo, int hi) {
        if (lo == hi) {
            REP (z, 0, 2) {
                seg[u].sm[z] = (z << 1) == f;
                seg[u].bsm[z] = (ll) seg[u].sm[z] * seg[u].oc;
            }
            return;
        }
        const int mid = (lo + hi) >> 1, lc = u << 1, rc = lc | 1;
        propo(u, lo, hi);
        if (p <= mid) {
            upd(p, f, lc, lo, mid);
        } else {
            upd(p, f, rc, mid + 1, hi);
        }
        pull(u, lo, hi);
    }
    void upd(int p, int f) {
        upd(p, f, 1, 0, sub[r] - 1);
    }
    // Aggregate over preorder positions [s, e] (empty when s > e).
    Value qsm(int s, int e, int u, int lo, int hi) {
        if (s > e) {
            return Value();
        }
        if (lo >= s && hi <= e) {
            return toValue(seg[u]);
        }
        const int mid = (lo + hi) >> 1, lc = u << 1, rc = lc | 1;
        propo(u, lo, hi);
        Value res = Value();
        if (s <= mid) {
            res += qsm(s, e, lc, lo, mid);
        }
        if (e > mid) {
            res += qsm(s, e, rc, mid + 1, hi);
        }
        return res;
    }
    Value qsm(int s, int e) {
        return qsm(s, e, 1, 0, sub[r] - 1);
    }
    // For the range [s, e], return:
    //   FI - summed label-1 counts over (0, 2) paths with one end inside
    //        and one outside the range
    //   SE - number of such pairs
    // The path count for a pair is oc[a] + oc[c], minus one when r itself
    // is labeled 1 (r appears in both oc values).
    pll calc(int s, int e) {
        Value in = qsm(s, e);
        // The root aggregate is always current, so "outside" = total - inside.
        Value out = toValue(seg[1]);
        out -= in;
        cerr << "   CALC " << s << ' ' << e << ": " << in << ' ' << out << '\n';
        pll res = {0, 0};
        REP (z, 0, 2) {
            res.FI += in.bsm[z] * out.sm[z ^ 1] + in.sm[z] * out.bsm[z ^ 1];
            // Both counts are ints; multiply in 64 bits to avoid overflow.
            res.SE += (ll) in.sm[z] * out.sm[z ^ 1];
        }
        if (f[r] == 1) {
            res.FI -= res.SE;
        }
        cerr << "   " << res.FI << ' ' << res.SE << '\n';
        return res;
    }
    Solver() : r(0), pairs(0) {}
    // Build for centroid r right after fsub() has been run from r at level
    // cl[r], and add r's cross-range contribution to the global answer.
    explicit Solver(int r) : r(r), pairs(0), ch(), seg() {
        // A recursive segment tree over n leaves uses indices < 2 * next_pow2(n).
        int cap = 1;
        while (cap < sub[r]) {
            cap <<= 1;
        }
        seg.assign(2 * cap, Node{});
        init();
        ch.pb(0);
        for (int v : adj[r]) {
            if (cl[v] == -1) {
                ch.pb(pre[cl[r]][v]);
            }
        }
        ch.pb(sub[r]);
        ll res = 0;
        REP (i, 0, SZ(ch) - 1) {
            auto [a, b] = calc(ch[i], ch[i + 1] - 1);
            res += a;
            pairs += b;
        }
        // Every cross-range pair was counted once from each of its two ranges.
        pairs /= 2;
        res /= 2;
        ans += res;
    }
} sol[MAXN];
void build(int u, int p, int l) {
    ptr = 0;
    u = fcent(u, -1, fsub(u, -1, l));
    cp[u] = p;
    cl[u] = l;
    ptr = 0;
    oc[u] = 0;
    fsub(u, -1, l);
    sol[u] = Solver(u);
    for (int v : adj[u]) {
        if (cl[v] != -1) {
            continue;
        }
        build(v, u, l + 1);
    }
}

void init(int N, vi F, vi U, vi V, int Q) {
    n = N;
    REP (i, 0, n) {
        f[i] = F[i];
    }
    REP (i, 0, n - 1) {
        adj[U[i]].pb(V[i]);
        adj[V[i]].pb(U[i]);
    }

    ans = 0;
    REP (i, 0, n) {
        cl[i] = -1;
    }
    build(0, -1, 0);
}

// Relabel vertex x to y and update the running answer `ans`.
// Two sweeps run over x's centroid-tree ancestors u = x, cp[x], ...
// (levels cl[x] down to 0):
//   1. With the old label: subtract the contribution of the range of u that
//      contains x, and remove x's +1 on oc over its subtree if x was 1.
//   2. With the new label: rewrite x's leaf, re-add the +1 if y == 1, and
//      add the recomputed contribution back.
// calc() on a single range counts each crossing pair once, so no halving is
// needed here (unlike the constructor, which sums over all ranges).
void change(int x, int y) {
    // At u == x, x's range is [0, 0] and its subtree is the whole component.
    // If x is labeled 1, the -1 increment below lowers oc at both ends of
    // every pair in sol[x].pairs, and calc's f[r] == 1 correction no longer
    // applies. That is a net -1 per pair. calc only recomputes [0, 0], so
    // the change for those pairs is applied here in one step (and mirrored
    // at the end for the new label).
    if (f[x] == 1) {
        ans -= sol[x].pairs;
    }
    int u = x;
    // Sweep 1: remove contributions computed with the old label.
    RREP (l, cl[x], 0) {
        // ch is increasing, so [ch[id - 1], ch[id] - 1] is the range holding x.
        int id = upper_bound(ALL(sol[u].ch), pre[l][x]) - sol[u].ch.begin();
        auto [a, b] = sol[u].calc(sol[u].ch[id - 1], sol[u].ch[id] - 1);
        ans -= a;
        sol[u].pairs -= b;
        if (f[x] == 1) {
            // Drop x's +1 from oc over its subtree (re-added in sweep 2 if y == 1).
            sol[u].incre(pre[l][x], pst[l][x] - 1, -1);
        }
        u = cp[u];
    }
    f[x] = y;
    u = x;
    // Sweep 2: apply the new label and add the contributions back.
    RREP (l, cl[x], 0) {
        int id = upper_bound(ALL(sol[u].ch), pre[l][x]) - sol[u].ch.begin();
        sol[u].upd(pre[l][x], f[x]);
        if (y == 1) {
            sol[u].incre(pre[l][x], pst[l][x] - 1, 1);
        }
        auto [a, b] = sol[u].calc(sol[u].ch[id - 1], sol[u].ch[id] - 1);
        ans += a;
        sol[u].pairs += b;
        u = cp[u];
    }
    // Counterpart of the adjustment at the top, now for the new label.
    if (f[x] == 1) {
        ans += sol[x].pairs;
    }
}

// Grader query: the running answer maintained by init() and change().
ll num_tours() {
  const ll current = ans;
  return current;
}

컴파일 시 표준 에러 (stderr) 메시지

joitour.cpp: In member function 'void Solver::pull(int, int, int)':
joitour.cpp:92:26: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                       ~~~^~~~
joitour.cpp:102:9: note: in expansion of macro 'MLR'
  102 |         MLR;
      |         ^~~
joitour.cpp:92:17: warning: unused variable 'mid' [-Wunused-variable]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                 ^~~
joitour.cpp:102:9: note: in expansion of macro 'MLR'
  102 |         MLR;
      |         ^~~
joitour.cpp: In member function 'void Solver::propo(int, int, int)':
joitour.cpp:92:26: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                       ~~~^~~~
joitour.cpp:113:9: note: in expansion of macro 'MLR'
  113 |         MLR;
      |         ^~~
joitour.cpp: In member function 'void Solver::init(int, int, int)':
joitour.cpp:92:26: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                       ~~~^~~~
joitour.cpp:128:13: note: in expansion of macro 'MLR'
  128 |             MLR;
      |             ^~~
joitour.cpp: In member function 'void Solver::incre(int, int, int, int, int, int)':
joitour.cpp:92:26: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                       ~~~^~~~
joitour.cpp:142:9: note: in expansion of macro 'MLR'
  142 |         MLR;
      |         ^~~
joitour.cpp: In member function 'void Solver::upd(int, int, int, int, int)':
joitour.cpp:92:26: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                       ~~~^~~~
joitour.cpp:163:9: note: in expansion of macro 'MLR'
  163 |         MLR;
      |         ^~~
joitour.cpp: In member function 'Solver::Value Solver::qsm(int, int, int, int, int)':
joitour.cpp:92:26: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   92 | #define MLR int mid = lo + hi >> 1, lc = u << 1, rc = u << 1 ^ 1
      |                       ~~~^~~~
joitour.cpp:208:9: note: in expansion of macro 'MLR'
  208 |         MLR;
      |         ^~~
joitour.cpp: In constructor 'Solver::Solver(int)':
joitour.cpp:93:18: warning: 'Solver::seg' will be initialized after [-Wreorder]
   93 |     vector<Node> seg;
      |                  ^~~
joitour.cpp:87:8: warning:   'vi Solver::ch' [-Wreorder]
   87 |     vi ch;
      |        ^~
joitour.cpp:237:5: warning:   when initialized here [-Wreorder]
  237 |     Solver(int r): r(r), pairs(0), seg(sub[r] * 4), ch(0) {
      |     ^~~~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...