답안 #594212

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
594212 2022-07-12T08:33:11 Z 이동현(#8435) Unique Cities (JOI19_ho_t5) C++17
32 / 100
1575 ms 116020 KB
#include <bits/stdc++.h>

using namespace std;

// Capacity of every per-vertex array below.
constexpr int NS = 200000 + 4;
int n;                    // number of vertices
int m;                    // number of colors (only read in main)
vector<int> way[NS];      // adjacency lists (0-indexed vertices)
int col[NS];              // color of each vertex
int from[NS];             // parent pointers written by the latest getfar() DFS
int ans[NS];              // per-vertex result accumulated by sol()
int online[NS];           // 1 if the vertex is on the extracted path
pair<int, int> mdis[NS];  // largest / second-largest dis[] reached below a vertex (getdep)
int dis[NS];              // depth below the owning path vertex (getdep)
multiset<int> mrset;      // copies of colors backing rset
set<int> rset;            // colors added to ans[] of every vertex visited by sol()

// Segment tree over depth positions [0, sz-1], used by sol() to track the
// colors stored along the current root-to-vertex path.
//   cnt[x] : number of cover intervals that push() assigned to node x; a
//            nonzero count hides every position in x's range.
//   st[x]  : multiset of every value inserted (push2) at a position inside
//            x's range, including positions that are currently hidden.
//   sum[x] : number of stored values at positions in x's range that are not
//            hidden by a cover at x or at one of its descendants.
// Every method takes a node x and its range [s, e]; callers start at (1, 0, n - 1).
struct seg{
    int n;
    vector<int> cnt, sum;
    vector<multiset<int>> st;
    seg(){}
    // sz positions; 4 * sz nodes are enough for a 1-indexed tree over [0, sz-1].
    seg(int sz):n(sz + 4){
        cnt.resize(sz * 4);
        sum.resize(sz * 4);
        st.resize(sz * 4);
    }
    // Recompute sum[x] from x's cover count, its own multiset (leaf) or its children.
    void pull(int x, int s, int e){
        if(cnt[x]) sum[x] = 0;
        else if(s == e) sum[x] = (int)st[x].size();
        else sum[x] = sum[x * 2] + sum[x * 2 + 1];
    }
    // Add val (+1 to hide, -1 to undo a hide) to the cover count of [ps, pe].
    // An empty range (ps > pe) is a no-op; ps may be negative.
    void push(int x, int s, int e, int ps, int pe, int val){
        if(pe < s || ps > e || ps > pe) return;
        if(ps <= s && pe >= e){
            cnt[x] += val;
            pull(x, s, e);
            return;
        }
        const int mid = (s + e) >> 1;
        push(x * 2, s, mid, ps, pe, val);
        push(x * 2 + 1, mid + 1, e, ps, pe, val);
        pull(x, s, e);
    }
    // Number of stored values at non-hidden positions in [fs, fe] (0 if fs > fe).
    int get(int x, int s, int e, int fs, int fe){
        if(fe < s || fs > e || cnt[x] || fs > fe) return 0;
        if(fs <= s && fe >= e) return sum[x];
        const int mid = (s + e) >> 1;
        return get(x * 2, s, mid, fs, fe) + get(x * 2 + 1, mid + 1, e, fs, fe);
    }
    // Insert val at position pos (into every multiset on the root-to-leaf path).
    void push2(int x, int s, int e, int pos, int val){
        st[x].insert(val);
        if(s != e){
            const int mid = (s + e) >> 1;
            if(pos <= mid) push2(x * 2, s, mid, pos, val);
            else push2(x * 2 + 1, mid + 1, e, pos, val);
        }
        pull(x, s, e);
    }
    // Largest position in [fs, fe] whose leaf stores val, or -1 if none.
    // Cover counts are ignored here.
    int rightmost(int x, int s, int e, int fs, int fe, int val){
        if(fe < s || fs > e || fs > fe || st[x].find(val) == st[x].end()) return -1;
        if(s == e) return s;
        const int mid = (s + e) >> 1;
        const int right = rightmost(x * 2 + 1, mid + 1, e, fs, fe, val);
        return right != -1 ? right : rightmost(x * 2, s, mid, fs, fe, val);
    }
    // True if a node on the path from x down to leaf pos has a nonzero cover count.
    bool hidden(int x, int s, int e, int pos){
        while(true){
            if(cnt[x]) return true;
            if(s == e) return false;
            const int mid = (s + e) >> 1;
            if(pos <= mid){ x = x * 2; e = mid; }
            else{ x = x * 2 + 1; s = mid + 1; }
        }
    }
    // Does val occur at a non-hidden position in [fs, fe]?
    // Checking st[] of a node fully inside the range is not enough: st[x] also
    // holds values at positions hidden by covers stored below x.
    // Only the rightmost (deepest) stored occurrence is examined. This is exact
    // for sol()'s usage: sol() inserts a color at depth d only when get2 finds no
    // visible copy above it, and the covers that hid those copies stay in place
    // for as long as the insertion is alive. So if the deepest copy is hidden,
    // every shallower one is too.
    bool get2(int x, int s, int e, int fs, int fe, int val){
        const int pos = rightmost(x, s, e, fs, fe, val);
        return pos != -1 && !hidden(x, s, e, pos);
    }
    // Remove one copy of val from position pos. If pos does not hold val the
    // tree is left unchanged (previously this could erase end() or a different value).
    void erase(int x, int s, int e, int pos, int val){
        erase_at(x, s, e, pos, val);
    }
    // Recursive worker for erase(); returns whether a copy was removed.
    bool erase_at(int x, int s, int e, int pos, int val){
        if(s != e){
            const int mid = (s + e) >> 1;
            const bool removed = pos <= mid ? erase_at(x * 2, s, mid, pos, val)
                                            : erase_at(x * 2 + 1, mid + 1, e, pos, val);
            if(!removed) return false;
        }
        auto it = st[x].find(val);
        if(it == st[x].end()) return false;  // only reachable at the leaf
        st[x].erase(it);
        pull(x, s, e);
        return true;
    }
}tree;

// DFS from x (pr = parent, -1 at the start) that records every visited vertex's
// parent in from[] and returns {number of vertices on the longest downward
// path starting at x, the far endpoint of that path}. Among equally long
// paths the endpoint with the larger vertex index wins.
pair<int, int> getfar(int x, int pr = -1){
    from[x] = pr;
    int far_len = 0;
    int far_vertex = x;
    for(const int child : way[x]){
        if(child == pr) continue;
        const pair<int, int> sub = getfar(child, x);
        const bool longer = sub.first > far_len;
        const bool tie_bigger = sub.first == far_len && sub.second > far_vertex;
        if(longer || tie_bigger){
            far_len = sub.first;
            far_vertex = sub.second;
        }
    }
    return {far_len + 1, far_vertex};
}

// DFS below x that skips path (online) vertices and the parent pr.
// For each child it sets dis[child] = dis[x] + 1 and folds the child's
// mdis.first into mdis[x] as (largest, second largest), with ties filling the
// second slot. A vertex that ends with mdis.first == 0 gets its own dis[].
// Returns the height of x's subtree in edges.
int getdep(int x, int pr = -1){
    int height = 0;
    for(const int child : way[x]){
        if(child == pr || online[child]) continue;
        dis[child] = dis[x] + 1;
        const int below = getdep(child, x) + 1;
        if(below > height) height = below;
        const int reach = mdis[child].first;
        pair<int, int>& best = mdis[x];
        if(reach > best.first){
            best.second = best.first;
            best.first = reach;
        }else if(reach > best.second){
            best.second = reach;
        }
    }
    if(mdis[x].first == 0) mdis[x].first = dis[x];
    return height;
}

// DFS over the part of the tree hanging below a path vertex (path vertices are
// skipped via online[]), adding to ans[x] for every visited x.
// The global segment tree is indexed by depth dis[]: the vertices on the DFS
// path above x sit at positions 0..dis[x]-1, and x may store col[x] at
// position dis[x] while its subtree is processed. All changes are undone
// before returning.
//   x  : current vertex
//   pr : parent of x in this DFS (-1 for the path vertex itself)
void sol(int x, int pr = -1){
    // Colors currently in rset are credited to every visited vertex.
    ans[x] += (int)rset.size();
    // Hide the positions at distance 1..(mdis[x].first - dis[x]) above x. The
    // deepest path below x has a vertex at each of those distances from x.
    if(mdis[x].first){
        tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, 1);
    }
    // Count the values stored at non-hidden positions above x.
    if(dis[x]) ans[x] += tree.get(1, 0, n - 1, 0, dis[x] - 1);
    // Store col[x] at depth dis[x] only if the color is not in rset and get2
    // finds no non-hidden copy above x. pushed records whether this happened
    // so it can be undone.
    // NOTE(review): correctness relies on tree.get2 ignoring values at hidden positions.
    int pushed = 0;
    auto p = rset.lower_bound(col[x]);
    if(p == rset.end() || *p != col[x]){
        if(!dis[x] || !tree.get2(1, 0, n - 1, 0, dis[x] - 1, col[x])){
            tree.push2(1, 0, n - 1, dis[x], col[x]);
            pushed = 1;
        }
    }
    for(auto&nxt:way[x]){
        if(online[nxt] || nxt == pr){
            continue;
        }
        // nxt is the only child reaching depth mdis[x].first (first > second
        // excludes a tie). While inside it, x's other branches reach only
        // mdis[x].second, so swap the hidden range accordingly.
        if(mdis[x].first == mdis[nxt].first && mdis[x].first > mdis[x].second){
            tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, -1);
            if(mdis[x].second) tree.push(1, 0, n - 1, dis[x] - (mdis[x].second - dis[x]), dis[x] - 1, 1);

            // Re-decide whether col[x] is stored under the narrower hidden range.
            if(pushed){
                tree.erase(1, 0, n - 1, dis[x], col[x]);
            }
            int pushed2 = 0;
            auto p = rset.lower_bound(col[x]);
            if(p == rset.end() || *p != col[x]){
                if(!dis[x] || !tree.get2(1, 0, n - 1, 0, dis[x] - 1, col[x])){
                    tree.push2(1, 0, n - 1, dis[x], col[x]);
                    pushed2 = 1;
                }
            }

            sol(nxt, x);

            // Restore the state used for the remaining children.
            if(pushed2 == 1){
                tree.erase(1, 0, n - 1, dis[x], col[x]);
            }
            if(pushed){
                tree.push2(1, 0, n - 1, dis[x], col[x]);
            }

            if(mdis[x].second) tree.push(1, 0, n - 1, dis[x] - (mdis[x].second - dis[x]), dis[x] - 1, -1);
            tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, 1);
        }
        else{
            sol(nxt, x);
        }
    }
    // Undo this call's insertion and hidden range.
    if(pushed){
        tree.erase(1, 0, n - 1, dis[x], col[x]);
        pushed = 1;  // NOTE(review): dead store; pushed is already 1 and never read again
    }
    if(mdis[x].first){
        tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, -1);
    }
}

