이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
// Type shorthands used throughout the solution.
typedef long long ll;
typedef std::pair<ll, ll> pll;
typedef std::pair<int, int> pii;
#define all(x) (x).begin(),(x).end()
#define X first
#define Y second
#define sep ' '
#define endl '\n'
// Argument parenthesized so expressions like SZ(*p) expand correctly.
#define SZ(x) ll((x).size())
// Segment-tree child indices of node `id`. Fully parenthesized so they expand
// correctly inside any surrounding expression (e.g. `rc * 2`, `lc + 1`).
#define lc (id << 1)
#define rc (lc | 1)
const ll MAXN = 2e5 + 10;  // array bound for vertices / colors (with slack)
const ll LOG = 22;
const ll INF = 8e18;
const ll MOD = 1e9 + 7; //998244353; //1e9 + 9;
// n, m: vertex and color counts read in main; C[v]: color of vertex v;
// dist[v]: edge distance from the root of the last DistDFS call;
// mx[v], up[v]: subtree height / best sibling branch length (filled by PreDFS);
// lz: pending lazy additions of the segment tree; ans[v]: answer for vertex v.
int n , m , C[MAXN] , dist[MAXN] , mx[MAXN] , up[MAXN], lz[MAXN << 2] , ans[MAXN];
// seg[id] = {minimum value on the node's range, number of positions attaining it}.
pii seg[MAXN << 2];
// adj: adjacency lists; col[c]: stack of depths pushed for color c by SolveDFS.
std::vector<int> adj[MAXN] , col[MAXN];
// Combines two {minimum, count-of-minimum} summaries: the result keeps the
// smaller minimum and sums the counts of every side that attains it.
pii Merge(pii L, pii R){
    const int best = min(L.first , R.first);
    const int fromLeft = (L.first == best) ? L.second : 0;
    const int fromRight = (R.first == best) ? R.second : 0;
    return pii(best , fromLeft + fromRight);
}
void build(int id = 1 , int l = 0 , int r = n){
seg[id] = {0 , r - l};
lz[id] = 0;
if(r - l == 1) return;
int mid = l + r >> 1;
build(lc , l , mid);
build(rc , mid , r);
}
void shift(int id){
seg[lc].X += lz[id]; lz[lc] += lz[id];
seg[rc].X += lz[id]; lz[rc] += lz[id];
lz[id] = 0;
}
// Range update: adds `val` to every position in [ql, qr). Parts of the query
// outside the node range [l, r) are ignored, so ql may be negative.
void add(int ql , int qr , int val , int id = 1 , int l = 0 , int r = n){
    const bool disjoint = (qr <= l) || (r <= ql);
    if(disjoint) return;
    const bool covered = (ql <= l) && (r <= qr);
    if(covered){
        // Whole node affected: adjust its minimum and defer the children.
        seg[id].first += val;
        lz[id] += val;
        return;
    }
    shift(id);
    const int mid = (l + r) / 2;
    add(ql , qr , val , lc , l , mid);
    add(ql , qr , val , rc , mid , r);
    seg[id] = Merge(seg[lc] , seg[rc]);
}
// Counts the positions in [ql, qr) whose current value is 0 and returns
// {0, that count}. Assumes stored values are non-negative (true here: every
// +1 added in SolveDFS is later undone by a matching -1). Under that
// assumption, Merge(x, {0, 0}) turns any {min, count-of-min} summary x into
// {0, number of zeros}.
pii get(int ql , int qr , int id = 1 , int l = 0 , int r = n){
    if(qr <= l || r <= ql) return {0 , 0};
    // Merge with {0, 0} so a fully covered node reports its number of zeros
    // (0 when its minimum is positive). Returning seg[id] unchanged would make
    // a query spanning all of [0, n) yield the count of the minimum instead.
    if(ql <= l && r <= qr) return Merge(seg[id] , {0 , 0});
    shift(id);
    const int mid = (l + r) >> 1;
    return Merge(get(ql , qr , lc , l , mid) , get(ql , qr , rc , mid , r));
}
// Point query: returns 1 if position `pos` currently holds 0, otherwise 0.
// Lazy tags on the root-to-leaf path are pushed down while descending.
int Find(int pos , int id = 1 , int l = 0 , int r = n){
    while(r - l > 1){
        shift(id);
        const int mid = (l + r) >> 1;
        if(pos < mid){
            id = lc;
            r = mid;
        }
        else{
            id = rc;
            l = mid;
        }
    }
    return seg[id].first == 0 ? 1 : 0;
}
void DistDFS(int v , int p = -1){
dist[v] = (p == -1 ? 0 : dist[p] + 1);
for(int u : adj[v]){
if(u == p) continue;
DistDFS(u , v);
}
}
// For the tree rooted where PreDFS was first called:
//   mx[v] = height of v's subtree (edges to its deepest descendant);
//   up[c] = max over v's other children s of mx[s] + 1, for each child c of v.
// The two sweeps compute prefix maxima (forward order) and suffix maxima
// (after reversing adj[v]; the reversed order is left in place).
void PreDFS(int v , int p = -1){
    mx[v] = 0;
    up[v] = 0;
    // Forward sweep: up[child] = best branch among the earlier children.
    for(const int child : adj[v]){
        if(child == p) continue;
        PreDFS(child , v);
        up[child] = mx[v];
        if(mx[child] + 1 > mx[v]) mx[v] = mx[child] + 1;
    }
    reverse(adj[v].begin() , adj[v].end());
    // Backward sweep: fold in the best branch among the later children.
    int best = 0;
    for(const int child : adj[v]){
        if(child != p){
            if(best > up[child]) up[child] = best;
            best = max(best , mx[child] + 1);
        }
    }
}
// Main sweep for the current root. The segment tree is indexed by depth on
// the root-to-v path: position d holds a cover count for the ancestor at that
// depth, and get(...).Y is read as the number of positions whose count is 0.
// Every change made for a child u is undone after u's subtree returns, so the
// tree and the col[] stacks reflect only the current root-to-v path.
//   v: current vertex, p: its parent (-1 at the root), h: depth of v.
void SolveDFS(int v , int p = -1 , int h = 0){
for(int u : adj[v]){
if(u == p) continue;
// flag == 1: v's depth was pushed onto col[C[v]] (undo = pop);
// flag == 0: v's depth was covered instead (undo = add -1).
int flag = 1;
// up[u] (from PreDFS) is the largest mx[s] + 1 over v's other children s.
// Cover the ancestors at depths [h - up[u], h), i.e. within up[u] edges of v.
// The start may be negative; add() ignores positions outside [0, n).
add(h - up[u] , h , 1);
// If the most recently pushed depth for v's color is still uncovered, cover
// v's own depth so that this color contributes only one uncovered position.
if(SZ(col[C[v]]) && Find(col[C[v]].back())){
flag = 0;
add(h , h + 1 , 1);
}
else{
col[C[v]].push_back(h);
}
// u is at depth h + 1. Count uncovered depths d < h + 1 - mx[u], i.e. those
// with (h + 1) - d > mx[u]: strictly farther from u than u's deepest descendant.
// main() runs this sweep from two roots, so keep the larger value.
ans[u] = max(ans[u] , get(0 , h + 1 - mx[u]).Y);
SolveDFS(u , v , h + 1);
// Undo the changes made for child u, in reverse order.
if(flag){
col[C[v]].pop_back();
}
else{
add(h , h + 1 , -1);
}
add(h - up[u] , h , -1);
}
}
// Reads the tree (n vertices, n - 1 edges) and the vertex colors, then runs
// the sweep twice, rooted at each endpoint of a diameter, keeping the larger
// value per vertex in ans[], and prints ans[1..n].
int main() {
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    cin >> n >> m;
    for(int i = 1 ; i < n ; i++){
        int v , u;
        cin >> v >> u;
        adj[v].push_back(u);
        adj[u].push_back(v);
    }
    for(int i = 1 ; i <= n ; i++){
        cin >> C[i];
    }
    // Two farthest-vertex searches give the endpoints v, u of a diameter.
    // Only vertices 1..n exist; scanning all MAXN slots would also consider
    // index 0 (never a vertex), which max_element returns when n == 1.
    DistDFS(1);
    int v = max_element(dist + 1 , dist + n + 1) - dist;
    DistDFS(v);
    int u = max_element(dist + 1 , dist + n + 1) - dist;
    build();
    PreDFS(v);
    SolveDFS(v);
    // Reset per-root state before sweeping from the other endpoint.
    for(int i = 0 ; i < MAXN ; i++) col[i].clear();
    build();
    PreDFS(u);
    SolveDFS(u);
    for(int i = 1 ; i <= n ; i++){
        cout << ans[i] << endl;
    }
    return 0;
}
/*
*/
컴파일 시 표준 에러 (stderr) 메시지
joi2019_ho_t5.cpp: In function 'void build(int, int, int)':
joi2019_ho_t5.cpp:37:14: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
37 | int mid = l + r >> 1;
| ~~^~~
joi2019_ho_t5.cpp: In function 'void add(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:56:14: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
56 | int mid = l + r >> 1;
| ~~^~~
joi2019_ho_t5.cpp: In function 'pii get(int, int, int, int, int)':
joi2019_ho_t5.cpp:66:14: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
66 | int mid = l + r >> 1;
| ~~^~~
joi2019_ho_t5.cpp: In function 'int Find(int, int, int, int)':
joi2019_ho_t5.cpp:73:14: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
73 | int mid = l + r >> 1;
| ~~^~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |