#include<bits/stdc++.h>
using namespace std;
// Array bound for node ids (nodes are numbered 1..n) and the answer modulus.
constexpr int N = 100005;
constexpr int MOD = 1000000007;
// tp / crit : per-node values for the subtree below a node, filled by dfs;
//             changeroot also rewrites them while it reroots the tree.
// tpr / critr: per-node values with that node as root of the whole tree,
//             filled by changeroot.
int tp[N], tpr[N];
int crit[N], critr[N];
// Adjacency lists of the (undirected) tree.
vector<int> gr[N];
// Adds y to the accumulator x modulo MOD. y is reduced first, so any value
// (including negatives) is accepted; x stays in [0, MOD) provided it started
// there.
void add(long long &x, long long y){
    const long long delta = y % MOD;
    x += delta;
    if (x >= MOD) {
        x -= MOD;
    } else if (x < 0) {
        x += MOD;
    }
}
// int-accumulator overload: adds y (reduced modulo MOD first, so negatives are
// accepted) to x, keeping x in [0, MOD) provided it started there. With x and
// the reduced y both below MOD in magnitude, the sum fits in an int.
void add(int &x, long long y){
    x += y % MOD;
    if (x < MOD) {
        if (x < 0)
            x += MOD;
    } else {
        x -= MOD;
    }
}
// Post-order pass over the tree rooted at node 1 (called as dfs(1, 0)).
// For every node nod it fills, for nod's subtree:
//   tp[nod]   = 1 iff some child has tp 0 (leaves get 0) -- the usual
//               win/lose recurrence for the player to move at nod;
//   crit[nod] = number of nodes u in the subtree such that giving u one extra
//               losing (tp 0) child would flip tp[nod].
// Losing node  : every child is winning, so nod itself and each child's
//                critical nodes count -> 1 + sum of crit over children.
// Winning node : only a *unique* losing child can be flipped to make nod
//                losing; winning children are irrelevant (turning one of them
//                losing keeps nod winning). Two or more losing children -> 0.
void dfs(int nod, int dad){
    tp[nod] = 0;
    for(auto x : gr[nod]){
        if(x == dad)
            continue;
        dfs(x, nod);
        if(tp[x] == 0)
            tp[nod] = 1;
    }
    if(tp[nod] == 0){
        crit[nod] = 1;
        for(auto x : gr[nod])
            if(x != dad)
                crit[nod] += crit[x];
    }
    else{
        int losing = 0, onlyCrit = 0;
        for(auto x : gr[nod]){
            // Previously winning children with nonzero crit were counted too.
            if(x == dad || tp[x] != 0)
                continue;
            losing++;
            onlyCrit = crit[x];
        }
        crit[nod] = (losing == 1) ? onlyCrit : 0;
    }
}
void changeroot(int nod, int dad){
int nr0 = 0, sum = 0, unq = 0;
for(auto x:gr[nod]){
if(tp[x] == 0 )
nr0++, unq = crit[x];
sum += crit[x];
}
if(nr0 == 0)
tp[nod] = 0;
else
tp[nod] = 1;
if(tp[nod] == 0){
crit[nod] = sum + 1;
}
else if(tp[nod] == 1){
crit[nod] = 0;
if(nr0 == 1)
crit[nod] = unq;
}
critr[nod] = crit[nod];
tpr[nod] = tp[nod];
for(auto x:gr[nod]){
if(x == dad)
continue;
int oldtp = tp[nod], oldcrit = crit[nod];
if(tp[nod] == 0){
tp[nod] = 0;
crit[nod] = sum + 1 - crit[x];
}
else{
if(tp[x] == 1){
tp[nod] = 1;
crit[nod] = oldcrit;
}
else{
if(nr0 == 1){
tp[nod] = 0;
crit[nod] = sum + 1 - crit[x];
}
else if(nr0 == 2){
tp[nod] = 1;
int othunq = 0;
for(auto vc:gr[nod]){
if(vc == x)
continue;
if(tp[vc] == 0)
othunq = crit[vc];
}
crit[nod] = othunq;
}
else if(nr0 >= 3){
tp[nod] = 1;
crit[nod] = 0;
}
}
}
changeroot(x, nod);
tp[nod] = oldtp;
crit[nod] = oldcrit;
}
}
int main()
{
//freopen(".in","r",stdin);
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
int n, d;
cin>>n>>d;
for(int i = 1; i < n; i++){
int u, v;
cin>>u>>v;
gr[u].push_back(v);
gr[v].push_back(u);
}
dfs(1, 0);
changeroot(1, 0);
long long nr0 = 0, sumc0= 0, nr1 = 0, sumc1= 0;
for(int i = 1; i<=n; i++){
if(tpr[i] == 0){
nr0++;
add(sumc0, critr[i]);
}
else{
nr1++;
add(sumc1, critr[i]);
}
}
if(tpr[1] == 0){
cout<<1LL * critr[1] * nr0 % MOD;
}
else{
//cout<<(1LL * n * n % MOD - 1LL * critr[1] * nr0 % MOD + MOD) % MOD;
}
return 0;
}
// Judge result table 1 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct
// Judge result table 2 (columns: test, result, time, memory, grader output)
//   1  Incorrect  2 ms  2636 KB  Output isn't correct
//   2  Halted     0 ms     0 KB  -
// Judge result table 3 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct
//   3  Halted     0 ms     0 KB  -
// Judge result table 4 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct
//   3  Halted     0 ms     0 KB  -
// Judge result table 5 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct
//   3  Halted     0 ms     0 KB  -
// Judge result table 6 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct
//   3  Halted     0 ms     0 KB  -
// Judge result table 7 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct
//   3  Halted     0 ms     0 KB  -
// Judge result table 8 (columns: test, result, time, memory, grader output)
//   1  Correct    2 ms  2636 KB  Output is correct
//   2  Incorrect  2 ms  2636 KB  Output isn't correct