// Note (oj.uz): This submission was judged on an older version of oj.uz. Judging now runs on a different server than at submission time, so resubmitting may give a different result.
#include "joitour.h"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>
// Shared integer alias, capacity limits and problem-size globals.
using llong = long long;
constexpr int MAXN = 200000 + 10;           // node capacity (nodes are stored 1-based)
constexpr int MAXLOG = 20;                  // per-node capacity for centroid-tree levels
constexpr llong INF = 1000000000000000000LL;
constexpr int INTINF = 1000000000;
int n, q;                                   // N and Q as received by init()
// Binary indexed (Fenwick) tree over 1-based positions 1 .. MAXN * MAXLOG - 1.
// The rest of the file uses it in two modes:
//   * update + rangeQuery : point add, range sum;
//   * rangeUpdate + query : range add, point value.
struct Fenwick
{
    int tree[MAXN * MAXLOG];

    // Adds `val` at position `pos`. Position 0 is rejected because the
    // lowest-set-bit step would never advance from it.
    void update(int pos, int val)
    {
        assert(pos != 0);
        while (pos < MAXN * MAXLOG)
        {
            tree[pos] += val;
            pos += pos & (-pos);
        }
    }

    // Returns the sum over positions 1 .. pos (0 when pos <= 0).
    int query(int pos)
    {
        int sum = 0;
        while (pos > 0)
        {
            sum += tree[pos];
            pos &= pos - 1;   // drop the lowest set bit
        }
        return sum;
    }

    // Sum over the inclusive range [l, r].
    int rangeQuery(int l, int r)
    {
        const int upToR = query(r);
        const int beforeL = query(l - 1);
        return upToR - beforeL;
    }

    // Adds `val` to every position of [l, r]; individual values are read with query().
    void rangeUpdate(int l, int r, int val)
    {
        update(l, val);
        update(r + 1, -val);
    }
};
// ---- Per-node state (1-based node ids) ----
int f[MAXN];                    // current type (0, 1 or 2) of each node
int sz[MAXN];                   // scratch subtree sizes filled by findSize()
bool vis[MAXN];                 // true once the node has been chosen as a centroid
int dep[MAXN];                  // centroid-tree level at which the node became a centroid (top = 1)

// ---- Per (node, level d) data from the DFS of the node's level-d component ----
int in[MAXN][MAXLOG];           // entry time (global clock, unique across levels)
int out[MAXN][MAXLOG];          // largest entry time inside the node's DFS subtree
int parC[MAXN][MAXLOG];         // child of the level-d centroid whose subtree holds the node (0 for the centroid)
int centroid[MAXN][MAXLOG];     // centroid of the node's level-d component

// ---- Per-centroid aggregates (indexed by centroid id) ----
llong centroidAnswer[MAXN];     // tours currently counted at this centroid
llong global10[MAXN];           // sum of local10 over the centroid's child subtrees
llong global12[MAXN];           // sum of local12 over the centroid's child subtrees

// ---- Per (child-subtree root, level d) counters ----
llong local10[MAXN][MAXLOG];    // pairs (type-0 x, type-1 y) with y on the path from x up to the subtree root
llong local12[MAXN][MAXLOG];    // pairs (type-2 x, type-1 y) with y on the path from x up to the subtree root

int global0[MAXN];              // sum of local0 over the centroid's child subtrees
int global2[MAXN];              // sum of local2 over the centroid's child subtrees
int local0[MAXN][MAXLOG];       // number of type-0 nodes in the subtree
int local2[MAXN][MAXLOG];       // number of type-2 nodes in the subtree
llong localSum02[MAXN];         // per centroid: sum over child subtrees of local0 * local2

std::vector<int> g[MAXN];       // tree adjacency lists
std::vector<int> c[MAXN];       // children of each centroid in the centroid tree
Fenwick fenwick[3];             // one tree per node type, shared by all levels
llong answer;                   // total number of tours, returned by num_tours()
// Fills sz[] for the not-yet-decomposed component containing `node`, rooted
// at `node`. `par` is the neighbour to skip (0 for the root). Nodes already
// chosen as centroids (vis) are treated as removed from the tree.
void findSize(int node, int par)
{
    int total = 1;
    for (const int next : g[node])
    {
        if (next != par && !vis[next])
        {
            findSize(next, node);
            total += sz[next];
        }
    }
    sz[node] = total;
}
// Returns the centroid of the current component. Starting at `node`, it
// repeatedly steps into a child (with respect to the rooting used by the most
// recent findSize() call) whose size exceeds globalSZ / 2. It stops at the
// first node that has no such child. `par` is the node we came from (0 at
// the start); removed nodes (vis) are skipped.
//
// Requires sz[] to have been filled by findSize() from the starting node.
// This function only reads sz[]; it no longer overwrites entries along the
// search path. It iterates instead of recursing.
int findCentroid(int node, int par, int globalSZ)
{
    for (;;)
    {
        int heavy = 0;  // 0 = no child exceeds half of the component
        for (const int &u : g[node])
        {
            if (u == par || vis[u])
            {
                continue;
            }
            if (sz[u] > globalSZ / 2)
            {
                heavy = u;
                break;
            }
        }
        if (heavy == 0)
        {
            return node;
        }
        par = node;
        node = heavy;
    }
}
// Recursively builds the centroid tree of the component containing `node`.
// The chosen centroid gets level `d` (dep) and is marked removed (vis). Each
// remaining neighbouring component is decomposed at level d + 1, and its
// centroid is appended to c[centroid]. Returns the centroid.
int decompose(int node, int d)
{
    findSize(node, 0);
    const int center = findCentroid(node, 0, sz[node]);
    dep[center] = d;
    vis[center] = true;
    for (const int &next : g[center])
    {
        if (!vis[next])
        {
            const int childCenter = decompose(next, d + 1);
            c[center].push_back(childCenter);
        }
    }
    return center;
}
void removeSubtreeAnswer(int node, int d)
{
    int cntr = centroid[node][d];
    answer -= centroidAnswer[cntr];
    centroidAnswer[cntr] -= (global10[cntr] - local10[node][d]) * local2[node][d]; 
    centroidAnswer[cntr] -= (global12[cntr] - local12[node][d]) * local0[node][d]; 
    centroidAnswer[cntr] -= (global2[cntr] - local2[node][d]) * local10[node][d]; 
    centroidAnswer[cntr] -= (global0[cntr] - local0[node][d]) * local12[node][d]; 
    answer += centroidAnswer[cntr];
    
    global10[cntr] -= local10[node][d];
    global12[cntr] -= local12[node][d];
    global0[cntr] -= local0[node][d];
    global2[cntr] -= local2[node][d];
    localSum02[cntr] -= 1LL * local0[node][d] * local2[node][d];
}
void addSubtreeAnswer(int node, int d)
{
    int cntr = centroid[node][d];
    global10[cntr] += local10[node][d];
    global12[cntr] += local12[node][d];
    global0[cntr] += local0[node][d];
    global2[cntr] += local2[node][d];
    localSum02[cntr] += 1LL * local0[node][d] * local2[node][d];
 
    answer -= centroidAnswer[cntr];
    centroidAnswer[cntr] += (global10[cntr] - local10[node][d]) * local2[node][d]; 
    centroidAnswer[cntr] += (global12[cntr] - local12[node][d]) * local0[node][d]; 
    centroidAnswer[cntr] += (global2[cntr] - local2[node][d]) * local10[node][d]; 
    centroidAnswer[cntr] += (global0[cntr] - local0[node][d]) * local12[node][d]; 
    answer += centroidAnswer[cntr];
}
void removeCentroidAnswer(int cntr)
{
    answer -= centroidAnswer[cntr];
    if (f[cntr] == 0)
    {
        centroidAnswer[cntr] -= global12[cntr];
    } else if (f[cntr] == 1)
    {
        centroidAnswer[cntr] -= 1LL * global0[cntr] * global2[cntr] - localSum02[cntr];
    } else
    {
        centroidAnswer[cntr] -= global10[cntr];
    }
 
    answer += centroidAnswer[cntr];
}
void addCentroidAnswer(int cntr)
{
    answer -= centroidAnswer[cntr];
    if (f[cntr] == 0)
    {
        centroidAnswer[cntr] += global12[cntr];
    } else if (f[cntr] == 1)
    {
        centroidAnswer[cntr] += 1LL * global0[cntr] * global2[cntr] - localSum02[cntr];
    } else
    {
        centroidAnswer[cntr] += global10[cntr];
    }
 
    answer += centroidAnswer[cntr];
}
void removeNodeAnswer(int node, int d)
{
    if (f[node] == 0)
    {
        local0[parC[node][d]][d]--;
        local10[parC[node][d]][d] -= fenwick[1].query(in[node][d]);
        fenwick[0].update(in[node][d], -1);
    } else if (f[node] == 1)
    {
        local10[parC[node][d]][d] -= fenwick[0].rangeQuery(in[node][d], out[node][d]);
        local12[parC[node][d]][d] -= fenwick[2].rangeQuery(in[node][d], out[node][d]);
        fenwick[1].rangeUpdate(in[node][d], out[node][d], -1);
    } else 
    {
        local2[parC[node][d]][d]--;
        local12[parC[node][d]][d] -= fenwick[1].query(in[node][d]);
        fenwick[2].update(in[node][d], -1);
    }
}
void addNodeAnswer(int node, int d)
{
    if (f[node] == 0)
    {
        local0[parC[node][d]][d]++;
        local10[parC[node][d]][d] += fenwick[1].query(in[node][d]);
        fenwick[0].update(in[node][d], 1);
    } else if (f[node] == 1)
    {
        local10[parC[node][d]][d] += fenwick[0].rangeQuery(in[node][d], out[node][d]);
        local12[parC[node][d]][d] += fenwick[2].rangeQuery(in[node][d], out[node][d]);
        fenwick[1].rangeUpdate(in[node][d], out[node][d], 1);
    } else 
    {
        local2[parC[node][d]][d]++;
        local12[parC[node][d]][d] += fenwick[1].query(in[node][d]);
        fenwick[2].update(in[node][d], 1);
    }
}
// Re-evaluates level d as if `node` had type `to` instead of its current type.
// f[node] is set to `from` on return, so the caller decides when to commit.
void updateNode(int node, int d, int from, int to)
{
    const int cntr = centroid[node][d];
    if (cntr != node)
    {
        // Peel the centroid term, the node's subtree and the node itself,
        // then rebuild them in reverse order with the new type.
        const int sub = parC[node][d];
        removeCentroidAnswer(cntr);
        removeSubtreeAnswer(sub, d);
        removeNodeAnswer(node, d);
        f[node] = to;
        addNodeAnswer(node, d);
        addSubtreeAnswer(sub, d);
        addCentroidAnswer(cntr);
    }
    else
    {
        // The node is this level's centroid: only its own term changes.
        removeCentroidAnswer(node);
        f[node] = to;
        addCentroidAnswer(node);
    }
    f[node] = from;
}
// Inserts `node` (with its current type) into the level-d bookkeeping during
// the initial build.
void addNode(int node, int d)
{
    const int cntr = centroid[node][d];
    if (cntr != node)
    {
        const int sub = parC[node][d];
        removeCentroidAnswer(cntr);
        removeSubtreeAnswer(sub, d);
        addNodeAnswer(node, d);
        addSubtreeAnswer(sub, d);
        addCentroidAnswer(cntr);
        return;
    }
    addCentroidAnswer(node);
}
// Calls updateNode for level d, then d - 1, ..., down to level 1. Level d is
// always processed, even when d <= 1.
void callUpdate(int node, int d, int from, int to)
{
    int level = d;
    do
    {
        updateNode(node, level, from, to);
    } while (--level >= 1);
}
int timer;  // Euler-tour clock; never reset, so positions of different levels never collide

