Submission #1102042

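| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1102042 | | Pacybwoah | JOI tour (JOI24_joitour) | C++17 | 100 / 100 | 1280 ms | 153824 KiB |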
#include "joitour.h"
#include<algorithm>
#include<vector>
#include<utility>
#include<iostream>
#include<algorithm>
using namespace std;
using ll = long long;  // 64-bit counter type used for all triple counts

namespace{
    // ---- File-local state shared by init / change / num_tours ----
    int n = 0;  // number of vertices (stored 1-indexed)
    // Number of vertices currently having colour 0, 1 and 2 respectively.
    ll cnta = 0, cntb = 0, cntc = 0;
    ll allp = 0;  // cnta * cntb * cntc
    ll allw = 0;  // sum of wall[v] over all vertices v that have colour 1
    // cur[v]    : colour of v
    // head[v]   : top vertex of v's heavy chain (final after dfs_head)
    // in/out[v] : first / last preorder index inside v's subtree
    // ord[t]    : vertex whose preorder index is t
    // sz[v]     : subtree size;  par[v]: parent (root 1 is its own parent)
    // wid[v]    : index of v inside graph[par[v]]
    vector<int> cur, head, in, out, ord, sz, par, wid;
    // wall[v]: sum over the light children u of v of
    //          (#colour-0 vertices) * (#colour-2 vertices) in u's subtree.
    vector<ll> wall;
    // (#colour 0, #colour 2) pairs:
    //   hans[v]: inside the subtree of v's heavy child;
    //   uans[v]: inside v's subtree while dfs_head runs, converted by init
    //            into the counts outside v's subtree.
    vector<pair<int, int>> hans, uans;
    int timer = 0;  // preorder counter used by dfs_head
    // Adjacency lists; after dfs, graph[v][0] is v's heavy child (if any).
    vector<vector<int>> graph;
    // wans[v][i]: (#colour 0, #colour 2) in the subtree of graph[v][i] when
    //             that neighbour is a light child of v; (0, 0) otherwise.
    vector<vector<pair<int, int>>> wans;
    // First DFS from the root.
    // - Fills sz[] (subtree sizes) and par[] (parents).
    // - Moves each vertex's heavy child to graph[node][0]. The heavy child is
    //   the child with the largest subtree; on ties, the first one seen wins.
    // - Sets head[] only provisionally: a heavy child gets head = its parent,
    //   and every other vertex gets head = itself. dfs_head resolves the
    //   actual chain tops.
    // Recursion depth equals the height of the tree.
    void dfs(int node, int parent){
        par[node] = parent;
        head[node] = node;
        sz[node] = 1;
        int heavy = -1, heavy_size = 0;
        for(int child : graph[node]){
            if(child == parent) continue;
            dfs(child, node);
            sz[node] += sz[child];
            if(sz[child] > heavy_size){
                heavy_size = sz[child];
                heavy = child;
            }
        }
        if(heavy > 0){
            head[heavy] = node;
            auto &adj = graph[node];
            for(int k = 0; k < (int)adj.size(); k++){
                if(adj[k] == heavy){
                    swap(adj[k], adj[0]);
                    break;
                }
            }
        }
    }
    // Second DFS, run after dfs has placed each heavy child at
    // graph[node][0]. For each vertex it:
    // - numbers it in preorder (ord/in/out), so the heavy child's subtree
    //   gets the next indices and every heavy chain is contiguous;
    // - resolves head[] to the top of its heavy chain;
    // - records wid[child] = the child's index in graph[node];
    // - accumulates in uans[node] the number of colour-0 (first) and
    //   colour-2 (second) vertices in node's subtree.
    // The heavy child's counts are copied to hans[node], and each light
    // child's counts to wans[node][index].
    void dfs_head(int node, int parent){
        in[node] = ++timer;
        ord[timer] = node;
        // dfs left head == parent for heavy children: inherit the chain top.
        if(head[node] != node) head[node] = head[parent];
        const int deg = (int)graph[node].size();
        for(int idx = 0; idx < deg; idx++){
            const int child = graph[node][idx];
            if(child == parent) continue;
            wid[child] = idx;
            dfs_head(child, node);
            // uans[child] is final once its own recursion has returned.
            const auto &sub = uans[child];
            uans[node].first += sub.first;
            uans[node].second += sub.second;
            if(idx == 0) hans[node] = sub;
            else wans[node][idx] = sub;
        }
        if(cur[node] == 0) uans[node].first++;
        else if(cur[node] == 2) uans[node].second++;
        out[node] = timer;
    }
    // One segment-tree node. Each position p carries ca(p), cc(p) and a flag
    // act(p) in {0, 1}; a node stores these aggregated over its range.
    struct node{
        ll ca = 0;    // sum of ca(p)
        ll cc = 0;    // sum of cc(p)
        ll acta = 0;  // sum of act(p) * ca(p)
        ll actc = 0;  // sum of act(p) * cc(p)
        ll act = 0;   // sum of act(p)
        ll sum = 0;   // sum of act(p) * ca(p) * cc(p)
        ll atag = 0;  // pending addition to ca(p) for the whole range
        ll ctag = 0;  // pending addition to cc(p) for the whole range
    };
    struct st{
        vector<node> seg;
        // in[x], ord[x]
        void pull(int l, int r, int ind){
            seg[ind].ca = seg[ind * 2].ca + seg[ind * 2 + 1].ca;
            seg[ind].cc = seg[ind * 2].cc + seg[ind * 2 + 1].cc;
            seg[ind].acta = seg[ind * 2].acta + seg[ind * 2 + 1].acta;
            seg[ind].actc = seg[ind * 2].actc + seg[ind * 2 + 1].actc;
            seg[ind].act = seg[ind * 2].act + seg[ind * 2 + 1].act;
            seg[ind].sum = seg[ind * 2].sum + seg[ind * 2 + 1].sum;
        }
        void push(int l, int r, int ind){
            if(l == r) return;
            int mid = (l + r) >> 1;
            seg[ind * 2].sum += seg[ind].atag * seg[ind * 2].actc + seg[ind].ctag * seg[ind * 2].acta + seg[ind].atag * seg[ind].ctag * seg[ind * 2].act;
            seg[ind * 2 + 1].sum += seg[ind].atag * seg[ind * 2 + 1].actc + seg[ind].ctag * seg[ind * 2 + 1].acta + seg[ind].atag * seg[ind].ctag * seg[ind * 2 + 1].act;
            seg[ind * 2].acta += seg[ind].atag * seg[ind * 2].act;
            seg[ind * 2 + 1].acta += seg[ind].atag * seg[ind * 2 + 1].act;
            seg[ind * 2].actc += seg[ind].ctag * seg[ind * 2].act;
            seg[ind * 2 + 1].actc += seg[ind].ctag * seg[ind * 2 + 1].act;
            seg[ind * 2].ca += seg[ind].atag * (mid - l + 1);
            seg[ind * 2 + 1].ca += seg[ind].atag * (r - mid);
            seg[ind * 2].cc += seg[ind].ctag * (mid - l + 1);
            seg[ind * 2 + 1].cc += seg[ind].ctag * (r - mid);
            seg[ind * 2].atag += seg[ind].atag;
            seg[ind * 2 + 1].atag += seg[ind].atag;
            seg[ind * 2].ctag += seg[ind].ctag;
            seg[ind * 2 + 1].ctag += seg[ind].ctag;
            seg[ind].atag = 0;
            seg[ind].ctag = 0;
        }
        void build(int l, int r, int ind, vector<pair<int, int>> &vec){
            if(l == r){
                if(cur[ord[l]] == 1) seg[ind].act = 1;
                seg[ind].ca = vec[ord[l]].first;
                seg[ind].cc = vec[ord[l]].second;
                seg[ind].acta = seg[ind].act * seg[ind].ca;
                seg[ind].actc = seg[ind].act * seg[ind].cc;
                seg[ind].sum = seg[ind].acta * seg[ind].actc;
                return;
            }
            int mid = (l + r) >> 1;
            build(l, mid, ind * 2, vec);
            build(mid + 1, r, ind * 2 + 1, vec);
            pull(l, r, ind);
        }
        void modify(int l, int r, int start, int end, ll anum, ll cnum, int ind){
            if(r < start || end < l) return;
            if(start <= l && r <= end){
                seg[ind].atag += anum;
                seg[ind].ctag += cnum;
                seg[ind].ca += anum * (r - l + 1);
                seg[ind].cc += cnum * (r - l + 1);
                seg[ind].sum += anum * seg[ind].actc + cnum * seg[ind].acta + anum * cnum * seg[ind].act;
                seg[ind].acta += anum * seg[ind].act;
                seg[ind].actc += cnum * seg[ind].act;
                return;
            }
            int mid = (l + r) >> 1;
            push(l, r, ind);
            modify(l, mid, start, end, anum, cnum, ind * 2);
            modify(mid + 1, r, start, end, anum, cnum, ind * 2 + 1);
            pull(l, r, ind);
        }
        void toggle(int l, int r, int pos, int ind){
            if(l == r){
                seg[ind].act = 1 - seg[ind].act;
                seg[ind].acta = seg[ind].ca * seg[ind].act;
                seg[ind].actc = seg[ind].cc * seg[ind].act;
                seg[ind].sum = seg[ind].acta * seg[ind].actc;
                return;
            }
            int mid = (l + r) >> 1;
            push(l, r, ind);
            if(pos <= mid) toggle(l, mid, pos, ind * 2);
            else toggle(mid + 1, r, pos, ind * 2 + 1);
            pull(l, r, ind);
        }
        /*void toga(int l, int r, int pos, int ind){
            if(l == r){
                seg[ind].ca = 1 - seg[ind].ca;
                seg[ind].acta = seg[ind].ca * seg[ind].act;
                seg[ind].actc = seg[ind].cc * seg[ind].act;
                seg[ind].sum = seg[ind].acta * seg[ind].actc;
                return;
            }
            int mid = (l + r) >> 1;
            push(l, r, ind);
            if(pos <= mid) toggle(l, mid, pos, ind * 2);
            else toggle(mid + 1, r, pos, ind * 2 + 1);
        }
        void togc(int l, int r, int pos, int ind){
            if(l == r){
                seg[ind].cc = 1 - seg[ind].cc;
                seg[ind].acta = seg[ind].ca * seg[ind].act;
                seg[ind].actc = seg[ind].cc * seg[ind].act;
                seg[ind].sum = seg[ind].acta * seg[ind].actc;
                return;
            }
            int mid = (l + r) >> 1;
            push(l, r, ind);
            if(pos <= mid) toggle(l, mid, pos, ind * 2);
            else toggle(mid + 1, r, pos, ind * 2 + 1);
        }*/
    };
    st hseg, useg;
}
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
    n = N;
    cur.resize(n + 1);
    graph.resize(n + 1);
    head.resize(n + 1);
    in.resize(n + 1);
    out.resize(n + 1);
    ord.resize(n + 1);
    sz.resize(n + 1);
    par.resize(n + 1);
    hans.resize(n + 1);
    uans.resize(n + 1);
    wid.resize(n + 1);
    wall.resize(n + 1);
    wans.resize(n + 1);
    for(int i = 0; i < n; i++){
        cur[i + 1] = F[i];
        if(F[i] == 0) cnta++;
        else if(F[i] == 1) cntb++;
        else cntc++;
    }
    allp = cnta * cntb * cntc;
    for(int i = 0; i < n - 1; i++){
        U[i]++;
        V[i]++;
        graph[U[i]].push_back(V[i]);
        graph[V[i]].push_back(U[i]);
    }
    for(int i = 1; i <= n; i++){
        wans[i].resize((int)graph[i].size());
    }
    dfs(1, 1);
    dfs_head(1, 1);
    for(int i = 1; i <= n; i++){
        uans[i].first = cnta - uans[i].first;
        uans[i].second = cntc - uans[i].second;
    }
    for(int i = 1; i <= n; i++){
        int szz = graph[i].size();
        for(int j = 0; j < szz; j++) wall[i] += 1ll * wans[i][j].first * wans[i][j].second;
        if(cur[i] == 1) allw += wall[i];
    }
    hseg.seg.resize(4 * n + 4);
    useg.seg.resize(4 * n + 4);
    hseg.build(1, n, 1, hans);
    useg.build(1, n, 1, uans);
}
void change(int X, int Y) {
    int v = X + 1, y = Y;
    if(cur[v] == 0) cnta--;
    else if(cur[v] == 1) cntb--;
    else cntc--;
    if(y == 0) cnta++;
    else if(y == 1) cntb++;
    else cntc++;
    allp = cnta * cntb * cntc;
    if(cur[v] == 1){
        hseg.toggle(1, n, in[v], 1);
        useg.toggle(1, n, in[v], 1);
        allw -= wall[v];
    }
    if(y == 1){
        hseg.toggle(1, n, in[v], 1);
        useg.toggle(1, n, in[v], 1);
        allw += wall[v];
    }
    /*if(cur[v] == 0 || y == 0){
        hseg.toga(1, n, in[v], 1);
        useg.toga(1, n, in[v], 1);
    }
    if(cur[v] == 2 || y == 2){
        hseg.togc(1, n, in[v], 1);
        useg.togc(1, n, in[v], 1);
    }*/
    auto add = [&](int tp, int val){
        int now = v;
        if(tp == 0) useg.modify(1, n, 1, n, val, 0, 1);
        else useg.modify(1, n, 1, n, 0, val, 1);
        int lh = -1;
        while(1){
            if(lh != -1){
                if(cur[now] == 1) allw -= wall[now];
                if(tp == 0){
                    wall[now] += val * wans[now][lh].second;
                    wans[now][lh].first += val;
                }
                else{
                    wall[now] += val * wans[now][lh].first;
                    wans[now][lh].second += val;
                }
                if(cur[now] == 1) allw += wall[now];
            }
            if(tp == 0) useg.modify(1, n, in[head[now]], in[now], -val, 0, 1);
            else useg.modify(1, n, in[head[now]], in[now], 0, -val, 1);
            if(now != head[now]){
                if(tp == 0) hseg.modify(1, n, in[head[now]], in[now] - 1, val, 0, 1);
                else hseg.modify(1, n, in[head[now]], in[now] - 1, 0, val, 1);
            }
            now = head[now];
            if(now == 1) break;
            lh = wid[now];
            now = par[now];
        }
    };
    if(cur[v] != 1) add(cur[v], -1);
    if(y != 1) add(y, 1);
    cur[v] = y;
}
// Returns the number of (colour 0, colour 1, colour 2) vertex triples whose
// colour-1 vertex separates the other two.
// It takes all triples (allp) and subtracts those whose colour-0 and
// colour-2 vertices lie in the same component after removing the colour-1
// vertex. Those components are the light children (allw), the outside of
// the subtree (useg) and the heavy child (hseg).
long long num_tours() {
    const long long same_side = allw + useg.seg[1].sum + hseg.seg[1].sum;
    return allp - same_side;
}
// g++ -std=c++17 -Wall -Wextra -Wshadow -fsanitize=undefined -fsanitize=address -o run grader.cpp joitour.cpp
/*
7
1 0 2 2 0 1 0
0 1
0 2
1 3
1 4
2 5
2 6
1
0 0
*/

Compilation message (stderr)

joitour.cpp:12:44: warning: '{anonymous}::ans' defined but not used [-Wunused-variable]
   12 |     ll cnta = 0, cntb = 0, cntc = 0, allp, ans = 0, allw = 0;
      |                                            ^~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...