답안 #594228

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
594228 2022-07-12T08:57:40 Z 이동현(#8435) Unique Cities (JOI19_ho_t5) C++17
0 / 100
2000 ms 119252 KB
#include <bits/stdc++.h>

using namespace std;

const int NS = (int)2e5 + 4;
int n, m;
vector<int> way[NS];
int col[NS], from[NS];
int ans[NS], online[NS];
pair<int, int> mdis[NS];
int dis[NS];
multiset<int> mrset;
set<int> rset;

struct seg{
    int n;
    vector<int> cnt, sum, osum;
    vector<multiset<int>> st;
    seg(){}
    seg(int n):n(n + 4){
        cnt.resize(n * 4);
        sum.resize(n * 4);
        st.resize(n * 4);
        osum.resize(n * 4);
    }
    void push(int x, int s, int e, int ps, int pe, int val){
        //cout << x << ' ' << s << ' ' << e << ' ' << ps << ' ' << pe << ' ' << val << endl;
        if(pe < s || ps > e || ps > pe){
            return;
        }
        if(ps <= s && pe >= e){
            osum[x] += (e - s + 1) * val;
            cnt[x] += val;
            if(cnt[x]) sum[x] = 0;
            else{
                if(s == e) sum[x] = (int)st[x].size();
                else sum[x] = sum[x * 2] + sum[x * 2 + 1];
            }
            return;
        }
        int m = s + e >> 1;
        push(x * 2, s, m, ps, pe, val), push(x * 2 + 1, m + 1, e, ps, pe, val);
        if(cnt[x]) sum[x] = 0;
        else sum[x] = sum[x * 2] + sum[x * 2 + 1];
        osum[x] = osum[x * 2] + osum[x * 2 + 1];
    }
    int get(int x, int s, int e, int fs, int fe){
        //cout << x << ' ' << s << ' ' << e << ' ' << fs << ' ' << fe << endl;
        if(fe < s || fs > e || cnt[x] || fs > fe) return 0;
        if(fs <= s && fe >= e){
            return sum[x];
        }
        int m = s + e >> 1;
        return get(x * 2, s, m, fs, fe) + get(x * 2 + 1, m + 1, e, fs, fe);
    }
    void push2(int x, int s, int e, int pos, int val){
        st[x].insert(val);
        if(s == e){
            if(cnt[x]) sum[x] = 0;
            else{
                if(s == e) sum[x] = (int)st[x].size();
                else sum[x] = sum[x * 2] + sum[x * 2 + 1];
            }
            return;
        }
        int m = s + e >> 1;
        if(pos <= m) push2(x * 2, s, m, pos, val);
        else push2(x * 2 + 1, m + 1, e, pos, val);
        if(cnt[x]) sum[x] = 0;
        else sum[x] = sum[x * 2] + sum[x * 2 + 1];
    }
    bool get2(int x, int s, int e, int fs, int fe, int val){
        if(fe < s || fs > e || fs > fe || cnt[x]) return 0;
        if(fs <= s && fe >= e && !osum[x]){
            auto p = st[x].lower_bound(val);
            if(p != st[x].end() && *p == val) return true;
            return false;
        }
        int m = s + e >> 1;
        return (get2(x * 2, s, m, fs, fe, val) | get2(x * 2 + 1, m + 1, e, fs, fe, val));
    }
    void erase(int x, int s, int e, int pos, int val){
        auto p = st[x].lower_bound(val);
        st[x].erase(p);
        if(s == e){
            if(cnt[x]) sum[x] = 0;
            else{
                if(s == e) sum[x] = (int)st[x].size();
                else sum[x] = sum[x * 2] + sum[x * 2 + 1];
            }
            return;
        }
        int m = s + e >> 1;
        if(pos <= m) erase(x * 2, s, m, pos, val);
        else erase(x * 2 + 1, m + 1, e, pos, val);
        if(cnt[x]) sum[x] = 0;
        else sum[x] = sum[x * 2] + sum[x * 2 + 1];
    }
}tree;

pair<int, int> getfar(int x, int pr = -1){
    from[x] = pr;
    pair<int, int> rv = {0, x};
    for(auto&nxt:way[x]){
        if(nxt == pr){
            continue;
        }
        rv = max(rv, getfar(nxt, x));
    }
    return {rv.first + 1, rv.second};
}

int getdep(int x, int pr = -1){
    int rv = 0;
    for(auto&nxt:way[x]){
        if(online[nxt] || nxt == pr){
            continue;
        }
        dis[nxt] = dis[x] + 1;
        rv = max(rv, getdep(nxt, x) + 1);
        if(mdis[nxt].first > mdis[x].first) mdis[x].second = mdis[x].first, mdis[x].first = mdis[nxt].first;
        else if(mdis[nxt].first > mdis[x].second) mdis[x].second = mdis[nxt].first;
    }
    if(!mdis[x].first) mdis[x].first = dis[x];
    return rv;
}

void sol(int x, int pr = -1){
    ans[x] += (int)rset.size();
    if(mdis[x].first){
        tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, 1);
    }
    if(dis[x]) ans[x] += tree.get(1, 0, n - 1, 0, dis[x] - 1);
    int pushed = 0;
    auto p = rset.lower_bound(col[x]);
    if(p == rset.end() || *p != col[x]){
        if(!dis[x] || !tree.get2(1, 0, n - 1, 0, dis[x] - 1, col[x])){
            tree.push2(1, 0, n - 1, dis[x], col[x]);
            pushed = 1;
        }
    }
    for(auto&nxt:way[x]){
        if(online[nxt] || nxt == pr){
            continue;
        }
        if(mdis[x].first == mdis[nxt].first && mdis[x].first > mdis[x].second){
            tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, -1);
            if(mdis[x].second) tree.push(1, 0, n - 1, dis[x] - (mdis[x].second - dis[x]), dis[x] - 1, 1);

            if(pushed){
                tree.erase(1, 0, n - 1, dis[x], col[x]);
            }
            int pushed2 = 0;
            auto p = rset.lower_bound(col[x]);
            if(p == rset.end() || *p != col[x]){
                if(!dis[x] || !tree.get2(1, 0, n - 1, 0, dis[x] - 1, col[x])){
                    tree.push2(1, 0, n - 1, dis[x], col[x]);
                    pushed2 = 1;
                }
            }

            sol(nxt, x);

            if(pushed2 == 1){
                tree.erase(1, 0, n - 1, dis[x], col[x]);
            }
            if(pushed){
                tree.push2(1, 0, n - 1, dis[x], col[x]);
            }

            if(mdis[x].second) tree.push(1, 0, n - 1, dis[x] - (mdis[x].second - dis[x]), dis[x] - 1, -1);
            tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, 1);
        }
        else{
            sol(nxt, x);
        }
    }
    if(pushed){
        tree.erase(1, 0, n - 1, dis[x], col[x]);
        pushed = 1;
    }
    if(mdis[x].first){
        tree.push(1, 0, n - 1, dis[x] - (mdis[x].first - dis[x]), dis[x] - 1, -1);
    }
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> m;
    tree = seg(n);
    for(int i = 1; i < n; ++i){
        int x, y; cin >> x >> y; --x, --y;
        way[x].push_back(y);
        way[y].push_back(x);
    }
    for(int i = 0; i < n; ++i){
        cin >> col[i];
    }
    int l = getfar(1).second;
    int r = getfar(l).second;
    vector<int> line;
    int now = r;
    while(true){
        line.push_back(now);
        online[now] = 1;
        if(from[now] == -1){
            break;
        }
        now = from[now];
    }    for(int rep = 0; rep < 2; ++rep){
        for(int i = 0; i < n; ++i){
            dis[i] = 0;
            mdis[i] = {0, 0};
        }
        vector<int> dep((int)line.size()), lpos((int)line.size());
        for(int i = 0; i < (int)line.size(); ++i){
            dep[i] = getdep(line[i]);
        }
        vector<int> stk;
        for(int i = 0; i < (int)line.size(); ++i){
            while((int)stk.size() && stk.back() + dep[stk.back()] < i){
                stk.pop_back();
            }
            if((int)stk.size()){
                lpos[i] = stk.back();
            }
            while((int)stk.size() && stk.back() + dep[stk.back()] <= i + dep[i]){
                stk.pop_back();
            }
            stk.push_back(i);
        }
        int rpp = (int)line.size() - 1;
        vector<vector<int>> pop((int)line.size());
        for(int i = (int)line.size() / 2 - 1 + rep * ((int)line.size() % 2); i >= 0; --i){
            while((int)pop[i].size()){
                auto p = mrset.lower_bound(pop[i].back());
                mrset.erase(p);
                p = mrset.lower_bound(pop[i].back());
                if(!(p != mrset.end() && *p == pop[i].back())){
                    rset.erase(pop[i].back());
                }
                pop[i].pop_back();
            }
            while(rpp > i * 2){
                if(i >= lpos[rpp]){
                    if(lpos[rpp]){
                        pop[lpos[rpp] - 1].push_back(col[line[rpp]]);
                    }
                    mrset.insert(col[line[rpp]]);
                    rset.insert(col[line[rpp]]);
                }
                --rpp;
            }
            sol(line[i]);
        }
        rset.clear();
        mrset.clear();
        reverse(line.begin(), line.end());
    }
    for(int i = 0; i < n; ++i){
        cout << ans[i] << '\n';
    }
    return 0;
}

Compilation message

joi2019_ho_t5.cpp: In member function 'void seg::push(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:41:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   41 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'int seg::get(int, int, int, int, int)':
joi2019_ho_t5.cpp:53:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   53 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::push2(int, int, int, int, int)':
joi2019_ho_t5.cpp:66:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   66 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'bool seg::get2(int, int, int, int, int, int)':
joi2019_ho_t5.cpp:79:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   79 |         int m = s + e >> 1;
      |                 ~~^~~
joi2019_ho_t5.cpp: In member function 'void seg::erase(int, int, int, int, int)':
joi2019_ho_t5.cpp:93:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   93 |         int m = s + e >> 1;
      |                 ~~^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Correct 11 ms 5588 KB Output is correct
3 Correct 4 ms 5420 KB Output is correct
4 Correct 8 ms 5844 KB Output is correct
5 Incorrect 9 ms 5676 KB Output isn't correct
6 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 325 ms 39524 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1075 ms 52408 KB Output is correct
2 Correct 1052 ms 99244 KB Output is correct
3 Correct 194 ms 22356 KB Output is correct
4 Correct 1326 ms 64660 KB Output is correct
5 Correct 1134 ms 103668 KB Output is correct
6 Correct 1735 ms 119252 KB Output is correct
7 Correct 957 ms 63976 KB Output is correct
8 Correct 1026 ms 68956 KB Output is correct
9 Correct 964 ms 67356 KB Output is correct
10 Correct 1923 ms 72412 KB Output is correct
11 Correct 776 ms 64084 KB Output is correct
12 Correct 929 ms 92456 KB Output is correct
13 Correct 814 ms 77652 KB Output is correct
14 Execution timed out 2084 ms 86028 KB Time limit exceeded
15 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Correct 11 ms 5588 KB Output is correct
3 Correct 4 ms 5420 KB Output is correct
4 Correct 8 ms 5844 KB Output is correct
5 Incorrect 9 ms 5676 KB Output isn't correct
6 Halted 0 ms 0 KB -