#include <bits/stdc++.h>
// REP(v, i, j): loop an int v from i up to (but excluding) j. The condition is
// `v != j`, so j must be reachable from i by incrementing.
#define REP(v, i, j) for (int v = i; v != j; v++)
// FORI / FORJ: range-for over v, copying each element into `i` / `j`.
#define FORI(v) for (auto i : v)
#define FORJ(v) for (auto j : v)
// OUT(v, a): print every element of v, each followed by separator a.
#define OUT(v, a) \
    FORI(v)       \
    cout << i << a;
// OUTS(v, a, b): print v.size() followed by a, then the elements separated by b.
// Expands to two statements, so do not use it as the unbraced body of an if/for.
#define OUTS(v, a, b)      \
    cout << v.size() << a; \
    OUT(v, b)
// in(a, n): read a[0] .. a[n-1] from cin (loop variable is named i).
#define in(a, n) \
    REP(i, 0, n) \
    cin >> a[i];
// SORT / REV: sort ascending / reverse an entire container or array.
#define SORT(v) sort(begin(v), end(v))
#define REV(v) reverse(begin(v), end(v))
// MEMSET(m): set every byte of m to 0xFF, which reads back as -1 for signed
// integer element types. Uses `sizeof m`, so m must be an array, not a pointer.
#define MEMSET(m) memset(m, -1, sizeof m)
// Short aliases for push_back / first / second.
#define pb push_back
#define fi first
#define se second
// detachIO: unsync C++ streams from C stdio and untie cin for faster I/O.
#define detachIO                      \
    ios_base::sync_with_stdio(false); \
    cin.tie(0);                       \
    cout.tie(0);
    
using namespace std;
// Element-wise equality of two vectors (same length and equal elements).
// NOTE(review): the standard library already defines this operator. With
// `using namespace std`, an unqualified `v1 == v2` on vectors is ambiguous between
// this template and std's, because both match equally well. Consider deleting this
// overload if vector comparisons are ever needed.
template<typename Elem, typename Alloc = std::allocator<Elem> >
bool operator==(const std::vector<Elem, Alloc>& lhs, const std::vector<Elem, Alloc>& rhs) {
    return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin());
}
// Shorthand pair types: pii = (int, int), piii = ((int, int), int),
// piiii = ((int, int), (int, int)).
using pii = std::pair<int, int>;
using piii = std::pair<pii, int>;
using piiii = std::pair<pii, pii>;
// Prime modulus for modint; primality is what lets invert() use Fermat's little theorem.
const int MOD = 1000000007;

// Integer modulo MOD, always stored in the canonical range [0, MOD).
struct modint {
    long long val = 0;

    modint() = default;
    // Implicit on purpose so plain integers mix with modint in expressions.
    // Any value (negative, or >= MOD) is reduced into [0, MOD).
    modint(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }

    modint operator+(modint b) const { return val + b.val; }
    modint operator-(modint b) const { return MOD + val - b.val; }
    // Both operands are < MOD (about 1e9), so the product fits in long long.
    modint operator*(modint b) const { return val * b.val; }

    // *this raised to the power a, by iterative binary exponentiation
    // (O(log |a|) multiplications). x^0 == 1, including 0^0. A negative
    // exponent means a power of the modular inverse, which is only meaningful
    // when *this is non-zero.
    modint operator^(int a) const {
        modint base = *this;
        long long e = a;  // widened so that negating INT_MIN is safe
        if (e < 0) {
            base = base ^ (MOD - 2);
            e = -e;
        }
        modint result = 1;
        while (e > 0) {
            if (e & 1) result = result * base;
            base = base * base;
            e >>= 1;
        }
        return result;
    }
};

// Modular inverse a^(MOD-2) (Fermat's little theorem, MOD prime).
// invert(0) yields 0, because 0 has no inverse.
modint invert(modint a){
    return a ^ (MOD - 2);
}

