#include "joitour.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
/* Capacity: the per-vertex arrays below hold at most N vertices; the edge
 * arrays (e, Ln) hold 2N directed edges. */
#define N 200000
/* Global tree state, indexed by vertex id unless noted:
 *   n           number of vertices
 *   color       colour of each vertex
 *   hd, e, Ln   adjacency lists: hd[u] = first edge id of u, e[j] = target of
 *               edge j, Ln[j] = next edge id (ids start at 1; 0 ends a list)
 *   ii          number of directed edges added so far
 *   sz          subtree size (dfs1)
 *   hld         head (topmost vertex) of the vertex's heavy path
 *   tail        bottom vertex of a heavy path (written for heads by dfs2,
 *               copied to every vertex by init)
 *   nfd         inverse of dfn: vertex at a given preorder position
 *   dfn, dfc    preorder position of a vertex / running preorder counter
 *   fa          parent (the root is its own parent)
 *   lv          depth; not read anywhere else in this file
 *   n_          leaf count (a power of two) of the compress segment tree of
 *               the heavy path headed by this vertex
 *   childcount  number of light children
 *   id          ordinal of a light child among its parent's light children
 *   m_          leaf count (a power of two) of this vertex's rake segment tree
 */
int n, color[N], hd[N], e[N * 2], Ln[N * 2], ii, sz[N], hld[N], tail[N], nfd[N], dfn[N], dfc, fa[N], lv[N], n_[N], childcount[N], id[N], m_[N];
/* Summary of a compress cluster: a contiguous stretch of one heavy path plus
 * everything hanging off those vertices. Leaves are built by Vertex() and
 * Addvertex(), and clusters are merged top-to-bottom by operator+.
 *
 * Field meanings as implied by the merge formulas
 * (NOTE(review): interpretation, confirm against the problem statement):
 *   c0, c2   number of colour-0 / colour-2 vertices in the cluster
 *   s1       number of colour-1 vertices on the cluster's path itself
 *   c10, c12 (1,0) / (1,2) pairs counted towards the cluster's top end
 *            (1s on an upper path combined with vertices of a lower cluster)
 *   p10, p12 mirror image: pairs counted towards the cluster's bottom end
 *   c02      not used for paths; always 0
 *   ans      triples completed inside the cluster
 */
struct path {
    long long c0 = 0, c2 = 0, c10 = 0, c12 = 0, c02 = 0;
    long long p10 = 0, p12 = 0, s1 = 0, ans = 0;

    /* compress: p is the upper (shallower) part, c the lower part.
     * Operand order matters: the operation is not commutative.
     * Operands are taken by const reference so each merge in the segment-tree
     * pull loop does not copy two nine-long-long structs. */
    friend path operator+(const path &p, const path &c) {
        path res;
        res.c0 = p.c0 + c.c0;
        res.c2 = p.c2 + c.c2;
        res.s1 = p.s1 + c.s1;
        // 1s on the upper path pair with every 0/2 of the lower cluster.
        res.c10 = p.c10 + c.c10 + p.s1 * c.c0;
        res.c12 = p.c12 + c.c12 + p.s1 * c.c2;
        // 1s on the lower path pair with every 0/2 of the upper cluster.
        res.p10 = p.p10 + c.p10 + c.s1 * p.c0;
        res.p12 = p.p12 + c.p12 + c.s1 * p.c2;
        // Triples whose endpoints straddle the two halves.
        res.ans = p.ans + c.ans + p.c0 * c.c12 + p.c2 * c.c10
                + c.c0 * p.p12 + c.c2 * p.p10;
        return res;
    }
};
/* Summary of a rake cluster: a set of light subtrees hanging off one vertex.
 * Sibling sets are merged side by side with operator+, and Addvertex() turns
 * the result back into a path cluster.
 * (NOTE(review): interpretation inferred from the formulas.)
 *   c0, c2       colour-0 / colour-2 vertices in the subtrees
 *   c02          (0,2) pairs taken from two different subtrees
 *   c10, c12     pair counts carried over from each subtree's path summary
 *   ans          triples completed inside the subtrees
 *   s1, p10, p12 never set by operator+ or Addedge; always 0
 */
struct point {
    long long c0 = 0, c2 = 0, c10 = 0, c12 = 0, c02 = 0;
    long long p10 = 0, p12 = 0, s1 = 0, ans = 0;

    /* rake: merge two sibling sets (symmetric in l and r). Operands are taken
     * by const reference to avoid copying both structs on every merge. */
    friend point operator+(const point &l, const point &r) {
        point res;
        res.c0 = l.c0 + r.c0;
        res.c2 = l.c2 + r.c2;
        // A 0 and a 2 from different sides form a cross-subtree pair.
        res.c02 = l.c02 + r.c02 + l.c0 * r.c2 + l.c2 * r.c0;
        res.c10 = l.c10 + r.c10;
        res.c12 = l.c12 + r.c12;
        res.ans = l.ans + r.ans + l.c0 * r.c12 + l.c2 * r.c10
                + r.c0 * l.c12 + r.c2 * l.c10;
        return res;
    }
};
/* tt[h]: compress segment tree (1-indexed; leaves start at n_[h]) for the
 * heavy path headed by h. Allocated only for heads; tt[0][1] is what
 * num_tours() reads. */
