답안 #594203

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
594203 2022-07-12T08:21:58 Z 이동현(#8435) Unique Cities (JOI19_ho_t5) C++17
32 / 100
1151 ms 115716 KB
#include <bits/stdc++.h>

using namespace std;

// Capacity of every per-vertex array (2e5 vertices plus a little slack).
const int NS = (int)2e5 + 4;
// n: number of vertices; m: second integer of the first input line (read in main, not used afterwards).
int n, m;
// Undirected adjacency list; main stores each input edge in both directions (0-based).
vector<int> way[NS];
// col[v]: value read for vertex v; from[v]: DFS parent written by getfar (-1 for the start vertex).
int col[NS], from[NS];
// ans[v]: per-vertex result accumulated by sol and printed by main; online[v]: 1 if v lies on `line` in main.
int ans[NS], online[NS];
// Not referenced anywhere in the visible code.
int ransval;
// mdis[v]: {largest, second largest} of mdis[child].first over v's children, maintained by getdep
// (first falls back to dis[v] when it would otherwise stay 0).
pair<int, int> mdis[NS];
// dis[v]: depth assigned by getdep, counted from the `line` vertex the traversal started at.
int dis[NS];
// Colours currently held by main's sweep, with multiplicity; used to know when a colour leaves rset.
multiset<int> mrset;
// Distinct colours currently held by main's sweep; sol adds its size to ans and tests membership in it.
set<int> rset;

/// Segment tree over positions [0, size-1], indexed from node 1 (children 2x and 2x+1).
///
/// Every method takes the node index `x` and the node's range `[s, e]`; callers
/// start at `x = 1` with the full range.
///
/// Per node:
///  - `st[x]`  : multiset of every value stored (via push2) at a position below x.
///  - `cnt[x]` : net amount added by push() calls whose range fully contained x's range;
///               a non-zero value hides the whole subtree from get()/get2().
///  - `sum[x]` : number of stored values below x that are not hidden by a cover
///               at x or at any node below x.
struct seg{
    int n = 0;                       // requested size + 4; kept for compatibility, no method reads it
    std::vector<int> cnt, sum;
    std::vector<std::multiset<int>> st;

    seg(){}

    /// Allocates 4 * size nodes (at least 4, so node 1 always exists).
    seg(int size):n(size + 4){
        const std::size_t nodes = 4 * static_cast<std::size_t>(std::max(size, 1));
        cnt.resize(nodes);
        sum.resize(nodes);
        st.resize(nodes);
    }

    /// Recomputes sum[x] from cnt[x] and either the leaf's own multiset or the children's sums.
    void refresh(int x, int s, int e){
        if(cnt[x]) sum[x] = 0;
        else if(s == e) sum[x] = (int)st[x].size();
        else sum[x] = sum[x * 2] + sum[x * 2 + 1];
    }

    /// Adds `val` to the cover count of positions [ps, pe]; an empty range (ps > pe) is a no-op.
    /// Covers are not pushed down, so a +1 must later be undone by a -1 on the same range.
    void push(int x, int s, int e, int ps, int pe, int val){
        if(pe < s || ps > e || ps > pe){
            return;
        }
        if(ps <= s && pe >= e){
            cnt[x] += val;
            refresh(x, s, e);
            return;
        }
        const int mid = s + (e - s) / 2;
        push(x * 2, s, mid, ps, pe, val);
        push(x * 2 + 1, mid + 1, e, ps, pe, val);
        refresh(x, s, e);
    }

    /// Returns how many stored values lie at non-hidden positions in [fs, fe].
    int get(int x, int s, int e, int fs, int fe){
        if(fe < s || fs > e || cnt[x] || fs > fe) return 0;
        if(fs <= s && fe >= e){
            return sum[x];
        }
        const int mid = s + (e - s) / 2;
        return get(x * 2, s, mid, fs, fe) + get(x * 2 + 1, mid + 1, e, fs, fe);
    }

    /// Stores `val` at position `pos` (recorded in the multiset of every node on the path).
    void push2(int x, int s, int e, int pos, int val){
        st[x].insert(val);
        if(s != e){
            const int mid = s + (e - s) / 2;
            if(pos <= mid) push2(x * 2, s, mid, pos, val);
            else push2(x * 2 + 1, mid + 1, e, pos, val);
        }
        refresh(x, s, e);
    }

    /// Returns whether `val` is stored somewhere in [fs, fe], skipping nodes hidden on the way down.
    /// NOTE(review): once a node is fully inside the query only its own multiset is consulted,
    /// so a value sitting solely under a hidden descendant is still reported — confirm this is intended.
    bool get2(int x, int s, int e, int fs, int fe, int val){
        if(fe < s || fs > e || fs > fe || cnt[x]) return false;
        if(fs <= s && fe >= e){
            return st[x].find(val) != st[x].end();
        }
        const int mid = s + (e - s) / 2;
        const bool inLeft = get2(x * 2, s, mid, fs, fe, val);
        const bool inRight = get2(x * 2 + 1, mid + 1, e, fs, fe, val);
        return inLeft || inRight;
    }

    /// Removes one occurrence of `val` stored at `pos`, returning whether anything was removed.
    /// The leaf is checked first so that a missing value leaves every ancestor untouched
    /// (erasing an end() iterator, or an unrelated element, would corrupt the tree).
    bool erase_path(int x, int s, int e, int pos, int val){
        if(s != e){
            const int mid = s + (e - s) / 2;
            const bool removed = pos <= mid ? erase_path(x * 2, s, mid, pos, val)
                                            : erase_path(x * 2 + 1, mid + 1, e, pos, val);
            if(!removed) return false;
        }
        const auto it = st[x].find(val);
        if(it == st[x].end()) return false;
        st[x].erase(it);
        refresh(x, s, e);
        return true;
    }

    /// Undoes a matching push2(pos, val); does nothing when `val` is not stored at `pos`.
    void erase(int x, int s, int e, int pos, int val){
        erase_path(x, s, e, pos, val);
    }
}tree;

// Depth-first search from x (arriving from pr) that records from[v] = parent for every
// visited vertex and returns {number of vertices on the longest downward path from x,
// the vertex at the end of that path}. Ties are resolved by pair comparison, i.e. the
// larger vertex index wins.
pair<int, int> getfar(int x, int pr = -1){
    from[x] = pr;
    pair<int, int> best(0, x);
    for(const int child : way[x]){
        if(child != pr){
            const pair<int, int> cand = getfar(child, x);
            if(best < cand){
                best = cand;
            }
        }
    }
    best.first += 1;   // count x itself
    return best;
}

int getdep(int x, int pr = -1){
    int rv = 0;
    for(auto&nxt:way[x]){
        if(online[nxt] || nxt == pr){
            continue;
        }
        dis[nxt] = dis[x] + 1;
        rv = max(rv, getdep(nxt, x) + 1);
        if(mdis[nxt].first > mdis[x].first) mdis[x].second = mdis[x].first, mdis[x].first = mdis[nxt].first;
        else if(mdis[nxt].first > mdis[x].second) mdis[x].second = mdis[nxt].first;
    }
    if(!mdis[x].first) mdis[x].first = dis[x];
    return rv;
}

// Walks the off-line subtree below x (skipping pr and online[] vertices) and adds to
// ans[x] the size of rset plus the number of values in `tree` at depths < dis[x] that
// are not hidden. While a vertex is being processed, depths within
// (mdis[x].first - dis[x]) above it are hidden in `tree`; when descending into the only
// child that attains mdis[x].first, the hidden span is temporarily switched to the one
// derived from mdis[x].second. Every change made to `tree` is undone before returning.
void sol(int x, int pr = -1){
    const int depth = dis[x];
    const int deepest = mdis[x].first;
    const int runnerUp = mdis[x].second;

    // Applies `delta` to the cover count of depths [depth - (reach - depth), depth - 1].
    const auto cover = [&](int reach, int delta){
        tree.push(1, 0, n - 1, depth - (reach - depth), depth - 1, delta);
    };

    ans[x] += (int)rset.size();
    if(deepest){
        cover(deepest, 1);
    }
    if(depth){
        ans[x] += tree.get(1, 0, n - 1, 0, depth - 1);
    }

    // Store this vertex's colour at its depth unless rset or a visible shallower depth already has it.
    bool pushed = false;
    if(rset.find(col[x]) == rset.end()){
        if(depth == 0 || !tree.get2(1, 0, n - 1, 0, depth - 1, col[x])){
            tree.push2(1, 0, n - 1, depth, col[x]);
            pushed = true;
        }
    }

    for(const int child : way[x]){
        if(child == pr || online[child]){
            continue;
        }
        const bool holdsUniqueDeepest = deepest == mdis[child].first && deepest > runnerUp;
        if(!holdsUniqueDeepest){
            sol(child, x);
            continue;
        }
        cover(deepest, -1);
        if(runnerUp) cover(runnerUp, 1);
        sol(child, x);
        if(runnerUp) cover(runnerUp, -1);
        cover(deepest, 1);
    }

    if(pushed){
        tree.erase(1, 0, n - 1, depth, col[x]);
    }
    if(deepest){
        cover(deepest, -1);
    }
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m;
    tree = seg(n);
    for(int i = 1; i < n; ++i){
        int x, y; cin >> x >> y; --x, --y;
        way[x].push_back(y);
        way[y].push_back(x);
    }
    for(int i = 0; i < n; ++i){
        cin >> col[i];
    }
    int l = getfar(1).second;
    int r = getfar(l).second;
    vector<int> line;
    int now = r;
    while(true){
        line.push_back(now);
        online[now] = 1;
        if(from[now] == -1){
            break;
        }
        now = from[now];
    }
    for(int rep = 0; rep < 2; ++rep){
        for(int i = 0; i < n; ++i){
            dis[i] = 0;
            mdis[i] = {0, 0};
        }
        vector<int> dep((int)line.size()), lpos((int)line.size());
        for(int i = 0; i < (int)line.size(); ++i){
            dep[i] = getdep(line[i]);
        }
        vector<int> stk;
        for(int i = 0; i < (int)line.size(); ++i){
            while((int)stk.size() && stk.back() + dep[stk.back()] < i){
                stk.pop_back();
            }
            if((int)stk.size()){
                lpos[i] = stk.back();
            }
            while((int)stk.size() && stk.back() + dep[stk.back()] <= i + dep[i]){
                stk.pop_back();
            }
            stk.push_back(i);
        }
        int rpp = (int)line.size() - 1;
        vector<vector<int>> pop((int)line.size());
        for(int i = (int)line.size() / 2 - 1 + rep * ((int)line.size() % 2); i >= 0; --i){
            while((int)pop[i].size()){
                auto p = mrset.lower_bound(pop[i].back());
                mrset.erase(p);
                p = mrset.lower_bound(pop[i].back());
                if(!(p != mrset.end() && *p == pop[i].back())){
                    rset.erase(pop[i].back());
                }
                pop[i].pop_back();
            }
            while(rpp > i * 2){
                if(i >= lpos[rpp]){
                    if(lpos[rpp]){
                        pop[lpos[rpp] - 1].push_back(col[line[rpp]]);
                    }
                    mrset.insert(col[line[rpp]]);
                    rset.insert(col[line[rpp]]);
                }
                --rpp;
            }
            sol(line[i]);
        }
        rset.clear();
        mrset.clear();
        reverse(line.begin(), line.end());
    }
    for(int i = 0; i < n; ++i){
        cout << ans[i] << '\n';
    }
    return 0;
}

Compilation message

joi2019_ho_t5.cpp: In member function 'void seg::push(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:40:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   40 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'int seg::get(int, int, int, int, int)':
joi2019_ho_t5.cpp:51:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   51 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::push2(int, int, int, int, int)':
joi2019_ho_t5.cpp:64:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   64 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'bool seg::get2(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:77:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   77 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::erase(int, int, int, int, int)':
joi2019_ho_t5.cpp:91:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   91 |         int m = s + e >> 1;
      |                 ~~^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 4948 KB Output is correct
2 Incorrect 6 ms 5604 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 265 ms 37448 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 683 ms 49420 KB Output is correct
2 Correct 1039 ms 96032 KB Output is correct
3 Correct 104 ms 21592 KB Output is correct
4 Correct 866 ms 61204 KB Output is correct
5 Correct 1134 ms 100240 KB Output is correct
6 Correct 1151 ms 115716 KB Output is correct
7 Correct 698 ms 60600 KB Output is correct
8 Correct 663 ms 65588 KB Output is correct
9 Correct 716 ms 64004 KB Output is correct
10 Correct 878 ms 68892 KB Output is correct
11 Correct 562 ms 60868 KB Output is correct
12 Correct 867 ms 89020 KB Output is correct
13 Correct 600 ms 74244 KB Output is correct
14 Correct 1127 ms 108148 KB Output is correct
15 Correct 700 ms 60992 KB Output is correct
16 Correct 870 ms 87380 KB Output is correct
17 Correct 1071 ms 115252 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 4948 KB Output is correct
2 Incorrect 6 ms 5604 KB Output isn't correct
3 Halted 0 ms 0 KB -