Submission #1330406

# | Time | Username | Problem | Language | Result | Execution time | Memory
1330406 | (not shown) | Zbyszek99 | JOI tour (JOI24_joitour) | C++20 | 48 / 100 | 3082 ms | 896804 KiB
#include "joitour.h"
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// NOTE(review): GCC applies these pragmas only to functions defined after
// them, so code from the headers included above is not affected.
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// Short type aliases. They are macros, so e.g. `ll` is textually replaced by
// `long long` everywhere below.
#define ld long double
#define ll long long
#define ull unsigned long long
// Pair member / container shorthands.
#define ff first
#define ss second
#define pii pair<int,int>
#define pll pair<long long, long long>
#define vi vector<int>
#define vl vector<long long>
#define pb push_back
// Loop helpers: rep iterates i over [0, b), rep2 over [a, b] inclusive,
// rep3 over [a, b] inclusive in steps of c.
#define rep(i, b) for(int i = 0; i < (b); ++i)
#define rep2(i,a,b) for(int i = a; i <= (b); ++i)
#define rep3(i,a,b,c) for(int i = a; i <= (b); i+=c)
#define count_bits(x) __builtin_popcountll((x))
#define all(x) (x).begin(),(x).end()
#define siz(x) (int)(x).size()
// Range-for binding each element of x by (non-const) reference.
#define forall(it,x) for(auto& it:(x))
using namespace __gnu_pbds;
using namespace std;
// Order-statistics set from the policy-based data structures; not referenced
// in the code visible in this file.
typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
// Commented-out template helpers (RNG seeding / random range), unused here.
//mt19937 mt;void random_start(){mt.seed(chrono::time_point_cast<chrono::milliseconds>(chrono::high_resolution_clock::now()).time_since_epoch().count());}
//ll los(ll a, ll b) {return a + (mt() % (b-a+1));}
// Generic contest constants; none is referenced in the code visible here.
const int INF = 1e9+50;
const ll INF_L = 1e18+40;
const ll MOD = 1e9+7;

// Membership record: "this vertex lies in the component of centroid `centr`".
// One record is stored per (vertex, centroid) pair and replayed on every
// recolouring. Field order matters: records are built with aggregate
// initialisation {centr, p1, p2, s1, s2}.
struct elm
{
    int centr;  // centroid whose component contains the vertex
    int p1;     // preorder number of the vertex inside that component (centroid = 1)
    int p2;     // largest preorder number in the vertex's subtree
    int s1;     // preorder number of the centroid child heading the vertex's branch (1 for the centroid itself)
    int s2;     // largest preorder number in that branch
};

// Lazy segment tree over positions [0, tree_siz/2] supporting range add and
// range sum.
//
// Nodes are stored implicitly: the root is 1 and the children of k are 2k and
// 2k+1. The leaf count (tree_siz + 1) / 2 is a power of two, so a node at depth
// d always covers exactly (leaf count >> d) positions. spych relies on this to
// scale a pending add by the children's length. (Previously the length was
// ignored, so get_sum over a multi-position range after a range add_seg was
// wrong. Point queries and point updates behave exactly as before.)
//
// NOTE(review): the arrays come from new[] and are never released. The file
// keeps one tree per centroid in a large global array for the whole run, and
// raw pointers keep that array small. Call setup exactly once per object.
struct segtree
{
    int tree_siz;     // highest node index; both arrays hold tree_siz + 1 entries
    long long* sum;   // sum[k]: sum of node k's range, pending adds included
    long long* oper;  // oper[k]: per-position add not yet pushed to k's children

    // floor(log2(x)) for x > 0; same value as libstdc++'s __lg(int).
    static int floor_log2(int x)
    {
        return 31 - __builtin_clz(static_cast<unsigned>(x));
    }

    // Allocates a zero-filled tree whose positions cover at least 0..n.
    void setup(int n)
    {
        tree_siz = (1 << (floor_log2(n) + 2)) - 1;
        sum = new long long[tree_siz + 1]();   // () value-initialises to 0
        oper = new long long[tree_siz + 1]();
    }

