# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
199508 | shahriarkhan | 경주 (Race) (IOI11_race) | C++14 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include<bits/stdc++.h>
using namespace std ;
const int mx = 2e5 + 5 , ww = 1e7 + 6 ;
map<pair<int,int> , int > cost ;
vector<int> adj[mx] ;
vector<pair<int , int > > vx ;
int subtree[mx] , vis[mx] , dist[ww] , n , ans = mx , k ;
void dfs(int s , int par)
{
int siz = adj[s].size() ;
subtree[s] = 1 ;
for(int i = 0 ; i < siz ; ++i)
{
if(adj[s][i]!=par && vis[adj[s][i]]==0)
{
dfs(adj[s][i],s) ;
subtree[s] += subtree[adj[s][i]] ;
}
}
}
int centroid(int s , int par , int n)
{
int siz = adj[s].size() ;
for(int i = 0 ; i < siz ; ++i)
{
if(adj[s][i]!=par && vis[adj[s][i]]==0)
{
if((subtree[adj[s][i]])>n) return centroid(adj[s][i],s,n) ;
}
}
return s ;
}
void findans(int s , int p , int num , int sum)
{
if(sum<=k) ans = min(ans,num + dist[k-sum]) ;
int siz = adj[s].size() ;
for(int i = 0 ; i < siz ; ++i)
{
if(vis[adj[s][i]]==0 && adj[s][i]!=p)
{
findans(adj[s][i],s,num+1,sum+cost[{s,adj[s][i]}]) ;
}
}
}
void addcnt(int s , int p , int num , int sum)
{
if(sum<=k) dist[sum] = min(dist[sum],num) ;
int siz = adj[s].size() ;
for(int i = 0 ; i < siz ; ++i)
{
if(vis[adj[s][i]]==0 && adj[s][i]!=p) addcnt(adj[s][i],s,num+1,sum+cost[{s,adj[s][i]}]) ;
}
}
void eras(int s , int p , int sum)
{
if(sum<=k) dist[sum] = mx ;
int siz = adj[s].size() ;
for(int i = 0 ; i < siz ; ++i)
{
if(adj[s][i] != p && vis[adj[s][i]]==0)
{
eras(adj[s][i],s,sum+cost[{s,adj[s][i]}]) ;
}
}
}
void decomp(int s , int par)
{
dfs(s,par) ;
int c = centroid(s,par,subtree[s]/2) ;
vis[c] = 1 ;
int siz = adj[c].size() ;
dist[0] = 0 ;
for(int i = 0 ; i < siz ; ++i)
{
int v = adj[c][i] ;
if(!vis[v])
{
findans(v,c,1,cost[{c,v}]) ;
addcnt(v,c,1,cost[{c,v}]) ;
}
}
for(int i = 0 ; i < siz ; ++i)
{
int v = adj[c][i] ;
if(!vis[v])
{
eras(v,c,cost[{c,v}]) ;
}
}
for(int i = 0 ; i < siz ; ++i)
{
if(!vis[adj[c][i]]) decomp(adj[c][i],c) ;
}
}
int main()
{
scanf("%d%d",&n,&k) ;
for(int i = 1 ; i < n ; ++i)
{
int a , b ;
scanf("%d%d",&a,&b) ;
adj[a].push_back(b) ;
adj[b].push_back(a) ;
vx.push_back({a,b}) ;
}
for(int i = 0 ; i < n - 1 ; ++i)
{
int x ;
scanf("%d",&x) ;
int a = vx[i].first , b = vx[i].second ;
cost[{a,b}] = x ;
cost[{b,a}] = x ;
}
for(int i = 0 ; i <= k ; ++i) dist[i] = mx ;
decomp(0,-1) ;
if(ans==mx) printf("-1\n") ;
else printf("%d\n",ans) ;
return 0 ;
}