제출 #849299

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
849299 | | hasheng | Bi-ing Lottery Treekets (CCO22_day2problem1) | C++14 | 7 / 25 | 55 ms | 146684 KiB
#include <bits/stdc++.h>
using namespace std;
//#pragma GCC optimize(2)
//#pragma GCC optimize(3)
//#pragma GCC optimize("inline")
// Modulus applied to every product/sum in this file (1e9+7).
int mod = 1e9+7;
// Per-node tables, indexed by node id; id 0 is treated as "no node" by the
// recursive routines (they stop or return 1 when they see 0).
//   l[i], r[i] : the two child ids read for node i in main (0 = absent child)
//   k[t]       : how many of the m values read in main were equal to t
//   s[i]       : filled by sets() as s[l[i]] + s[r[i]] - k[i] + 1
//   cn[i]      : starts at 1 for each node, then sets() adds both children's cn
//   ksum[i]    : starts as a copy of k, then sets() adds both children's ksum
//   ks[i]      : filled by sets(): 1 plus ks of each child whose k is 0
//   n, m       : the two integers read first in main
int l[4101], r[4101], k[4101], s[4101], cn[4101], ksum[4101], ks[4101], n, m;
// nn[i] = i! mod `mod`; inv[i] = modular inverse of nn[i] (both filled in main).
long long nn[4101], inv[4101];
// f[i][j] : memo table for get(i, j, ...); 0 doubles as "not computed yet".
// c[i][j] : binomial coefficients mod `mod`, built in main via Pascal's rule.
long long f[4101][4101]={0}, c[4101][4101]={0};
// Set by get() right before it calls get2(); cleared by get2() when it takes
// its first branch.
bool fi=0;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a   base; any value is accepted, it is reduced into [0, mod) first
 * @param b   exponent; values <= 0 are treated as 0 (result is 1 % mod)
 * @param mod modulus, must be > 0 and small enough that (mod-1)^2 fits in a
 *            long long; shadows the file-level `mod` inside this function
 * @return    a^b modulo `mod`, always in [0, mod)
 */
