Submission #960909

# | Time | Username | Problem | Language | Result | Execution time | Memory
960909 | | danikoynov | Construction of Highway (JOI18_construction) | C++14 | 100 / 100 | 402 ms | 30440 KiB
#include <bits/stdc++.h>
#define endl '\n'

using namespace std;
typedef long long ll;

/// Configures the standard C++ streams for bulk I/O.
/// Must run before any input or output is performed.
void speed()
{
     // Stop synchronising with C stdio so the C++ streams can buffer on their own.
     std::ios_base::sync_with_stdio(false);

     // Untie the streams: reading from cin no longer flushes cout first.
     std::cin.tie(nullptr);
     // cout is not tied to anything by default; this call only makes that explicit.
     std::cout.tie(nullptr);
}

// Capacity of every per-node / per-edge array (node ids are 1-based).
const int maxn = 100000 + 10;

/// One input edge. input() stores it as a directed edge v -> u
/// (adj[v] receives u), and calc() treats u as a child of v.
struct edge
{
     int v, u;

     edge(int from = 0, int to = 0) : v(from), u(to) {}
} edges[maxn];

// n: number of nodes; c[i]: colour of node i (compressed by compress_data).
int n, c[maxn];
// adj[x]: children of x, in input order.
vector < int > adj[maxn];

/// Reads n, the n colours c[1..n], and then n - 1 edges.
/// Each edge is kept in edges[1..n-1] in input order and is also stored as a
/// directed parent -> child link in adj[].
void input()
{
     cin >> n;
     for (int i = 1; i <= n; ++i)
          cin >> c[i];

     for (int i = 1; i < n; ++i)
     {
          int from = 0, to = 0;
          cin >> from >> to;
          edges[i] = edge(from, to);
          adj[from].push_back(to);
     }
}

int sub[maxn], depth[maxn], heavy[maxn], par[maxn];

/// Computes subtree data for the subtree rooted at v. The tree is given by
/// adj[x], the children of x:
///  - par[u] and depth[u] for every proper descendant u. v's own entries are
///    not touched; for the root they keep their zero initial value.
///  - sub[x]: number of nodes in x's subtree.
///  - heavy[x]: the child of x with the largest subtree (the first such child
///    in adj[x] order on ties), or -1 when x has no children.
/// Implemented iteratively. A recursive version nests one call per tree level,
/// so a path-shaped tree of ~1e5 nodes may overflow small default stacks.
void calc(int v)
{
     // Pass 1 (top-down): build an order in which every node precedes all of
     // its descendants, and fix par/depth along the way.
     vector < int > order, pending(1, v);
     while (!pending.empty())
     {
          int x = pending.back();
          pending.pop_back();
          order.push_back(x);
          for (int u : adj[x])
          {
               par[u] = x;
               depth[u] = depth[x] + 1;
               pending.push_back(u);
          }
     }

     // Pass 2 (bottom-up): in reverse order every child is finished before
     // its parent, so sub[u] is final when x reads it.
     for (int i = (int)order.size() - 1; i >= 0; i --)
     {
          int x = order[i];
          heavy[x] = -1;
          sub[x] = 1;
          for (int u : adj[x])
          {
               sub[x] += sub[u];
               // Strict '>' keeps the first child in adj order on ties.
               if (heavy[x] == -1 || sub[u] > sub[heavy[x]])
                    heavy[x] = u;
          }
     }
}

int cidx[maxn], chead[maxn], cpos[maxn];
vector < int > chain[maxn];
int csz = 1;
void hld(int v, int head)
{
     cidx[v] = csz;
     chead[v] = head;
     chain[cidx[v]].push_back(v);
     if (v == head)
          cpos[v] = 1;
     else
          cpos[v] = cpos[par[v]] + 1;

     if (heavy[v] != -1)
          hld(heavy[v], head);

     for (int u : adj[v])
     {
          if (u == heavy[v])
               continue;
          csz ++;
          hld(u, u);
     }
}


set < pair < int, int > > range[maxn];

vector < pair < int, int > > get_path_values(int v)
{
     vector < pair < int, int > > path;
     while(v != 0)
     {

          int d = cpos[v];
          while(d > 0)
          {

               set < pair < int, int > > :: iterator it = range[cidx[v]].lower_bound({d, n + 1});
               /**for (pair < int, int > cur : range[cidx[v]])
               cout << cur.first << " " << cur.second << endl;
               cout << d << endl;*/
               it = prev(it);
               path.push_back({d - it -> first + 1, c[chain[cidx[v]][it -> second - 1]]});
               d = it -> first - 1;
          }
          v = par[chead[v]];
     }
     reverse(path.begin(), path.end());
     return path;
}

