Submission #909984

# Submission time Handle Problem Language Result Execution time Memory
909984 2024-01-17T17:26:12 Z EJIC_B_KEDAX Just Long Neckties (JOI20_ho_t1) C++17
0 / 100
19 ms 24656 KB
#ifdef LOCAL
    # define _GLIBCXX_DEBUG
#endif
#include <bits/stdc++.h>
#include <immintrin.h>

#ifndef LOCAL
    #pragma GCC optimize("O3")
    // #pragma GCC optimize("Ofast")
    // #pragma GCC optimize("unroll-loops")
    // #pragma GCC target("avx,avx2,bmi,bmi2,popcnt,lzcnt,sse,sse2,sse3,ssse3,mmx")
    // #pragma GCC target("avx2")
#endif
using namespace std;
using ll = long long;
using ld = long double;
// NOTE(review): object-like macros named `x` and `y` rewrite EVERY later identifier
// spelled x / y in this file (e.g. the parameter of get() and of the fastio
// read/write templates is really named `first` after preprocessing). They are
// presumably meant only for pair access (a.x, a.y); consider removing them.
# define x first
# define y second
# define all(x) x.begin(), x.end()
# define rall(x) x.rbegin(), x.rend()

// 64-bit Mersenne Twister seeded from the wall clock; not referenced elsewhere in this file.
mt19937_64 mt(time(0));

// Forward declarations: main() calls these before their definitions appear.
void solve();
void init();

// Program entry: run the one-time init() hook, then solve a single test case.
// (For multi-test input the count would be read before the loop; iostream
// tuning such as cin.tie(nullptr) / fixed precision is intentionally not applied.)
int32_t main() {
    init();
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
}

namespace fastio {
    // Buffered stdin/stdout built on fread/fwrite. ibuf[pil, pir) is input not yet
    // consumed; obuf[0, por) is output not yet flushed. Readers refill when fewer
    // than 32 unread bytes remain, and writers flush when fewer than 32 bytes are
    // free, so a single number normally fits without crossing a refill/flush.
    static constexpr uint32_t SZ = 1 << 17;
    char ibuf[SZ];
    char obuf[SZ];
    uint32_t pil = 0, pir = 0, por = 0;

    // Compile-time table: num[4 * i .. 4 * i + 3] is i (0..9999) written as four
    // zero-padded ASCII digits, used to emit numbers four digits at a time.
    struct Pre {
        char num[40000];

        constexpr Pre() : num() {
            for (int i = 0; i < 10000; i++) {
                int n = i;
                for (int j = 3; j >= 0; j--) {
                    num[i * 4 + j] = n % 10 + '0';
                    n /= 10;
                }
            }
        }
    } constexpr pre;

    // Moves the unread tail of ibuf to the front and refills the rest from stdin.
    __attribute__((target("avx2"), optimize("O3"))) inline void load() {
        // A parser may already have stepped one byte past pir (onto the sentinel
        // written below); clamp so pir - pil cannot wrap around to ~4 GB.
        const uint32_t remaining = pil < pir ? pir - pil : 0;
        if (remaining) {
            // The tail can overlap its destination, so memcpy would be undefined.
            memmove(ibuf, ibuf + pil, remaining);
        }
        pir = remaining + fread(ibuf + remaining, 1, SZ - remaining, stdin);
        pil = 0;
        // NUL sentinel: a token that ends exactly at end of input stops on this
        // byte instead of running on into stale data from an earlier refill.
        if (pir < SZ) ibuf[pir] = 0;
    }

    // Writes all pending output to stdout and empties obuf.
    __attribute__((target("avx2"), optimize("O3"))) inline void flush() {
        fwrite(obuf, 1, por, stdout);
        por = 0;
    }

    // Raw byte read with no refill or bounds check; callers guarantee availability.
    inline void read(char &c) { c = ibuf[pil++]; }

