답안 #922141

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
922141 2024-02-05T06:48:33 Z lozergam Designated Cities (JOI19_designated_cities) C++17
0 / 100
412 ms 81972 KB
//#pragma GCC optimize("Ofast")

//#include <bits/stdc++.h>

#include <iostream>
#include <vector>
#include <algorithm>
#include <array>
#include <set>

using namespace std;

// Shorthand type aliases used throughout the file.
#define ll long long
#define ld long double
// Container / pair shorthands.
#define pb push_back //emplace_back
#define pp pop_back
#define pii pair<ll, ll> //pair<int, int>
// Expands to the begin/end iterator pair of x (two comma-separated arguments).
#define all(x) (x).begin(),(x).end()
#define mp(a,b) make_pair(a , b)
#define lb lower_bound
#define ub upper_bound
// Container size converted to signed ll, which avoids signed/unsigned comparisons.
#define sz(x) (ll)(x).size()
#define F first
#define S second
#define bg(x) (x).begin()
// Counting loop: declares an int x running over 0 .. n-1.
#define For(x, n) for(int (x) = 0 ; (x) < (n) ; (x)++)
// Prints "<expression text> : <value>" followed by a newline and flushes the stream.
#define debug(x) cout << #x << " : " << x << endl << flush
// Every later use of `endl` (including inside `debug`) becomes a plain '\n' with no flush.
#define endl '\n'
// Fixed-size array of x values of type ll.
#define arr(x) array<ll , (x)>
#define yes cout << "YES\n"
#define no cout << "NO\n"
// Turns off stdio synchronisation and unties cin from cout.
#define FAST ios_base::sync_with_stdio(0);cin.tie(0);

// Modular addition: both operands are reduced modulo MOD, added, and the
// sum is reduced once more. The sign of the result follows the built-in
// % operator, so negative operands can produce a negative result.
ll Sum(ll a , ll b , ll MOD)
{
    const ll lhs = a % MOD;
    const ll rhs = b % MOD;
    return (lhs + rhs) % MOD;
}

// Modular multiplication: returns (a % MOD) * (b % MOD) % MOD.
//
// The product of two reduced operands can need up to 126 bits, so it is
// computed in a 128-bit intermediate; a plain 64-bit multiplication would
// overflow (undefined behaviour) as soon as MOD exceeds roughly 3e9.
// __int128 is a GCC/Clang extension.
// The sign of the result follows the built-in % operator, exactly as before.
// MOD must be non-zero.
ll Mul(ll a , ll b , ll MOD)
{
    const __int128 lhs = a % MOD;
    const __int128 rhs = b % MOD;
    return (ll)((lhs * rhs) % MOD);
}

// Modular exponentiation by repeated squaring: returns a^b modulo MOD.
//
// a   - base (reduced modulo MOD by Mul)
// b   - exponent; expected to be non-negative. A negative exponent is not
//       supported and yields 1 % MOD instead of looping forever.
// MOD - modulus, must be non-zero.
ll Pow(ll a , ll b , ll MOD)
{
    // Start from 1 % MOD so that the result is 0 (not 1) when MOD == 1.
    ll res = 1 % MOD;
    while (b > 0)
    {
        // Multiply the current power of a into the result for every set bit of b.
        if (b & 1) res = Mul(res , a , MOD);
        a = Mul(a , a , MOD);
        b >>= 1;
    }
    return res;
}

// Returns the smaller of the two values (a when they are equal).
ll Min(ll a , ll b)
{
    return (a > b) ? b : a;
}

// Returns the larger of the two values (b when they are equal).
ll Max(ll a , ll b)
{
    return (a > b) ? a : b;
}

// Ceiling division: returns the smallest integer that is >= a / b.
//
// Integer division in C++ truncates toward zero, so the truncated quotient
// is already the ceiling whenever the exact quotient is negative. It only
// has to be bumped by one when there is a remainder AND the exact quotient
// is positive, i.e. the remainder (which carries the sign of a) has the same
// sign as b. The previous version added one for every non-zero remainder and
// was off by one for negative quotients such as Ceil(-7, 2).
// b must be non-zero.
ll Ceil(ll a , ll b)
{
    ll quotient = a / b;
    const ll remainder = a % b;
    if (remainder != 0 && ((remainder < 0) == (b < 0)))
        ++quotient;
    return quotient;
}

/////////////////////
//VALS
// Capacity of the per-node and per-answer arrays below. Nodes are indexed
// 0 .. n-1, but dp is indexed by the number of chosen cities, which runs up
// to n inclusive, so a few spare slots are required on top of the largest n
// (presumably 2e5, as the original constant suggests). With exactly 2e5
// slots, dp[n] was out of bounds for the largest input.
const ll maxn = 2e5 + 5;
const ll INF = 1e18;