    // Pushes node akt's pending add down to its two children (akt is never a leaf).
    void spych(int akt)
    {
        const long long pending = oper[akt];
        if (pending == 0) return;
        // akt sits at depth floor_log2(akt); each child covers
        // leaves >> (depth + 1) == (tree_siz + 1) >> (depth + 2) positions.
        const long long child_len = static_cast<long long>(tree_siz + 1) >> (floor_log2(akt) + 2);
        sum[akt * 2] += pending * child_len;
        sum[akt * 2 + 1] += pending * child_len;
        oper[akt * 2] += pending;
        oper[akt * 2 + 1] += pending;
        oper[akt] = 0;
    }

    // Adds x to every position of [s1, s2] intersected with node akt's range [p1, p2].
    void add_seg2(int akt, int p1, int p2, int s1, int s2, long long x)
    {
        if (p2 < s1 || p1 > s2) return;
        if (s1 <= p1 && p2 <= s2)
        {
            sum[akt] += x * (p2 - p1 + 1);
            oper[akt] += x;
            return;
        }
        spych(akt);
        const int mid = (p1 + p2) / 2;
        add_seg2(akt * 2, p1, mid, s1, s2, x);
        add_seg2(akt * 2 + 1, mid + 1, p2, s1, s2, x);
        sum[akt] = sum[akt * 2] + sum[akt * 2 + 1];
    }

    // Sum of positions in [s1, s2] intersected with node akt's range [p1, p2].
    long long get_sum2(int akt, int p1, int p2, int s1, int s2)
    {
        if (p2 < s1 || p1 > s2) return 0;
        if (s1 <= p1 && p2 <= s2) return sum[akt];
        spych(akt);
        const int mid = (p1 + p2) / 2;
        return get_sum2(akt * 2, p1, mid, s1, s2) + get_sum2(akt * 2 + 1, mid + 1, p2, s1, s2);
    }

    // Adds x to every position in [l, r].
    void add_seg(int l, int r, long long x)
    {
        add_seg2(1, 0, tree_siz / 2, l, r, x);
    }

    // Returns the sum of positions in [l, r].
    long long get_sum(int l, int r)
    {
        return get_sum2(1, 0, tree_siz / 2, l, r);
    }
};

// Running answer reported by num_tours; adjusted incrementally by centr::set_vert.
ll cur_ans = 0;
// Current colour of each vertex, stored as the grader's colour + 1 (so 1..3);
// -1 until init assigns it.
// NOTE(review): arrays are sized 200001, presumably the task's N bound + 1 — confirm.
int F[200001];
// Adjacency lists of the input tree.
vi graph[200001];
// true once a vertex has been picked as a centroid; traversals stop at such vertices.
bool odw[200001];
// Subtree sizes computed by the most recent dfs_sub call.
int sub[200001];
// Next preorder number handed out by dfs_pre; reset to 1 for each centroid.
int cur_pre = 1;
// Preorder number of each vertex, and the largest preorder number inside its
// subtree, relative to the most recently processed centroid.
int pre[200001];
int maxpre[200001];
// For each vertex, one record per centroid whose component contains it.
vector<elm> elms[200001];