    // Parses the next integer. Bytes below '-' are treated as separators; only
    // signed T accept a leading '-'; every following byte >= '0' is consumed as a
    // digit (its low four bits). At end of input x is left as 0.
    template<typename T>
    __attribute__((target("avx2"), optimize("O3"))) inline void read(T &x) {
        x = 0;
        char c;
        for (;;) {
            // Refill inside the loop: leading whitespace may exceed the 32-byte lookahead.
            if (pil + 32 > pir) load();
            if (pil >= pir) return;
            read(c);
            if (c >= '-') break;
        }
        bool minus = false;
        if (std::is_signed<T>::value && c == '-') {
            minus = true;
            read(c);
        }
        while (c >= '0') {
            x = x * 10 + (c & 15);
            read(c);
        }
        if (minus) x = -x;
    }

    // Appends one byte to obuf without a flush check; callers reserve space.
    inline void write(char c) { obuf[por++] = c; }

    // Appends the decimal representation of integer x.
    template<typename T>
    __attribute__((target("avx2"), optimize("O3"))) inline void write(T x) {
        if (por + 32 > SZ) flush();
        if (!x) {
            write('0');
            return;
        }
        if constexpr (std::is_signed<T>::value) {
            if (x < 0) {
                write('-');
                // Negate in the unsigned type: -x overflows for the most negative value.
                using U = std::make_unsigned_t<T>;
                write(static_cast<U>(U(0) - static_cast<U>(x)));
                return;
            }
        }
        if (x >= 10000000000000000) {
            // Split into a short head (q2) followed by 16 digits in four groups.
            uint32_t r1 = x % 100000000;
            uint64_t q1 = x / 100000000;
            uint32_t r2 = q1 % 100000000;
            uint32_t q2 = q1 / 100000000;
            uint32_t n1 = r1 % 10000;
            uint32_t n2 = r1 / 10000;
            uint32_t n3 = r2 % 10000;
            uint32_t n4 = r2 / 10000;
            uint32_t head;
            if (static_cast<uint64_t>(x) >= 10000000000000000000ull) {
                // 20 digits (only 64-bit unsigned values): q2 is 1000..1844.
                memcpy(obuf + por, pre.num + (q2 << 2), 4);
                head = 4;
            } else if (x >= 1000000000000000000) {
                // 19 digits: q2 is 100..999; multiply-shift replaces /100 and /10.
                uint32_t q3 = (q2 * 20972) >> 21;
                uint32_t r3 = q2 - q3 * 100;
                uint32_t q4 = (r3 * 205) >> 11;
                uint32_t r4 = r3 - q4 * 10;
                obuf[por + 0] = '0' + q3;
                obuf[por + 1] = '0' + q4;
                obuf[por + 2] = '0' + r4;
                head = 3;
            } else if (x >= 100000000000000000) {
                // 18 digits: q2 is 10..99.
                uint32_t q3 = (q2 * 205) >> 11;
                uint32_t r3 = q2 - q3 * 10;
                obuf[por + 0] = '0' + q3;
                obuf[por + 1] = '0' + r3;
                head = 2;
            } else {
                // 17 digits: q2 is a single digit.
                obuf[por + 0] = '0' + q2;
                head = 1;
            }
            memcpy(obuf + por + head, pre.num + (n4 << 2), 4);
            memcpy(obuf + por + head + 4, pre.num + (n3 << 2), 4);
            memcpy(obuf + por + head + 8, pre.num + (n2 << 2), 4);
            memcpy(obuf + por + head + 12, pre.num + (n1 << 2), 4);
            por += head + 16;
        } else {
            // Collect full 4-digit groups right-to-left in buf, then emit the
            // unpadded leading group followed by the collected groups.
            int i = 8;
            char buf[12];
            while (x >= 10000) {
                memcpy(buf + i, pre.num + (x % 10000) * 4, 4);
                x /= 10000;
                i -= 4;
            }
            if (x < 100) {
                if (x < 10) {
                    write(char('0' + x));
                } else {
                    obuf[por + 0] = '0' + x / 10;
                    obuf[por + 1] = '0' + x % 10;
                    por += 2;
                }
            } else {
                if (x < 1000) {
                    memcpy(obuf + por, pre.num + (x << 2) + 1, 3);
                    por += 3;
                } else {
                    memcpy(obuf + por, pre.num + (x << 2), 4);
                    por += 4;
                }
            }
            memcpy(obuf + por, buf + i + 4, 8 - i);
            por += 8 - i;
        }
    }