// Entry point. Reads n, m, the n - 1 edges (1-indexed) and the n colors, then
// extracts a diameter path with two getfar() sweeps. It makes two passes (the
// second over the reversed path) that call sol() on path vertices, and prints
// ans[i] for every vertex.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m;
    tree = seg(n);
    // Undirected edges, converted to 0-indexed vertices.
    for(int i = 1; i < n; ++i){
        int x, y; cin >> x >> y; --x, --y;
        way[x].push_back(y);
        way[y].push_back(x);
    }
    for(int i = 0; i < n; ++i){
        cin >> col[i];
    }
    // l: farthest vertex from vertex 1; r: farthest vertex from l. The second
    // call leaves from[] rooted at l.
    // NOTE(review): the first sweep starts at index 1, not 0. For n == 1 that
    // index is not an input vertex (its adjacency list is simply empty); confirm intended.
    int l = getfar(1).second;
    int r = getfar(l).second;
    // Follow from[] from r back to l (parent -1), storing the path in line
    // (line[0] == r) and marking its vertices online.
    vector<int> line;
    int now = r;
    while(true){
        line.push_back(now);
        online[now] = 1;
        if(from[now] == -1){
            break;
        }
        now = from[now];
    }
    for(int rep = 0; rep < 2; ++rep){
        // Reset the per-pass depth data.
        for(int i = 0; i < n; ++i){
            dis[i] = 0;
            mdis[i] = {0, 0};
        }
        // dep[i]: height of the part hanging off line[i] (path vertices excluded).
        // getdep also fills dis[] and mdis[] for those vertices.
        vector<int> dep((int)line.size()), lpos((int)line.size());
        for(int i = 0; i < (int)line.size(); ++i){
            dep[i] = getdep(line[i]);
        }
        // Monotonic stack over path indices. First discard j with j + dep[j] < i;
        // lpos[i] is the remaining top (it stays 0 if the stack is empty). Then
        // discard j with j + dep[j] <= i + dep[i] and push i.
        vector<int> stk;
        for(int i = 0; i < (int)line.size(); ++i){
            while((int)stk.size() && stk.back() + dep[stk.back()] < i){
                stk.pop_back();
            }
            if((int)stk.size()){
                lpos[i] = stk.back();
            }
            while((int)stk.size() && stk.back() + dep[stk.back()] <= i + dep[i]){
                stk.pop_back();
            }
            stk.push_back(i);
        }
        // rpp scans path indices downward from the far end.
        int rpp = (int)line.size() - 1;
        // pop[k]: colors to remove from mrset/rset when the sweep reaches index k.
        vector<vector<int>> pop((int)line.size());
        // Sweep i from the middle of the path down to 0. On an odd-length path the
        // middle index itself is handled only in the second pass.
        for(int i = (int)line.size() / 2 - 1 + rep * ((int)line.size() % 2); i >= 0; --i){
            // Drop colors scheduled for index i; a color leaves rset once no copy
            // of it remains in mrset.
            while((int)pop[i].size()){
                auto p = mrset.lower_bound(pop[i].back());
                mrset.erase(p);
                p = mrset.lower_bound(pop[i].back());
                if(!(p != mrset.end() && *p == pop[i].back())){
                    rset.erase(pop[i].back());
                }
                pop[i].pop_back();
            }
            // Bring in path indices greater than 2 * i. If lpos[rpp] <= i, the color
            // is added to mrset/rset, with removal scheduled at index lpos[rpp] - 1
            // (never scheduled when lpos[rpp] is 0).
            while(rpp > i * 2){
                if(i >= lpos[rpp]){
                    if(lpos[rpp]){
                        pop[lpos[rpp] - 1].push_back(col[line[rpp]]);
                    }
                    mrset.insert(col[line[rpp]]);
                    rset.insert(col[line[rpp]]);
                }
                --rpp;
            }
            sol(line[i]);
        }
        rset.clear();
        mrset.clear();
        // The second pass treats the other end of the path as line[0].
        reverse(line.begin(), line.end());
    }
    // ans[] has accumulated contributions from both passes.
    for(int i = 0; i < n; ++i){
        cout << ans[i] << '\n';
    }
    return 0;
}

Compilation message

joi2019_ho_t5.cpp: In member function 'void seg::push(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:39:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   39 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'int seg::get(int, int, int, int, int)':
joi2019_ho_t5.cpp:50:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   50 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::push2(int, int, int, int, int)':
joi2019_ho_t5.cpp:63:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   63 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'bool seg::get2(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:76:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   76 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::erase(int, int, int, int, int)':
joi2019_ho_t5.cpp:90:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   90 |         int m = s + e >> 1;
      |                 ~~^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Incorrect 9 ms 5588 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 267 ms 37444 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1037 ms 49372 KB Output is correct
2 Correct 1115 ms 96060 KB Output is correct
3 Correct 154 ms 21724 KB Output is correct
4 Correct 1266 ms 61180 KB Output is correct
5 Correct 1126 ms 100232 KB Output is correct
6 Correct 1575 ms 115860 KB Output is correct
7 Correct 1018 ms 60588 KB Output is correct
8 Correct 922 ms 65664 KB Output is correct
9 Correct 968 ms 63996 KB Output is correct
10 Correct 1291 ms 69012 KB Output is correct
11 Correct 729 ms 60696 KB Output is correct
12 Correct 1000 ms 89172 KB Output is correct
13 Correct 839 ms 74240 KB Output is correct
14 Correct 1319 ms 108892 KB Output is correct
15 Correct 405 ms 61000 KB Output is correct
16 Correct 898 ms 87340 KB Output is correct
17 Correct 1371 ms 116020 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Incorrect 9 ms 5588 KB Output isn't correct
3 Halted 0 ms 0 KB -