path *tt[N];
/* yy[u]: rake segment tree over u's light children (leaf id[v] + m_[u] for
 * light child v). Allocated only when childcount[u] > 0, otherwise null. */
point *yy[N];
/* Wrap the summary of a whole heavy path (headed by a light child) as a
 * single rake entry for its parent. A lone subtree has no cross-subtree
 * (0,2) pairs, so c02 is left at zero. */
point Addedge(path x) {
    point entry;
    entry.c0 = x.c0;
    entry.c2 = x.c2;
    entry.c10 = x.c10;
    entry.c12 = x.c12;
    entry.ans = x.ans;
    entry.c02 = 0;
    return entry;
}
/* Put a vertex of colour col on top of the rake summary x of its light
 * subtrees, producing a one-vertex compress cluster. With an all-zero x the
 * result equals Vertex(col). */
path Addvertex(point x, char col) {
    const bool zero = (col == 0);
    const bool one = (col == 1);
    const bool two = (col == 2);
    path out;
    out.c0 = x.c0 + zero;
    out.c2 = x.c2 + two;
    out.s1 = one;
    out.c10 = x.c10 + out.s1 * x.c0;
    out.c12 = x.c12 + out.s1 * x.c2;
    // For a single-vertex path the top-facing and bottom-facing counts agree.
    out.p10 = out.c10;
    out.p12 = out.c12;
    // Triples newly completed at this vertex.
    out.ans = x.ans + out.s1 * x.c02 + zero * x.c12 + two * x.c10;
    return out;
}
/* Compress-cluster leaf for a bare vertex of colour col: exactly one of
 * c0 / s1 / c2 is 1 for colours 0 / 1 / 2; every other counter is zero. */
path Vertex(int col) {
    path leaf;
    switch (col) {
    case 0:
        leaf.c0 = 1;
        break;
    case 1:
        leaf.s1 = 1;
        break;
    case 2:
        leaf.c2 = 1;
        break;
    default:
        break;
    }
    return leaf;
}
/* Prepend the directed edge u -> v to u's adjacency list.
 * Edge ids start at 1; an id of 0 terminates a list. */
void add(int u, int v) {
    const int edge = ++ii;
    e[edge] = v;
    Ln[edge] = hd[u];
    hd[u] = edge;
}
/* First DFS from u: computes subtree sizes sz[] and parents fa[] for u's
 * descendants. The caller sets fa[] of the start vertex (init uses fa[0] = 0).
 * It also reorders u's adjacency targets in place so that, when u has any
 * child, e[hd[u]] ends up holding a child with the largest subtree. dfs2
 * relies on this to treat the first edge as the heavy edge.
 * NOTE(review): recursion depth equals the tree height, which can approach n
 * on a path-shaped tree, so this relies on a large enough stack. */
void dfs1(int u) {
sz[u] = 1;
for (int v, j = hd[u]; j; j = Ln[j]) {
v = e[j];
if (v == fa[u])
continue;
fa[v] = u;
dfs1(v);
sz[u] += sz[v];
// Keep the biggest child seen so far in the first slot. If that slot still
// holds the parent, any child displaces it. The swap only permutes targets
// within u's own list, and slot j has already been visited.
if (e[hd[u]] == fa[u] || sz[e[hd[u]]] < sz[v])
e[j] = e[hd[u]], e[hd[u]] = v;
}
}
void dfs2(int u) {
nfd[dfn[u] = dfc++] = u;
tail[hld[u]] = u;
for (int v, j = hd[u]; j; j = Ln[j]) {
v = e[j];
if (v == fa[u])
continue;
hld[v] = (j != hd[u]) ? v: hld[u];
lv[v] = lv[u] + 1;
dfs2(v);
if (j != hd[u]) {
id[v] = childcount[u]++;
}
}
}
/* Rebuild internal node v of a compress segment tree (1-indexed, children 2v
 * and 2v+1) from its children.
 * Leaves are laid out in dfn order starting at the heavy-path head, so the
 * left child is the shallower half and must be the upper operand of the
 * non-commutative compress operator.
 * (The original computed the reversed sum first and immediately overwrote
 * it. That dead store doubled the merge cost and has been removed, together
 * with the unused `md` macro that preceded this function.) */
void pul(path *t, int v) {
    t[v] = t[v << 1] + t[v << 1 | 1];
}
/* Rebuild internal node v of a rake segment tree (1-indexed, children 2v and
 * 2v+1) from its two children. */
