Submission #976833

# | Time | Username | Problem | Language | Result | Execution time | Memory
976833 | | saayan007 | Factories (JOI14_factories) | C++17 | 100 / 100 | 7461 ms | 276012 KiB
#include "factories.h"
#include "bits/stdc++.h"
using namespace std;
 
// Inclusive counting loops over an int variable named by the first argument:
//   rep(i, a, b)  runs i = a, a+1, ..., b   (no iterations when a > b)
//   repd(i, a, b) runs i = a, a-1, ..., b   (no iterations when a < b)
// The loop variable is declared as plain `int i`; writing the declarator as
// `int (i)` made GCC emit -Wparentheses at every expansion site.
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define repd(i, a, b) for(int i = (a); i >= (b); --i)

// Short aliases for the container member functions emplace / emplace_back.
#define em emplace
#define eb emplace_back

// Pair shorthands: pi is pair<int, int>, fr / sc name the two members,
// mp is make_pair.
#define pi pair<int, int>
#define fr first
#define sc second
#define mp make_pair

// Newline character; in this file it is only referenced from commented-out
// debug prints.
const char nl = '\n';
 
// Reminder left by the author: the code mixes int (vertex ids, depths, stamps)
// with long long (edge lengths and distances).
#warning using int and long long as well
using ll = long long;
// Capacity of every per-vertex array below.
const int mxN = 5e5L + 10;
// Number of binary-lifting levels: Init fills levels 0 .. logN - 1.
const int logN = 19;
// Adjacency lists of the input tree as (neighbour, edge length); filled by Init.
vector<pair<int, ll>> adj[mxN];
// Vertex count, stored by Init.
int n;
// par[v][j]: ancestor of v that is 2^j edges up, or -1 when dep[v] < 2^j.
// dep[v]: number of edges between v and vertex 0 (the DFS root).
int par[mxN][logN + 1], dep[mxN];
// dist[v][j]: total length of the 2^j edges above v, or -1 when that
// ancestor does not exist.
ll dist[mxN][logN + 1];
// DFS clock; advanced once on entering and once on leaving a vertex.
int tmr = 0;
// Entry / exit timestamps assigned by dfs.
int tin[mxN], tout[mxN];
// Adjacency lists of the per-query virtual tree as (neighbour, distance);
// entries from earlier queries are discarded lazily inside Query.
vector<pair<int, ll>> vt[mxN];
/* set<int> white, black, graph; */
// Stamp arrays replacing the sets above: white[v] == TC means v is in X for
// the current query, black[v] == TC means v is in Y, graph[v] == TC means
// vt[v] has already been reset for the current query.
int white[mxN] = {}, black[mxN] = {}, graph[mxN] = {};
// Smallest distance from the current centroid to a black (cB) / white (cW)
// vertex seen so far; reset by decompose, updated by solve.
ll cB, cW;
// Best white-black distance found so far in the current Query.
ll ans;
// Stamp of the current query; incremented at the start of every Query.
int TC = 5;
 
void dfs(int x, int p) {
    tin[x] = tmr++;
    for(auto U : adj[x]) {
        int y = U.first; ll w = U.second;
        if(y == p) continue;
        dep[y] = dep[x] + 1;
        dist[y][0] = w;
        par[y][0] = x;
        dfs(y, x);
    }
    tout[x] = tmr++;
}
 
void Init(int N, int A[], int B[], int D[]) {
    n = N;
    for(int i = 0; i < N - 1; ++i) {
        adj[A[i]].emplace_back(B[i], ll(D[i]));
        adj[B[i]].emplace_back(A[i], ll(D[i]));
    }
 
    dep[0] = 0;
    dist[0][0] = -1;
    par[0][0] = -1;
    dfs(0, -1);
    for(int j = 1; j < logN; ++j) {
        for(int i = 0; i < N; ++i) {
            if(dep[i] < (1ll << j)) {
                par[i][j] = dist[i][j] = -1;
            }
            else {
                par[i][j] = par[par[i][j - 1]][j - 1];
                dist[i][j] = dist[i][j - 1] + dist[par[i][j - 1]][j - 1];
            }
        }
    }
}
 
int lca(int x, int y) {
    if(dep[x] < dep[y]) swap(x, y);
    ll diff = dep[x] - dep[y];
    for(int j = logN - 1; j >= 0; --j) {
        if(!(diff & (1ll << j))) continue;
        x = par[x][j];
    }
    if(x == y) return x;
 
    for(int j = logN - 1; j >= 0; --j) {
        if(par[x][j] == par[y][j]) continue;
        x = par[x][j];
        y = par[y][j];
    }
    x = par[x][0];
    y = par[y][0];
    assert(x == y);
    return x;
}