// Number of nodes of the tree.
ll n;
// adj[u] holds one entry {v, cost of direction u->v, cost of direction v->u}
// per neighbour v of u.
set<arr(3)> adj[maxn];
// Number of queries.
ll q;
// Number of leaves (degree-1 nodes) counted by solve().
ll barg;
// par[u] is the parent of u in the tree rooted at `root` (par[root] == root).
ll par[maxn];
// dp[k] is the answer printed for a query with value k.
ll dp[maxn];
// Sum of the parent->child direction costs over all edges of the rooted tree.
ll dp1;
// Node used as the root of the traversals.
ll root;
/////////////////////
//FUNCTIONS
// First traversal of the tree rooted at the initial call's node.
// Records the parent of every visited node in par[] and accumulates into dp1
// the cost of the parent->child direction (entry[1]) of every tree edge.
//
// u - node being visited
// p - node it was reached from (the root is passed as its own parent)
//
// Recursive: the recursion depth equals the depth of the tree.
void dfs_1_1(ll u, ll p)
{
	par[u] = p;

	for (const auto& edge : adj[u])
	{
		const ll child = edge[0];
		if (child != p)
		{
			dp1 += edge[1];
			dfs_1_1(child, u);
		}
	}
}

// Rerooting pass. `val` is the total cost when u is the root; moving the root
// across an edge to a child swaps which direction of that edge is counted
// (subtract entry[1], add entry[2]). dp[1] is lowered to the smallest total
// seen over all nodes reachable below u.
//
// u   - current node (par[] must already be filled in by dfs_1_1)
// val - total cost associated with rooting the tree at u
void dfs_1_2(ll u, ll val)
{
	for (const auto& edge : adj[u])
	{
		if (edge[0] == par[u])
			continue;

		const ll moved = val - edge[1] + edge[2];
		if (dp[1] > moved)
			dp[1] = moved;

		dfs_1_2(edge[0], moved);
	}
}

