/*
 * Submission #1347063
 *
 * #:              1347063
 * Time:           (not present in the captured text)
 * Username:       MMihalev
 * Problem:        Tourism (JOI23_tourism)
 * Language:       C++20
 * Result:         59 / 100
 * Execution time: 5018 ms
 * Memory:         55484 KiB
 */
#include<iostream>
#include<vector>
#include<algorithm>
#include<set>
#include<cmath>
using namespace std;

// Capacity for vertex-indexed arrays; presumably n <= 1e5 (TODO confirm against the problem limits).
const int MAX_N=1e5+5;
// Levels 0..LOG-1 of the binary-lifting and sparse tables; 2^LOG >= 2*MAX_N covers the Euler tour length.
const int LOG=18;

int n,m,q;               // vertex count, length of the sequence c, number of queries (read in main)
vector<int>g[MAX_N];     // adjacency lists of the tree, 1-based vertices
int c[MAX_N];            // c[1..m]: sequence the queries index into; assumes m < MAX_N (TODO confirm)
int parent[MAX_N][LOG];  // binary-lifting ancestors filled by dfs(); not read anywhere else in this file

// Sparse table over the Euler tour built by dfs(): sp[i][j] is the minimum depth in
// positions [i, i+2^j), which[i][j] the vertex attaining it. Tour positions are 1-based.
int sp[2*MAX_N][LOG];
int which[2*MAX_N][LOG];
int lg[2*MAX_N];         // lg[i] = floor(log2(i)), filled in main
int first[MAX_N];        // first[u] = position of u's first occurrence in the Euler tour

int depth[MAX_N];        // distance from the root; the root keeps the zero-initialised value 0
int T=-1;                // last DFS entry time handed out; entry times start at 0
int in[MAX_N];           // DFS entry time of each vertex
int out[MAX_N];          // last entry time inside u's subtree; not read anywhere else in this file
int ver[MAX_N];          // inverse of in[]: ver[t] is the vertex entered at time t

int szsp;                // current length of the Euler tour stored in sp/which
void dfs(int u,int par)
{
    in[u]=++T;
    ver[T]=u;

    sp[++szsp][0]=depth[u];
    which[szsp][0]=u;
    first[u]=szsp;

    parent[u][0]=par;
    for(int j=1;j<LOG;j++)
    {
        parent[u][j]=parent[parent[u][j-1]][j-1];
    }

    for(int v:g[u])
    {
        if(v==par)continue;

        depth[v]=depth[u]+1;
        dfs(v,u);

        sp[++szsp][0]=depth[u];
        which[szsp][0]=u;
    }

    out[u]=T;
}

int ll,lr,lk;
int lca(int u,int v)
{
    ll=first[u];lr=first[v];
    if(ll>lr)swap(ll,lr);

    lk=lg[lr-ll+1];

    if(sp[ll][lk]<sp[lr-(1<<lk)+1][lk])
    {
        return which[ll][lk];
    }
    return which[lr-(1<<lk)+1][lk];
}

// State of the current Mo window, indexed by DFS entry time (in[]).
int cntorder[MAX_N];      // multiplicity of each vertex (by entry time) among the window's cities
int cntorderblock[MAX_N]; // per-block totals of cntorder; not read anywhere in this file
int cntlcs[MAX_N];        // cntlcs[t]: adjacent pairs (in DFS order) of distinct window vertices whose LCA has entry time t
int cntlcsblock[MAX_N];   // per-block totals of cntlcs (block = t/sz), scanned by calc()
int sum;                  // sum over those adjacent pairs (a,b), a before b, of depth[b]-depth[lca(a,b)]
int sz;                   // block size of the sqrt decomposition over entry times, set in main from sqrt(n)

int nxt,prv,lc_old,lc_new; // file-scope scratch: neighbour vertices and LCAs during window updates


int tree[MAX_N];
void Update(int pos,int val)
{
    for(;;)
    {
        if(pos>n)break;
        tree[pos]+=val;
        pos+=((pos)&(-pos));
    }
}