// Modular division a * b^{-1}; b must be non-zero modulo MOD.
modint operator/(modint a, modint b){
    return a * invert(b);
}
// 2-D point (or vector) with integer coordinates.
struct Point {
    int x, y;
    Point(): x(0), y(0) {}
    Point(int _x, int _y): x(_x), y(_y) {}
};
// Lexicographic order: by x, ties broken by y.
bool operator<(Point a, Point b){
    return a.x < b.x || (a.x == b.x && a.y < b.y);
}
bool operator==(Point a, Point b){
    return a.x == b.x && a.y == b.y;
}
// Dot product a . b. Products are taken in 64 bits; with 32-bit int arithmetic
// they overflow once coordinates exceed a few tens of thousands in magnitude.
long long dot(Point a, Point b){
    return 1LL * a.x * b.x + 1LL * a.y * b.y;
}
// z-component of the cross product a x b, computed in 64 bits like dot().
// Positive when b is counter-clockwise from a, negative when clockwise,
// zero when they are collinear.
long long cross(Point a, Point b){
    return 1LL * a.x * b.y - 1LL * a.y * b.x;
}
// Squared length |a|^2.
long long mag2(Point a){
    return dot(a, a);
}
// Component-wise difference: the vector from b to a.
Point operator-(Point a, Point b){
    return {a.x - b.x, a.y - b.y};
}
// Angle abc (at vertex b) in radians, in [0, pi]:
// acos(dot(a-b, c-b) / (|a-b| * |c-b|)).
// Returns NaN when a == b or c == b, where the angle is undefined.
double angle(Point a, Point b, Point c) {
    const Point u = a - b;
    const Point v = c - b;
    const double len_u = std::sqrt(static_cast<double>(mag2(u)));
    const double len_v = std::sqrt(static_cast<double>(mag2(v)));
    if (len_u == 0 || len_v == 0) return std::numeric_limits<double>::quiet_NaN();
    // Clamp: rounding can push the ratio slightly outside [-1, 1], where acos is NaN.
    const double ratio = static_cast<double>(dot(u, v)) / (len_u * len_v);
    return std::acos(std::max(-1.0, std::min(1.0, ratio)));
}
// Line a*x + b*y + c = 0 with integer coefficients, kept in a canonical form
// so that different representations of the same line compare equal.
struct Line {
    int a = 0, b = 0, c = 0;