    // Returns the next input byte, or -1 once the input is exhausted.
    inline int getChar() {
        if (pil + 32 > pir) load();
        if (pil >= pir) return -1;
        return ibuf[pil++];
    }

    // Returns the next byte above ' ' (skipping whitespace/control bytes), or -1 at EOF.
    inline int readChar() {
        if (pil + 32 > pir) load();
        int c = getChar();
        while (c != -1 && c <= 32) c = getChar();
        return c;
    }

    // Reads one whitespace-delimited token into s (no bounds check) and NUL-terminates it.
    inline void read(char *s) {
        int c = readChar();
        while (c > 32)
            *s++ = c, c = getChar();
        *s = 0;
    }

    // Reads one whitespace-delimited token into s.
    inline void read(std::string &s) {
        s.clear();
        int c = readChar();
        while (c > 32) s.push_back(c), c = getChar();
    }

    // Appends a NUL-terminated string.
    inline void write(const char *s) {
        while (*s) {
            if (por + 32 > SZ) flush();
            write(*s++);
        }
    }

    // Appends a string; const& so temporaries and const strings bind here instead
    // of falling through to the integer template.
    inline void write(const std::string &s) {
        for (char c : s) {
            if (por + 32 > SZ) flush();
            write(c);
        }
    }

    // Appends x in fixed notation with exactly 18 fractional digits (last one rounded).
    // NOTE(review): the integer part goes through int, so |x| must fit in int.
    inline void write(double x) {
        if (por + 32 > SZ) flush();
        if (x < 0)
            write('-'), x = -x;
        int t = (int) x;
        write(t), x -= t;
        write('.');
        for (int i = 18 - 1; i > 0; i--) {
            x *= 10;
            t = std::min(9, (int) x);
            // char(...) selects the single-byte overload; '0' + t alone is an int
            // and would be printed as a number ("53" instead of '5').
            write(char('0' + t)), x -= t;
        }
        x *= 10;
        t = std::min(9, (int) (x + 0.5));
        write(char('0' + t));
    }

    // Reads several values in order.
    template<typename T, typename ...Args>
    inline void read(T &x, Args &...args) {
        read(x);
        read(args...);
    }

    // Writes several values in order, with no separators.
    template<typename T, typename ...Args>
    inline void write(T x, Args ...args) {
        write(x);
        write(args...);
    }

    // Flushes any remaining output when static objects are destroyed at exit.
    struct AutoFlush {
        ~AutoFlush() { flush(); }
    } AutoFlush;

}  // namespace fastio
using fastio::read;
using fastio::write;


// One-time global setup hook called by main() before solve(); this task needs none.
void init() {
    // intentionally empty
}

// Size limits used by solve(): N bounds vertex indices (and is the multiplier used
// to pack an edge's endpoints into one key), K is the edge spacing between DSU
// snapshots, and M bounds the edge count (giving M / K snapshot rows).
constexpr int N = 50500;
constexpr int K = 700;
constexpr int M = 500500;
int p[N];          // main DSU parent array
int sq[M / K][N];  // sq[b]: DSU parents copied before edge b * K (later patched by solve())
int *pp;           // parent array that get() / merge() currently operate on

// Returns the representative of x in *pp, pointing every node on the walked
// path directly at it (path compression).
int get(int x) {
    int root = x;
    while (pp[root] != root) {
        root = pp[root];
    }
    while (pp[x] != root) {
        const int next = pp[x];
        pp[x] = root;
        x = next;
    }
    return root;
}

// Links y's set under x's representative (x's root is resolved first).
void merge(int x, int y) {
    const int parent_root = get(x);
    const int child_root = get(y);
    pp[child_root] = parent_root;
}

// Undirected weighted edge; solve() stores 0-based endpoints u <= v.
struct edge {
    int u;
    int v;
    int w;  // weight
};

// Strict weak order by weight alone (ties compare equivalent), used to sort edges.
bool operator <(const edge& a, const edge& b) {
    return b.w > a.w;
}

void solve() {
    int n, m, k, s;
    // cin >> n >> m >> k >> s; s--;
    read(n, m, k, s); s--;
    edge e[M];
    int e_size = 0;
    unordered_map<ll, int> mp;
    int cnt = 0;
    for (int i = 0; i < m; i++) {
        int u, v, w;
        // cin >> u >> v >> w; u--; v--;
        read(u, v, w); u--; v--;
        if (u > v) {
            swap(u, v);
        }
        if (mp.find(1ll * N * u + v) != mp.end()) {
            mp[1ll * N * u + v] = min(mp[1ll * N * u + v], w);
        } else {
            mp[1ll * N * u + v] = w;
            if (u == s || v == s) {
                cnt++;
            }
        }
    }
    for (auto [p, w] : mp) {
        e[e_size++] = {p / N, p % N, w};
    }
    sort(e, e + e_size);
    for (int i = 0; i < n; i++) {
        p[i] = i;
    }
    m = e_size;
    int com[M], comm[M], commold[M];
    // vector<int> com(m), comm(m + 1), commold(m + 1);
    int comps = n, comps_nos = n;
    pp = p;
    for (int i = 0; i < m; i++) {
        if (!(i % K)) {
            int ind = i / K;
            for (int j = 0; j < n; j++) {
                sq[ind][j] = p[j];
            }
        }
        comm[i] = comps;
        if (get(e[i].u) != get(e[i].v)) {
            comps--;
        }
        merge(e[i].u, e[i].v);
    }
    comm[m] = 1;
    for (int i = 0; i <= m; i++) {
        commold[i] = comm[i];
    }
    for (int i = 0; i < n; i++) {
        p[i] = i;
    }
    for (int i = 0; i < m; i++) {
        com[i] = comps_nos;
        if (e[i].v == s || e[i].u == s) {
            continue;
        }
        if (get(e[i].u) != get(e[i].v)) {
            comps_nos--;
        }
        merge(e[i].u, e[i].v);
    }
    com[m] = comps_nos;
    if (comps > 1 || cnt < k || com[m] > k + 1) {
        // cout << "-1\n";
        write("-1\n");
        return;
    }
    vector<int> good;
    for (int j = 0; j < n; j++) {
        p[j] = j;
    }
    int nos = 0, sst = cnt;
    int used[N], used_sz = 0;
    for (int i = m - 1; i >= 0; i--) {
        if (commold[i] != commold[i + 1]) {
            int ind = i / K;
            comps = comm[ind * K];
            used_sz = 0;
            for (int j = ind * K; j < i; j++) {
                pp = sq[ind];
                int u = get(e[j].u), v = get(e[j].v);
                pp = p;
                u = get(u);
                v = get(v);
                if (u != v) {
                    comps--;
                    p[u] = v;
                    used[used_sz++] = u;
                }
            }
            for (int i = 0; i < used_sz; i++) {
                p[used[i]] = used[i];
            }
        } else {
            comps = 1;
        }
        int is_s = 0;
        if (e[i].v == s || e[i].u == s) {
            is_s = 1;
        }
        if (comps > 1 || sst - is_s < k || com[i] - nos > k + 1) {
            if (e[i].v != s && e[i].u != s) {
                nos++;
            }
            good.push_back(i);
            for (int j = 0, jj = 0; jj < i; j++, jj += K) {
                pp = sq[j];
                int u = get(e[i].u), v = get(e[i].v);
                if (u != v) {
                    comm[jj]--;
                    sq[j][v] = u;
                }
                // merge(u, v, sq[j]);
            }
        } else if (is_s) {
            sst--;
        }
    }
    ll ans = 0;
    for (int i : good) {
        ans += e[i].w;
    }
    // cout << ans << '\n';
    write(ans, '\n');
}

Compilation message

ho_t1.cpp: In function 'void solve()':
ho_t1.cpp:305:26: warning: narrowing conversion of '(((long long int)p) / ((long long int)((int)N)))' from 'long long int' to 'int' [-Wnarrowing]
  305 |         e[e_size++] = {p / N, p % N, w};
      |                        ~~^~~
ho_t1.cpp:305:33: warning: narrowing conversion of '(((long long int)p) % ((long long int)((int)N)))' from 'long long int' to 'int' [-Wnarrowing]
  305 |         e[e_size++] = {p / N, p % N, w};
      |                               ~~^~~
# Verdict Execution time Memory Grader output
1 Runtime error 19 ms 24656 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Runtime error 19 ms 24656 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Runtime error 19 ms 24656 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -