답안 #466277

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
466277 2021-08-18T12:49:10 Z ivan_tudor Star Trek (CEOI20_startrek) C++14
컴파일 오류
0 ms 0 KB
#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-node array below; nodes are indexed from 1 in main.
const long long N = 1e5 + 5;
// Modulus applied by add() and by the final answer computation.
const long long MOD = 1e9 + 7;
// tp[v] / crit[v]: working 0/1 state and counter of node v, written by dfs()
// and temporarily overwritten by changeroot().
// presumably tp is the win(1)/lose(0) state of v and crit the number of
// "critical" nodes below it - inferred from the recurrences, TODO confirm.
// tpr[v] / critr[v]: the values changeroot() records for v from all of its
// neighbours; these are what main() aggregates.
long long tp[N], tpr[N];
long long crit[N], critr[N];
// Adjacency lists of the undirected input graph, filled in main().
vector<int> gr[N];
// Adds y to x in place, modulo MOD.
//
// y is first reduced with %, so it may be negative or larger than MOD.
// Provided x is already in [0, MOD) on entry, the sum lies in (-MOD, 2*MOD)
// and the single correction in each direction brings it back into [0, MOD).
//
// NOTE: this helper used to be defined twice with identical bodies, which is
// a hard compile error (redefinition); only one definition is kept.
void add(long long &x, long long y){
  y %= MOD;
  x += y;
  if(x >= MOD)
    x -= MOD;
  if(x < 0)
    x += MOD;
}
// Post-order pass over the component reachable from nod, never stepping back
// to dad. Fills tp[] and crit[] for nod from the values of its children:
//   - no child with tp == 0   -> tp = 0, crit = 1 + sum of the children's crit
//   - exactly one such child  -> tp = 1, crit = that child's crit
//   - two or more             -> tp = 1, crit = 0
void dfs(long long nod, long long dad){
  tp[nod] = 0;

  long long zeroKids = 0;      // children whose tp is 0
  long long critTotal = 0;     // sum of crit over all children
  long long zeroKidCrit = 0;   // crit of the most recently seen tp == 0 child

  for(const auto child : gr[nod]){
    if(child == dad)
      continue;
    dfs(child, nod);
    critTotal += crit[child];
    if(tp[child] == 0){
      ++zeroKids;
      zeroKidCrit = crit[child];
    }
  }

  switch(zeroKids){
    case 0:
      tp[nod] = 0;
      crit[nod] = critTotal + 1;
      break;
    case 1:
      tp[nod] = 1;
      crit[nod] = zeroKidCrit;
      break;
    default:
      tp[nod] = 1;
      crit[nod] = 0;
      break;
  }
}
// Second pass over the tree: for every node, computes tp/crit from ALL of its
// neighbours (children and dad alike) and records them in tpr[]/critr[].
//
// Before descending into a neighbour x, tp[nod]/crit[nod] are temporarily
// replaced by the values nod would have if x were not among its neighbours,
// so that the recursive call on x reads the right values for its own "dad".
// Every temporary change is undone before returning, leaving tp[]/crit[]
// as they were on entry.
//
// nod: node being processed; dad: neighbour we arrived from (0 for the
// initial call made from main).
void changeroot(long long nod, long long dad){
  // Remember the values on entry so they can be restored at the very end.
  long long OLDTP = tp[nod], OLDCRIT = crit[nod];
  // nr0: neighbours with tp == 0; sum: total crit of all neighbours;
  // unq: crit of the last tp == 0 neighbour seen (meaningful when nr0 == 1).
  long long nr0 = 0, sum = 0, unq = 0;
  for(auto x:gr[nod]){
    if(tp[x] == 0 )
      nr0++, unq = crit[x];
    sum += crit[x];
  }
  // Same three-way rule as in dfs(), but applied over every neighbour.
  if(nr0 == 0)
    tp[nod] = 0;
  else
    tp[nod] = 1;
  if(tp[nod] == 0){
    crit[nod] = sum + 1;
  }
  else if(tp[nod] == 1){
    crit[nod] = 0;
    if(nr0 == 1)
      crit[nod] = unq;
  }
  // Record the all-neighbours result for nod.
  critr[nod] = crit[nod];
  tpr[nod] = tp[nod];

  for(auto x:gr[nod]){
    if(x == dad)
      continue;
    // Values computed from all neighbours; restored after the recursive call.
    long long oldtp = tp[nod], oldcrit = crit[nod];
    if(tp[nod] == 0){
      // No neighbour has tp == 0, so dropping x keeps tp at 0 and only
      // removes x's contribution from the sum.
      tp[nod] = 0;
      crit[nod] = sum + 1 - crit[x];
    }
    else{
      if(tp[x] == 1){
        // x is not one of the tp == 0 neighbours, so nr0 is unchanged and
        // crit (0 or unq) does not involve x.
        tp[nod] = 1;
        crit[nod] = oldcrit;
      }
      else{
        if(nr0 == 1){
          // x was the only tp == 0 neighbour: without it none remain.
          tp[nod] = 0;
          crit[nod] = sum + 1 - crit[x];
        }
        else if(nr0 == 2){
          // Exactly one tp == 0 neighbour remains; scan for it to take its crit.
          tp[nod] = 1;
          long long othunq = 0;
          for(auto vc:gr[nod]){
            if(vc == x)
              continue;
            if(tp[vc] == 0)
              othunq = crit[vc];
          }
          crit[nod] = othunq;
        }
        else if(nr0 >= 3){
          // At least two tp == 0 neighbours remain.
          tp[nod] = 1;
          crit[nod] = 0;
        }
      }

    }

    changeroot(x, nod);

    // Undo the temporary "x excluded" values before handling the next neighbour.
    tp[nod] = oldtp;
    crit[nod] = oldcrit;
  }

  // Put back the values nod had when this call started.
  tp[nod] = OLDTP;
  crit[nod] = OLDCRIT;
}
// 2x2 working matrices shared by lgcput() and main().
// mat: the matrix that lgcput() raises to a power (it is squared in place,
//      so its original contents are destroyed); main() later reuses it.
long long mat[2][2];
// aux: scratch destination for multiply() results before they are copied out.
long long aux[2][2];
// rez: receives the result of lgcput().
long long rez[2][2];

