Submission #1178368

# | Time | Username | Problem | Language | Result | Execution time | Memory
1178368 | 12345678 | JOI tour (JOI24_joitour) | C++17
100 / 100
316 ms | 95688 KiB
#include "joitour.h"
#include <bits/stdc++.h>

using namespace std;

// 64-bit integer type used for every counter in `info` (a type alias rather
// than a macro, so it obeys scoping and does not rewrite later tokens).
using ll = long long;

// Capacity of the vertex-indexed arrays. Vertices are stored 1-based, so this
// leaves room for N up to a little under 2e5 + 5. Written as an exact integer
// instead of the double expression 2e5+5.
constexpr int nx = 200005;

int n;              // number of vertices, set by init()
int f[nx];          // f[i] = label of vertex i (1-based); compared against 0/1/2
vector<int> d[nx];  // adjacency lists of the input tree, 1-based

// Kinds of top-tree nodes. Kept unscoped so the enumerators can be used
// unqualified; the one-byte underlying type keeps the 4*nx-sized Type array
// small (the values 0..4 fit easily).
enum Type : unsigned char {Compress, Rake, AddEdge, Vertex, AddVertex};

struct statictoptree
{
    // Heavy-path data over the input tree (vertices are 1-based):
    // pa[v] = parent of v (0 for the root), hv[v] = heavy child of v
    // (left at 0 when v has no children).
    int pa[nx], hv[nx];
    // Top-tree nodes. Ids 1..n are the per-vertex nodes (Vertex / AddVertex,
    // node id == vertex id); ids above n are allocated by add(). Id 0 means
    // "no node". lch/rch = children, par = parent (0 for the root rt),
    // cnt = last id handed out.
    int lch[4*nx], rch[4*nx], par[4*nx], cnt, rt;
    Type type[4*nx];

    // Roots the subtree at u: fills pa[] and hv[] for every vertex below u and
    // returns the number of vertices in u's subtree. Expects pa[u] to already
    // hold u's parent (0 for the tree root). The heavy child is the first child
    // in d[] order whose subtree is strictly the largest.
    // Iterative on purpose: a recursive walk nests once per vertex, which can
    // exhaust the call stack on path-shaped trees near the nx limit.
    int dfs(int u)
    {
        vector<int> order{u};  // BFS order: every parent precedes its children
        for (size_t k=0; k<order.size(); k++)
        {
            int x=order[k];
            for (auto y:d[x]) if (y!=pa[x]) pa[y]=x, order.push_back(y);
        }
        vector<int> sz(nx, 0);  // subtree sizes, indexed by vertex id
        // Walk BFS order backwards so each child's size is final before its parent.
        for (size_t k=order.size(); k-->0; )
        {
            int x=order[k], mx=0;
            sz[x]=1;
            for (auto y:d[x])
            {
                if (y==pa[x]) continue;
                if (sz[y]>mx) mx=sz[y], hv[x]=y;
                sz[x]+=sz[y];
            }
        }
        return sz[u];
    }

    // Sets node i to type t with children l and r (0 = absent), allocating a
    // fresh id when i == 0, records the children's parent links, returns i.
    int add(int i, int l, int r, Type t)
    {
        if (!i) i=++cnt;
        lch[i]=l, rch[i]=r, type[i]=t;
        if (l) par[l]=i;
        if (r) par[r]=i;
        return i;
    }

    // Merges the clusters in a (pairs of node id and weight, kept in order)
    // into one binary tree of t-nodes and returns (root id, total weight).
    // a must be non-empty. The list is cut into a prefix b and a suffix c near
    // the weight midpoint, so the resulting tree is balanced by weight.
    pair<int, int> merge(vector<pair<int, int>> &a, Type t)
    {
        if (a.size()==1) return a[0];
        int sm=0;
        vector<pair<int, int>> b, c;
        for (auto [i, sz]:a) sm+=sz;
        for (auto [i, sz]:a)
        {
            // sm = total - 2 * (weight already placed); once an element lands
            // in c, every later one does too, so b stays a prefix.
            (sm>sz?b:c).push_back({i, sz});
            sm-=2*sz;
        }
        auto [i, szi]=merge(b, t);
        auto [j, szj]=merge(c, t);
        return {add(0, i, j, t), szi+szj};
    }