// Fenwick prefix sum: returns tree-sum over positions 1..pos (0 when pos < 1).
// The accumulator is a local rather than the former file-scope `res`, which no
// other code read. This keeps the hot loop free of global stores.
int Find(int pos)
{
    int total=0;
    for(;pos>=1;pos-=pos&-pos)
        total+=tree[pos];
    return total;
}

// Fenwick binary descent: returns the smallest 1-based position p whose prefix
// sum Find(p) reaches k. Assumes 1 <= k <= total count stored and n < 2^LOG.
// The cursor, running sum and loop index are locals; they were file-scope
// globals (`poss`, `curr`, `jj`) that no other code read.
int kth(int k)
{
    int pos=0;
    int acc=0;
    for(int step=LOG-1;step>=0;step--)
    {
        const int cand=pos+(1<<step);
        // Only advance while the prefix sum stays strictly below k.
        if(cand>n || acc+tree[cand]>=k)continue;
        pos=cand;
        acc+=tree[cand];
    }
    return pos+1;
}

int all,x;
void add(int u)
{
    cntorder[in[u]]++;
    cntorderblock[in[u]/sz]++;

    if(cntorder[in[u]]>1)return;
    

    x=Find(in[u]+1);
    
    if(x>=1 && x+1<=all)
    {
        nxt=ver[kth(x+1)-1];
        prv=ver[kth(x)-1];

        lc_old=lca(nxt,prv);
        cntlcs[in[lc_old]]--;
        cntlcsblock[in[lc_old]/sz]--;
        sum-=depth[nxt]-depth[lc_old];

        lc_new=lca(u,prv);
        cntlcs[in[lc_new]]++;
        cntlcsblock[in[lc_new]/sz]++;
        sum+=depth[u]-depth[lc_new];

        lc_new=lca(nxt,u);
        cntlcs[in[lc_new]]++;
        cntlcsblock[in[lc_new]/sz]++;
        sum+=depth[nxt]-depth[lc_new];
    }
    else if(x>=1)
    {
        prv=ver[kth(x)-1];

        lc_new=lca(u,prv);
        cntlcs[in[lc_new]]++;
        cntlcsblock[in[lc_new]/sz]++;
        sum+=depth[u]-depth[lc_new];
    }
    else if(x+1<=all)
    {
        nxt=ver[kth(x+1)-1];

        lc_new=lca(nxt,u);
        cntlcs[in[lc_new]]++;
        cntlcsblock[in[lc_new]/sz]++;
        sum+=depth[nxt]-depth[lc_new];
    }
    
    Update(in[u]+1,+1);
    all++;
}

void rem(int u)
{
    cntorder[in[u]]--;
    cntorderblock[in[u]/sz]--;

    if(cntorder[in[u]]>0)return;
    Update(in[u]+1,-1);
    all--;

    x=Find(in[u]+1);
    
    if(x>=1 && x+1<=all)
    {
        nxt=ver[kth(x+1)-1];
        prv=ver[kth(x)-1];

        lc_old=lca(nxt,prv);
        cntlcs[in[lc_old]]++;
        cntlcsblock[in[lc_old]/sz]++;
        sum+=depth[nxt]-depth[lc_old];

        lc_new=lca(u,prv);
        cntlcs[in[lc_new]]--;
        cntlcsblock[in[lc_new]/sz]--;
        sum-=depth[u]-depth[lc_new];

        lc_new=lca(nxt,u);
        cntlcs[in[lc_new]]--;
        cntlcsblock[in[lc_new]/sz]--;
        sum-=depth[nxt]-depth[lc_new];
    }
    else if(x>=1)
    {
        prv=ver[kth(x)-1];

        lc_new=lca(u,prv);
        cntlcs[in[lc_new]]--;
        cntlcsblock[in[lc_new]/sz]--;
        sum-=depth[u]-depth[lc_new];
    }
    else if(x+1<=all)
    {
        nxt=ver[kth(x+1)-1];

        lc_new=lca(nxt,u);
        cntlcs[in[lc_new]]--;
        cntlcsblock[in[lc_new]/sz]--;
        sum-=depth[nxt]-depth[lc_new];
    }
}