void pull(point *t, int v) {
    const point &lhs = t[2 * v];
    const point &rhs = t[2 * v + 1];
    t[v] = lhs + rhs;
}
void re_lp(int u) {
int p = fa[u];
if (p == u)
return;
yy[p][id[u] + m_[p]] = Addedge(tt[u][1]);
for (int j = id[u] + m_[p]; j >>= 1; )
pull(yy[p], j);
tt[hld[p]][dfn[p] - dfn[hld[p]] + n_[hld[p]]] = Addvertex(yy[p][1], color[p]);
for (int j = dfn[p] - dfn[hld[p]] + n_[hld[p]]; j >>= 1; )
pul(tt[hld[p]], j);
}
/* Post-order pass that fills the rake trees. Once the whole subtree of a
 * light child has been processed, its heavy-path summary is pushed into its
 * parent via re_lp. The root (fa[u] == u) and heavy children
 * (dfn[u] == dfn[fa[u]] + 1) are skipped. */
void dfs3(int u) {
    for (int j = hd[u]; j != 0; j = Ln[j]) {
        const int child = e[j];
        if (child != fa[u])
            dfs3(child);
    }
    const bool is_root = (fa[u] == u);
    if (is_root)
        return;
    const bool is_heavy_child = (dfn[u] == dfn[fa[u]] + 1);
    if (!is_heavy_child)
        re_lp(u);
}
/* Build every data structure for a tree with NN vertices rooted at vertex 0.
 *   F     colour of each vertex; the first NN entries are copied into color[]
 *         (NOTE(review): assumes F.size() >= NN)
 *   U, V  endpoints of the NN - 1 edges
 *   Q     not used here
 * Segment-tree arrays are allocated with new[] and never freed in the visible
 * code, so init is presumably called once. */
void init(int NN, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
n = NN;
for (int i = 0; i + 1 < n; ++i)
add(U[i], V[i]), add(V[i], U[i]);
memcpy(color, F.data(), sizeof(int) * n);
// The root is its own parent; dfs3 and re_lp use this to recognise it.
fa[0] = 0;
dfs1(0);
hld[0] = 0;
lv[0] = 0;
dfs2(0);
/* build compress tree */
// For every heavy-path head i: a segment tree tt[i] whose leaves, starting at
// index n_[i], hold the path's vertices in dfn (top-to-bottom) order.
for (int i = 0; i < n; ++i) {
// dfs2 only wrote tail[] at head indices; copy it to every vertex.
tail[i] = tail[hld[i]];
if (hld[i] != i)
continue;
int len = dfn[tail[i]] - dfn[i] + 1;
// n_[i] = smallest power of two >= path length.
for (n_[i] = 1; n_[i] < len; n_[i] <<= 1);
tt[i] = new path[n_[i] << 1];
for (int j = dfn[i]; j <= dfn[tail[i]]; ++j)
tt[i][j + n_[i] - dfn[i]] = Vertex(color[nfd[j]]);
for (int j = n_[i] - 1; j >= 1; --j)
pul(tt[i], j);
}
/* build rake tree */
// Allocate a zero-initialised rake segment tree for each vertex with light
// children. dfs3 then fills the leaves bottom-up and refreshes the compress
// trees above them.
for (int i = 0; i < n; ++i) {
if (! childcount[i])
continue;
for (m_[i] = 1; m_[i] < childcount[i]; m_[i] <<= 1);
yy[i] = new point[m_[i] << 1];
}
dfs3(0);
//for (int i = 0; i < n; ++i) printf(" [%d] - id = %d, dfn = %d, hld = %d, tail = %d, childcount = %d\n", i, id[i], dfn[i], hld[i], tail[i], childcount[i]);
}
/* Recolour vertex X to Y and refresh every summary that depends on it.
 * First X's own leaf in its heavy path's compress tree is rebuilt and the
 * tree is climbed to its root. Then each light edge on the way to the root is
 * walked with re_lp(head).
 * re_lp already rebuilds the parent's leaf on the next heavy path and climbs
 * that tree to its root. The original loop rebuilt that same leaf a second
 * time on the following iteration; that redundant O(log n) work per light
 * edge is gone, and the final state is unchanged. */
void change(int X, int Y) {
    color[X] = Y;

    int h = hld[X];
    const int leaf = dfn[X] - dfn[h] + n_[h];
    tt[h][leaf] = childcount[X] ? Addvertex(yy[X][1], color[X]) : Vertex(color[X]);
    for (int j = leaf; j >>= 1; )
        pul(tt[h], j);

    // Vertex 0 heads the root's heavy path; stop once that path is up to date.
    for (; h != 0; h = hld[fa[h]])
        re_lp(h);
}
long long num_tours() {
path x = tt[0][1];
return x.ans;
}
/*
 * Pasted online-judge submission status table. It is not part of the
 * program and is kept only as a comment so the file compiles. Every row
 * still read "Fetching results..." when it was copied:
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */