Submission #1301472

# | Time | Username | Problem | Language | Result | Execution time | Memory
1301472 | (not captured) | sitingfake | Tourism (JOI23_tourism) | C++20 | 100 / 100 | 2302 ms | 33040 KiB
#include <bits/stdc++.h>
using namespace std;

// Shorthand macros from a competitive-programming template.
// NOTE(review): none of them appears to be referenced by the code in this file.

// Prints the CPU time consumed so far (clock() / CLOCKS_PER_SEC) to stderr.
#define execute cerr << " Time: " << fixed << setprecision(6) << (1.0 * clock() / CLOCKS_PER_SEC) << "s\n";
// Type aliases implemented as textual macros.
#define ll long long
#define ii pair<int,int>
#define iii pair<int,ii>
// Pair member shorthands: x.se -> x.second, x.fi -> x.first.
#define se second
#define fi first
// Begin/end iterator pair covering a whole container.
#define all(v) (v).begin(), (v).end()
// Sorts v and erases duplicate elements in place.
#define Unique(v) sort(all(v)), v.resize(unique(all(v)) - v.begin())
// bit: bit i of x as 0/1; flip: x with bit i toggled (both shift a 64-bit 1).
#define bit(x,i) (((x) >> (i)) & 1LL)
#define flip(x,i) ((x) ^ (1LL << (i)))
// Sets all sizeof(d) bytes of d to the byte value x (meant for arrays, not pointers).
#define ms(d,x) memset(d , x , sizeof(d))

// Constants (mod, the "infinity" values and LOG are not referenced in this file).
constexpr long long mod = 1'000'000'007LL;
// Every byte of these equals 0x3F / 0xC1, i.e. the value memset(arr, 0x3F, ...)
// or memset(arr, 0xC1, ...) leaves in each element.
constexpr long long linf = 0x3F3F3F3F3F3F3F3FLL;
constexpr long long nlinf = -0x3E3E3E3E3E3E3E3FLL;  // bit pattern 0xC1C1C1C1C1C1C1C1
constexpr int LOG = 20;
constexpr int inf = 0x3F3F3F3F;
constexpr int ninf = -0x3E3E3E3F;                   // bit pattern 0xC1C1C1C1

// Capacity of the static per-node / per-query arrays (100000 plus slack).
constexpr int maxn = 100'007;

// Two-level bitset over positions [0, SIZE) with predecessor / successor search.
//   Level 0: group[g] stores positions 64*g .. 64*g + 63.
//   Level 1: bit b of GroupOnline[k] is set iff group[64*k + b] is non-zero, so a
//            search can skip 64 empty words by testing one summary word.
struct Bitset {
    using ull = unsigned long long;

    // Capacity: kMaxGroups * 64 = 102720 positions, enough for n <= 100000.
    static constexpr int kMaxGroups = 1605;
    // Summary words: kMaxBlocks * 64 = 3200 >= kMaxGroups.
    static constexpr int kMaxBlocks = 50;

    ull group[kMaxGroups];
    ull GroupOnline[kMaxBlocks];
    ull mask[66];  // mask[i] = bits [0..i]; filled by init() but not used by the queries

    int SIZE = 0;
    int numGroup = 0;       // number of 64-bit words in use
    int numOfnumGroup = 0;  // number of summary words in use

    // Mask with bits [0..idx] set (0 when idx < 0, all bits when idx >= 63).
    inline ull mask_up_to(int idx) const {
        if (idx < 0) return 0ULL;
        if (idx >= 63) return ~0ULL;
        return (1ULL << (idx + 1)) - 1ULL;
    }

    // Mask with bits [idx..63] set (all bits when idx <= 0, 0 when idx >= 64).
    inline ull mask_from(int idx) const {
        if (idx <= 0) return ~0ULL;
        if (idx >= 64) return 0ULL;
        return ~((1ULL << idx) - 1ULL);
    }

    // Highest / lowest set position inside word g; group[g] must be non-zero.
    inline int highestIn(int g) const { return (g << 6) + 63 - __builtin_clzll(group[g]); }
    inline int lowestIn(int g) const { return (g << 6) + __builtin_ctzll(group[g]); }

