Submission #999269

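| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 999269 | | vjudge1 | Sumtree (INOI20_sumtree) | C++17 | 100 / 100 | 707 ms | 90196 KiB |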
#include <bits/stdc++.h>
using namespace std;

// Prime modulus for every combinatorial quantity (Fermat inverses rely on primality).
constexpr long long mod = 1000000007LL;
// Capacity for per-node arrays. The Euler tour in dfs() consumes two slots per
// node (tin and tout), and BIT indexes those slots.
constexpr int N = 400010;
// Number of precomputed factorials available to ncr().
constexpr int MAXNC = 500010;
// Binary-lifting levels: up[v][k] is the 2^k-th ancestor of v (0 past the root).
constexpr int LG = 20;

// Adjacency lists of the (undirected) tree.
std::vector<int> g[N];
// Euler-tour entry/exit stamps and the running stamp counter.
int tin[N];
int tout[N];
int tim = 0;
// Binary-lifting table and the tag value assigned to each node.
int up[N][LG];
int tags[N];
// fact[i] = i! mod `mod`, filled in main().
long long fact[MAXNC];
// Node depth (root = 0), subtree size, a "used" marker, and the initial
// per-node tag sums loaded into HLD2.
int depth[N];
int sz[N];
int used[N];
int tagsms[N];

namespace BIT {
  // Fenwick tree: point add / prefix-sum query over 0-based positions.
  // Slot i of `fen` covers 1-based index i, so positions are shifted by one.
  int fen[N];

  // Adds vl at 0-based position `at`.
  void add(int at, int vl) {
    for (int i = at + 1; i < N; i += i & -i) {
      fen[i] += vl;
    }
  }

  // Returns the sum of all values at positions 0..at (inclusive).
  int query(int at) {
    int total = 0;
    for (int i = at + 1; i != 0; i -= i & -i) {
      total += fen[i];
    }
    return total;
  }
}
// Gives node x weight vl on the Euler tour: +vl at entry, -vl at exit. A prefix
// sum up to tin[v] therefore accumulates the weights of v and its ancestors only.
void change(int x, int vl) {
  const int enter = tin[x];
  const int leave = tout[x];
  BIT::add(enter, vl);
  BIT::add(leave, -vl);
}
// Sum of weights on the path from ch upward, excluding anc itself.
// Callers (findroot) pass an ancestor of ch as anc.
int query(int anc, int ch) {
  const int fromRootToCh = BIT::query(tin[ch]);
  const int fromRootToAnc = BIT::query(tin[anc]);
  return fromRootToCh - fromRootToAnc;
}
// Computes a^b modulo `mod` by square-and-multiply.
// The base is first reduced into [0, mod). This keeps a*a and res*a below
// mod^2 (no signed overflow for any input), and a negative base yields a
// canonical residue. The exponent is expected to be non-negative; b <= 0
// returns 1.
long long binpow(long long a, long long b) {
  a %= mod;
  if (a < 0) a += mod;
  long long res = 1;
  while (b > 0) {
    if (b & 1) {
      res = (res * a) % mod;
    }
    a = (a * a) % mod;
    b >>= 1;
  }
  return res;
}
// Binomial coefficient C(nn, rr) modulo `mod`, read from the factorial table.
// Returns 0 when rr lies outside [0, nn].
// Precondition: nn < MAXNC, since fact[] is only filled up to there.
long long ncr(long long nn, long long rr) {
  if (rr > nn || rr < 0) return 0;
  // Invert the product of the two denominator factorials once (Fermat:
  // x^(mod-2)). Modular inversion is multiplicative, so this equals the two
  // separate inverses but needs only one exponentiation.
  const long long denom = (fact[rr] * fact[nn - rr]) % mod;
  return (fact[nn] * binpow(denom, mod - 2)) % mod;
}
// Recursive traversal from `at` (parent p, 0 for the root) that fills:
//  - up[at][k]: the 2^k-th ancestor (binary lifting; 0 once past the root),
//  - tin/tout: Euler-tour entry and exit stamps,
//  - depth[] of every child and sz[at] (subtree size including at).
void dfs(int at, int p) {
  sz[at] = 1;
  up[at][0] = p;
  for (int k = 1; k < LG; ++k) {
    const int halfway = up[at][k - 1];
    up[at][k] = up[halfway][k - 1];
  }
  tin[at] = tim++;
  for (const int child : g[at]) {
    if (child == p) {
      continue;
    }
    depth[child] = depth[at] + 1;
    dfs(child, at);
    sz[at] += sz[child];
  }
  tout[at] = tim++;
}
// Fenwick tree supporting range add and point query over 0-based indices.
// Internally it stores a difference array, 1-based.
// Cells and deltas are 64-bit. Aggregated values pushed through this structure,
// such as sums of tags, may exceed the int range.
struct DS1 {
  long long fen[3*N];
  DS1() {
    for (int i = 0; i < 3*N; i++) fen[i] = 0;
  }
  // Adds val to the difference array at 0-based position `at`.
  void upd(int at, long long val) {
    at++;
    while (at < 3*N) {
      fen[at] += val;
      at += at & (-at);
    }
  }
  // Adds val to every position in the closed range [l, r].
  void update(int l, int r, long long val) {
    upd(l, val);
    upd(r + 1, -val);
  }
  // Current value at position `at` (prefix sum of the difference array).
  long long get(int at) {
    at++;
    long long res = 0;
    while (at) {
      res += fen[at];
      at -= at & (-at);
    }
    return res;
  }
};
namespace HLD1 {
  // Heavy-light decomposition of the tree rooted at node 1. `ds` holds one
  // value per node at its HLD position; paths are updated chain by chain.
  DS1 ds;
  int tin[N];
  int top[N];
  int tim = 0;
  int par[N];
  int V[N];

  // Assigns consecutive positions along heavy chains. The heavy child is the
  // first child with maximal sz[]; it continues chain tp. Every other child
  // starts its own chain.
  void dfs(int at, int p, int tp) {
    tin[at] = ++tim;
    top[at] = tp;
    par[at] = p;
    int heavy = -1;
    for (const int child : g[at]) {
      if (child == p) continue;
      const bool better = (heavy == -1) || (sz[child] > sz[heavy]);
      if (better) heavy = child;
    }
    if (heavy == -1) {
      return;  // leaf: nothing below
    }
    dfs(heavy, at, tp);
    for (const int child : g[at]) {
      if (child != heavy && child != p) {
        dfs(child, at, child);
      }
    }
  }

  // Adds val to every node on the u..v path. Always returns 0.
  int upd(int u, int v, long long val) {
    for (; top[u] != top[v]; u = par[top[u]]) {
      if (depth[top[u]] < depth[top[v]]) std::swap(u, v);
      ds.update(tin[top[u]], tin[u], val);
    }
    if (depth[u] > depth[v]) std::swap(u, v);
    ds.update(tin[u], tin[v], val);
    return 0;
  }

  // Value currently stored for node u.
  long long get(int u) {
    return ds.get(tin[u]);
  }
};

namespace HLD2 {
  // Heavy-light decomposition with the same layout as HLD1. Per-node values
  // are sums of tags, so path deltas and results are 64-bit (matching HLD1).
  // This avoids truncating large tag sums through an int parameter or return.
  DS1 ds;
  int tin[N], top[N], tim=0, par[N], V[N];

  // Assigns HLD positions. The heavy child (largest sz, first on ties)
  // continues chain tp; every other child starts a new chain.
  void dfs(int at, int p, int tp) {
    tin[at] = ++tim;
    top[at] = tp;
    par[at] = p;
    int big = -1;
    for (int to : g[at]) {
      if (to == p) continue;
      if (big == -1 || sz[to] > sz[big]) {
        big = to;
      }
    }
    if (big == -1) return;
    dfs(big, at, tp);
    for (int to : g[at]) {
      if (to == big || to == p) continue;
      dfs(to, at, to);
    }
  }

  // Adds val to every node on the u..v path. Always returns 0 (kept for
  // interface compatibility).
  int upd(int u, int v, long long val) {
    while (top[u] != top[v]) {
      if (depth[top[u]] < depth[top[v]]) swap(u, v);
      ds.update(tin[top[u]], tin[u], val);
      u = par[top[u]];
    }
    if (depth[u] > depth[v]) swap(u, v);
    ds.update(tin[u], tin[v], val);
    return 0;
  }

  // Value currently stored for node u (full 64-bit, no truncation).
  long long get(int u) {
    return ds.get(tin[u]);
  }
};

int findroot(int at) {
  for(int i = LG-1;i>=0;i--) {
    if(up[at][i] != 0 and query(up[at][i], at) == depth[at]-depth[up[at][i]]) {
      at = up[at][i];
    }
  }
  return at;
}

// Running product of factors modulo `mod`. Zero factors cannot be divided out
// with a modular inverse, so they are counted instead of multiplied in.
struct ANS {
  long long mult = 1;  // product of the non-zero factors (mod `mod`)
  long long zero = 0;  // how many zero factors are currently present

  // Multiplies factor vl into the product.
  void add(long long vl) {
    if (vl != 0) {
      mult = (mult * vl) % mod;
      return;
    }
    ++zero;
  }

  // Removes a factor vl previously added (multiplies by its inverse).
  void del(long long vl) {
    if (vl != 0) {
      mult = (mult * binpow(vl, mod - 2)) % mod;
      return;
    }
    --zero;
  }

  // Current product: 0 if any zero factor is present.
  long long get() {
    return zero > 0 ? 0 : mult;
  }
};
int main () {
  cin.tie(0)->sync_with_stdio(0);
  fact[0] = 1;
  for(int i = 1;i<MAXNC;i++)fact[i] = (fact[i-1]*i)%mod;
  int n, r;
  cin >> n >> r;
  for(int i = 1;i<n;i++) {
    int u, v;
    cin >> u >> v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  dfs(1, 0);
  HLD1::dfs(1, 0, 1);
  HLD2::dfs(1, 0, 1);
  tags[1] = r;
  used[1] = 1;
  for(int i = 2;i<=n;i++){
    change(i, 1);
  }
  for(int i=  1;i<=n;i++) {
    HLD1::upd(i, i, sz[i]);
    HLD2::upd(i, i, tagsms[i]);
  }
  ANS ans;
  ans.add(ncr(n+r-1, r));
  cout << ans.get() << "\n";
  int q;
  cin >> q;
  while(q--) {
    int t;
    cin >> t;
    if(t == 1) {
      int u, v;
      cin >> u >> v;
      tags[u]=v;
      int rr = findroot(u);
      {
        long long tmp = tags[rr] - HLD2::get(rr);
        long long szz = HLD1::get(rr);
        ans.del(ncr(szz+tmp-1, tmp)); 
      }
      int par = up[u][0];
      HLD1::upd(par, rr, -HLD1::get(u));
      HLD2::upd(par, rr, v-HLD2::get(u));
//      while(1) {
//        sz[par]-=sz[u];
//        tagsms[par]-=tagsms[u];
//        tagsms[par]+=v;
//        if(par==rr)break;
//        par = up[par][0];
//      }
      {
        long long tmp = tags[u] - HLD2::get(u);
        long long szz = HLD1::get(u);
        ans.add(ncr(szz+tmp-1, tmp)); 
      }
      {
        long long tmp = tags[rr] - HLD2::get(rr);
        long long szz = HLD1::get(rr);
        ans.add(ncr(szz+tmp-1, tmp)); 
      }
      change(u, -1);
    }
    else {
      int u;
      cin >> u;
      int v = tags[u];
      change(u, 1);
      int rr = findroot(u);
      {
        long long tmp = tags[rr] - HLD2::get(rr);
        long long szz = HLD1::get(rr);
        ans.del(ncr(szz+tmp-1, tmp)); 
      }
      int par = up[u][0];
      HLD1::upd(par, rr, HLD1::get(u));
      HLD2::upd(par, rr, HLD2::get(u)-v);
//      while(1) {
//        sz[par]+=sz[u];
//        tagsms[par]+=tagsms[u];
//        tagsms[par]-=v;
//        if(par==rr)break;
//        par = up[par][0];
//      }
      {
        long long tmp = tags[u] - HLD2::get(u);
        long long szz = HLD1::get(u);
        ans.del(ncr(szz+tmp-1, tmp)); 
      }
      {
        long long tmp = tags[rr] - HLD2::get(rr);
        long long szz = HLD1::get(rr);
        ans.add(ncr(szz+tmp-1, tmp)); 
      }
    }
    cout << ans.get() << "\n";
  }
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...