int calc()
{
    int u=-1,lc=-1;

    u=ver[kth(1)-1];

    for(int b=0;b<=(n-1)/sz;b++)
    {
        if(cntlcsblock[b]>0)
        {
            for(int i=b*sz;i<min(n,(b+1)*sz);i++)
            {
                if(cntlcs[i]>0)
                {
                    lc=ver[i];
                    break;
                }
            }
            if(lc!=-1)break;
        }
    }
    
    if(lc==-1 or in[lc]>in[u])lc=u;

    return sum+(depth[u]-depth[lc]+1);
}

vector<pair<pair<int,int>,int>>queries;  // ((l, r), original 1-based query index)
int ans[MAX_N];                           // ans[id] = answer to query id

// Reads the tree, the sequence c[1..m] and q range queries. Answers the queries
// offline with Mo's algorithm (add/rem maintain the window, calc evaluates it) and
// prints the answers in input order.
int main ()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin>>n>>m>>q;
    // Block size of the sqrt decomposition over DFS entry times; never 0.
    sz=max(1,(int)(sqrt(n)));

    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    dfs(1,0);

    // lg[i] = floor(log2(i)) via an exact integer recurrence (no floating point).
    lg[1]=0;
    for(int i=2;i<=szsp;i++)lg[i]=lg[i/2]+1;

    // Sparse table of (min depth, vertex) over the Euler tour; ties take the right half.
    for(int j=1;j<LOG;j++)
    {
        const int half=1<<(j-1);
        for(int i=1;i+(1<<j)-1<=szsp;i++)
        {
            const int r=i+half;
            if(sp[i][j-1]<sp[r][j-1])
            {
                sp[i][j]=sp[i][j-1];
                which[i][j]=which[i][j-1];
            }
            else
            {
                sp[i][j]=sp[r][j-1];
                which[i][j]=which[r][j-1];
            }
        }
    }

    for(int i=1;i<=m;i++)
    {
        cin>>c[i];
    }

    queries.reserve(q);
    for(int i=1;i<=q;i++)
    {
        int l,r;
        cin>>l>>r;
        queries.push_back({{l,r},i});
    }

    // Mo block size ~ m/sqrt(q) balances left-pointer moves (~q*B) against
    // right-pointer moves (~m*m/B). It replaces the former hard-coded 593;
    // max(q,1) keeps the division defined when there are no queries.
    const int moBlock=max(1,(int)(m/sqrt((double)max(q,1))));
    sort(queries.begin(),queries.end(),[moBlock](const pair<pair<int,int>,int>&a,const pair<pair<int,int>,int>&b)
    {
        const int ba=a.first.first/moBlock;
        const int bb=b.first.first/moBlock;
        if(ba!=bb)return ba<bb;
        // Serpentine order: alternate the direction of r between blocks so the
        // right pointer does not rewind to the start at every block boundary.
        if(ba&1)return a.first.second>b.first.second;
        return a.first.second<b.first.second;
    });

    int posl=1,posr=0;
    for(const auto &[pa,id]:queries)
    {
        const int l=pa.first,r=pa.second;

        // Extend first, then shrink: the window stays a valid interval, so
        // rem() only ever removes elements that are currently present.
        while(posr<r)add(c[++posr]);
        while(posl>l)add(c[--posl]);
        while(posr>r)rem(c[posr--]);
        while(posl<l)rem(c[posl++]);

        ans[id]=calc();
    }

    for(int i=1;i<=q;i++)
    {
        cout<<ans[i]<<'\n';
    }

    return 0;
}
/*
 * Result tables from the submission page (columns: # | Verdict | Execution time | Memory | Grader output).
 * All six tables still showed "Fetching results..." when this page was captured.
 */