#include "closing.h"
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pb push_back
#define ff first
#define sd second
#define debug(x) cerr << #x << "----> " << x << endl;
//#pragma GCC optimize("unroll-loops")
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("O3")

// Capacity of the per-node static arrays below; assumed to exceed the largest N
// the grader passes (TODO confirm against the task limits).
const int mxN = 3e5 + 5;

// n:    number of nodes in the current call (set by max_score).
// dis:  scratch distances written by dfs, indexed by node id; max_score finally
//       overwrites it with distances from Y indexed by position in vv.
// dis1: distances from X, indexed by position in vv.
// dis2: scratch copy of the distances from Y, indexed by node id.
ll n,dis[mxN],dis1[mxN],dis2[mxN];
// Adjacency list: v[u] holds (neighbour, edge weight) pairs.
vector<pair<ll,ll>> v[mxN];
// Nodes in DFS preorder from the root chosen in max_score.
vector<ll> vv;

// Depth-first traversal from `at`, whose parent is `par` (pass par == at for the
// start node). For every tree edge walked it sets
// dis[child] = dis[parent] + weight, so the caller must initialise dis[at].
// When `ok` is true every visited node is appended to vv in preorder.
//
// Uses an explicit stack instead of recursion: max_score roots the first
// traversal at a leaf, so on a path-shaped tree a recursive version would nest
// N-1 calls deep and could overflow the call stack for large N. Visiting order
// (and therefore the contents of vv) is the same as the recursive preorder.
void dfs(ll at, ll par, bool ok = false){
    struct Frame {
        ll node;
        ll parent;
        size_t next;  // index of the next adjacency entry of `node` to examine
    };
    std::vector<Frame> stk;
    stk.push_back({at, par, 0});
    if(ok) vv.pb(at);
    while(!stk.empty()){
        Frame &top = stk.back();
        if(top.next == v[top.node].size()){
            stk.pop_back();
            continue;
        }
        const pair<ll,ll> edge = v[top.node][top.next++];
        const ll u = top.node;
        if(edge.ff == top.parent) continue;
        dis[edge.ff] = dis[u] + edge.sd;
        if(ok) vv.pb(edge.ff);
        // push_back may reallocate and invalidate `top`; it is not used afterwards.
        stk.push_back({edge.ff, u, 0});
    }
}

// Returns the best score found: the number of (node, city) pairs "bought", where
// a node counts once for X and/or once for Y, and the total cost must not exceed
// K. NOTE(review): this presumably implements the IOI 2023 "Closing Time"
// scoring; confirm against the task statement.
//
// Two candidate families are evaluated:
//  1. Independent greedy: buy the cheapest of all 2N values (distance from X,
//     distance from Y) while the running total stays within K.
//  2. Overlapping coverage, with nodes indexed by their position in vv:
//     enumerate `l`, the leftmost position reached from the right city, and
//     `r`, the rightmost position reached from both cities. Mandatory costs go
//     into `sum` / `ans1`; optional extensions form a pool from which the
//     cheapest affordable items are bought.
// NOTE(review): family 2 treats the DFS preorder from a leaf as positions along
// a line, so it appears to assume the tree is a path. Each `l` rebuilds the
// pools, so the total work grows quadratically in N.
int max_score(int N, int X, int Y, long long K,
std::vector<int> U, std::vector<int> V, std::vector<int> W){
    n = N;
    for(int i = 0; i < n - 1; i++){
        v[U[i]].pb({V[i], W[i]});
        v[V[i]].pb({U[i], W[i]});
    }
    // Root the ordering traversal at the lowest-numbered leaf.
    ll root = 0;
    for(int i = 0; i < n; i++) if(v[i].size() == 1){
        root = i;
        break;
    }
    dis[root] = 0;
    vv.clear();
    dfs(root, root, true);

    // x, y: positions of X and Y in vv.
    int x = 0, y = 0;
    for(int i = 0; i < n; i++){
        if(vv[i] == X) x = i;
        if(vv[i] == Y) y = i;
    }
    // Everything below needs dis1 to hold distances from the city at position x
    // (the left one) and dis from the city at position y. Swapping only the
    // positions would pair each array with the wrong city, so swap the cities
    // too; the total counted below treats X and Y symmetrically.
    if(x > y){
        swap(x, y);
        swap(X, Y);
    }
    dis[X] = 0;
    dfs(X, X);
    for(int i = 0; i < n; i++) dis1[i] = dis[vv[i]];
    dis[Y] = 0;
    dfs(Y, Y);
    for(int i = 0; i < n; i++) dis2[i] = dis[i];
    for(int i = 0; i < n; i++) dis[i] = dis2[vv[i]];

    // Family 1: independent greedy over all 2N distances.
    ll sum = 0,ans = 0;
    vector<ll> v2;
    v2.reserve(2 * n);
    for(int i = 0; i < n; i++){
        v2.pb(dis[i]);
        v2.pb(dis1[i]);
    }
    sort(v2.begin(), v2.end());
    for(auto it : v2){
        sum += it;
        if(sum > K) break;
        ans++;
    }

    // Family 2: overlapping coverage.
    for(int l = 0; l <= y; l++){
        // Pool of optional items (cost, 0 = reached from Y / 1 = reached from X),
        // split into s (bought, total cost sum1) and s1 (not bought).
        // Invariant: every element of s is <= every element of s1, so s is always
        // the cheapest prefix of the pool.
        multiset<pair<ll,ll>> s,s1;
        sum = 0;
        // Positions l..max(l, x) are covered by both cities.
        ll ans1 = 2 * max(1, x - l + 1),sum1 = 0;
        // Optional: Y extending left of l.
        for(int i = 0; i < l; i++){
            s.insert({dis[i], 0});
            sum1 += dis[i];
        }
        // Optional: Y extending right of y.
        for(int i = y + 1; i < n; i++){
            s.insert({dis[i], 0});
            sum1 += dis[i];
        }
        // Optional: X extending left.
        for(int i = 0; i < min(x, l); i++){
            s.insert({dis1[i], 1});
            sum1 += dis1[i];
        }
        // Right of the shared part: Y is mandatory up to y, X is optional.
        for(int i = max(l, x) + 1; i < n; i++){
            if(i <= y){
                ans1++;
                sum += dis[i];
            }
            sum1 += dis1[i];
            s.insert({dis1[i], 1});
        }
        // When l > x, X must cover x..l-1 on its own.
        for(int i = x; i < l; i++){
            ans1++;
            sum += dis1[i];
        }
        for(int i = l; i <= max(l, x); i++) sum += max(dis[i], dis1[i]);
        if(sum > K) continue;

        // Remove one copy of `item` from the pool, preferring the bought part.
        // Equal (cost, side) items are interchangeable, so which copy goes does
        // not change the result.
        auto dropFromPool = [&](const pair<ll,ll> &item){
            auto it = s.find(item);
            if(it != s.end()){
                sum1 -= it->ff;
                s.erase(it);
                return;
            }
            auto it2 = s1.find(item);
            if(it2 != s1.end()) s1.erase(it2);
        };
        // Make s the longest affordable prefix of the pool: sell the most
        // expensive items while over budget, then buy the cheapest unbought
        // items while they fit.
        auto rebalance = [&](){
            while(!s.empty() && sum + sum1 > K){
                auto it = std::prev(s.end());
                sum1 -= it->ff;
                s1.insert(*it);
                s.erase(it);
            }
            while(!s1.empty() && sum + sum1 + s1.begin()->ff <= K){
                auto it = s1.begin();
                sum1 += it->ff;
                s.insert(*it);
                s1.erase(it);
            }
        };
        rebalance();
        ans = max(ans, ans1 + (ll)s.size());

        for(int r = max(l, x) + 1; r < n; r++){
            // Position r becomes covered by both cities, so its optional X item
            // leaves the pool. For r <= y the Y cost dis[r] is already in `sum`
            // (it was never a pool item); beyond y it is a pool item as well.
            dropFromPool({dis1[r], 1});
            if(r <= y){
                sum -= dis[r];
                ans1--;
            }else{
                dropFromPool({dis[r], 0});
            }
            ans1 += 2;
            sum += max(dis[r], dis1[r]);
            // `sum` never decreases as r grows, so no larger r can fit either.
            if(sum > K) break;
            rebalance();
            ans = max(ans, ans1 + (ll)s.size());
        }
    }
    // Reset the adjacency lists so a later call starts from a clean state.
    for(int i = 0; i < n; i++) v[i].clear();
    return ans;
}
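// Sample grader input (presumably: number of scenarios, then "N X Y K",
// then N-1 edge lines "U V W" -- confirm against the grader):
//1
//6 0 3 10
//5 3 5
//3 1 2
//1 2 1
//2 0 3
//0 4 7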
/*
Pasted grader results table (copied before the results had loaded):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/