    // Clears the structure and sizes it for positions [0, sz).
    void init(int sz) {
        assert(0 <= sz && sz <= kMaxGroups * 64);  // fixed-size storage
        SIZE = sz;
        numGroup = (SIZE + 63) / 64;           // ceil(SIZE / 64)
        numOfnumGroup = (numGroup + 63) / 64;  // ceil(numGroup / 64)
        std::fill(std::begin(group), std::end(group), 0ULL);
        std::fill(std::begin(GroupOnline), std::end(GroupOnline), 0ULL);
        for (int i = 0; i <= 63; ++i) mask[i] = mask_up_to(i);
        mask[64] = ~0ULL;
    }

    // Marks position x (0 <= x < SIZE) as present.
    void setOn(int x) {
        const int g = x >> 6;
        group[g] |= 1ULL << (x & 63);
        GroupOnline[g >> 6] |= 1ULL << (g & 63);
    }

    // Marks position x (0 <= x < SIZE) as absent; clears the summary bit when its word empties.
    void setOff(int x) {
        const int g = x >> 6;
        group[g] &= ~(1ULL << (x & 63));
        if (group[g] == 0ULL) GroupOnline[g >> 6] &= ~(1ULL << (g & 63));
    }

    // Returns the largest set position <= val, or -1 if there is none.
    int lower(int val) {
        if (val < 0 || numGroup == 0) return -1;
        int g = val >> 6;
        int off = val & 63;
        if (g >= numGroup) {
            // val lies beyond the last word: every bit of that word is <= val.
            g = numGroup - 1;
            off = 63;
        }
        const ull w = group[g] & mask_up_to(off);
        if (w) return (g << 6) + 63 - __builtin_clzll(w);

        // Non-empty words strictly before g within the same summary word.
        const int blk = g >> 6;
        const ull before = GroupOnline[blk] & mask_up_to((g & 63) - 1);
        if (before) return highestIn((blk << 6) + 63 - __builtin_clzll(before));

        // Earlier summary words.
        for (int b = blk - 1; b >= 0; --b) {
            if (GroupOnline[b]) return highestIn((b << 6) + 63 - __builtin_clzll(GroupOnline[b]));
        }
        return -1;
    }

    // Returns the smallest set position >= val, or -1 if there is none.
    // A negative val searches from position 0.
    int upper(int val) {
        if (val < 0) val = 0;
        const int g = val >> 6;
        if (g >= numGroup) return -1;
        const ull w = group[g] & mask_from(val & 63);
        if (w) return (g << 6) + __builtin_ctzll(w);

        // Non-empty words strictly after g within the same summary word.
        const int blk = g >> 6;
        const ull after = GroupOnline[blk] & ~mask_up_to(g & 63);
        if (after) return lowestIn((blk << 6) + __builtin_ctzll(after));

        // Later summary words.
        for (int b = blk + 1; b < numOfnumGroup; ++b) {
            if (!GroupOnline[b]) continue;
            const int next = (b << 6) + __builtin_ctzll(GroupOnline[b]);
            if (next >= numGroup) continue;  // defensive: never set for in-range positions
            return lowestIn(next);
        }
        return -1;
    }
} B;

// Input sizes: n tree nodes, m entries of the sequence c, q queries.
int n, m, q;
// Bucket width for Mo's ordering of queries by left endpoint; set in solve().
int BlockSize = 0;

// c[1..m]: the input sequence of (0-based) nodes.
int c[maxn];
// Number of distinct DFS positions currently active in the Mo window.
int curNodes = 0;
// cnt[p]: how many times DFS position p occurs inside the Mo window.
int cnt[maxn];
// id[p]: the node whose (0-based) pre-order position is p.
int id[maxn];

// One range query [l, r] over c; id is its position in the input order.
struct Query {
    int l, r;
    int id;
    // Mo's ordering: bucket by l / BlockSize, then sort by r inside a bucket.
    // Odd buckets sort r descending ("odd-even" trick), so the right pointer sweeps
    // back instead of jumping to the start of each new bucket. Only the processing
    // order changes, never the answers. const-qualified so it is a proper comparator.
    bool operator < (const Query &other) const {
        const int bucket = l / BlockSize;
        const int otherBucket = other.l / BlockSize;
        if (bucket != otherBucket) return bucket < otherBucket;
        return (bucket & 1) ? (r > other.r) : (r < other.r);
    }
} Q[maxn];

// ans[i]: answer of the i-th query in input order.
int ans[maxn];
// Sum of tree distances between cyclically consecutive active nodes (in DFS
// order); solve() reports sumDist / 2 + 1 for each query.
int sumDist;

// Adjacency lists of the tree (0-based nodes).
vector<int> adj[maxn];
// in[u]: first Euler-tour index of u; out[u]: last tour index written while
// visiting u's subtree; depth[u]: distance from the root; timer: tour length.
int in[maxn];
int out[maxn];
int depth[maxn];
int timer;
// arr[i][0]: node at Euler-tour step i; arr[i][j]: shallowest node on steps i .. i + 2^j - 1.
int arr[maxn * 2][20];
// tin[u]: pre-order index of u; tout[u]: largest pre-order index in u's subtree.
int tin[maxn];
int tout[maxn];
int timer2;

// Walks the tree from u (parent p) and records, for every node x:
//   in[x]   - 1-based Euler-tour index of its first visit (arr[in[x]][0] = x),
//   tin[x]  - 1-based pre-order index,
//   out[x]  - last Euler-tour index written while inside x's subtree,
//   tout[x] - largest pre-order index inside x's subtree,
//   depth[child] = depth[parent] + 1 for every child (depth[u] is set by the caller).
// After each child's subtree the parent is appended to the Euler tour again.
// Uses an explicit stack instead of recursion, so a path-shaped tree with
// ~1e5 nodes cannot overflow the call stack; the visiting order is unchanged.
void dfs(int u, int p) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next neighbour of `node` to examine
    };
    std::vector<Frame> frames;

    // Same bookkeeping the recursive version did on entry to a node.
    auto enter = [&](int node, int parent) {
        in[node] = ++timer;
        tin[node] = ++timer2;
        arr[timer][0] = node;
        frames.push_back({node, parent, 0});
    };

    enter(u, p);
    while (!frames.empty()) {
        Frame &top = frames.back();
        const std::vector<int> &nbrs = adj[top.node];
        while (top.next < nbrs.size() && nbrs[top.next] == top.parent) ++top.next;

        if (top.next < nbrs.size()) {
            const int child = nbrs[top.next++];
            depth[child] = depth[top.node] + 1;
            // push_back may reallocate `frames`; `top` is not touched after this call.
            enter(child, top.node);
        } else {
            // All children done: close this node, then re-append its parent to the tour.
            out[top.node] = timer;
            tout[top.node] = timer2;
            frames.pop_back();
            if (!frames.empty()) arr[++timer][0] = frames.back().node;
        }
    }
}

// Fills the sparse table over the Euler tour arr[1..timer][0]: afterwards
// arr[i][j] is the shallowest node among tour entries i .. i + 2^j - 1
// (on equal depth the entry from the right half is kept).
void BuildLca() {
    for (int lvl = 1; (1 << lvl) <= timer; lvl++) {
        const int half = 1 << (lvl - 1);
        const int lastStart = timer - (1 << lvl) + 1;
        for (int start = 1; start <= lastStart; start++) {
            const int left = arr[start][lvl - 1];
            const int right = arr[start + half][lvl - 1];
            arr[start][lvl] = depth[left] < depth[right] ? left : right;
        }
    }
}

// Lowest common ancestor of u and v: the shallowest node on the Euler tour
// between their first occurrences, answered with two overlapping table lookups.
int lca(int u, int v) {
    const int lo = std::min(in[u], in[v]);
    const int hi = std::max(in[u], in[v]);
    const int k = __lg(hi - lo + 1);
    const int fromLeft = arr[lo][k];
    const int fromRight = arr[hi - (1 << k) + 1][k];
    return depth[fromLeft] < depth[fromRight] ? fromLeft : fromRight;
}

// Number of edges on the tree path between u and v.
int dist(int u, int v) {
    const int meet = lca(u, v);
    return (depth[u] - depth[meet]) + (depth[v] - depth[meet]);
}

// Activates DFS position u in B and updates sumDist, the length of the cyclic
// walk through the active nodes taken in DFS order.
// Called by add_Node after ++curNodes and before cnt[u] is incremented.
void AddSet(int u) {
    if (curNodes == 1) {
        // u is the only active position, so the walk is empty.
        B.setOn(u);
        return;
    }

    // Nearest active positions on each side of u, wrapping to the opposite end.
    int pred = B.lower(u);
    if (pred == -1) pred = (cnt[n - 1] != 0) ? n - 1 : B.lower(n - 1);
    int succ = B.upper(u);
    if (succ == -1) succ = (cnt[0] != 0) ? 0 : B.upper(0);
    assert(pred != -1 && succ != -1);

    // Replace the step pred -> succ with pred -> u -> succ.
    const int detour = dist(id[pred], id[u]) + dist(id[u], id[succ]);
    sumDist += detour - dist(id[pred], id[succ]);
    B.setOn(u);
}

// Deactivates DFS position u in B and updates sumDist accordingly.
// Called by del_Node after --curNodes and before cnt[u] is decremented.
void DelSet(int u) {
    B.setOff(u);
    if (curNodes <= 1) {
        // At most one active node remains: its walk has length 0.
        sumDist = 0;
        return;
    }

    // Nearest remaining active positions around u, wrapping to the opposite end.
    int pred = B.lower(u);
    if (pred == -1) pred = cnt[n - 1] ? n - 1 : B.lower(n - 1);
    int succ = B.upper(u);
    if (succ == -1) succ = cnt[0] ? 0 : B.upper(0);
    assert(pred != -1 && succ != -1);

    // Replace pred -> u -> succ with the direct step pred -> succ.
    const int shortcut = dist(id[pred], id[succ]);
    sumDist += shortcut - dist(id[pred], id[u]) - dist(id[succ], id[u]);
}

// Adds one occurrence of DFS position u to the Mo window; only the first
// occurrence changes the active set.
void add_Node(int u) {
    if (cnt[u] != 0) {
        ++cnt[u];
        return;
    }
    ++curNodes;
    AddSet(u);
    ++cnt[u];
}

// Removes one occurrence of DFS position u from the Mo window; only removing
// the last occurrence changes the active set.
void del_Node(int u) {
    if (cnt[u] != 1) {
        --cnt[u];
        return;
    }
    --curNodes;
    DelSet(u);
    --cnt[u];
}

// Current Mo window [L, R] over c; it starts empty.
int L = 1;
int R = 0;

// Moves the window to [l, r]. Both ends are extended before either is shrunk,
// so every element removed was added earlier.
void Mo(const int &l, const int &r) {
    while (R < r) add_Node(tin[c[++R]]);
    while (L > l) add_Node(tin[c[--L]]);
    while (L < l) del_Node(tin[c[L++]]);
    while (R > r) del_Node(tin[c[R--]]);
}

void solve(void) {
    cin >> n >> m >> q;
    // clear
    for (int i = 0; i < n; ++i) {
        adj[i].clear();
        cnt[i] = 0;
    }
    curNodes = 0;
    sumDist = 0;
    timer = timer2 = 0;
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        --u; --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= m; ++i) {
        cin >> c[i];
        --c[i];
    }
    for (int i = 1; i <= q; ++i) {
        cin >> Q[i].l >> Q[i].r;
        Q[i].id = i;
    }
    // root at 0 (nodes are 0-based after decrement)
    depth[0] = 0;
    dfs(0, -1);
    BuildLca();
    B.init(n);
    for (int i = 0; i < n; ++i) {
        tin[i]--;
        id[tin[i]] = i;
    }
    BlockSize = max(1, (int)sqrt(m));
    sort(Q + 1, Q + q + 1);
    for (int i = 1; i <= q; ++i) {
        Mo(Q[i].l, Q[i].r);
        ans[Q[i].id] = (sumDist / 2) + 1;
    }
    for (int i = 1; i <= q; ++i) {
        cout << ans[i] << "\n";
    }
}

/**
Sample input (a single query; the expected output is 6)
7 6 1
1 2
1 3
2 4
2 5
3 6
3 7
2 3 6 4 5 7
4 6
**/

// Entry point: fast unsynchronised I/O, then one call to solve() per test case.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // A single test case; read this count from the input to process several.
    int tests = 1;
    while (tests-- > 0) {
        solve();
    }
    // To report elapsed CPU time, invoke the `execute` macro here.
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...