    // Cluster for the heavy path that starts at vertex i: one addvertex
    // cluster per path vertex (top to bottom), merged with Compress nodes.
    pair<int, int> compress(int i)
    {
        vector<pair<int, int>> a={addvertex(i)};
        while (hv[i]) a.push_back({addvertex(i=hv[i])});
        return merge(a, Compress);
    }

    // Cluster for all light children of i (each wrapped by addedge), merged
    // with Rake nodes; returns (0, 0) when i has no light children.
    pair<int, int> rake(int i)
    {
        vector<pair<int, int>> a;
        for (auto v:d[i]) if (v!=pa[i]&&v!=hv[i]) a.push_back(addedge(v));
        return a.empty()?make_pair(0, 0):merge(a, Rake);
    }

    // Compresses the heavy path starting at light child i and hangs it under
    // a single AddEdge node, keeping the weight.
    pair<int, int> addedge(int i)
    {
        auto [j, sz]=compress(i);
        return {add(0, j, 0, AddEdge), sz};
    }

    // Node for vertex i itself (id i): AddVertex over the raked light
    // subtrees when there are any, otherwise a bare Vertex. Weight is the
    // raked weight plus one for i.
    pair<int, int> addvertex(int i)
    {
        auto [j, sz]=rake(i);
        return {add(i, j, 0, j?AddVertex:Vertex), sz+1};
    }

    // Builds the whole top tree: ids 1..n are reserved for vertices, vertex 1
    // is the root (pa[1] is 0 because s has static storage), and rt becomes
    // the root cluster of vertex 1's heavy path.
    void build()
    {
        cnt=n;
        dfs(1);
        rt=compress(1).first;
    }
} s;

// Summary stored for every top-tree node (v[id]). The meaning of each field is
// what compress/rake/addedge/addvertex/vertex maintain; every field starts at 0.
struct info
{
    // Number of vertices labelled 0 / labelled 2 in the cluster.
    ll c0 = 0, c2 = 0;
    // (1,0) / (1,2) pairs in the cluster whose 0 / 2 lies below the 1 when
    // seen from the cluster's top end.
    ll c10 = 0, c12 = 0;
    // (0,2) pairs whose endpoints sit in different raked light subtrees.
    ll c02 = 0;
    // Same idea as c10 / c12, but seen from the bottom end of a path cluster.
    ll dn10 = 0, dn12 = 0;
    // Number of vertices labelled 1 on the cluster's path.
    ll ch1 = 0;
    // Number of (0,1,2) tours that lie entirely inside the cluster.
    ll ans = 0;
} v[4*nx];

// Joins path cluster p (the upper part of a heavy path) with path cluster c
// (the part directly below it) into one path cluster. c02 is left at 0.
info compress(info p, info c)
{
    info out;
    // Plain totals.
    out.c0 = p.c0 + c.c0;
    out.c2 = p.c2 + c.c2;
    out.ch1 = p.ch1 + c.ch1;
    // From the top, every 1 on the upper path lies above all 0s / 2s of c.
    out.c10 = p.c10 + c.c10 + p.ch1 * c.c0;
    out.c12 = p.c12 + c.c12 + p.ch1 * c.c2;
    // From the bottom, every 1 on the lower path lies above all 0s / 2s of p.
    out.dn10 = p.dn10 + c.dn10 + c.ch1 * p.c0;
    out.dn12 = p.dn12 + c.dn12 + c.ch1 * p.c2;
    // Tours inside each half plus tours whose 0 and 2 are in different halves,
    // with the 1 taken from either half.
    out.ans = p.ans + c.ans
            + c.c0 * p.dn12 + c.c2 * p.dn10
            + p.c0 * c.c12 + p.c2 * c.c10;
    return out;
}