// Computes rez = A * B for 2x2 matrices, accumulating each entry with add()
// so the result is reduced modulo MOD.
// rez is cleared completely before any product is accumulated.
void multiply(long long A[2][2], long long B[2][2], long long rez[2][2]){
  // Clear the destination first.
  rez[0][0] = 0;
  rez[0][1] = 0;
  rez[1][0] = 0;
  rez[1][1] = 0;

  for(int row = 0; row < 2; ++row){
    for(int col = 0; col < 2; ++col){
      for(int mid = 0; mid < 2; ++mid){
        const long long term = A[row][mid] * B[mid][col];
        add(rez[row][col], term);
      }
    }
  }
}
// Copies the 2x2 matrix A into B, entry by entry in row-major order.
void cpy(long long A[2][2], long long B[2][2]){
  for(int cell = 0; cell < 4; ++cell){
    const int row = cell / 2;
    const int col = cell % 2;
    B[row][col] = A[row][col];
  }
}
// Overwrites A with the 2x2 identity matrix (ones on the diagonal, zeros
// elsewhere), i.e. the neutral element for multiply().
void mknull(long long A[2][2]){
  for(int row = 0; row < 2; ++row){
    for(int col = 0; col < 2; ++col){
      A[row][col] = (row == col) ? 1 : 0;
    }
  }
}
// Binary exponentiation on the global matrices: leaves mat^pwr in rez.
// mat is squared in place on every step, so its contents are destroyed;
// aux is used as scratch space for each product.
void lgcput(long long pwr){
  mknull(rez);
  for(; pwr != 0; pwr /= 2){
    if((pwr & 1) != 0){
      // Current bit is set: fold the current power of mat into the result.
      multiply(mat, rez, aux);
      cpy(aux, rez);
    }
    // Advance mat to its square for the next bit.
    multiply(mat, mat, aux);
    cpy(aux, mat);
  }
}
// Entry point.
// Reads n and d followed by n - 1 edges, computes tpr[]/critr[] for every
// node via dfs() + changeroot(), builds the 2x2 transition matrix from the
// per-state totals, raises it to the (d - 1)-th power and prints the answer
// modulo MOD on stdout.
//
// The per-node debug dump to stderr that used to sit in the aggregation loop
// has been removed: it printed one line per node and was never part of the
// expected output.
int main()
{
  ios::sync_with_stdio(false);
  cin.tie(nullptr), cout.tie(nullptr);
  long long n, d;
  cin>>n>>d;
  for(long long i = 1; i < n; i++){
    long long u, v;
    cin>>u>>v;
    gr[u].push_back(v);
    gr[v].push_back(u);
  }
  dfs(1, 0);
  changeroot(1, 0);

  // Count the nodes in each tpr state and sum their critr values (mod MOD).
  long long nr0 = 0, sumc0= 0, nr1 = 0, sumc1= 0;
  for(long long i = 1; i<=n; i++){
    if(tpr[i] == 0){
      nr0++;
      add(sumc0, critr[i]);
    }
    else{
      nr1++;
      add(sumc1, critr[i]);
    }
  }

  // Transition matrix; sumc0/sumc1 are already in [0, MOD), so adding MOD
  // keeps the parenthesised value non-negative before the final reduction.
  mat[0][0] = 1LL *(1LL * n * nr0 + sumc1 - sumc0 + MOD)%MOD;
  mat[0][1] = 1LL *(1LL * n * nr1 + sumc0 - sumc1 + MOD)%MOD;
  mat[1][0] = 1LL * n * nr0 % MOD;
  mat[1][1] = 1LL * n * nr1 % MOD;
  lgcput(d - 1);
  // The matrix power is now in rez.

  // Reuse mat as the row vector (nr0, nr1) and multiply it by the power.
  mat[0][0] = nr0;
  mat[0][1] = nr1;
  mat[1][0] = mat[1][1] = 0;

  multiply(mat, rez, aux);

  // Combine the result with node 1's own state and counter.
  long long ans = 0;
  if(tpr[1] == 0){
    ans = 1LL * critr[1] * aux[0][0] % MOD;
  }
  else{
    ans = 1LL * critr[1] * aux[0][1] % MOD;
    add(ans, 1LL * (n - critr[1] + MOD)%MOD * ((aux[0][0] + aux[0][1])%MOD));
  }
  cout<<ans;
  return 0;
}

Compilation message

startrek.cpp:16:6: error: redefinition of 'void add(long long int&, long long int)'
   16 | void add(long long &x, long long y){
      |      ^~~
startrek.cpp:8:6: note: 'void add(long long int&, long long int)' previously defined here
    8 | void add(long long &x, long long y){
      |      ^~~