/// Applies one construction step: v is the newly connected node, and every
/// node on the path from the root to v takes v's colour c[v].
///
/// Representation (shared with get_path_values): range[id] holds segments
/// (start, src) of chain id. A segment covers chain positions from `start` up
/// to the next segment's start minus one. Its colour is the current value of
/// c[] at the node sitting at chain position `src` (chain[id][src - 1]).
/// Rather than touching every path node, this rewrites the segments of each
/// chain on the path and stores the new colour in c[] of one node per chain.
///
/// Side effect: c[x] is overwritten with v's colour for every node x reached
/// through the par[chead[...]] hops below. x is the deepest path node on each
/// ancestor chain.
void clear_path(int v)
{
     int col = c[v];
     // v's own chain: positions 1..cpos[v] become a single segment whose
     // colour is read from v itself.
     range[cidx[v]].clear();
     range[cidx[v]].insert({1, cpos[v]});
     v = par[chead[v]];
     // Climb chain by chain; v is the deepest path node on the current chain.
     while(v != 0)
     {
          int d = cpos[v];
          // lower_bound({d, n + 1}) finds the first segment starting after d,
          // because src values are chain positions and never exceed n. prev()
          // is then the segment containing position d.
          // NOTE: this relies on every chain with a built node having a
          // segment that starts at 1. simulate() seeds one for the root chain,
          // and this function always re-inserts one.
          set < pair < int, int > > :: iterator it = range[cidx[v]].lower_bound({d, n + 1});
          it = prev(it);
          // Keep the part of that segment below position d: it now starts at
          // d + 1 and still reads its colour from the same src node.
          pair < int, int > nw = {d + 1, it -> second};
          range[cidx[v]].erase(it);
          range[cidx[v]].insert(nw);

          // Erase every segment whose start is <= d. The tail inserted above
          // starts at d + 1 and survives.
          while(true)
          {
               set < pair < int, int > > :: iterator it = range[cidx[v]].lower_bound({d, n + 1});

               if (it == range[cidx[v]].begin())
                    {
                         break;
                    }
               it = prev(it);
               range[cidx[v]].erase(it);
          }

          // Positions 1..d of this chain now form one segment coloured by the
          // node at position d, whose colour is set to the new colour.
          range[cidx[v]].insert({1, d});
          c[v] = col;

          v = par[chead[v]];

     }
}


// Fenwick (binary indexed) tree over positions 1..n.
ll fen[maxn];

/// Adds val at position pos. Callers pass pos >= 1; pos == 0 would never
/// advance the loop.
void update(int pos, ll val)
{
     while (pos <= n)
     {
          fen[pos] += val;
          pos += pos & -pos;
     }
}

/// Returns the sum of all values added at positions 1..pos (0 when pos <= 0).
ll query(int pos)
{
     ll total = 0;
     // pos &= pos - 1 clears the lowest set bit, the same step as pos -= pos & -pos.
     for (; pos > 0; pos &= pos - 1)
          total += fen[pos];
     return total;
}

/// Prints, followed by a newline, the weighted inversion count of `path`:
/// the sum of path[i].first * path[j].first over all i < j with
/// path[i].second > path[j].second.
/// path[k] is (run length, colour). Colours must lie in 1..n, the Fenwick
/// tree's range; compress_data maps colours into that range.
/// The global Fenwick tree is returned to its state before the call.
void find_inversions(const vector < pair < int, int > > &path)
{
     ll ans = 0;

     // Sweep right to left. query() sums the weights of later runs with a
     // strictly smaller colour, i.e. those forming an inversion with run i.
     for (int i = (int)path.size() - 1; i >= 0; i --)
     {
          ans += query(path[i].second - 1) * (ll)path[i].first;
          update(path[i].second, path[i].first);
     }

     // Undo every insertion so the shared tree can be reused by the next call.
     for (const pair < int, int > &run : path)
          update(run.second, -run.first);

     cout << ans << endl;
}
void simulate()
{
     //for (int i = 1; i <= n; i ++)
       //cout << i << " : " << cidx[i] <<  " " << cpos[i] << " " << heavy[i] << " " << sub[i] << endl;
     //exit(0);
     range[cidx[1]].insert({1, 1});
     for (int i = 1; i < n; i ++)
     {
          ///cout << "pass " << i << endl;
          vector < pair < int, int > > path = get_path_values(edges[i].v);
          //cout << "survive " << i << endl;
          //cout << "inversions  ";
          find_inversions(path);
          //cout << "path " << endl;
          //for (pair < int, int > cur : path)
               //cout << cur.first << " " << cur.second << endl;
          //cout << "path" << endl;
          //for (pair <int, int > cur : path)
            //   cout << cur.first << " " << cur.second << endl;
          clear_path(edges[i].u);
          /**cout << "edge " << i << endl;
          for (int i = 1; i <= csz; i ++)
          {
               cout << "chain " << i << endl;
               for (pair < int, int > cur : range[i])
                    cout << cur.first << " " << cur.second << endl;
          }
          if (i == 4)
               break;*/
     }
}

/// Replaces every colour c[i] (1 <= i <= n) by a rank in 1..n. The rank is
/// the 1-based index of the last occurrence of that value in the sorted list
/// of all colours, so equal colours share a rank and the relative order is kept.
void compress_data()
{
     vector < int > sorted_vals(c + 1, c + n + 1);
     sort(sorted_vals.begin(), sorted_vals.end());

     // upper_bound gives the count of values <= c[i], i.e. last occurrence + 1.
     for (int i = 1; i <= n; ++i)
          c[i] = (int)(upper_bound(sorted_vals.begin(), sorted_vals.end(), c[i]) - sorted_vals.begin());
}
/// Runs the whole pipeline. It reads the input and normalises the colours.
/// It then builds the tree data and heavy-light decomposition rooted at node 1.
/// Finally it prints one answer per input edge.
void solve()
{
     input();
     compress_data();

     // Tree preprocessing: subtree sizes and heavy children, then chains.
     calc(1);
     hld(1, 1);

     simulate();
}

int main()
{
     speed();
     solve();
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...