답안 #594202

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
594202 2022-07-12T08:19:41 Z 이동현(#8435) Unique Cities (JOI19_ho_t5) C++17
32 / 100
901 ms 100496 KB
#include <bits/stdc++.h>

using namespace std;

// Capacity shared by every per-node array declared below.
const int NS = static_cast<int>(2e5) + 4;

// n: number of nodes; m: second value on the input header line
// (read by main(), not referenced anywhere else in the visible source).
int n;
int m;

// way[v]: neighbours of node v (0-based ids).
std::vector<int> way[NS];

// col[v]: value read for node v.  from[v]: parent recorded by getfar().
int col[NS];
int from[NS];

// ans[v]: accumulated result printed by main().
// online[v]: non-zero once v has been placed on the path `line` in main().
int ans[NS];
int online[NS];

// Not referenced by any other code in the visible source.
int ransval;

// mdis[v]: (largest, second largest) value collected by getdep() for v.
std::pair<int, int> mdis[NS];

// dis[v]: depth assigned by getdep(), counted from the path node it started at.
int dis[NS];

// Colours currently contributed by path nodes: mrset keeps multiplicities,
// rset keeps the distinct values (its size is added to ans[] in sol()).
std::multiset<int> mrset;
std::set<int> rset;

/**
 * Segment tree over integer positions.
 *
 * Every method takes the node index `x` together with the inclusive position
 * range [s, e] that node represents; callers start at (1, 0, last).
 *
 *  - st[leaf]  : set of values stored at that position (leaves only).
 *  - cnt[x]    : cover counter; while it is non-zero the whole range of x is
 *                hidden from get()/get2().
 *  - sum[x]    : number of stored values in x's range that are not hidden by
 *                x itself or by any node below it.
 *  - where[v]  : positions whose leaf set currently contains value v.  It lets
 *                get2() find candidates directly, because the per-node sets of
 *                internal nodes are never filled.
 */
struct seg{
    int n;  // size + 4; not used by the tree operations themselves
    std::vector<int> cnt, sum;
    std::vector<std::set<int>> st;
    std::map<int, std::set<int>> where;

    seg(){}

    /** Allocates storage for positions [0, size-1]. */
    seg(int size):n(size + 4){
        // Guard against size <= 0 so node 1 always exists.
        const int nodes = 4 * std::max(size, 1);
        cnt.resize(nodes);
        sum.resize(nodes);
        st.resize(nodes);
    }

    /** Recomputes sum[x] from cnt[x] and either the leaf set or the children. */
    void pull(int x, int s, int e){
        if(cnt[x]) sum[x] = 0;
        else if(s == e) sum[x] = (int)st[x].size();
        else sum[x] = sum[x * 2] + sum[x * 2 + 1];
    }

    /**
     * Adds `val` to the cover counter of positions [ps, pe].
     * Empty or out-of-range intervals (including negative ps) are ignored or
     * clipped by the usual range tests.
     */
    void push(int x, int s, int e, int ps, int pe, int val){
        if(pe < s || ps > e || ps > pe){
            return;
        }
        if(ps <= s && pe >= e){
            cnt[x] += val;
            pull(x, s, e);
            return;
        }
        const int m = (s + e) >> 1;
        push(x * 2, s, m, ps, pe, val);
        push(x * 2 + 1, m + 1, e, ps, pe, val);
        pull(x, s, e);
    }

    /** Returns how many stored values in [fs, fe] are currently not hidden. */
    int get(int x, int s, int e, int fs, int fe){
        if(fe < s || fs > e || cnt[x] || fs > fe) return 0;
        if(fs <= s && fe >= e){
            return sum[x];
        }
        const int m = (s + e) >> 1;
        return get(x * 2, s, m, fs, fe) + get(x * 2 + 1, m + 1, e, fs, fe);
    }

    /** Stores `val` at the leaf reached by descending towards `pos`. */
    void push2(int x, int s, int e, int pos, int val){
        if(s == e){
            st[x].insert(val);
            where[val].insert(s);  // index by the leaf actually reached
            pull(x, s, e);
            return;
        }
        const int m = (s + e) >> 1;
        if(pos <= m) push2(x * 2, s, m, pos, val);
        else push2(x * 2 + 1, m + 1, e, pos, val);
        pull(x, s, e);
    }

    /**
     * Returns true when no node on the path from x down to the leaf of `pos`
     * has a non-zero cover counter.  `pos` must lie inside [s, e].
     */
    bool visible(int x, int s, int e, int pos) const{
        while(true){
            if(cnt[x]) return false;
            if(s == e) return true;
            const int m = (s + e) >> 1;
            if(pos <= m){
                x = x * 2;
                e = m;
            }
            else{
                x = x * 2 + 1;
                s = m + 1;
            }
        }
    }

    /**
     * Returns true when `val` is stored at some position of [fs, fe] that is
     * not hidden by a cover counter at or below node x.
     *
     * The previous implementation inspected st[x] at every fully contained
     * node, but only leaves ever receive values, so any query that aligned
     * with an internal node reported "absent".  Candidates are now taken from
     * the `where` index and checked individually.
     */
    bool get2(int x, int s, int e, int fs, int fe, int val){
        if(fe < s || fs > e || fs > fe || cnt[x]) return false;
        const auto hit = where.find(val);
        if(hit == where.end()) return false;
        const int lo = std::max(fs, s);
        const int hi = std::min(fe, e);
        for(auto p = hit->second.lower_bound(lo); p != hit->second.end() && *p <= hi; ++p){
            if(visible(x, s, e, *p)) return true;
        }
        return false;
    }

    /** Removes `val` from the leaf reached by descending towards `pos`. */
    void erase(int x, int s, int e, int pos, int val){
        if(s == e){
            if(st[x].erase(val)){
                // Keep the index in step with the leaf set.
                const auto hit = where.find(val);
                if(hit != where.end()){
                    hit->second.erase(s);
                    if(hit->second.empty()) where.erase(hit);
                }
            }
            pull(x, s, e);
            return;
        }
        const int m = (s + e) >> 1;
        if(pos <= m) erase(x * 2, s, m, pos, val);
        else erase(x * 2 + 1, m + 1, e, pos, val);
        pull(x, s, e);
    }
}tree;

/**
 * Depth-first search from `x` that skips the edge back to `pr`.
 *
 * Records `pr` in from[x] for every visited node.
 *
 * @return {number of nodes on the longest path going down from x (x itself
 *          counted), last node of that path}.  Among equally long paths the
 *          end node with the larger index wins.
 */
std::pair<int, int> getfar(int x, int pr = -1){
    from[x] = pr;

    int bestLen = 0;
    int bestNode = x;
    for(const int child : way[x]){
        if(child == pr){
            continue;
        }
        const std::pair<int, int> sub = getfar(child, x);
        const bool longer = sub.first > bestLen;
        const bool tieHigherId = sub.first == bestLen && sub.second > bestNode;
        if(longer || tieHigherId){
            bestLen = sub.first;
            bestNode = sub.second;
        }
    }
    return {bestLen + 1, bestNode};
}

/**
 * Explores the part of the tree below `x` without stepping onto nodes flagged
 * in online[] or back onto `pr`.
 *
 * Side effects:
 *  - dis[child] is set to dis[x] + 1 for every explored child;
 *  - mdis[x].first / mdis[x].second become the largest / second largest
 *    mdis[child].first over the explored children;
 *  - if mdis[x].first is still zero afterwards it is set to dis[x].
 *
 * @return the largest number of edges from x down to an explored node.
 */
int getdep(int x, int pr = -1){
    int deepest = 0;
    std::pair<int, int>& top = mdis[x];

    for(const int child : way[x]){
        if(online[child] || child == pr){
            continue;
        }
        dis[child] = dis[x] + 1;
        deepest = std::max(deepest, getdep(child, x) + 1);

        const int reach = mdis[child].first;
        if(reach > top.first){
            top.second = top.first;
            top.first = reach;
        }
        else if(reach > top.second){
            top.second = reach;
        }
    }

    if(top.first == 0){
        top.first = dis[x];
    }
    return deepest;
}

/**
 * Accumulates ans[] for `x` and for everything below it that is reachable
 * without entering an online[] node or going back to `pr`.
 *
 * Steps for one node, using depth = dis[x] and (longest, runnerUp) = mdis[x]:
 *  1. add rset.size() to ans[x];
 *  2. if longest is non-zero, hide tree positions
 *     [depth - (longest - depth), depth - 1];
 *  3. if depth is non-zero, add the number of still visible entries in
 *     [0, depth - 1] to ans[x];
 *  4. store col[x] at position depth unless it is already in rset or get2()
 *     reports it among positions [0, depth - 1];
 *  5. recurse into the children; for a child whose mdis first value equals
 *     `longest` while `longest` is strictly greater than `runnerUp`, the
 *     hidden range is temporarily switched from the `longest` one to the
 *     `runnerUp` one (or to nothing when runnerUp is zero);
 *  6. undo steps 4 and 2.
 */
void sol(int x, int pr = -1){
    const int depth = dis[x];
    const int longest = mdis[x].first;
    const int runnerUp = mdis[x].second;
    const int last = n - 1;

    // Adjusts the cover counter of the range mirrored around `depth` for a
    // subtree reaching down to `reach`.
    const auto cover = [depth, last](int reach, int delta){
        tree.push(1, 0, last, depth - (reach - depth), depth - 1, delta);
    };

    ans[x] += static_cast<int>(rset.size());
    if(longest != 0){
        cover(longest, 1);
    }
    if(depth != 0){
        ans[x] += tree.get(1, 0, last, 0, depth - 1);
    }

    bool stored = false;
    if(rset.find(col[x]) == rset.end()){
        // get2() is only consulted when there are positions above this one.
        const bool seenAbove = depth != 0 && tree.get2(1, 0, last, 0, depth - 1, col[x]);
        if(!seenAbove){
            tree.push2(1, 0, last, depth, col[x]);
            stored = true;
        }
    }

    for(const int child : way[x]){
        if(online[child] || child == pr){
            continue;
        }
        const bool ownsLongest = longest == mdis[child].first && longest > runnerUp;
        if(!ownsLongest){
            sol(child, x);
            continue;
        }
        cover(longest, -1);
        if(runnerUp != 0){
            cover(runnerUp, 1);
        }
        sol(child, x);
        if(runnerUp != 0){
            cover(runnerUp, -1);
        }
        cover(longest, 1);
    }

    if(stored){
        tree.erase(1, 0, last, depth, col[x]);
    }
    if(longest != 0){
        cover(longest, -1);
    }
}

// Reads the tree and the node colours, extracts a longest path (`line`),
// then runs two sweeps over that path (one per direction) which call sol()
// for path nodes, and finally prints ans[] one value per line.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m;
    tree = seg(n);
    // n - 1 edges, converted from 1-based input to 0-based ids.
    for(int i = 1; i < n; ++i){
        int x, y; cin >> x >> y; --x, --y;
        way[x].push_back(y);
        way[y].push_back(x);
    }
    for(int i = 0; i < n; ++i){
        cin >> col[i];
    }
    // Two farthest-node searches: from index 1, then from the node found.
    // The second search leaves from[] holding parents with `l` as the root.
    // NOTE(review): starting at index 1 presumes n >= 2 - confirm against the
    // problem constraints.
    int l = getfar(1).second;
    int r = getfar(l).second;
    // Walk the parent pointers from r back to l, collecting the path and
    // flagging its nodes in online[].
    vector<int> line;
    int now = r;
    while(true){
        line.push_back(now);
        online[now] = 1;
        if(from[now] == -1){
            break;
        }
        now = from[now];
    }
    // Two passes; `line` is reversed at the end of each one, so the second
    // pass walks the same path from the opposite end.
    for(int rep = 0; rep < 2; ++rep){
        for(int i = 0; i < n; ++i){
            dis[i] = 0;
            mdis[i] = {0, 0};
        }
        // dep[i]: value returned by getdep() for the i-th path node, i.e. how
        // far the off-path part hanging from it reaches.
        vector<int> dep((int)line.size()), lpos((int)line.size());
        for(int i = 0; i < (int)line.size(); ++i){
            dep[i] = getdep(line[i]);
        }
        // Stack of path indices.  Entries j with j + dep[j] < i are dropped,
        // then lpos[i] takes the index on top (it stays 0 when the stack is
        // empty).  Before i is pushed, entries with j + dep[j] <= i + dep[i]
        // are dropped as well.
        vector<int> stk;
        for(int i = 0; i < (int)line.size(); ++i){
            while((int)stk.size() && stk.back() + dep[stk.back()] < i){
                stk.pop_back();
            }
            if((int)stk.size()){
                lpos[i] = stk.back();
            }
            while((int)stk.size() && stk.back() + dep[stk.back()] <= i + dep[i]){
                stk.pop_back();
            }
            stk.push_back(i);
        }
        // rpp scans path indices from the far end towards the front.
        // pop[k] lists colours to withdraw from mrset/rset once the sweep
        // reaches index k.
        int rpp = (int)line.size() - 1;
        vector<vector<int>> pop((int)line.size());
        // Sweep the first half of the path from its middle down to index 0;
        // on the second pass the middle node of an odd-length path is
        // included too.
        for(int i = (int)line.size() / 2 - 1 + rep * ((int)line.size() % 2); i >= 0; --i){
            // Withdraw the colours scheduled for this index: remove one
            // occurrence from mrset and, if none is left, the value from rset.
            while((int)pop[i].size()){
                auto p = mrset.lower_bound(pop[i].back());
                mrset.erase(p);
                p = mrset.lower_bound(pop[i].back());
                if(!(p != mrset.end() && *p == pop[i].back())){
                    rset.erase(pop[i].back());
                }
                pop[i].pop_back();
            }
            // Consume path indices greater than 2 * i.  Index rpp contributes
            // its colour when i >= lpos[rpp]; if lpos[rpp] is non-zero the
            // colour is scheduled for removal at index lpos[rpp] - 1.
            while(rpp > i * 2){
                if(i >= lpos[rpp]){
                    if(lpos[rpp]){
                        pop[lpos[rpp] - 1].push_back(col[line[rpp]]);
                    }
                    mrset.insert(col[line[rpp]]);
                    rset.insert(col[line[rpp]]);
                }
                --rpp;
            }
            sol(line[i]);
        }
        rset.clear();
        mrset.clear();
        reverse(line.begin(), line.end());
    }
    for(int i = 0; i < n; ++i){
        cout << ans[i] << '\n';
    }
    return 0;
}