// Post-order chain compression of the subtree below u.
// After all children have been processed, a non-root node that is left with
// exactly two neighbours (its parent and one child) is removed from the tree:
// the two incident edges are replaced by a single parent<->child edge whose
// direction costs are the sums along the removed path, and the child's parent
// pointer is redirected. The root itself is never removed here.
void compress(ll u)
{
	// Iterate over a snapshot: compressing a child rewrites adj[u].
	const set<arr(3)> snapshot = adj[u];

	for (const auto& edge : snapshot)
	{
		if (edge[0] == par[u])
			continue;
		compress(edge[0]);
	}

	if (u == root or sz(adj[u]) != 2)
		return;

	// Work out which of the two remaining entries leads up and which leads down.
	const auto lowest = (*adj[u].begin());
	const auto highest = (*adj[u].rbegin());
	const bool lowestIsParent = (lowest[0] == par[u]);
	const auto down = lowestIsParent ? highest : lowest;   // {child,  u->child,  child->u}
	const auto up   = lowestIsParent ? lowest : highest;   // {parent, u->parent, parent->u}

	const ll child = down[0];
	const ll parent = up[0];

	par[child] = parent;

	// Drop the two edges that touch u ...
	adj[child].erase(adj[child].find({u, down[2], down[1]}));
	adj[parent].erase(adj[parent].find({u, up[2], up[1]}));

	// ... and connect parent and child directly with the summed costs.
	adj[parent].insert({child, up[2] + down[1], up[1] + down[2]});
	adj[child].insert({parent, up[1] + down[2], up[2] + down[1]});

	adj[u].clear();
}
/////////////////////
//SOLVE
void solve()
{
	cin >> n;
	
	For(i,n-1)
	{
		ll a, b ,c ,d;
		cin >> a >> b >> c >> d;
		a--;
		b--;
		
		adj[a].insert({b,c,d});
		adj[b].insert({a,d,c});
	}
	
	barg = 0;
	For(i,n)if(sz(adj[i]) == 1)barg++;
			
	for(int i = barg; i <= n; i++)dp[i] = 0;
	
	root = 0;
	For(i,n)
		if(sz(adj[i]) != 1)
		{
			root = i;
			break;
		}
	
	dfs_1_1(root, root);
	dp[1] = dp1;
	//debug(dp[1]);	
	dfs_1_2(root, dp1);
	
												
	compress(root);
	
	/*
	For(i,n)
	{
		debug(i);
		debug(sz(adj[i]));
		for(auto v : adj[i])
		{
			debug(v[0]);
			debug(v[1]);
			debug(v[2]);
			cout << endl;		
		}
			debug("\n");
	}
	*/
	set<arr(3)> s;
	
	For(i,n)
		if(sz(adj[i]) == 1)
		{
			auto u = (*adj[i].begin());
			s.insert({u[2],u[1], i});
		}
	/*
	debug(sz(s));
	for(auto u : s)
	{
		debug(u[0]);
		debug(u[1]);
		debug(u[2]);
		cout << "AAAAAA\n" << flush;
	}
	*/
		
	ll cur = 0;
	barg--;	
	
	//debug(barg);
/*
15
14 5 12 7
14 12 6 5
14 10 14 16
9 14 16 12
13 7 4 15
1 3 8 1
6 7 15 13
15 4 4 6
9 1 12 6
13 1 7 6
13 4 5 15
2 6 11 19
8 4 12 7
13 11 14 5
3
3
6
7
*/
	for(int i = barg; i >= 2; i--)
	{
		//cout <<"AAAAAAA\n" << flush;
		auto bst = (*s.begin());
		s.erase(bst);
		//debug(bst[0]);
		//debug(bst[1]);
		//debug(bst[2]);
		cur += bst[0];
		dp[i] = cur;
		
		//bst -> {0 -> bottom to top, 1 -> top to bottom, 2 -> index}
		//cout << "BF\n" << flush;
		adj[par[bst[2]]].erase(adj[par[bst[2]]].find({bst[2], bst[0], bst[1]}));
		//cout << "AF\n" << flush;
		if(i != 1 and sz(adj[par[bst[2]]]) == 2 and par[bst[2]] != root)
		{
			//cout << ":P\n" << flush;
			ll u = par[bst[2]];
			auto ch = (*adj[u].begin());
			auto p = (*adj[u].rbegin());
			
			//debug(ch[0]);
			//debug(p[0]);
			
			if(ch[0] == par[u])
				swap(ch,p);
			
			par[ch[0]] = p[0];
			adj[ch[0]].erase(adj[ch[0]].find({u, ch[2], ch[1]}));
			//debug("I:");
			adj[p[0]].erase(adj[p[0]].find({u, p[2], p[1]}));
			//debug("I:");
			adj[p[0]].insert({ch[0], p[2]+ch[1], p[1]+ch[2]});
			//debug("I:");
			adj[ch[0]].insert({p[0], p[1]+ch[2], p[2]+ch[1]});
			//debug("I:");
			s.erase(s.find({ch[1], ch[2], ch[0]}));
			s.insert({p[2]+ch[1], p[1]+ch[2], ch[0]});
		}
		//cout << ";l\n" << flush;
	}
	
	cin >> q;
	
	while(q--)
	{
		//cout << "@_@\n" << endl;
		ll e;
		cin >> e;
		cout <<	endl;
		cout << dp[e] << endl;	
	}
}
/////////////////////
//MAIN
int main()
{
    FAST;
    ll t = 1;
//    cin >> t;
    while(t--)
    {
        solve();
    }
}
/////////////////////
/*
ZZZZZZZ     A        M     M     IIIIIII  N     N
     Z     A A      M M   M M       I     NN    N
    Z     A   A    M   M M   M      I     N N   N
   Z     AAAAAAA  M     M     M     I     N  N  N
  Z      A     A  M           M     I     N   N N
 Z       A     A  M           M     I     N    NN
ZZZZZZZ  A     A  M           M  IIIIIII  N     N  TREE
*/

Compilation message

designated_cities.cpp: In function 'void solve()':
designated_cities.cpp:26:27: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
   26 | #define For(x, n) for(int (x) = 0 ; (x) < (n) ; (x)++)
      |                           ^
designated_cities.cpp:156:2: note: in expansion of macro 'For'
  156 |  For(i,n-1)
      |  ^~~
designated_cities.cpp:26:27: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
   26 | #define For(x, n) for(int (x) = 0 ; (x) < (n) ; (x)++)
      |                           ^
designated_cities.cpp:168:2: note: in expansion of macro 'For'
  168 |  For(i,n)if(sz(adj[i]) == 1)barg++;
      |  ^~~
designated_cities.cpp:26:27: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
   26 | #define For(x, n) for(int (x) = 0 ; (x) < (n) ; (x)++)
      |                           ^
designated_cities.cpp:173:2: note: in expansion of macro 'For'
  173 |  For(i,n)
      |  ^~~
designated_cities.cpp:26:27: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
   26 | #define For(x, n) for(int (x) = 0 ; (x) < (n) ; (x)++)
      |                           ^
designated_cities.cpp:205:2: note: in expansion of macro 'For'
  205 |  For(i,n)
      |  ^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 9820 KB Output is correct
2 Runtime error 14 ms 19588 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 7 ms 9820 KB Output is correct
2 Runtime error 381 ms 79872 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 9820 KB Output is correct
2 Runtime error 412 ms 81972 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 9820 KB Output is correct
2 Runtime error 14 ms 19588 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 7 ms 9820 KB Output is correct
2 Runtime error 381 ms 79872 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 9820 KB Output is correct
2 Runtime error 14 ms 19588 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -