제출 #1115494

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1115494 | | Skymagic | Election Campaign (JOI15_election_campaign) | C++17 | 100 / 100 | 222 ms | 52048 KiB |
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define sc second
#define endl "\n"
#define pii pair<int,int>
 
using namespace std;
 
// Capacity for per-vertex arrays (vertex ids are used directly as indices).
const int MAXN = 100000 + 5;
// 1e9 + 7; not referenced anywhere in this file.
const int mod7 = 1000000007;
// 1e18 "infinity" sentinel; not referenced anywhere in this file.
const long long inf = 1000000000000000000LL;
// Binary-lifting levels: 2^(lg-1) = 131072 >= MAXN, so 18 levels cover any depth.
const int lg = 18;
 
// One campaign plan as read in main(): a path between vertices a and b that
// contributes w to the total if it is chosen.
struct put
{
    int a; // first endpoint
    int b; // second endpoint
    int w; // value gained by choosing this path
};
 
// Euler-tour clock used by dfs(): each vertex takes one tick on entry (tin)
// and one on exit (tout), starting from 1, so times fall in [1, 2n].
int t = 1;
// n = number of vertices, m = number of plans (both read in main).
int n,m;
// Adjacency lists of the tree (vertices are 1-indexed).
vector<vector<int>> graf(MAXN);
// sviputevi[v] holds the plans whose endpoints have LCA v (filled in main).
vector<vector<put>> sviputevi(MAXN);
// Binary-lifting table: up[v][k] is the 2^k-th ancestor of v. The root's
// parent is the root itself (dfs(1,1,0)), so jumps past the root stay there.
int up[MAXN][lg];
// Depth of each vertex below root 1 (the root has depth 0).
int dist[MAXN];
// Euler entry/exit times; x is a proper ancestor of y iff
// tin[x] < tin[y] && tout[x] > tout[y].
int tin[MAXN];
int tout[MAXN];
// dp[v]: best total plan value computed for v's subtree; dp[1] is the answer.
int dp[MAXN];
 
 
// Lazy segment tree (range add, range sum) over Euler-time positions
// [1, 2*MAXN]; MAXN<<3 slots is 4x that range, the usual recursive-tree bound.
vector<int>seg(MAXN<<3);
vector<int>lazy(MAXN<<3);
 
// Apply node `nod`'s pending lazy addition (the node covers [tl, tr]) to its
// stored sum, and pass the same pending value down to both children unless
// the node is a leaf. Does nothing when no addition is pending.
void push(int nod, int tl, int tr)
{
    const int pending = lazy[nod];
    if (pending == 0)
        return;
    if (tl != tr)
    {
        lazy[2 * nod] += pending;
        lazy[2 * nod + 1] += pending;
    }
    seg[nod] += pending * (tr - tl + 1);
    lazy[nod] = 0;
}
 
// Add v to every position in [l, r]. `nod` is the current node, covering
// [tl, tr]. Pending additions are flushed on the way down, so seg[nod] is
// exact for every node visited.
void update(int nod, int tl, int tr, int l, int r, int v)
{
    push(nod, tl, tr);
    const bool outside = tl > r || tr < l;
    const bool emptyRange = tl > tr || l > r;
    if (outside || emptyRange)
        return;

    const bool covered = l <= tl && tr <= r;
    if (covered)
    {
        // Tag the node, then apply the tag at once so seg[nod] is current
        // for the parent's recomputation below.
        lazy[nod] += v;
        push(nod, tl, tr);
        return;
    }

    const int mid = (tl + tr) >> 1;
    const int lc = 2 * nod;
    const int rc = 2 * nod + 1;
    update(lc, tl, mid, l, r, v);
    update(rc, mid + 1, tr, l, r, v);
    seg[nod] = seg[lc] + seg[rc];
}
 
// Sum of positions in [l, r] within node `nod`, which covers [tl, tr].
// Returns 0 for disjoint or empty ranges.
int query(int nod, int tl, int tr, int l, int r)
{
    push(nod, tl, tr);
    const bool outside = tl > r || tr < l;
    const bool emptyRange = tl > tr || l > r;
    if (outside || emptyRange)
        return 0;
    if (l <= tl && tr <= r)
        return seg[nod];

    const int mid = (tl + tr) >> 1;
    const int fromLeft = query(2 * nod, tl, mid, l, r);
    const int fromRight = query(2 * nod + 1, mid + 1, tr, l, r);
    return fromLeft + fromRight;
}
 
// Root the subtree of `nod`: record each vertex's parent (p for nod), depth
// (d for nod) and Euler entry/exit times taken from the global clock t.
// Recursion depth equals the height of the tree.
void dfs(int nod, int p, int d)
{
    up[nod][0] = p;
    dist[nod] = d;
    tin[nod] = t++;
    for (const int child : graf[nod])
    {
        if (child == p)
            continue;
        dfs(child, nod, d + 1);
    }
    tout[nod] = t++;
}
 
// Build the binary-lifting table for vertices 1..n. The 2^k-th ancestor is the
// 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor, so level k-1 must be complete
// for every vertex before level k is built (hence the level-major loops).
void fillUp()
{
    for (int k = 1; k < lg; ++k)
    {
        for (int v = 1; v <= n; ++v)
        {
            const int half = up[v][k - 1];
            up[v][k] = up[half][k - 1];
        }
    }
}
 
// Lowest common ancestor of u and v, using Euler times for ancestor tests and
// the binary-lifting table to climb from v.
int lca(int u, int v)
{
    // True when x is a proper ancestor of y.
    auto above = [](int x, int y) { return tin[x] < tin[y] && tout[x] > tout[y]; };

    if (u == v)
        return u;
    if (above(u, v))
        return u;
    if (above(v, u))
        return v;

    // Lift v as high as possible while staying strictly below the LCA;
    // the LCA is then v's parent.
    for (int k = lg - 1; k >= 0; --k)
    {
        if (!above(up[v][k], u))
            v = up[v][k];
    }
    return up[v][0];
}
 
// Recursive (post-order) equivalent of the deepest-first loop in main(); it is
// not called by main(). It assumes dp[] is still zero and the segment tree
// holds no updates, and it recurses as deep as the tree is tall.
//
// After solve(nod, p) returns, dp[v] is final for every v in nod's subtree, and
// each such v has added (dp[v] - sum of its children's dp) over [tin[v], tout[v]].
void solve(int nod, int p)
{
    // Finish every child first; dp[nod] becomes the sum of their results.
    for(auto x: graf[nod])
    {
        if(x==p)continue;   // the parent edge contributes nothing
        solve(x, nod);
        dp[nod]+=dp[x];
    }
    const int childSum = dp[nod];
    int rez = childSum;
    for(auto x: sviputevi[nod])
    {
        int a = x.a;
        int b = x.b;
        if(dist[a] > dist[b])swap(a,b);   // a is now the shallower endpoint
        // The tree is indexed by Euler time, so point-query tin[], not vertex ids.
        // Each query returns what a path through that branch would forfeit.
        int cand = childSum + x.w - query(1,1,2*MAXN, tin[b], tin[b]);
        if(a != nod)cand -= query(1,1,2*MAXN, tin[a], tin[a]);
        rez = max(rez, cand);
    }
    dp[nod] = rez;
    // Publish only the gain over the children's sum: point queries from below
    // must subtract exactly what routing a new path through nod would forfeit.
    update(1,1, 2*MAXN, tin[nod], tout[nod], rez - childSum);
}
 
// Reads the tree (n vertices, n-1 edges) and m plans "a b w", then prints dp[1]:
// the best total value of chosen plans. Choosing a plan replaces dp[x] by its
// children's sum at every vertex x on the plan's path, so chosen paths never
// share a vertex.
signed main()
{
    ios_base::sync_with_stdio(false),cin.tie(0), cout.tie(0);
    int tt=1;
    //cin >> tt;   // single test case; left over from a multi-test template
    while(tt--)
    {
        cin >> n;

        // Tree edges; vertices are 1-indexed.
        for(int i=0; i<n-1; i++)
        {
            int a,b;cin >> a >> b;
            graf[a].pb(b);
            graf[b].pb(a);
        }
        // Root at vertex 1 (its own parent) and build the lifting table.
        dfs(1,1,0);
        fillUp();

        // Attach each plan to the LCA of its endpoints: the topmost vertex
        // of its path, where the DP can first decide to take it.
        cin >> m;
        for(int i=0; i<m; i++)
        {
            int a,b,w;cin >> a >> b >> w;
            sviputevi[lca(a,b)].pb({a,b,w});
        }

        // Visit vertices deepest-first: when nod is processed, all of its
        // descendants are finished and none of its proper ancestors has started.
        vector<pii> dubine;
        for(int i=1; i<=n; i++)dubine.pb({dist[i], i});
        sort(all(dubine));
        reverse(all(dubine));

        for(int i=0; i<n; i++)
        {
            int nod = dubine[i].sc;
            int p = up[nod][0];
            // Without a plan at nod, the best is the sum of the children.
            for(auto x: graf[nod])
            {
                if(x==p)continue;
                dp[nod]+=dp[x];
            }
            int rez = dp[nod];
            // A point query at tin[v] (v inside nod's subtree) returns the sum
            // of (dp[x] - children's sum of x) over the finished vertices x
            // from v up to just below nod: what a new path through them forfeits.
            for(auto x: sviputevi[nod])
            {
                int a = x.a;
                int b = x.b;
                int w = x.w;
                if(dist[a] > dist[b])swap(a,b);   // a is the shallower endpoint

                if(a == nod)
                {
                    // Path runs from nod down to b (b may be nod itself).
                    rez = max(rez, dp[nod] - query(1,1,2*MAXN, tin[b],tin[b]) + w);
                }
                else
                {
                    // Path bends at nod: both branches are affected.
                    rez = max(rez, dp[nod] - query(1,1,2*MAXN, tin[a],tin[a])- query(1,1,2*MAXN, tin[b],tin[b]) + w);
                }
            }
            // Record nod's gain over its children's sum across its Euler range,
            // so later point queries from inside this subtree account for it.
            update(1,1, 2*MAXN, tin[nod], tout[nod], rez - dp[nod]);
            dp[nod] = rez;
        }
        cout << dp[1] << endl;
    }
}

컴파일 시 표준 에러 (stderr) 메시지

election_campaign.cpp: In function 'void update(long long int, long long int, long long int, long long int, long long int, long long int)':
election_campaign.cpp:56:17: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   56 |     int mid = tl+tr >> 1;
      |               ~~^~~
election_campaign.cpp: In function 'long long int query(long long int, long long int, long long int, long long int, long long int)':
election_campaign.cpp:67:17: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   67 |     int mid = tl+tr >> 1;
      |               ~~^~~
election_campaign.cpp: In function 'int main()':
election_campaign.cpp:177:25: warning: unused variable 'q' [-Wunused-variable]
  177 |                     int q = query(1,1,2*MAXN, tin[b],tin[b]);
      |                         ^
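| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...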