Submission #493993

# | Time | Username | Problem | Language | Result | Execution time | Memory
493993 | | fcmalkcin | Mergers (JOI19_mergers) | C++17 | 0 / 100 | 124 ms | 54568 KiB
/*#pragma GCC optimize("Ofast")
#pragma GCC optimization("unroll-loops, no-stack-protector")
#pragma GCC target("avx,avx2,fma")*/

#include <bits/stdc++.h>
using namespace std;

// Shorthand macros used throughout the solution.
#define ll  long long
#define pll pair<ll,ll>
#define ff first
#define ss second
#define pb push_back
// Replaces every `endl` token with a plain newline, so output lines are not flushed one by one.
#define endl "\n"
// Randomly seeded generator; not referenced in the code shown here.
mt19937 rnd(chrono::steady_clock::now().time_since_epoch().count());

// Capacity of every per-vertex / per-colour array (vertices and colours are indexed from 1).
const ll maxn=5e5+20;
// mod and base are not referenced in the code shown here.
const ll mod=998244353  ;
const ll base=1e9+100;

/// you will be the best but now you just are trash
/// goal 5/7

// anc[u][i]: the 2^i-th ancestor of u in the tree rooted at vertex 1 (0 once past the root); filled by dfs.
ll anc[maxn][20];
// par1[u]: parent of u in the rooted tree (0 for the root); filled by dfs.
ll par1[maxn];
// a[u]: colour (state) of vertex u, read from input.
ll a[maxn];
// par[u]: DSU parent link; dsu() always hangs the deeper root under the shallower one.
ll par[maxn];
// gr[c]: vertices of colour c that still have to be pulled into a merge; emptied once handled.
vector<ll> gr[maxn];
// adj[u]: neighbours of u in the input tree.
vector<ll> adj[maxn];
// Preorder counter used by dfs to assign f[].
ll cntnw=0;
// f[u]: 1-based preorder (DFS entry) index of u.
ll f[maxn];
// dep[u]: depth of u below the root (the root has depth 0).
ll dep[maxn];
/// Returns the DSU representative of u and points every vertex on the
/// walked path directly at it (path compression).
/// Iterative on purpose: before compression a DSU chain can be O(n) long,
/// and a recursive find would then need O(n) stack frames.
ll find_par(ll u)
{
    // First pass: locate the representative.
    ll root=u;
    while (par[root]!=root) root=par[root];
    // Second pass: re-link every vertex on the path straight to the root.
    while (par[u]!=root)
    {
        ll nxt=par[u];
        par[u]=root;
        u=nxt;
    }
    return root;
}
/// Unites the DSU sets containing x and y.
/// The root with the larger depth is attached below the other one (on equal
/// depth, x's root goes under y's), so the representative stays the shallower root.
void dsu(ll x,ll y)
{
    const ll rx=find_par(x);
    const ll ry=find_par(y);
    if (rx==ry) return;
    if (dep[rx]<dep[ry]) par[ry]=rx;
    else par[rx]=ry;
}
/// Traverses the tree from u (whose parent is `par`, 0 for the root).
/// In preorder it assigns the entry index f[], the binary-lifting table anc[],
/// the parent par1[] and, for every non-start vertex, dep[] = parent depth + 1.
/// Uses an explicit stack instead of recursion: a path-shaped tree with ~5e5
/// vertices would otherwise need ~5e5 nested call frames.
void dfs(ll u,ll par)
{
    // Each entry is (vertex, index of the next neighbour to scan).
    vector<pair<ll,size_t>> stk;
    // Preorder visit: stamp entry time, fill the lifting row, then push.
    // Parents are always entered before their children, so anc[p][*] is ready.
    auto enter=[&](ll v,ll p)
    {
        cntnw++;
        f[v]=cntnw;
        anc[v][0]=p;
        for (int i=1;i<20;i++) anc[v][i]=anc[anc[v][i-1]][i-1];
        par1[v]=p;
        stk.emplace_back(v,0);
    };
    enter(u,par);
    while (!stk.empty())
    {
        ll v=stk.back().first;
        size_t &idx=stk.back().second;
        if (idx==adj[v].size())
        {
            stk.pop_back();
            continue;
        }
        ll to=adj[v][idx];
        // Advance before enter(): emplace_back may reallocate and invalidate idx.
        idx++;
        if (to==par1[v]) continue;
        dep[to]=dep[v]+1;
        enter(to,v);
    }
}
/// Lowest common ancestor of x and y, using the binary-lifting table anc[][].
ll lca(ll x,ll y)
{
    // Make x the deeper endpoint.
    if (dep[y]>dep[x]) swap(x,y);
    // Lift x by the depth gap, consuming the gap's bits from least significant.
    const ll gap=dep[x]-dep[y];
    for (int j=0;j<20;j++)
    {
        if ((gap>>j)&1) x=anc[x][j];
    }
    if (x==y) return x;
    // Raise both while their 2^j-th ancestors differ; they end just below the LCA.
    for (int j=19;j>=0;j--)
    {
        const ll ax=anc[x][j];
        const ll ay=anc[y][j];
        if (ax!=ay)
        {
            x=ax;
            y=ay;
        }
    }
    return anc[x][0];
}
ll cnt[maxn];

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    if (fopen("t.inp", "r"))
    {
        freopen("test.inp", "r", stdin);
        freopen("test.out", "w", stdout);
    }
    ll n, k;
    cin>> n>> k;
    vector<pll> vt;
    for (int i=1;i<=n-1;i++)
    {
        ll x, y;
        cin>>x>> y;
        vt.pb(make_pair(x,y));
        adj[x].pb(y);
        adj[y].pb(x);
    }
    for (int i=1;i<=n;i++)
    {
        par[i]=i;
        cin>>a[i];
        gr[a[i]].pb(i);
    }
    dfs(1,0);
    for (int i=1;i<=n;i++)
    {
        for (auto to:adj[i])
        {
            if (a[to]==a[i]) dsu(to,i);
        }
    }
    for (int i=1;i<=k;i++)
    {
        set<pll> st;
        for (auto p:gr[i])
        {
            ll x=find_par(p);
            st.insert(make_pair(f[x],x));
         //   if (i==1) cout <<f[x]<<" "<<x<<" "<<p<<endl;
        }
        gr[i].clear();
        while (st.size()>=2)
        {
            auto p=(*st.begin());
            st.erase(st.begin());
            auto p1=(*st.begin());
            st.erase(st.begin());
            ll pos=p.ss;
            ll pos1=p1.ss;
            vector<ll> vt;
            ll nw=lca(pos,pos1);
            nw=find_par(nw);

            while (dep[nw]<dep[find_par(pos)])
            {
                pos=find_par(pos);
                ll h=par1[pos];
                dsu(pos,h);

                for (auto to:gr[h])
                {
                    vt.pb(to);
                }
                gr[a[h]].clear();
            }
            swap(pos,pos1);
          //  cout <<pos<<endl;
             while (dep[nw]<dep[find_par(pos)])
            {
                pos=find_par(pos);
                ll h=par1[pos];
                dsu(pos,h);
               // cout <<pos<<" "<<h<<endl;
                for (auto to:gr[a[h]])
                {
                    vt.pb(to);
                }
                gr[a[h]].clear();
            }
            for (auto to:vt)
            {
                ll x=find_par(to);
                st.insert(make_pair(f[x],x));
            }
            ll h=find_par(pos);
            st.insert(make_pair(f[h],h));
        }
    }
    for (int i=1;i<=n;i++) adj[i].clear();
    for (auto to:vt)
    {
        ll x=find_par(to.ff);
        ll y=find_par(to.ss);
        if (x!=y)
        {
           cnt[x]++;
           cnt[y]++;
        }
    }
    ll ans=0;
    ll dem=0;
    for (int i=1;i<=n;i++)
    {
        if (i==find_par(i))
        {
         //  cout <<i<<" "<<cnt[i]<<endl;
            dem++;
            if (cnt[i]==1) ans++;
        }
    }
    if (dem==1) cout <<0<<endl;
    else cout <<(ans+1)/2<<endl;


}

Compilation message (stderr)

mergers.cpp: In function 'int main()':
mergers.cpp:93:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
   93 |         freopen("test.inp", "r", stdin);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
mergers.cpp:94:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
   94 |         freopen("test.out", "w", stdout);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...