#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
const int Nmax=300010;
int N, A[Nmax], P[Nmax];
ll ans, C[3][Nmax], D1[Nmax], D2[Nmax];
vector<int> adj[Nmax];
void DFS(int curr, int prev) {
C[A[curr]][curr]=1;
for(int next:adj[curr]) if(next!=prev) {
DFS(next, curr);
for(int i=0; i<3; i++) C[i][curr]+=C[i][next];
D1[curr]+=D1[next], D2[curr]+=D2[next];
}
if(A[curr]==1) D1[curr]+=C[0][curr], D2[curr]+=C[2][curr];
}
void init(int N_, vector<int> F, vector<int> U, vector<int> V, int Q) {
N=N_;
for(int i=1; i<=N; i++) A[i]=F[i-1];
for(int i=0; i<N-1; i++) {
adj[U[i]+1].push_back(V[i]+1), adj[V[i]+1].push_back(U[i]+1);
}
for(int i=1; i<=N; i++) P[i]=i/2;
DFS(1, 0);
for(int i=1; i<=N; i++) if(A[i]==1) {
if((i<<1|1)<=N) ans+=C[0][i<<1]*C[2][i<<1|1]+C[2][i<<1]*C[0][i<<1|1];
ans+=C[0][i]*(C[2][1]-C[2][i])+C[2][i]*(C[0][1]-C[0][i]);
}
}
void Delete(int x) {
if(A[x]==1) {
if((x<<1|1)<=N) ans-=C[0][x<<1]*C[2][x<<1|1]+C[2][x<<1]*C[0][x<<1|1];
ans-=C[0][x]*(C[2][1]-C[2][x])+C[2][x]*(C[0][1]-C[0][x]);
for(int i=x; i; i=P[i]) D1[i]-=C[0][x], D2[i]-=C[2][x];
for(int i=x; i; i=P[i]) C[1][i]--;
}
else if(A[x]==0) {
ans-=D2[x];
for(int i=x; i!=1; i=P[i]) ans-=((i==(P[i]<<1))?D2[P[i]<<1|1]:D2[P[i]<<1]);
for(int i=x; i!=1; i=P[i]) if(A[P[i]]==1) ans-=C[2][1]-C[2][i];
if(x>1) for(int i=P[x], c=0; i!=1; i=P[i]) c+=(A[i]==1), D1[i]-=c;
for(int i=x; i; i=P[i]) C[0][i]--;
}
else if(A[x]==2) {
ans-=D1[x];
for(int i=x; i!=1; i=P[i]) ans-=((i==(P[i]<<1))?D1[P[i]<<1|1]:D1[P[i]<<1]);
for(int i=x; i!=1; i=P[i]) if(A[P[i]]==1) ans-=C[0][1]-C[0][i];
if(x>1) for(int i=P[x], c=0; i!=1; i=P[i]) c+=(A[i]==1), D2[i]-=c;
for(int i=x; i; i=P[i]) C[2][i]--;
}
}
void Add(int x) {
if(A[x]==1) {
for(int i=x; i; i=P[i]) C[1][i]++;
for(int i=x; i; i=P[i]) D1[i]+=C[0][x], D2[i]+=C[2][x];
ans+=C[0][x]*(C[2][1]-C[2][x])+C[2][x]*(C[0][1]-C[0][x]);
if((x<<1|1)<=N) ans+=C[0][x<<1]*C[2][x<<1|1]+C[2][x<<1]*C[0][x<<1|1];
}
else if(A[x]==0) {
for(int i=x; i; i=P[i]) C[0][i]++;
if(x>1) for(int i=P[x], c=0; i!=1; i=P[i]) c+=(A[i]==1), D1[i]+=c;
for(int i=x; i!=1; i=P[i]) if(A[P[i]]==1) ans+=C[2][1]-C[2][i];
for(int i=x; i!=1; i=P[i]) ans+=((i==(P[i]<<1))?D2[P[i]<<1|1]:D2[P[i]<<1]);
ans+=D2[x];
}
else if(A[x]==2) {
for(int i=x; i; i=P[i]) C[2][i]++;
if(x>1) for(int i=P[x], c=0; i!=1; i=P[i]) c+=(A[i]==1), D2[i]+=c;
for(int i=x; i!=1; i=P[i]) if(A[P[i]]==1) ans+=C[0][1]-C[0][i];
for(int i=x; i!=1; i=P[i]) ans+=((i==(P[i]<<1))?D1[P[i]<<1|1]:D1[P[i]<<1]);
ans+=D1[x];
}
}
void change(int X, int Y) {
X++;
Delete(X);
A[X]=Y;
Add(X);
}
long long num_tours() {
return ans;
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
3 ms |
7512 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
3 ms |
7512 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
137 ms |
39016 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
3 ms |
7512 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
3 ms |
7512 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
3 ms |
7512 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
3 ms |
7512 KB |
Output is correct |
2 |
Incorrect |
3 ms |
7512 KB |
Wrong Answer [1] |
3 |
Halted |
0 ms |
0 KB |
- |