// Returns the length of the tree path between x and y, summing the dist
// entries of every binary-lifting jump taken on the way to their common
// ancestor.
ll distance(int x, int y) {
    if(dep[y] > dep[x]) swap(x, y);

    ll total = 0;
    // Raise the deeper vertex to the depth of the other one.
    const ll gap = dep[x] - dep[y];
    for(int lvl = logN - 1; lvl >= 0; --lvl) {
        if((gap >> lvl) & 1) {
            total += dist[x][lvl];
            x = par[x][lvl];
        }
    }
    if(x == y) return total;

    // Raise both while their ancestors differ, paying for both jumps.
    for(int lvl = logN - 1; lvl >= 0; --lvl) {
        if(par[x][lvl] != par[y][lvl]) {
            total += dist[x][lvl];
            x = par[x][lvl];
            total += dist[y][lvl];
            y = par[y][lvl];
        }
    }

    // Final step from each side onto the common ancestor.
    total += dist[x][0];
    total += dist[y][0];
    x = par[x][0];
    y = par[y][0];
    assert(x == y);
    return total;
}
 
// True when the [tin, tout] interval of dn lies strictly inside that of up,
// i.e. dn was entered after and left before up. A vertex does not contain
// itself.
bool contains(int up, int dn) {
    if(tin[dn] <= tin[up]) return false;
    return tout[up] > tout[dn];
}

/* void solve(int x, int p, ll tot) { */
/*     if(black.count(x) && tot < ans) ans = tot; */
/*     for(auto U : vt[x]) { */
/*         int y = U.fr; ll w = U.sc; */
/*         if(y == p) continue; */
/*         solve(y, x, tot + w); */
/*     } */
/* } */

// proc[v] == TC marks v as already taken as a centroid in the current query;
// such vertices are skipped by get_sz, get_cen, solve and decompose.
int proc[mxN] = {};
// sz[v]: subtree size written by the most recent get_sz pass that reached v.
int sz[mxN];

// Computes, stores in sz[] and returns the number of virtual-tree vertices in
// the subtree of x when entered from p, ignoring vertices already marked in
// proc for the current query.
int get_sz(int x, int p) {
    int total = 1;
    for(const auto &edge : vt[x]) {
        const int y = edge.first;
        if(y != p && proc[y] != TC) total += get_sz(y, x);
    }
    sz[x] = total;
    return total;
}

// Starting at x (entered from p), repeatedly steps into the first unprocessed
// child whose recorded size is at least half of tot, and returns the vertex at
// which no such child exists. Relies on sz[] from a preceding get_sz(x, p).
int get_cen(int x, int p, int tot) {
    int cur = x;
    int from = p;
    while(true) {
        int heavy = -1;
        for(const auto &edge : vt[cur]) {
            const int y = edge.first;
            if(y == from || proc[y] == TC) continue;
            if(sz[y] * 2 >= tot) {
                heavy = y;
                break;
            }
        }
        if(heavy < 0) return cur;
        from = cur;
        cur = heavy;
    }
}

void solve(int x, int p, ll dd) {
    /* if(white.count(x)) { */
    if(white[x] == TC) {
        ans = min(ans, dd + cB);
        cW = min(cW, dd);
    }
    /* else if(black.count(x)) { */
    else if(black[x] == TC) {
        ans = min(ans, dd + cW);
        cB = min(cB, dd);
    }
    for(auto U : vt[x]) {
        int y = U.fr; ll w = U.sc;
        if(y == p || proc[y] == TC) continue;
        solve(y, x, dd + w);
    }
}

// Centroid decomposition of the virtual-tree component containing x (p is the
// previously removed centroid, or -1). Picks a centroid c, resets cW / cB,
// runs solve from c to update ans, marks c as processed and recurses into
// every still-unprocessed neighbour of c.
void decompose(int x, int p) {
    const int total = get_sz(x, p);
    const int c = get_cen(x, p, total);

    // Start with "no white / black vertex seen"; the centroid itself counts
    // at distance 0 if it carries a colour.
    cB = 1e18L;
    cW = cB;
    if(white[c] == TC) {
        cW = 0;
    } else if(black[c] == TC) {
        cB = 0;
    }
    solve(c, p, 0);

    proc[c] = TC;
    for(const auto &edge : vt[c]) {
        const int d = edge.first;
        if(proc[d] != TC) decompose(d, c);
    }
}

