# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
1064177 |  | parsadox2 | Closing Time (IOI23_closing) | C++17 |  | 0 ms | 0 KiB |
This submission was migrated from the previous version of oj.uz, which used a different grading machine. It may receive a different result if resubmitted.
#include "closing.h"
#include <bits/stdc++.h>
// Shorthand for pair members: in adj, u.F is the neighbour and u.S the edge weight.
#define F first
#define S second
using namespace std;
// Capacity of the per-vertex arrays (presumably an upper bound on the vertex count).
const int N = 2e5 + 10;
// Vertex count of the current call and the two start vertices (max_score swaps them so x <= y).
int n , x , y;
// dis[0][v] / dis[1][v]: weighted distance from x / y to v.
// k: total closing-time budget for the current call.
// psum[0..2]: prefix sums over vertex index of dis[0], dis[1] and max(dis[0], dis[1]).
long long dis[2][N] , k , psum[3][N];
// Adjacency list: adj[v] holds (neighbour, edge weight) pairs.
vector <pair<int , int>> adj[N];
void Dfs(int v , int ty , int p = -1)
{
for(auto u : adj[v]) if(u.F != p)
{
dis[ty][u.F] = dis[ty][v] + u.S;
Dfs(u.F , ty , v);
}
}
// Strict ordering of vertices by their distance to the nearer of x and y.
// NOTE(review): not referenced anywhere in the visible code.
bool cmp(int aa , int bb)
{
    const long long near_a = std::min(dis[0][aa] , dis[1][aa]);
    const long long near_b = std::min(dis[0][bb] , dis[1][bb]);
    return near_a < near_b;
}
// Sum over vertex indices [l, r] of the sequence whose prefix sums are stored
// in psum[ty]. An empty range (l > r) yields 0.
long long Get(int l , int r , int ty)
{
    if(r < l)
        return 0;
    const long long through_r = psum[ty][r];
    const long long before_l = (l > 0) ? psum[ty][l - 1] : 0;
    return through_r - before_l;
}
// Scores one candidate shape on a path whose vertex indices are assumed to be
// path order (NOTE(review): only valid for line-shaped inputs; confirm).
// [l, r] is the block of vertices reached from BOTH x and y: each costs
// max(dis[0], dis[1]) and scores 2. Vertices x..l-1 are reached from x only
// (cost dis[0]). Vertices r+1..y are reached from y only (cost dis[1]).
// Any budget left after that mandatory part is spent extending outward.
// Returns the score achieved under this shape, or 0 if the mandatory part
// alone exceeds k.
int Solve(int l , int r)
{
    // Core [l, r], reached from both sides.
    long long val = Get(l , r , 2);
    int ans = 2 * (r - l + 1);
    // x-only stretch x..l-1. It must stop at l - 1: vertex l is already paid
    // for in the core, and the score below counts exactly l - x vertices.
    val += Get(x , l - 1 , 0);
    ans += max(0 , l - x);          // both operands int (0LL would not compile)
    // y-only stretch r+1..y.
    val += Get(r + 1 , y , 1);
    ans += max(0 , y - r);
    if(val > k)
        return 0;

    // Extend outward: px walks left (reached from x, cost dis[0]); py walks
    // right (reached from y, cost dis[1]). Assuming non-negative edge weights,
    // costs do not decrease away from the core. So each side must be taken as
    // a prefix, and repeatedly taking the cheaper frontier is optimal.
    int px = min(l , x) - 1 , py = max(r , y) + 1;
    while(px != -1 && py != n)
    {
        if(dis[0][px] <= dis[1][py])
        {
            if(val + dis[0][px] > k)
                break;          // the other frontier is no cheaper either
            ans++;
            val += dis[0][px];
            px--;
            continue;
        }
        if(val + dis[1][py] > k)
            break;
        ans++;
        val += dis[1][py];
        py++;
    }
    // One side is exhausted (or nothing more fits); drain whatever is left.
    while(py != n && val + dis[1][py] <= k)
    {
        val += dis[1][py];
        py++;
        ans++;
    }
    while(px != -1 && val + dis[0][px] <= k)
    {
        val += dis[0][px];
        px--;
        ans++;
    }
    return ans;
}
int max_score(int nn, int X, int Y, long long K,
vector<int> U, vector<int> V, vector<int> W)
{
n = nn;
x = X;
y = Y;
k = K;
if(x > y)
swap(x , y);
for(int i = 0 ; i < n ; i++)
adj[i].clear();
for(int i = 0 ; i < n - 1 ; i++)
{
adj[U[i]].push_back(make_pair(V[i] , W[i]));
adj[V[i]].push_back(make_pair(U[i] , W[i]));
}
//cout << x << " " << y << endl;
dis[0][x] = 0;
Dfs(x , 0);
dis[1][y] = 0;
Dfs(y , 1);
psum[0][0] = dis[0][0];
psum[1][0] = dis[1][0];
psum[2][0] = max(dis[0][0] , dis[1][0]);
for(int i = 1 ; i < n ; i++)
{
psum[0][i] = dis[0][i] + psum[0][i - 1];
psum[1][i] = dis[1][i] + psum[1][i - 1];
psum[2][i] = max(dis[0][i] , dis[1][i]) + psum[2][i - 1];
}
int ans = 0;
for(int l = 0 ; l < n ; l++) for(int r = l ; r < n ; r++)
{
ans = max(ans , Solve(l , r));
}
vector <long long> all;
for(int i = 0 ; i < n ; i++)
{
//cout << i << " : " << dis[0][i] << " " << dis[1][i] << endl;
all.push_back(min(dis[0][i] , dis[1][i]));
}
sort(all.rbegin() , all.rend());
int ans2 = 0;
while(!all.empty() && k >= all.back())
{
k -= all.back();
all.pop_back();
ans2++;
}
return max(ans , ans2);
}