Submission #991995

# Submission time ID Problem Language Result Execution time Memory
991995 2024-06-03T14:09:27 Z amine_aroua Mergers (JOI19_mergers) C++17
0 / 100
36 ms 22992 KB
#include <bits/stdc++.h>
//#pragma GCC optimize("O3")
//#pragma GCC optimize("unroll-loops")
using namespace std;
#define int long long
#define pb push_back
#define nl '\n'
#define fore(i, y) for(int i = 0; i < y; i++)
#define forr(i, x, y) for(int i = x;i<=y;i++)
#define forn(i, y, x) for(int i = y; i >= x; i--)

// Upper bound on the number of cities; sizes the static adjacency array below.
constexpr int N = 5e5 + 10;

// adj[v]: the cities joined to city v by a road.
vector<int> adj[N];
// s[v]: the state of city v, 0-based (main stores the input value minus one).
vector<int> s;
// c: the state index that DSU::activate compares s[] against.
int c{};
/*
 * Disjoint-set union with union by size and undo support.
 *
 *   e[x] <  0 : x is a root and -e[x] is the size of its component.
 *   e[x] >= 0 : e[x] is the parent of x.
 *
 * Every unite() call pushes exactly one record onto `mod` (a dummy
 * {-1, -1, -1, -1} when u and v were already in the same set), and
 * roll_back() undoes the most recent record, so undos happen in LIFO order.
 *
 * nbC starts at 0; every successful unite() decrements it and every undone
 * union increments it again.
 */
struct DSU{
    int nbC = 0;
    vector<int> e;
    DSU(int n)
    {
        e.assign(n , -1);
    }
    // Root of x's component.
    // Deliberately NO path compression: roll_back() restores only the two
    // roots touched by a union, so compressed parent pointers of other nodes
    // would survive an undo and keep pointing into the merged set.
    // Union by size keeps every tree O(log n) deep, so walking up is cheap.
    int get(int x){
        while(e[x] >= 0)
            x = e[x];
        return x;
    }
    // Number of elements in x's component.
    int size(int x)
    {
        return -e[get(x)];
    }
    // Non-zero iff u and v are currently in the same component.
    int same(int u , int v)
    {
        return get(u) == get(v);
    }
    // Undo log: {larger root, smaller root, old e[larger], old e[smaller]}.
    vector<vector<int>> mod;
    // Merge the sets of u and v (smaller under larger). Returns true if a
    // merge happened. Always logs one record so roll_back() stays in sync.
    bool unite(int u , int v)
    {
        u = get(u) , v = get(v);
        if(u == v)
        {
            mod.pb({-1 , -1 , -1 , -1});
            return 0;
        }
        nbC--;
        if(e[u] > e[v])
            swap(u , v);
        mod.pb({u , v , e[u] , e[v]});
        e[u]+=e[v];
        e[v] = u;
        return 1;
    }
    // Undo the most recent unite() call (no-op on an empty log).
    void roll_back()
    {
        if(mod.empty())
            return;
        auto v = mod.back();
        mod.pop_back();
        if(v[0] != -1)
        {
            fore(i , 2)
                e[v[i]] = v[i + 2];
            nbC++;
        }
    }
    // If city x belongs to state c (global), unite it with every neighbour
    // in adj[x] that belongs to the same state.
    void activate(int x)
    {
        if(s[x] == c)
        {
            for(auto u : adj[x])
            {
                if(s[u] == c)
                    unite(u , x);
            }
        }
    }
};
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n , k;
    cin>>n>>k;
    fore(i , n - 1)
    {
        int u , v;
        cin>>u>>v;
        u-- , v--;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    s = vector<int>(n);
    vector<int> pos[k];
    fore(i , n)
    {
        cin>>s[i];
        s[i]--;
        pos[s[i]].pb(i);
    }
    DSU dsu(n);
    int ans = 0;
    fore(i , k)
    {
        c = i;
        dsu.nbC = (int)pos[i].size();
        for(auto x : pos[i])
        {
            dsu.activate(x);
        }
        if(dsu.nbC == 1)
        {
            int cnt = 0;
            for(auto x : pos[i])
            {
                for(auto u : adj[x])
                {
                    if(s[u] != s[x])
                        cnt++;
                }
            }
            if(cnt == 1)
                ans++;
        }
        while(!dsu.mod.empty())
            dsu.roll_back();
    }
    cout<<(ans + 1)/2;
}

# Result Execution time Memory Grader output
1 Correct 2 ms 12124 KB Output is correct
2 Incorrect 2 ms 12124 KB Output isn't correct
3 Halted 0 ms 0 KB -
# Result Execution time Memory Grader output
1 Correct 2 ms 12124 KB Output is correct
2 Incorrect 2 ms 12124 KB Output isn't correct
3 Halted 0 ms 0 KB -
# Result Execution time Memory Grader output
1 Correct 2 ms 12124 KB Output is correct
2 Incorrect 2 ms 12124 KB Output isn't correct
3 Halted 0 ms 0 KB -
# Result Execution time Memory Grader output
1 Correct 27 ms 18888 KB Output is correct
2 Correct 36 ms 22992 KB Output is correct
3 Incorrect 3 ms 12380 KB Output isn't correct
4 Halted 0 ms 0 KB -
# Result Execution time Memory Grader output
1 Correct 2 ms 12124 KB Output is correct
2 Incorrect 2 ms 12124 KB Output isn't correct
3 Halted 0 ms 0 KB -