long long Query(int S, int X[], int T, int Y[]) {
    ++TC;
    /* cout << nl; */
    /* cout << "TEST#" << TC << nl; */
    vector<int> nodes;

    /* graph.clear(); */
    /* white.clear(); */
    rep(i, 0, S - 1) {
        int x = X[i];
        /* white.insert(x); */
        white[x] = TC;
        /* if(!graph.count(x)) while(int(vt[x].size())) vt[x].pop_back(); */
        /* graph.insert(x); */
        if(graph[x] != TC) while(int(vt[x].size())) vt[x].pop_back();
        graph[x] = TC;
        /* proc[x] = 0; */
        nodes.eb(x);
    }
    /* black.clear(); */
    rep(i, 0, T - 1) {
        int y = Y[i];
        /* black.insert(y); */
        black[y] = TC;
        if(graph[y] != TC) while(int(vt[y].size())) vt[y].pop_back();
        graph[y] = TC;
        /* if(!graph.count(y)) while(int(vt[y].size())) vt[y].pop_back(); */
        /* graph.insert(y); */
        /* proc[y] = 0; */
        nodes.eb(y);
    }

    sort(nodes.begin(), nodes.end(), [&](int left, int right) {
        return tin[left] < tin[right];
    });
    rep(i, 0, S + T - 2) {
        nodes.eb(lca(nodes[i], nodes[i + 1]));
    }
    sort(nodes.begin(), nodes.end(), [&](int left, int right) {
        return tin[left] < tin[right];
    });
    nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());
    
    vector<int> st;
    st.eb(nodes[0]);
    rep(i, 1, int(nodes.size()) - 1) {
        int u = nodes[i];
        while(int(st.size()) >= 2 && !contains(st.back(), u)) {
            int a = st[int(st.size()) - 1];
            int b = st[int(st.size()) - 2];
            ll distab = distance(a, b);

            if(graph[a] != TC) while(int(vt[a].size())) vt[a].pop_back();
            graph[a] = TC;
            /* if(!graph.count(a)) while(int(vt[a].size())) vt[a].pop_back(); */
            /* graph.insert(a); */
            /* proc[a] = 0; */
            vt[a].eb(b, distab);
            /* cout << "Edge " << a << ' ' << b << ' ' << distab << nl; */

            if(graph[b] != TC) while(int(vt[b].size())) vt[b].pop_back();
            graph[b] = TC;
            /* if(!graph.count(b)) while(int(vt[b].size())) vt[b].pop_back(); */
            /* graph.insert(b); */
            /* proc[b] = 0; */
            vt[b].eb(a, distab);
            /* cout << "Edge " << b << ' ' << a << ' ' << distab << nl; */

            st.pop_back();
        }
        st.eb(u);
    }

    while(int(st.size()) >= 2) {
        int a = st[int(st.size()) - 1];
        int b = st[int(st.size()) - 2];

        ll distab = distance(a, b);
        if(graph[a] != TC) while(int(vt[a].size())) vt[a].pop_back();
        graph[a] = TC;
        /* if(!graph.count(a)) while(int(vt[a].size())) vt[a].pop_back(); */
        /* graph.insert(a); */
        /* proc[a] = 0; */
        vt[a].eb(b, distab);
        /* cout << "Edge " << a << ' ' << b << ' ' << distab << nl; */

        if(graph[b] != TC) while(int(vt[b].size())) vt[b].pop_back();
        graph[b] = TC;
        /* if(!graph.count(b)) while(int(vt[b].size())) vt[b].pop_back(); */
        /* graph.insert(b); */
        vt[b].eb(a, distab);
        /* cout << "Edge " << b << ' ' << a << ' ' << distab << nl; */

        st.pop_back();
    }

    rep(i, 0, n - 1) {
        /* if(white[i] == TC) cout << i << " is white\n"; */
        /* if(black[i] == TC) cout << i << " is black\n"; */
        /* if(graph[i] == TC) cout << i << " is graph:"; */
        /* for(auto U : vt[i]) cout << U.fr << ' '; */
        /* cout << nl; */
    }
    ans = 1e18L;
    decompose(X[0], -1);
    return ans;
}

Compilation message (stderr)

factories.cpp:18:2: warning: #warning using int and long long as well [-Wcpp]
   18 | #warning using int and long long as well
      |  ^~~~~~~
factories.cpp: In function 'long long int Query(int, int*, int, int*)':
factories.cpp:5:30: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
    5 | #define rep(i, a, b) for(int (i) = (a); i <= (b); ++i)
      |                              ^
factories.cpp:204:5: note: in expansion of macro 'rep'
  204 |     rep(i, 0, S - 1) {
      |     ^~~
factories.cpp:5:30: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
    5 | #define rep(i, a, b) for(int (i) = (a); i <= (b); ++i)
      |                              ^
factories.cpp:216:5: note: in expansion of macro 'rep'
  216 |     rep(i, 0, T - 1) {
      |     ^~~
factories.cpp:5:30: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
    5 | #define rep(i, a, b) for(int (i) = (a); i <= (b); ++i)
      |                              ^
factories.cpp:231:5: note: in expansion of macro 'rep'
  231 |     rep(i, 0, S + T - 2) {
      |     ^~~
factories.cpp:5:30: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
    5 | #define rep(i, a, b) for(int (i) = (a); i <= (b); ++i)
      |                              ^
factories.cpp:241:5: note: in expansion of macro 'rep'
  241 |     rep(i, 1, int(nodes.size()) - 1) {
      |     ^~~
factories.cpp:5:30: warning: unnecessary parentheses in declaration of 'i' [-Wparentheses]
    5 | #define rep(i, a, b) for(int (i) = (a); i <= (b); ++i)
      |                              ^
factories.cpp:292:5: note: in expansion of macro 'rep'
  292 |     rep(i, 0, n - 1) {
      |     ^~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...