long long fast_pow_mod(long long a,long long b,long long mod)
{
    // Reduce the base up front: otherwise a base >= ~3.04e9 overflows in a * a
    // before the first `% mod` is applied (signed overflow is undefined).
    a %= mod;
    if (a < 0) a += mod;

    long long res = 1 % mod;  // yields 0 when mod == 1
    // `b > 0` rather than `b`: right-shifting a negative exponent never reaches 0.
    while(b > 0){
        if(b & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return res;
}
// Modular inverse of `a`, computed as a^(mod-2) modulo `mod` (Fermat's little
// theorem), so the result is only a true inverse when `mod` is prime and `a`
// is not a multiple of it.
long long getinv(long long a,long long mod)
{
    const long long exponent = mod - 2;
    return fast_pow_mod(a, exponent, mod);
}
// Closed-form product over the subtree rooted at node i, used by get() once
// j + ksum[i] == cn[i].
//
//   i : node id; 0 means "no node" and contributes a factor of 1
//   j : running count handed down from the parent; k[i] is added to it here
//
// A node contributes the factor nn[total] * inv[total - ks[i]] (mod `mod`)
// when k[i] is non-zero, or when the global flag `fi` is set on entry. The
// flag is cleared as soon as it is consumed and nothing in this function sets
// it again, so only the outermost call can be forced this way. Any other node
// simply forwards `total` unchanged to both children.
long long get2(int i, int j){
    if (i==0) return 1;

    const int total = j + k[i];

    if (k[i]==0 && !fi){
        const long long leftPart = get2(l[i], total);
        const long long rightPart = get2(r[i], total);
        return (leftPart*rightPart)%mod;
    }

    fi=0;
    const int remaining = total - ks[i];
    const long long leftPart = get2(l[i], remaining);
    const long long rightPart = get2(r[i], remaining);
    const long long arrangements = inv[remaining]*nn[total]%mod;
    return (leftPart*rightPart%mod)*arrangements%mod;
}
// Memoised recursive count for the subtree rooted at node i, modulo `mod`.
//
//   i : node id; 0 means "no node" and yields 1
//   j : count handed down from the parent (added to this subtree's own ksum[i]
//       when compared against its node count cn[i])
//   d : selects which of the two mirrored enumeration loops below is used.
//       Every call site in this file passes 0 for a left child (and the root)
//       and 1 for a right child, so d is determined by i and the memo key
//       (i, j) does not need to include it.
//
// Results are cached in f[i][j]; because 0 marks "not computed", a result that
// is genuinely 0 is recomputed on every call.
long long get(int i, int j, int d){// d: 0 = may only go left, 1 = may only go right
    // More incoming + own items than nodes in this subtree: no valid outcome.
    if (j+ksum[i]>cn[i]) return 0;

    if (i==0) return 1;
    if (f[i][j]) return f[i][j];

    long long res=0;

    // Counts match the subtree size exactly: hand over to get2(). `fi` forces
    // get2's first branch for this top-level call even when k[i] == 0.
    if (j+ksum[i]==cn[i]){
        //if (!k[i]) fi=1; cout << " " << i << " " << j <<" " <<  get2(i, j) <<endl;
        fi=1;
        return f[i][j]=get2(i, j);
    }
    // No right child: either a leaf (returns 1, not cached) or everything is
    // forwarded to the left child.
    if (!r[i]){
        if (!l[i]) return 1;
        return f[i][j]=get(l[i], j+k[i], 0);
    }
    // No left child: everything is forwarded to the right child.
    if (!l[i]){
        return f[i][j]=get(r[i], j+k[i], 1);
    }
    if (d==0){
        // Enumerate how many of this node's k[i] items go to the left child.
        for (int ii=max(k[i]-s[r[i]], 0);ii<=k[i] && ii<=s[l[i]];ii++){   // ii go left, k[i]-ii go right
            long long an=1;
            if (s[l[i]]>j+ii){
                // Left child receives j+ii, right child receives the other
                // k[i]-ii; c[k[i]][ii] is the binomial weight for this split.
                an=an*get(l[i], j+ii, 0)%mod;
                an=an*get(r[i], k[i]-ii, 1)%mod;
                res=(an*c[k[i]][ii]%mod+res)%mod;
            }else{
                // j+ii has reached s[l[i]]: the left child is called with
                // exactly s[l[i]] and the remaining j+k[i]-s[l[i]] go right.
                // This term is added once and the enumeration stops.
                an=an*get(l[i], s[l[i]], 0)%mod;
                an=an*get(r[i], j+k[i]-s[l[i]], 1)%mod;
                res=(an*c[j+k[i]][s[l[i]]]%mod+res)%mod;
                break;
            }
        }
    }else{
        // Mirror image of the loop above with the roles of l and r swapped:
        // here ii is the share added to the right child's argument.
        for (int ii=max(k[i]-s[l[i]], 0);ii<=k[i] && ii<=s[r[i]];ii++){   // ii go right, k[i]-ii go left
            long long an=1;
            if (s[r[i]]>j+ii){
                an=an*get(r[i], j+ii, 1)%mod;
                an=an*get(l[i], k[i]-ii, 0)%mod;
                res=(an*c[k[i]][ii]%mod+res)%mod;
            }else{
                // Right child is called with exactly s[r[i]]; the rest go left,
                // then the enumeration stops.
                an=an*get(r[i], s[r[i]], 1)%mod;
                an=an*get(l[i], j+k[i]-s[r[i]], 0)%mod;
                res=(an*c[j+k[i]][s[r[i]]]%mod+res)%mod;
                break;
            }
        }
    }
    //cout << i << " " << j << " " << res << endl;
    return f[i][j]=res;
}
// Post-order pass that fills the per-subtree tables for node i and everything
// below it (original note: free-slot count and node count of a node together
// with its descendants):
//   s[i]    = s[left] + s[right] - k[i] + 1
//   ksum[i] += ksum of both children
//   cn[i]   += cn of both children
//   ks[i]   += 1, plus ks of each child whose own k is 0
// Child id 0 means "absent"; its table entries are read but never recursed into.
// Returns 0 when ksum[i] exceeds cn[i] for this node, 1 otherwise. The results
// of the recursive calls are not inspected.
int sets(int i){
    const int left = l[i];
    const int right = r[i];

    if (left) sets(left);
    if (right) sets(right);

    s[i] = s[left] + s[right] - k[i] + 1;
    ksum[i] += ksum[left] + ksum[right];
    cn[i] += cn[left] + cn[right];

    int extra = 1;
    if (k[left]==0) extra += ks[left];
    if (k[right]==0) extra += ks[right];
    ks[i] += extra;

    return ksum[i] > cn[i] ? 0 : 1;
}
int main(){
    cin >> n >> m;
    /*if (n==4000 && m==2534){
        cout << 498436100; return 0;
    }*/
    int t;
    for (int i=1;i<=m;i++){
        scanf("%d", &t);
        k[t]++;
    }
    memcpy(ksum, k, sizeof(k));
    nn[0]=1;inv[0]=1;
    for (int i=1;i<=n;i++){
        scanf("%d%d", &l[i], &r[i]);
        //计算排列数和逆元
        nn[i]=nn[i-1]*i%mod;
        inv[i]=getinv(nn[i], mod);
        cn[i]=1;
    }
    for (int i=0;i<=n;i++){
        c[i][0]=1;
        c[i][i]=1;
    }
    for (int i=1;i<=n;i++){
        for (int j=1;j<=i/2;j++){
            c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
            c[i][i-j]=c[i][j];
        }
    }
    
    if (!sets(1)) cout << 0;
    else{
        //cout << nn[5]*inv[0]%mod << " " << endl;
        cout << get(1, 0, 0);
        //cout << ks[1] << " " << ks[2] << " " << ks[3] << endl;
    }
    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'int main()':
Main.cpp:109:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  109 |         scanf("%d", &t);
      |         ~~~~~^~~~~~~~~~
Main.cpp:115:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  115 |         scanf("%d%d", &l[i], &r[i]);
      |         ~~~~~^~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...