    // Make (a, b, c) canonical: the first non-zero coefficient becomes positive
    // and all three are divided by gcd(|a|, |b|, |c|). All-zero coefficients
    // (e.g. a Line built from two equal points) are left unchanged.
    void normalise(){
        const bool negative = a < 0 || (a == 0 && (b < 0 || (b == 0 && c < 0)));
        if (negative) a = -a, b = -b, c = -c;
        const int g = gcd_nonneg(gcd_nonneg(std::abs(a), std::abs(b)), std::abs(c));
        if (g == 0) return;
        a /= g, b /= g, c /= g;
    }
    Line() = default;
    // Line through A and B: (a, b) = (B.y - A.y, A.x - B.x) is a normal of B - A,
    // and c is chosen so that A satisfies the equation.
    Line(Point A, Point B){
        a = B.y - A.y;
        b = A.x - B.x;
        c = -(A.x * a + A.y * b);
        normalise();
    }
    Line(int _a, int _b, int _c){
        a = _a;
        b = _b;
        c = _c;
        normalise();
    }

private:
    // Euclid's algorithm for non-negative arguments; gcd(0, 0) == 0.
    static int gcd_nonneg(int x, int y){
        while (y != 0) {
            const int t = x % y;
            x = y;
            y = t;
        }
        return x;
    }
};
// Lines are equal when their canonical coefficients match.
bool operator==(Line a, Line b){
    a.normalise();
    b.normalise();
    return a.a == b.a && a.b == b.b && a.c == b.c;
}
// Returns true if the closed segments [a, b] and [c, d] share at least one point.
// Touching endpoints and collinear overlaps count as intersecting. A degenerate
// segment (a == b) is treated as a single point.
bool intersects(Point a, Point b, Point c, Point d){
    // Sign (-1, 0, +1) of the turn p -> q -> r.
    const auto orient = [](Point p, Point q, Point r) {
        const auto v = cross(q - p, r - p);
        return (v > 0) - (v < 0);
    };
    const int o1 = orient(a, b, c), o2 = orient(a, b, d);
    const int o3 = orient(c, d, a), o4 = orient(c, d, b);
    if (o1 == 0 && o2 == 0 && o3 == 0 && o4 == 0) {
        // All four points are collinear. Compare the segments as intervals along
        // the line using lexicographic order, which also works for vertical lines.
        if (b < a) std::swap(a, b);
        if (d < c) std::swap(c, d);
        return !(b < c) && !(d < a);
    }
    // Each segment's endpoints must lie on opposite sides of (or on) the other's line.
    return o1 * o2 <= 0 && o3 * o4 <= 0;
}
vector<Point> convexhull(vector<Point> coordinates){
    vector<Point> stck;
    REP(i,0,coordinates.size()){
        while(stck.size()>=2 && cross(stck.end()[-1]-coordinates[i],stck.end()[-2]-coordinates[i])<0)stck.pop_back();
        stck.pb(coordinates[i]);
    }
    REV(coordinates);
    REP(i,0,coordinates.size()){
        while(stck.size()>=2 && cross(stck.end()[-1]-coordinates[i],stck.end()[-2]-coordinates[i])<0)stck.pop_back();
        stck.pb(coordinates[i]);
    }
    return stck;
}
// Recursive point-update / range-query segment tree over positions [0, n).
// Nodes use a 0-based heap layout (the children of pos are 2*pos+1 and 2*pos+2).
// Callers pass each node's range explicitly, starting from (pos=0, l=0, r=n-1).
struct segment_tree {
    long long n;
    typedef int T;
    // Identity element of merge(). It is deliberately declared before `seg`:
    // members are initialised in declaration order and seg's initialiser reads e.
    // If e were declared after seg, seg would read e before e is initialised.
    const T e = 0;
    std::vector<T> seg;

    // Combines the values of two adjacent ranges.
    // NOTE(review): placeholder operation; replace it with the problem's associative
    // merge. Sum is the default because it is consistent with the identity e = 0.
    T merge(T a, T b){
        if (a == e) return b;
        if (b == e) return a;
        return a + b;
    }

    // 4n + 10 slots cover every node of a recursive tree over n leaves.
    segment_tree(long long _n): n(_n), seg(4 * _n + 10, e) {}

    // Merges val into leaf idx (a point add with the default sum), then refreshes
    // every ancestor on the way back up.
    void update(long long pos, long long l, long long r, long long idx, T val){
        if (idx < l || r < idx) return;
        if (l == r){
            seg[pos] = merge(seg[pos], val);
            return;
        }
        const long long mid = (l + r) / 2;
        update(pos * 2 + 1, l, mid, idx, val);
        update(pos * 2 + 2, mid + 1, r, idx, val);
        seg[pos] = merge(seg[pos * 2 + 1], seg[pos * 2 + 2]);
    }

    // merge() over the leaves in [ql, qr] within this node's range [l, r];
    // returns e when the two ranges do not overlap.
    T query(long long pos, long long l, long long r, long long ql, long long qr){
        if (qr < l || r < ql) return e;
        if (ql <= l && r <= qr) return seg[pos];
        const long long mid = (l + r) / 2;
        return merge(query(pos * 2 + 1, l, mid, ql, qr),
                     query(pos * 2 + 2, mid + 1, r, ql, qr));
    }
};
// Capacity of every per-node table (nodes are numbered from 1 by main()).
constexpr int MAX_NODES = 200100;

