제출 #1162939

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1162939 | | yeediot | JOI tour (JOI24_joitour) | C++20 | 100 / 100 | 2387 ms | 420968 KiB |
#include<bits/stdc++.h> using namespace std; #define F first #define S second #define all(x) x.begin(),x.end() #define pii pair<long long,long long> #define pb push_back #define sz(x) (long long)(x.size()) #define chmin(x,y) x=min(x,y) #ifdef local void CHECK(); void setio(){ freopen("/Users/iantsai/cpp/input.txt","r",stdin); freopen("/Users/iantsai/cpp/output.txt","w",stdout); } #else void setio(){} #endif #define TOI_is_so_de ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);setio(); const long long mxn = 2e5 + 5; struct node{ vector<vector<vector<long long>>>s; vector<vector<long long>>ss; long long neg; void init(int siz){ neg = 0; ss = vector<vector<long long>>(2, vector<long long>(3)); s = vector<vector<vector<long long>>>(2, vector<vector<long long>>(siz, vector<long long>(3))); } }; struct BIT{ vector<vector<long long>>bit; void init(long long tmr){ bit = vector<vector<long long>>(3, vector<long long>(tmr + 1)); } void m(long long d, long long p, long long v){ for(; p < sz(bit[d]); p += p & -p) bit[d][p] += v; } void m(long long d, long long l, long long r, long long v){ m(d, l, v); m(d, r + 1, -v); } long long q(long long d, long long p){ long long r = 0; for(; p; p -= p & -p) r += bit[d][p]; return r; } long long q(long long d, long long l, long long r){ return q(d, r) - q(d, l - 1); } }; long long ans = 0, ty[mxn], cnt[mxn], tmr; vector<long long>adj[mxn], p[mxn], in[mxn], out[mxn], num[mxn], ord; long long cal[mxn][3]; node dt[mxn]; bool vis[mxn]; BIT bt[mxn]; void upd(long long v, long long d, bool f){ auto add = [&](){ long long pa = p[v].back(), l = in[v].back(), r = out[v].back(), id = num[v].back(); if(ty[v] == 1) ans += (dt[pa].ss[0][0] * dt[pa].ss[0][2] - dt[pa].neg) * d; else{ ans += dt[pa].ss[1][2 - ty[v]] * d; } }; if(d == -1) add(); for(long long _ = 0; _ < sz(p[v]) - 1; _++){ long long pa = p[v][_], l = in[v][_], r = out[v][_], id = num[v][_]; int a[3] = {}; if(f) a[ty[pa]] = 1; auto m = [&](){ if(ty[v] == 1){ bt[pa].m(1, l, r, d); 
dt[pa].ss[0][1] += d; dt[pa].s[0][id][1] += d; for(int j = 0; j < 3; j++){ if(j == 1) continue; long long val = bt[pa].q(j, l, r);//1j dt[pa].ss[1][j] += val * d; dt[pa].s[1][id][j] += val * d; } } else{ bt[pa].m(ty[v], l, d); dt[pa].s[0][id][ty[v]] += d; dt[pa].ss[0][ty[v]] += d; dt[pa].neg += dt[pa].s[0][id][2 - ty[v]] * d; long long val = bt[pa].q(1, l); dt[pa].ss[1][ty[v]] += val * d; dt[pa].s[1][id][ty[v]] += val * d; } }; auto calc = [&](){ if(ty[v] == 1){ for(int j = 0; j < 3; j++){ if(j == 1) continue; long long val = bt[pa].q(j, l, r);//1j ans += val * (dt[pa].ss[0][2 - j] - dt[pa].s[0][id][2 - j] + a[2 - j]) * d; } } else{ long long val = bt[pa].q(1, l); ans += (dt[pa].ss[0][2 - ty[v]] - dt[pa].s[0][id][2 - ty[v]] + (a[2 - ty[v]])) * val * d; ans += (dt[pa].ss[1][2 - ty[v]] - dt[pa].s[1][id][2 - ty[v]] + a[1] * (dt[pa].ss[0][2 - ty[v]] - dt[pa].s[0][id][2 - ty[v]])) * d; } }; if(d == 1) m(); calc(); if(d == -1) m(); } if(d == 1) add(); } void cnt_sz(long long v, long long pa){ cnt[v] = 1; for(auto u : adj[v]){ if(u == pa or vis[u]) continue; cnt_sz(u, v); cnt[v] += cnt[u]; } } long long find(long long v, long long pa, long long tar){ for(auto u : adj[v]){ if(u == pa or vis[u]) continue; if(cnt[u] * 2 > tar) return find(u, v, tar); } return v; } void dfs(long long v, long long pa, long long top, long long cc){ p[v].pb(top); in[v].pb(++tmr); num[v].pb(cc); for(auto u : adj[v]){ if(u == pa or vis[u]) continue; dfs(u, v, top, cc); } out[v].pb(tmr); } void cd(long long v){ cnt_sz(v, v); v = find(v, v, cnt[v]); vis[v] = 1; ord.pb(v); tmr = 0; long long cc = 0; p[v].pb(v); in[v].pb(++tmr); num[v].pb(-1); for(auto u : adj[v]){ if(vis[u]) continue; dfs(u, v, v, cc); cc++; } out[v].pb(tmr); dt[v].init(cc); for(long long i = 0; i < 3; i ++){ bt[v].init(tmr); } for(auto u : adj[v]){ if(vis[u]) continue; cd(u); } } #ifdef local void solve(){ long long n; cin >> n; for(long long i = 1; i <= n; i++){ cin >> ty[i]; } for(long long i = 1; i < n; i++){ long long a, b; cin 
>> a >> b; a++, b++; adj[a].pb(b); adj[b].pb(a); } cd(1); reverse(all(ord)); for(auto v : ord){ upd(v, 1, 0); } cout << ans << '\n'; long long q; cin >> q; while(q--){ long long x, y; cin >> x >> y; x++; upd(x, -1, 1); ty[x] = y; upd(x, 1, 1); cout << ans << '\n'; } } signed main(){ TOI_is_so_de; long long t = 1; //cin >> t; while(t--){ solve(); } #ifdef local CHECK(); #endif } #else #include "joitour.h" void init(int n, vector<int>typ, vector<int>ea, vector<int>eb, int q){ for(long long i = 1; i <= n; i ++){ ty[i] = typ[i - 1]; } for(long long i = 1; i < n; i++){ ea[i - 1]++; eb[i - 1]++; adj[ea[i - 1]].pb(eb[i - 1]); adj[eb[i - 1]].pb(ea[i - 1]); } cd(1); reverse(all(ord)); for(auto v : ord){ upd(v, 1, 0); } } void change(int x, int y){ upd(x + 1, -1, 1); ty[x + 1] = y; upd(x + 1, 1, 1); } long long num_tours(){ return ans; } #endif /* input: */ #ifdef local void CHECK(){ cerr << "\n[Time]: " << 1000.0 * clock() / CLOCKS_PER_SEC << " ms.\n"; function<bool(string,string)> compareFiles = [](string p1, string p2)->bool { std::ifstream file1(p1); std::ifstream file2(p2); if(!file1.is_open() || !file2.is_open()) return false; std::string line1, line2; while (getline(file1, line1) && getline(file2, line2)) { if (line1 != line2)return false; } long long cnta = 0, cntb = 0; while(getline(file1,line1))cnta++; while(getline(file2,line2))cntb++; return cntb - cnta <= 1; }; bool check = compareFiles("output.txt","expected.txt"); if(check) cerr<<"ACCEPTED\n"; else cerr<<"WRONG ANSWER!\n"; } #else #endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...