Submission #959199

# | Time | Username | Problem | Language | Result | Execution time | Memory
959199 | andrei_boaca | Capital City (JOI20_capital_city) | C++17
100 / 100
960 ms | 120404 KiB
#include <bits/stdc++.h>

using namespace std;
using pii = std::pair<int,int>;

// ---- Tree (nodes are 1..n; node 0 acts as a "no node" sentinel) ----
int n;                                   // number of tree nodes
int k;                                   // number of colors
int nr[200005];                          // subtree size of each node
int par[200005];                         // parent of each node (0 for the root)
int dp[21][200005];                      // dp[i][v]: 2^i-th ancestor of v (0 once past the root)
int in[200005];                          // Euler-tour entry time
int out[200005];                         // Euler-tour exit time (largest entry time in the subtree)
int niv[200005];                         // depth of each node (root has depth 1)
bool use[200005];                        // per-color visited flag, reused by both Kosaraju passes
bool isheavy[200005];                    // true if the node is the heavy child of its parent
std::vector<int> mynodes[200005];        // mynodes[c]: nodes whose color is c
int cul[200005];                         // color of each node
// ---- Heavy-light decomposition ----
int poz[200005];                         // 1-based position inside its chain, counted from the chain bottom
int where[200005];                       // chain id of each node
int last[200005];                        // topmost node of each chain
int chains;                              // number of chains
int lg[200005];                          // length of each chain
int timp;                                // Euler-tour clock
// ---- SCC bookkeeping over colors ----
int nrcomp;                              // number of components found so far
int nrcul[200005];                       // number of colors per component
int nrnodes[200005];                     // number of tree nodes per component
int nrmuchii[200005];                    // tree edges whose endpoints both belong to the component
int comp[200005];                        // component id of each color
// ---- Per-chain segment trees ----
std::vector<std::vector<int>> arb;                  // arb[chain]: max segment tree over values set by plsput
std::vector<std::vector<std::vector<int>>> aint;    // aint[chain]: segment tree whose nodes hold color lists (upd/getnodes)
std::vector<int> muchii[200005];         // adjacency lists of the tree
std::vector<pii> init;                   // input edges, in reading order
std::vector<int> st;                     // colors in dfs1 finish (post-)order
std::vector<int> who;                    // scratch buffer filled by getnodes
// True when node a is an ancestor of node b (a node counts as its own ancestor),
// decided with the Euler-tour entry/exit times filled in by initdfs.
bool isancestor(int a,int b)
{
    if(in[a]>in[b])
        return false;
    return out[b]<=out[a];
}
int LCA(int a,int b)
{
    if(niv[a]>niv[b])
        swap(a,b);
    if(isancestor(a,b))
        return a;
    for(int i=18;i>=0;i--)
        if(dp[i][a]!=0&&!isancestor(dp[i][a],b))
            a=dp[i][a];
    return par[a];
}
// Recursive DFS from nod (par[nod] must already be set; 0 for the root).
// Fills the Euler entry/exit times (in/out), the subtree size (nr), the binary-lifting
// table (dp) and the children's parent/depth (par/niv). It then marks the child with
// the largest subtree (first such child on ties) as heavy.
void initdfs(int nod)
{
    in[nod]=++timp;
    nr[nod]=1;
    dp[0][nod]=par[nod];
    for(int lvl=1;lvl<=18;lvl++)
    {
        const int half=dp[lvl-1][nod];
        dp[lvl][nod]=dp[lvl-1][half];
    }
    int heavyChild=0;
    int heavySize=0;
    for(int child:muchii[nod])
    {
        if(child==par[nod])
            continue;
        par[child]=nod;
        niv[child]=niv[nod]+1;
        initdfs(child);
        nr[nod]+=nr[child];
        if(nr[child]>heavySize)
        {
            heavySize=nr[child];
            heavyChild=child;
        }
    }
    if(heavyChild!=0)
        isheavy[heavyChild]=1;
    out[nod]=timp;
}
void buildheavy()
{
    for(int i=1;i<=n;i++)
    {
        int nod=i;
        bool ok=1;
        for(int j:muchii[nod])
            if(j!=par[nod]&&isheavy[j])
            {
                ok=0;
                break;
            }
        if(!ok)
            continue;
        chains++;
        while(nod!=0)
        {
            lg[chains]++;
            poz[nod]=lg[chains];
            where[nod]=chains;
            last[chains]=nod;
            if(!isheavy[nod])
                break;
            nod=par[nod];
        }
    }
    arb.resize(chains+1);
    aint.resize(chains+1);
    for(int i=1;i<=chains;i++)
    {
        arb[i].resize(4*lg[i]+5);
        aint[i].resize(4*lg[i]+5);
    }
}
// Point assignment in chain ind's max segment tree: sets position p to val, then
// recomputes the maxima on the way back up. nod covers the range [st, dr].
void update(int ind,int nod,int st,int dr,int p,int val)
{
    std::vector<int>& tree=arb[ind];
    if(st==dr)
    {
        tree[nod]=val;
        return;
    }
    const int mij=(st+dr)/2;
    if(p>mij)
        update(ind,2*nod+1,mij+1,dr,p,val);
    else
        update(ind,2*nod,st,mij,p,val);
    tree[nod]=std::max(tree[2*nod],tree[2*nod+1]);
}
// Range maximum over [a, b] in chain ind's segment tree (nod covers [st, dr]).
// The result is never below -1.
int query(int ind,int nod,int st,int dr,int a,int b)
{
    if(a<=st&&b>=dr)
        return arb[ind][nod];
    const int mij=(st+dr)/2;
    const int fromLeft=(a<=mij)?query(ind,2*nod,st,mij,a,b):-1;
    const int fromRight=(b>mij)?query(ind,2*nod+1,mij+1,dr,a,b):-1;
    return std::max({-1,fromLeft,fromRight});
}
// Stores val at tree node nod inside its chain's max segment tree.
void plsput(int nod,int val)
{
    update(where[nod],1,1,lg[where[nod]],poz[nod],val);
}
// Despite the name, returns the MAXIMUM stored value on the vertical path from a up to
// lca (both inclusive), or -1 if nothing larger is found. lca must be an ancestor of a.
int chmin(int a,int lca)
{
    int best=-1;
    int cur=a;
    while(niv[cur]>=niv[lca])
    {
        const int chain=where[cur];
        const bool finalChain=(chain==where[lca]);
        const int upper=finalChain?poz[lca]:lg[chain];
        best=std::max(best,query(chain,1,1,lg[chain],poz[cur],upper));
        if(finalChain)
            break;
        // Jump to the parent of this chain's top node.
        cur=par[last[chain]];
    }
    return best;
}
// Despite the name, returns the maximum stored value on the tree path a..b.
int getmin(int a,int b)
{
    const int meet=LCA(a,b);
    const int fromA=chmin(a,meet);
    const int fromB=chmin(b,meet);
    return std::max(fromA,fromB);
}
// Range "tag": appends val to the list of every canonical segment-tree node covering
// [a, b] in chain ind. Left children are visited before right children.
void upd(int ind,int nod,int st,int dr,int a,int b,int val)
{
    if(a<=st&&b>=dr)
    {
        aint[ind][nod].push_back(val);
        return;
    }
    const int mij=(st+dr)/2;
    if(a<=mij)
        upd(ind,2*nod,st,mij,a,b,val);
    if(mij<b)
        upd(ind,2*nod+1,mij+1,dr,a,b,val);
}
// Walks from segment-tree node nod (covering [st, dr]) down to leaf p in chain ind.
// Every list met on the way is appended to the global `who` and then emptied, so each
// tag is reported at most once.
void getnodes(int ind,int nod,int st,int dr,int p)
{
    std::vector<std::vector<int>>& seg=aint[ind];
    while(true)
    {
        std::vector<int>& bucket=seg[nod];
        who.insert(who.end(),bucket.begin(),bucket.end());
        bucket.clear();
        if(st==dr)
            return;
        const int mij=(st+dr)/2;
        if(p<=mij)
        {
            nod=2*nod;
            dr=mij;
        }
        else
        {
            nod=2*nod+1;
            st=mij+1;
        }
    }
}
// Tags every node on the vertical path from a up to lca (inclusive) with val, one
// chain segment at a time. lca must be an ancestor of a.
void add(int a,int lca,int val)
{
    int cur=a;
    while(niv[cur]>=niv[lca])
    {
        const int chain=where[cur];
        const bool finalChain=(chain==where[lca]);
        const int upper=finalChain?poz[lca]:lg[chain];
        upd(chain,1,1,lg[chain],poz[cur],upper,val);
        if(finalChain)
            break;
        cur=par[last[chain]];
    }
}
// Tags every node on the tree path a..b with val. The LCA is tagged from both sides,
// which creates duplicate entries that dfs2 ignores through its visited check.
void plsadd(int a,int b,int val)
{
    const int meet=LCA(a,b);
    add(a,meet,val);
    add(b,meet,val);
}
void dfs1(int nod)
{
    use[nod]=1;
    for(int x:mynodes[nod])
        plsput(x,-1);
    for(int i=1;i<mynodes[nod].size();i++)
    {
        int a=mynodes[nod][0];
        int b=mynodes[nod][i];
        int x=getmin(a,b);
        while(x>=1)
        {
            dfs1(x);
            x=getmin(a,b);
        }
    }
    st.push_back(nod);
}
// Kosaraju pass 2: assigns color nod to component nrcomp. For each node of color nod,
// it collects (and consumes) the colors whose registered paths cover that node, i.e.
// the colors that depend on nod. It then recurses into the unvisited ones.
void dfs2(int nod)
{
    use[nod]=1;
    comp[nod]=nrcomp;
    nrcul[nrcomp]++;
    for(int i:mynodes[nod])
    {
        who.clear();
        const int chain=where[i];
        getnodes(chain,1,1,lg[chain],poz[i]);
        // The recursive calls below reuse the global `who` buffer. Take ownership of the
        // collected colors by moving them (the original made a full copy).
        std::vector<int> aux=std::move(who);
        for(int j:aux)
            if(!use[j])
                dfs2(j);
    }
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin>>n>>k;
    for(int i=1;i<n;i++)
    {
        int a,b;
        cin>>a>>b;
        muchii[a].push_back(b);
        muchii[b].push_back(a);
        init.push_back({a,b});
    }
    for(int i=1;i<=n;i++)
    {
        cin>>cul[i];
        mynodes[cul[i]].push_back(i);
    }
    niv[1]=1;
    initdfs(1);
    buildheavy();
    for(int i=1;i<=n;i++)
        plsput(i,cul[i]);
    for(int i=1;i<=k;i++)
        if(!use[i])
            dfs1(i);
    reverse(st.begin(),st.end());
    for(int i=1;i<=n;i++)
        use[i]=0;
    for(int i=1;i<=k;i++)
    {
        for(int j=1;j<mynodes[i].size();j++)
        {
            int a=mynodes[i][0];
            int b=mynodes[i][j];
            plsadd(a,b,i);
        }
    }
    for(int i:st)
        if(!use[i])
        {
            nrcomp++;
            dfs2(i);
        }
    /*for(int i=1;i<=k;i++)
        cout<<comp[i]<<' ';
    cout<<'\n';*/
    for(int i=1;i<=n;i++)
    {
        int c=comp[cul[i]];
        nrnodes[c]++;
    }
    for(pii p:init)
    {
        int a=p.first;
        int b=p.second;
        a=comp[cul[a]];
        b=comp[cul[b]];
        if(a==b)
            nrmuchii[a]++;
    }
    int ans=1e9;
    for(int i=1;i<=nrcomp;i++)
        if(nrmuchii[i]+1==nrnodes[i])
            ans=min(ans,nrcul[i]);
    cout<<ans-1;
    return 0;
}

Compilation message (stderr)

capital_city.cpp: In function 'void dfs1(int)':
capital_city.cpp:198:18: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  198 |     for(int i=1;i<mynodes[nod].size();i++)
      |                 ~^~~~~~~~~~~~~~~~~~~~
capital_city.cpp: In function 'int main()':
capital_city.cpp:260:22: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  260 |         for(int j=1;j<mynodes[i].size();j++)
      |                     ~^~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...