// DFS over the level-d component of centroid `cntr`. For each visited node it
// records the entry/exit times, the centroid, and the child-of-centroid
// subtree it belongs to. Neighbours with dep <= d lie outside the component.
// `parentC` is 0 only at the centroid itself.
void buildDFS(int node, int par, int d, int parentC, int cntr)
{
    in[node][d] = ++timer;
    centroid[node][d] = cntr;
    parC[node][d] = parentC;
    for (const int &next : g[node])
    {
        const bool insideComponent = dep[next] > d;
        if (next != par && insideComponent)
        {
            // Children of the centroid start their own subtree.
            const int subtreeRoot = (parentC != 0) ? parentC : next;
            buildDFS(next, node, d, subtreeRoot, cntr);
        }
    }
    out[node][d] = timer;
}
// Adds every node of the level-d component to the bookkeeping in DFS
// preorder, starting from `node` (the centroid when called from buildTour).
void addDFS(int node, int par, int d)
{
    addNode(node, d);
    for (const int &next : g[node])
    {
        if (next != par && dep[next] > d)
        {
            addDFS(next, node, d);
        }
    }
}
// Builds the level data bottom-up. All centroid-tree children of `node` are
// processed first; then the component of `node` is toured and populated.
void buildTour(int node)
{
    for (std::size_t i = 0; i < c[node].size(); ++i)
    {
        buildTour(c[node][i]);
    }
    const int level = dep[node];
    buildDFS(node, 0, level, 0, node);
    addDFS(node, 0, level);
}
// Grader entry point: stores types and edges (converted from 0-based to
// 1-based ids), builds the centroid tree from node 1 at level 1, and
// computes the initial tour count.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q)
{
    n = N;
    q = Q;
    for (int v = 1; v <= n; ++v)
    {
        f[v] = F[v - 1];
    }
    for (int e = 0; e + 1 < n; ++e)
    {
        const int a = U[e] + 1;
        const int b = V[e] + 1;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    const int top = decompose(1, 1);
    buildTour(top);
}
// Grader entry point: sets the type of 0-based node `node` to `to`,
// updating every centroid level the node belongs to.
void change(int node, int to)
{
    const int v = node + 1;
    callUpdate(v, dep[v], f[v], to);
    f[v] = to;
}
// Grader entry point: current number of tours, maintained incrementally.
long long num_tours() { return answer; }
/*
7
0 0 2 2 0 1 0
0 1
0 2
1 3
1 4
2 5
2 6
1
1 1
*/
/* Scraped oj.uz result tables (results were never loaded):
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... |
*/