struct centr
{
    segtree tree_1,tree_2,tree_3;
    int n;
    int cc = -1;
    ll sum_1_2 = 0, sum_3_2 = 0, sum_1_3 = 0, sum_1 = 0, sum_3 = 0;
    unordered_map<int,ll> map_1_2, map_3_2, map_1, map_3;
    void setup(int N)
    {
        n = N+2;
        tree_1.setup(n);
        tree_2.setup(n);
        tree_3.setup(n);
    }
    void set_vert(int p1, int p2, int s1, int s2, int f, int p)
    {
        if(p1 != 1)
        {
            if(f == 1)
            {
                int two = tree_2.get_sum(p1,p1);
                cur_ans += (two*(sum_3-map_3[s1])+sum_3_2-map_3_2[s1])*p;
                tree_1.add_seg(p1,p1,p);
                sum_1_2 += (two-(cc==2))*p;
                map_1_2[s1] += (two-(cc==2))*p;
                sum_1 += p;
                map_1[s1] += p;
                sum_1_3 += (sum_3 - map_3[s1] - (cc==3))*p;
            }
            if(f == 2)
            {
                int one = tree_1.get_sum(p1,p2);
                int three = tree_3.get_sum(p1,p2);
                cur_ans += (one*(sum_3-map_3[s1])+three*(sum_1-map_1[s1]))*p;
                tree_2.add_seg(p1,p2,p);
                if(p1 != 1)
                {
                    sum_1_2 += one*p;
                    map_1_2[s1] += one*p;
                    sum_3_2 += three*p;
                    map_3_2[s1] += three*p;
                }
            }
            if(f == 3)
            {
                int two = tree_2.get_sum(p1,p1);
                cur_ans += (two*(sum_1-map_1[s1])+sum_1_2-map_1_2[s1])*p;
                tree_3.add_seg(p1,p1,p);
                sum_3_2 += (two-(cc==2))*p;
                map_3_2[s1] += (two-(cc==2))*p;
                sum_3 += p;
                map_3[s1] += p;
                sum_1_3 += (sum_1 - map_1[s1] - (cc == 1))*p;
            }
        }
        else
        {
            cc = f;
            if(f == 1)
            {
                cur_ans += sum_3_2*p;
                sum_1 += p;
            }
            if(f == 2)
            {
                cur_ans += sum_1_3*p;
                tree_2.add_seg(p1,p2,p);
            }
            if(f == 3)
            {
                cur_ans += sum_1_2*p;
                sum_3 += p;
            }
        }
    }
};

// One centr slot per vertex id; only slots of vertices chosen as centroids are
// set up (in centroid()) and referenced by elms records.
centr centr_str[200001];

void dfs_sub(int v, int pop)
{
    sub[v] = 1;
    forall(it,graph[v]) if(it != pop && !odw[it])
    {
        dfs_sub(it,v);
        sub[v] += sub[it];
    }
}

void dfs_pre(int v, int pop)
{
    pre[v] = cur_pre++;
    forall(it,graph[v]) if(it != pop && !odw[it]) dfs_pre(it,v);
    maxpre[v] = cur_pre-1;
}

void dfs_add(int v, int pop, int c, int s1, int s2)
{
    elms[v].pb({c,pre[v],maxpre[v],s1,s2});
    forall(it,graph[v]) if(it != pop && !odw[it]) dfs_add(it,v,c,s1,s2);
}

void centroid(int v, int n)
{
    dfs_sub(v,v);
    int pop = v;
    while(true)
    {
        pii best = {-1,-1};
        forall(it,graph[v]) if(it != pop && !odw[it]) best = max(best,{sub[it],it});
        if(best.ff > n/2)
        {
            pop = v;
            v = best.ss;
        }
        else break;
    }
    odw[v] = 1;
    centr_str[v].setup(n);
    dfs_sub(v,v);
    cur_pre = 1;
    dfs_pre(v,v);
    forall(it,graph[v]) if(!odw[it]) dfs_add(it,v,v,pre[it],maxpre[it]);
    elms[v].pb({v,1,maxpre[v],1,maxpre[v]});
    forall(it,graph[v]) if(!odw[it]) centroid(it,sub[it]);
}

void change_vert(int v, int f)
{
    forall(it,elms[v])
    {
        if(F[v] != -1) centr_str[it.centr].set_vert(it.p1,it.p2,it.s1,it.s2,F[v],-1);
        centr_str[it.centr].set_vert(it.p1,it.p2,it.s1,it.s2,f,1);
    }
    F[v] = f;
}

void init(int n, vi F2, vi U, vi V, int Q) 
{   
    rep(i,n) F[i] = -1;
    rep(i,n-1)
    {
        graph[U[i]].pb(V[i]);
        graph[V[i]].pb(U[i]);
    }
    centroid(0,n);
    rep(i,n) change_vert(i,F2[i]+1);
}

// Grader entry point: vertex X takes grader colour Y (stored internally as Y + 1).
void change(int X, int Y) 
{
    const int internal_colour = Y + 1;
    change_vert(X, internal_colour);
}

// Grader entry point: the answer maintained incrementally in cur_ans.
ll num_tours() 
{
    const ll answer = cur_ans;
    return answer;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...