#include "closing.h"
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,int> pli;
typedef vector<int> vi;
typedef vector<ll> vl;
#define all(x) x.begin(), x.end()
#define sz(x) (int)x.size()
#define pb push_back
#define mk make_pair
#define fr first
#define sc second
// Single-source shortest distances from vertex X.
//
// edges[u] lists (neighbour, weight) pairs of u. Distances are computed with
// Dijkstra's algorithm using a min-heap with lazy deletion: stale heap entries
// are skipped via the `done` flag.
//
// The graph is taken by const reference. The previous by-value parameter
// deep-copied the whole adjacency list on every call.
//
// Returns a vector of size edges.size(). Vertices not reachable from X keep
// the sentinel value 1e12.
std::vector<long long> calc_dist(int X, const std::vector<std::vector<std::pair<int, int>>>& edges) {
    constexpr long long kUnreached = 1000000000000LL;  // same value as the original 1e12
    const int N = static_cast<int>(edges.size());
    std::vector<long long> dist(N, kUnreached);
    std::vector<char> done(N, 0);
    std::priority_queue<std::pair<long long, int>,
                        std::vector<std::pair<long long, int>>,
                        std::greater<std::pair<long long, int>>> pq;
    dist[X] = 0;
    pq.push({0, X});
    while (!pq.empty()) {
        const auto [D, at] = pq.top();
        pq.pop();
        if (done[at]) continue;  // stale entry: `at` was already settled with a smaller D
        done[at] = 1;
        for (const auto& [viz, peso] : edges[at]) {
            if (dist[viz] > D + peso) {
                dist[viz] = D + peso;
                pq.push({dist[viz], viz});
            }
        }
    }
    return dist;
}
// Maximum score reachable with total cost <= K.
//
// Model: every vertex i sits at level 0 (cost 0), level 1 (total cost a[i])
// or level 2 (total cost b[i]); the score is the sum of levels. Starting from
// all zeros, the score is raised by exactly one per step, choosing the
// cheaper of two kinds of move:
//   * a single +1 move: 0->1 (cost a[i]) or 1->2 (cost b[i]-a[i]);
//   * a 0->2 move on some vertex j (cost b[j]) combined with a -1 move on a
//     different vertex i: 1->0 (refund a[i]) or 2->1 (refund b[i]-a[i]).
// Ties prefer the single +1 move.
//
// Returns the number of successful steps. When K < 0 it returns the
// negative sentinel -2*n, meaning "infeasible".
//
// Fix: the original could dereference begin() of an empty set once no move
// was available (e.g. every vertex already at level 2, or n == 0). The
// affordability test now happens before any set is touched. For all other
// inputs the result is unchanged.
int solve(long long K, const std::vector<long long>& a, const std::vector<long long>& b) {
    const int n = static_cast<int>(a.size());
    if (K < 0) return -2 * n;

    // down: -1 moves, stored as negated refunds (so begin() is the largest refund).
    // up1:  +1 moves (0->1 or 1->2) keyed by their cost.
    // up2:  0->2 moves keyed by their cost.
    std::set<std::pair<long long, int>> down, up1, up2;
    for (int i = 0; i < n; i++) {
        up1.insert({a[i], i});
        up2.insert({b[i], i});
    }
    std::vector<int> level(n, 0);
    int score = 0;
    long long spent = 0;
    const long long NONE = K + 1;  // sentinel cost: never affordable

    while (true) {
        long long swapCost = NONE, stepCost = NONE;
        if (!down.empty() && !up2.empty())
            swapCost = down.begin()->first + up2.begin()->first;
        if (!up1.empty())
            stepCost = up1.begin()->first;

        // If even the cheaper move is unaffordable, stop. The sentinel NONE
        // always fails this test, so an empty set is never dereferenced below.
        if (spent + std::min(swapCost, stepCost) > K) break;

        if (swapCost < stepCost) {
            // Read both vertices before mutating. i has level >= 1 and j has
            // level 0, so they are distinct.
            const int i = down.begin()->second;
            const int j = up2.begin()->second;
            if (level[i] == 1) {  // lower i: 1 -> 0
                down.erase({-a[i], i});
                up1.erase({b[i] - a[i], i});
                level[i] = 0;
                up1.insert({a[i], i});
                up2.insert({b[i], i});
            } else {              // lower i: 2 -> 1
                down.erase({a[i] - b[i], i});
                level[i] = 1;
                down.insert({-a[i], i});
                up1.insert({b[i] - a[i], i});
            }
            // raise j: 0 -> 2
            up1.erase({a[j], j});
            up2.erase({b[j], j});
            level[j] = 2;
            down.insert({a[j] - b[j], j});
            spent += swapCost;
        } else {
            const int i = up1.begin()->second;
            if (level[i] == 1) {  // raise i: 1 -> 2
                down.erase({-a[i], i});
                up1.erase({b[i] - a[i], i});
                level[i] = 2;
                down.insert({a[i] - b[i], i});
            } else {              // raise i: 0 -> 1
                up1.erase({a[i], i});
                up2.erase({b[i], i});
                level[i] = 1;
                down.insert({-a[i], i});
                up1.insert({b[i] - a[i], i});
            }
            spent += stepCost;
        }
        score++;
    }
    return score;
}
int max_score(int N, int X, int Y, ll K, vi U, vi V, vi W) {
vector<vector<pii>> edges(N);
for(int i = 0; i+1 < N; i++)
edges[U[i]].pb({V[i], W[i]}), edges[V[i]].pb({U[i], W[i]});
vl distX = calc_dist(X, edges), distY = calc_dist(Y, edges);
vl a(N), b(N);
set<pli> cam;
for(int i = 0; i < N; i++) {
a[i] = min(distX[i], distY[i]);
b[i] = max(distX[i], distY[i]);
cam.insert({a[i], i});
}
int ans1 = 0, ans2 = 0;
ll at_val = 0;
while(!cam.empty()) {
auto [D, ind] = *cam.begin();
cam.erase(cam.begin());
at_val += D;
if(at_val > K) break;
ans1++;
}
at_val = 0;
for(int i = 0; i < N; i++) {
if(distX[Y] == a[i] + b[i])
at_val += a[i], a[i] = b[i]-a[i], b[i] = K+1, ans2++;
}
return max(ans1, ans2 + solve(K - at_val, a, b));
}
/* Judge results table (pasted from the submission page; not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/