// Merges two groups of light subtrees that hang from the same vertex.
// The path-relative fields (dn10, dn12, ch1) are left at 0.
info rake(info l, info r)
{
    info out;
    out.c0 = l.c0 + r.c0;
    out.c2 = l.c2 + r.c2;
    out.c10 = l.c10 + r.c10;
    out.c12 = l.c12 + r.c12;
    // (0,2) pairs split across the two sides, plus those already split inside each side.
    out.c02 = l.c02 + r.c02 + l.c0 * r.c2 + l.c2 * r.c0;
    // A lone 0 / 2 on one side combined with a (1,2) / (1,0) pair on the other.
    out.ans = l.ans + r.ans
            + l.c0 * r.c12 + l.c2 * r.c10
            + r.c0 * l.c12 + r.c2 * l.c10;
    return out;
}

// Turns the path cluster of a light child's heavy path into a subtree summary
// that can be raked onto its parent vertex. Only the fields read by rake() and
// addvertex() are carried over. The path-relative fields (dn10, dn12, ch1) are
// dropped. c02 starts at 0, because a single subtree has no pairs split across
// different light subtrees. (Previously `res` was built and then discarded in
// favour of returning x unchanged.)
info addedge(info x)
{
    info res;
    res.c0=x.c0;
    res.c2=x.c2;
    res.c10=x.c10;
    res.c12=x.c12;
    res.ans=x.ans;
    return res;
}

// Puts vertex i on top of x, the raked summary of i's light subtrees,
// giving a path cluster whose path is just i (so top and bottom views agree).
info addvertex(info x, int i)
{
    const bool isZero = f[i] == 0;
    const bool isOne = f[i] == 1;
    const bool isTwo = f[i] == 2;
    info out;
    out.c0 = x.c0 + isZero;
    out.c2 = x.c2 + isTwo;
    out.ch1 = isOne;
    // If i is a 1, it lies above every 0 / 2 in its light subtrees.
    out.c10 = x.c10 + out.ch1 * x.c0;
    out.c12 = x.c12 + out.ch1 * x.c2;
    out.dn10 = out.c10;
    out.dn12 = out.c12;
    // New tours through i: i as the 1 between split (0,2) pairs, or i as an
    // endpoint of a (1,2) / (1,0) pair below it.
    out.ans = x.ans + isOne * x.c02 + isZero * x.c12 + isTwo * x.c10;
    return out;
}

// One-vertex cluster for vertex i (used when i has no light children):
// only the label indicators are set.
info vertex(int i)
{
    const int label = f[i];
    info out;
    out.c0 = label == 0;
    out.c2 = label == 2;
    out.ch1 = label == 1;
    return out;
}

// Recomputes v[i] from its children's summaries, according to node i's type.
void update(int i)
{
    const int left = s.lch[i];
    const int right = s.rch[i];
    switch (s.type[i])
    {
        case Compress:
            v[i] = compress(v[left], v[right]);
            break;
        case Rake:
            v[i] = rake(v[left], v[right]);
            break;
        case AddEdge:
            v[i] = addedge(v[left]);
            break;
        case AddVertex:
            v[i] = addvertex(v[left], i);
            break;
        case Vertex:
            v[i] = vertex(i);
            break;
    }
}

// Post-order evaluation of the top-tree subtree rooted at node i: both
// children first, then i itself. Node id 0 stands for "no child".
void dfs(int i)
{
    if (i != 0)
    {
        dfs(s.lch[i]);
        dfs(s.rch[i]);
        update(i);
    }
}

// Grader entry point: stores the labels and the N-1 edges, shifting the
// 0-based input ids to 1-based. It then builds the static top tree and
// evaluates every node. Q is not used here.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) 
{
    n = N;
    for (int i = 0; i < N; i++) f[i + 1] = F[i];
    for (int e = 0; e + 1 < N; e++)
    {
        const int a = U[e] + 1;
        const int b = V[e] + 1;
        d[a].push_back(b);
        d[b].push_back(a);
    }
    s.build();
    dfs(s.rt);
}

// Relabels vertex X (0-based) to Y. It then refreshes the summaries on the
// chain from that vertex's own top-tree node up to the root (whose par is 0).
void change(int X, int Y) 
{
    const int vtx = X + 1;
    f[vtx] = Y;
    for (int node = vtx; node != 0; node = s.par[node]) update(node);
}

long long num_tours() 
{
    //cout<<"count "<<v[s.rt].c0<<' '<<v[s.rt].c2<<'\n';
    return v[s.rt].ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...