/*
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimize("unroll-loops")
*/
#include<bits/stdc++.h>
// Competitive-programming shorthand: from here on every `int` is really
// `long long` (hence `signed main` below), so distances and the 1e15
// binary-search bound fit.
#define int long long
using namespace std;
// Helper macros.  Of these, only fi / se / pb are used by the code in this
// file; the rest (and PI) are template leftovers.
#define all(x) x.begin(), x.end()
#define len(x) ll(x.size())
#define eb emplace_back
#define PI 3.14159265359
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define MIN(v) *min_element(all(v))
#define MAX(v) *max_element(all(v))
#define BIT(x,i) (1&((x)>>(i)))
#define MASK(x) (1LL<<(x))
// File-name stem for the (commented-out) freopen calls in main().
#define task "tnc"
typedef long long ll;
// INF is the "no park below this vertex" sentinel that check() stores in f[].
const ll INF=1e18;
// Capacity of the per-vertex arrays; assumes n < maxn.
const int maxn=1e6+5;
// mod, mo, the pi/vi/pii aliases and rng are not referenced by the code in this file.
const int mod=1e9+7;
const int mo=998244353;
using pi=pair<ll,ll>;
using vi=vector<ll>;
using pii=pair<pair<ll,ll>,ll>;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// n: number of vertices (1-indexed).  k: number of vertices to output;
// check() accepts a radius when it needs at most k parks.
int n,k;
// res[i] == 1 marks vertex i as one of the vertices printed by main().
int res[maxn];
// Adjacency list of the tree: G[u] holds (neighbor, edge weight) pairs.
vector<pair<int,int>>G[maxn];
// Reset in check() but otherwise unused.
int sum=0;
// Per-vertex state of the greedy run by dfs()/check():
//   f[u]    - best remaining coverage slack at u from parks placed below u (-INF if none)
//   dist[u] - distance from u to the farthest not-yet-covered vertex below (0 = u itself)
//   mark[v] - 0 unresolved, 1 park placed at v, 2 v's subtree covered
//   dp[u]   - number of parks placed below u
int f[maxn];
int dist[maxn];
int mark[maxn];
int dp[maxn];
// Greedy post-order pass used by check(mid).  Bottom-up, it decides which
// children of u must host a park so that every vertex ends up within
// distance `mid` of some park.
// pa is u's parent (-1 at the root); G[u][i] = (neighbor, weight).
// Recursion depth equals the height of the tree.
// After dfs(u) returns (check() resets all arrays before each run):
//   f[u]    - max over parks placed strictly below u of
//             (mid - distance from u to that park), or -INF if there are none.
//             A vertex at distance d from u, reached through u, is covered
//             by such a park when f[u] >= d.
//   dp[u]   - number of parks placed strictly below u.
//   dist[u] - distance from u to the farthest vertex of u's subtree not yet
//             known to be covered (0 stands for u itself).
//   mark[v] - for each child v: 1 = park placed at v, 2 = v's subtree is
//             fully covered, 0 = still contributes to dist[u].
void dfs(int u,int pa,int mid){
for(int i=0; i<(int)G[u].size(); ++i)
{
int v = G[u][i].fi;
if(v == pa) continue;
dfs(v, u, mid);
int w = G[u][i].se;
// Slack of the best park inside v's subtree, as seen from u.
f[u] = max(f[u], f[v] - w);
dp[u] += dp[v];
// Only v itself was left uncovered (dist 0) and a park below reaches v:
// v's whole subtree is done.
if(dist[v] == 0 && f[v] >= 0)
{
mark[v] = 2;
continue;
}
// The farthest uncovered vertex below v is already more than mid away from u.
// Any park at u, above u or in a sibling subtree is at least that far,
// so the park has to go on v itself.
if(dist[v] + w > mid)
{
mark[v] = 1;
// A park at v leaves slack mid - w at u.
f[u] = max(f[u], mid - w);
dp[u]++;
}
}
// Second pass: any unresolved child whose farthest uncovered vertex is within
// reach of the best park seen at u (path through u) is covered.
for(auto v:G[u]){
if(v.fi==pa || mark[v.fi])continue;
if(f[u]-v.se>=dist[v.fi]){
mark[v.fi]=2;
}
}
// Children still unresolved pass their farthest uncovered distance up to u.
for(auto v:G[u]){
if(v.fi==pa || mark[v.fi])continue;
dist[u]=max(dist[u],dist[v.fi]+v.se);
}
}
// Feasibility test for the binary search in main(): returns true when the
// greedy needs at most k parks to put every vertex within distance `mid` of a park.
//
// Resets the per-vertex state, runs dfs() from vertex 1, then settles the root
// (the root has no parent that could place a park on its behalf).
// On return, mark[i] == 1 exactly for the vertices chosen as parks and
// dp[1] is their count.
//
// This function deliberately does NOT touch res[].  main() keeps the park set
// of the best feasible radius there.  Clearing it on every probe wiped that
// answer whenever the final probe of the binary search was infeasible, and
// main() then printed vertices 1..k instead of the real parks.
bool check(int mid){
    sum=0;
    for(int i=1;i<=n;i++){
        f[i]=-INF;   // no park below i yet
        dist[i]=0;   // i itself is the farthest uncovered vertex so far
        mark[i]=0;
        dp[i]=0;
    }
    dfs(1,-1,mid);
    // The root is covered without an extra park only when the farthest remaining
    // uncovered vertex is the root itself (dist 0) and some park reaches it.
    bool rootCovered = (dist[1]==0 && f[1]>=0);
    if(!rootCovered){
        mark[1]=1;
        dp[1]++;
    }
    return dp[1]<=k;
}
// Reads a weighted tree (n vertices, k parks), binary-searches the smallest
// coverage radius accepted by check(), and prints that radius followed by
// exactly k vertex indices (the greedy's parks, padded with arbitrary extra
// vertices; extra parks can never increase any vertex's distance to its nearest park).
signed main()
{
    // Fast I/O.  The former `cout.tie(0)->sync_with_stdio(0)` invoked a member
    // through the pointer returned by cout.tie(0), which is null because cout is
    // untied by default.
    std::ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    //freopen(task".inp" , "r" , stdin);
    //freopen(task".out" , "w" , stdout);
    cin>>n>>k;
    for(int i=1;i<n;i++){
        int x,y,w;
        cin>>x>>y>>w;
        G[x].pb({y,w});
        G[y].pb({x,w});
    }
    // Smallest radius for which check() succeeds.  The upper bound 1e15 is assumed
    // to exceed any path length (n * max weight) -- confirm against the statement --
    // so check(r) is feasible and ans is always assigned a real value.
    int l=0, r=1e15;
    int ans=r;
    while(l<=r){
        int mid=l+(r-l)/2;
        if(check(mid)){
            ans=mid;
            r=mid-1;
        }
        else{
            l=mid+1;
        }
    }
    // Re-run the greedy at the optimal radius so mark[] holds exactly its park
    // set. This does not rely on whatever the last (possibly failing) probe left behind.
    check(ans);
    int cnt=0;
    for(int i=1;i<=n;i++){
        res[i]=(mark[i]==1);
        cnt+=res[i];
    }
    // Pad with unused vertices so that exactly k are printed.
    for(int i=1;i<=n && cnt<k;i++){
        if(!res[i]){
            res[i]=1;
            cnt++;
        }
    }
    cout<<ans<<"\n";
    for(int i=1;i<=n;i++){
        if(res[i]){
            cout<<i<<" ";
        }
    }
    cout<<"\n";
    return 0;
}
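/*
Grader results:
| # | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 12 ms            | 23804 KB        | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/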
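/*
Grader results:
| # | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Correct       | 397 ms           | 54656 KB        | Output is correct    |
| 2 | Correct       | 424 ms           | 55496 KB        | Output is correct    |
| 3 | Incorrect     | 533 ms           | 41008 KB        | Output isn't correct |
| 4 | Halted        | 0 ms             | 0 KB            | -                    |
*/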
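/*
Grader results:
| # | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 432 ms           | 56064 KB        | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/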
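/*
Grader results:
| # | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 12 ms            | 23804 KB        | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/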