Compilation message

joi2019_ho_t5.cpp: In member function 'void seg::push(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:40:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   40 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'int seg::get(int, int, int, int, int)':
joi2019_ho_t5.cpp:51:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   51 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::push2(int, int, int, int, int)':
joi2019_ho_t5.cpp:64:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   64 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'bool seg::get2(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:77:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   77 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::erase(int, int, int, int, int)':
joi2019_ho_t5.cpp:90:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   90 |         int m = s + e >> 1;
      |                 ~~^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Incorrect 6 ms 5632 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 216 ms 37364 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 403 ms 49356 KB Output is correct
2 Correct 882 ms 96232 KB Output is correct
3 Correct 68 ms 16608 KB Output is correct
4 Correct 491 ms 61304 KB Output is correct
5 Correct 901 ms 100496 KB Output is correct
6 Correct 547 ms 76884 KB Output is correct
7 Correct 450 ms 60828 KB Output is correct
8 Correct 456 ms 65888 KB Output is correct
9 Correct 488 ms 64228 KB Output is correct
10 Correct 522 ms 63080 KB Output is correct
11 Correct 379 ms 60936 KB Output is correct
12 Correct 718 ms 89244 KB Output is correct
13 Correct 520 ms 74564 KB Output is correct
14 Correct 460 ms 74324 KB Output is correct
15 Correct 240 ms 61376 KB Output is correct
16 Correct 614 ms 87632 KB Output is correct
17 Correct 521 ms 76936 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Incorrect 6 ms 5632 KB Output isn't correct
3 Halted 0 ms 0 KB -