// Adjacency lists of the tree.
std::vector<long long> adj[MAX_NODES];
// memo1[v]: number of children c of v (in the tree rooted by winning()) with memo1[c] == 0.
long long memo1[MAX_NODES];
// memo2[v]: the `dp` flag update() was called with for v.
bool memo2[MAX_NODES];
// memo3[v][b]: memo table of dp(v, b); -1 marks "not computed yet".
long long memo3[MAX_NODES][2];
// parent[v]: parent of v in the rooted tree (-1 for the root).
long long parent[MAX_NODES];
// memof[v]: losing-neighbour count of v computed by update() (tree rerooted at v).
long long memof[MAX_NODES];

// winning() returns `overdata` immediately for the node equal to `override` (-1 = none).
long long override = -1;
// Root of the tree, recorded by winning() when called with parent == -1.
long long root = 1;
bool overdata = false;
// Evaluates the game on the subtree of `node`, entered from `parent` (-1 for the root).
// Returns true iff some child of `node` is a losing position (its call returns false).
// Side effects:
//   - sets `root` when parent == -1;
//   - sets memo1[node] to the number of children c with memo1[c] == 0;
//   - records ::parent[node].
// If node == override, returns overdata immediately and touches neither memo1
// nor ::parent for that node.
// Recursion depth equals the depth of the tree.
bool winning(long long node, long long parent){
    if (parent == -1) root = node;
    if (node == override) return overdata;
    memo1[node] = 0;
    bool hasLosingChild = false;
    for (const long long child : adj[node]) {
        if (child == parent) continue;
        const bool childWins = winning(child, node);
        hasLosingChild = hasLosingChild || !childWins;
        if (memo1[child] == 0) ++memo1[node];
    }
    ::parent[node] = parent;
    return hasLosingChild;
}
void update(long long node, long long parent, bool dp){
    long long onset=0;
    if(!dp)onset++;
    FORI(adj[node]){
        if(i!=parent && !memo1[i])onset++;
    }
    memo2[node]=dp;
    memof[node]=onset;
    FORI(adj[node]){
        if(i==parent)continue;
        if(!memo1[i])onset--;
        update(i,node,onset);
        if(!memo1[i])onset++;
    }
    // cerr<<node<<"-> "<<onset<<'\n';
}
bool dp(long long node, bool b){
    if(node==root)return b;
    if(memo3[node][b]!=-1)return memo3[node][b];
    if(b == (bool)(memo1[node]))return memo3[node][b]=memo1[root];
    if(b && !memo1[node] && memo1[parent[node]]==1){
        return memo3[node][b]=dp(parent[node],false);
    }
    if(!b && memo1[node] && memo1[parent[node]]==0){
        return memo3[node][b]=dp(parent[node],true);
    }
    return memo3[node][b]=dp(parent[node],memo1[parent[node]]);
}
int main(){
    detachIO;
    long long n;cin>>n;
    long long d;cin>>d;
    REP(i,1,n){
        long long u,v;cin>>u>>v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    long long w=0,l=0;
    MEMSET(memof);
    winning(1,-1);
    update(1,-1,true);
    REP(i,1,n+1){
        assert(memof[i]!=-1);
        // assert(((bool)memof[i])==winning(i,-1));
        if(memof[i])w++;
        else l++;
    }
    // cerr<<endl;
    long long ans=0;
    winning(1,-1);
    update(1,-1,true);
    MEMSET(memo3);
    REP(i,1,n+1){
        if(dp(i,true))ans+=l;
        // cerr<<i<<" -> "<<dp(i,true)<<", "<<parent[i]<<", "<<memo1[i]<<'\n';
    }
    override=-1;
    overdata=true;
    const bool r=winning(1,-1);
    REP(i,1,n+1){
        if(r)ans+=w;
    }
    cout<<ans%MOD<<'\n';
}
/*
 * Judge "submission results" table that was pasted after the program. No
 * results were captured (every row reads "Fetching results..."). It is kept
 * inside this comment so the file still compiles.
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... |
*/