This submission was migrated from the previous version of oj.uz, which used a different grading machine. It may receive a different result if resubmitted.
#include "closing.h"
#include <bits/stdc++.h>
// Shorthand for std::pair members: F -> first, S -> second.
#define F first
#define S second
using namespace std;
// Capacity of the per-vertex arrays below; vertices are indexed 0..n-1, so n must not exceed N.
const int N = 2e5 + 10;
// n: vertex count; x, y: the two source vertices (max_score swaps them so that x <= y).
int n , x , y;
// dis[0][v] / dis[1][v]: weighted tree distance from x / from y to v (filled by Dfs in max_score).
// k: the limit K passed to max_score; Solve rejects or stops extending when the running cost would exceed it.
// psum[t][i]: prefix sum dis[t][0] + ... + dis[t][i], read by Get.
long long dis[2][N] , k , psum[2][N];
// adj[v]: (neighbour, edge weight) pairs for every edge incident to v.
vector <pair<int , int>> adj[N];
// Fills dis[ty][u] for every vertex u reachable from v without passing back
// through p, as dis[ty][v] plus the summed edge weights along the tree path.
// Uses an explicit stack instead of recursion: on a long path-shaped tree the
// recursive version reached depth ~n and could overflow the call stack.
// @param v   start vertex; dis[ty][v] must already be set by the caller
// @param ty  which distance table to fill (max_score uses 0 for x, 1 for y)
// @param p   neighbour of v that must not be entered (-1 for none)
void Dfs(int v , int ty , int p = -1)
{
    // Each entry is (vertex, the vertex we arrived from).
    std::vector<std::pair<int , int>> pending;
    pending.push_back(std::make_pair(v , p));
    while(!pending.empty())
    {
        const int cur = pending.back().first;
        const int from = pending.back().second;
        pending.pop_back();
        for(const auto& edge : adj[cur])
        {
            const int to = edge.first;
            if(to == from)
                continue;
            dis[ty][to] = dis[ty][cur] + edge.second;
            pending.push_back(std::make_pair(to , cur));
        }
    }
}
// Orders two vertices by their distance to whichever of x / y is nearer.
// NOTE(review): this comparator is not referenced anywhere in the visible code.
bool cmp(int aa , int bb)
{
    const long long nearA = dis[0][aa] < dis[1][aa] ? dis[0][aa] : dis[1][aa];
    const long long nearB = dis[0][bb] < dis[1][bb] ? dis[0][bb] : dis[1][bb];
    return nearA < nearB;
}
// Returns dis[ty][l] + ... + dis[ty][r] (inclusive) using the psum prefix table.
// NOTE(review): not called anywhere in the visible code; assumes 0 <= l <= r < n.
long long Get(int l , int r , int ty)
{
    const long long beforeL = (l > 0) ? psum[ty][l - 1] : 0LL;
    return psum[ty][r] - beforeL;
}
// Scores one configuration, treating vertex indices as positions on a line
// (NOTE(review): presumably the targeted subtask guarantees a path numbered in
// order — confirm). Relies on x <= y, which max_score establishes.
//  - each vertex in [l, r] pays max(dis[0], dis[1]) and counts twice;
//  - each vertex in [x, l - 1] pays dis[0] and counts once;
//  - each vertex in [r + 1, y] pays dis[1] and counts once.
// If that base cost exceeds k the configuration is rejected. Otherwise the rest
// of the limit is spent extending outward: leftwards from min(l, x) - 1 paying
// dis[0], rightwards from max(r, y) + 1 paying dis[1], always taking the
// cheaper next candidate (ties go left) while it still fits.
// @return the total count, or 0 when the base cost exceeds k
int Solve(int l , int r)
{
    long long spent = 0;
    int score = 0;
    for(int v = l ; v <= r ; ++v)
    {
        spent += max(dis[0][v] , dis[1][v]);
        score += 2;
    }
    for(int v = x ; v < l ; ++v)
    {
        spent += dis[0][v];
        ++score;
    }
    for(int v = r + 1 ; v <= y ; ++v)
    {
        spent += dis[1][v];
        ++score;
    }
    if(spent > k)
        return 0;

    int lo = min(l , x) - 1;   // next candidate on the left; -1 means exhausted
    int hi = max(r , y) + 1;   // next candidate on the right; n means exhausted

    // Merge the two outward sequences, cheapest first.
    while(lo >= 0 && hi < n)
    {
        const bool takeLo = dis[0][lo] <= dis[1][hi];
        const long long cost = takeLo ? dis[0][lo] : dis[1][hi];
        if(spent + cost > k)
            break;
        spent += cost;
        ++score;
        if(takeLo)
            --lo;
        else
            ++hi;
    }
    // Drain whatever remains on either side (right first, then left).
    for(; hi < n && spent + dis[1][hi] <= k ; ++hi)
    {
        spent += dis[1][hi];
        ++score;
    }
    for(; lo >= 0 && spent + dis[0][lo] <= k ; --lo)
    {
        spent += dis[0][lo];
        ++score;
    }
    return score;
}
int max_score(int nn, int X, int Y, long long K,
vector<int> U, vector<int> V, vector<int> W)
{
n = nn;
x = X;
y = Y;
k = K;
if(x > y)
swap(x , y);
for(int i = 0 ; i < n ; i++)
adj[i].clear();
for(int i = 0 ; i < n - 1 ; i++)
{
adj[U[i]].push_back(make_pair(V[i] , W[i]));
adj[V[i]].push_back(make_pair(U[i] , W[i]));
}
//cout << x << " " << y << endl;
dis[0][x] = 0;
Dfs(x , 0);
dis[1][y] = 0;
Dfs(y , 1);
psum[0][0] = dis[0][0];
psum[1][0] = dis[1][0];
for(int i = 1 ; i < n ; i++)
{
psum[0][i] = dis[0][i] + psum[0][i - 1];
psum[1][i] = dis[1][i] + psum[1][i - 1];
}
int ans = 0;
for(int l = 0 ; l < n ; l++) for(int r = l ; r < n ; r++)
{
ans = max(ans , Solve(l , r));
}
vector <long long> all;
for(int i = 0 ; i < n ; i++)
{
//cout << i << " : " << dis[0][i] << " " << dis[1][i] << endl;
all.push_back(min(dis[0][i] , dis[1][i]));
}
sort(all.rbegin() , all.rend());
int ans2 = 0;
while(!all.empty() && k >= all.back())
{
k -= all.back();
all.pop_back();
ans2++;
}
return max(ans , ans2);
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |