Algorithm


Problem Name: Data Structures - Kitty's Calculations on a Tree

Problem Link: https://www.hackerrank.com/challenges/kittys-calculations-on-a-tree/problem?isFullScreen=true

This page presents solutions to the HackerRank Data Structures problem "Kitty's Calculations on a Tree".

Kitty has a tree, $T$, consisting of $n$ nodes, where each node is uniquely labeled from $1$ to $n$. Her friend Alex gave her $q$ sets, where each set contains $k$ distinct nodes. Kitty needs to calculate the following expression on each set:

$$\left(\sum_{\{u,v\}} u \cdot v \cdot \operatorname{dist}(u, v)\right) \bmod (10^9 + 7)$$

where:

$\{u, v\}$ denotes an unordered pair of nodes belonging to the set.
$\operatorname{dist}(u, v)$ denotes the number of edges on the unique (shortest) path between nodes $u$ and $v$.
Given $T$ and $q$ sets of $k$ distinct nodes, calculate the expression for each set. For each set of nodes, print the value of the expression modulo $10^9 + 7$ on a new line.


Input Format

The first line contains two space-separated integers, the respective values of $n$ (the number of nodes in tree $T$) and $q$ (the number of query sets).
Each of the $n - 1$ subsequent lines contains two space-separated integers, $a$ and $b$, that describe an undirected edge between nodes $a$ and $b$.
The 2 * q subsequent lines define each set over two lines in the following format:

The first line contains an integer, k  , the size of the set.
The second line contains  k space-separated integers, the set's elements.

Output Format

Print q lines of output where each line i contains the expression for the ith query, modulo 10^9 + 7.

 

 

Code Examples

#1 Code Example with C Programming

Code - C Programming


#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Modulus required by the problem statement (10^9 + 7). An enum constant has
 * type int (the value fits), obeys scope and is visible in debuggers, unlike
 * an object-like macro. */
enum { LIMIT = 1000000007 };

/* Singly linked list cell that holds a pointer to a tree node. Not referenced
 * by the C program below. */
struct tree_node_list {
  struct tree_node *node;      /* node stored in this cell */
  struct tree_node_list *next; /* following cell (presumably NULL-terminated) */
};
typedef struct tree_node_list tree_node_list;

/* One vertex of the rooted tree; main stores them in a flat array where
 * element i carries label i + 1. */
struct tree_node {
  struct tree_node *parent; /* NULL for the root */
  uint32_t num;             /* 1-based node label */
  int32_t depth;            /* edges from the root, -1 until computed */
};
typedef struct tree_node tree_node;

/* Per-node scratch data for a single query; main zeroes the whole array
 * between queries. */
struct aux_info {
  uint32_t simple_sum; /* labels of marked strict descendants, reduced modulo LIMIT */
  uint32_t level_sum;  /* label * distance-to-this-node over those descendants, modulo LIMIT */
  uint32_t marker;     /* nonzero when the node is in the current query set */
};
typedef struct aux_info aux_info;

static void print_node(tree_node *node) {
  printf("(num: %ld, parent: %ld, depth: %ld) ", node->num,
         node->parent != NULL ? node->parent->num : 0, node->depth);
}

/*
 * Debug helper: prints `count` consecutive nodes starting at `nodes` on one
 * line, then a newline. The index is a size_t, so it is compared with `count`
 * without a signed/unsigned mismatch and cannot overflow for large counts.
 */
void print_tree(tree_node *nodes, size_t count) {
  for (size_t i = 0; i < count; i++) {
    print_node(&nodes[i]);
  }
  printf("\n");
}

static int order_tree(const void *lhs, const void *rhs) {
  tree_node *a = *((tree_node **)lhs);
  tree_node *b = *((tree_node **)rhs);
  return a->depth - b->depth;
}

/*
 * Computes node->depth (edges to the root). It also fills in every node on
 * the way up from `node` whose depth is still -1, stopping at the root or at
 * the first ancestor with a known depth. A node without a parent always gets
 * depth 0.
 *
 * Iterative rather than recursive: the recursive form used one stack frame
 * per ancestor and can overflow the call stack on long, path-shaped trees.
 */
static void add_depth(tree_node *node) {
  /* Pass 1: climb to the root or to the nearest node with a known depth. */
  tree_node *anchor = node;
  int32_t steps = 0;
  while (anchor->parent != NULL && anchor->depth == -1) {
    anchor = anchor->parent;
    ++steps;
  }
  if (anchor->parent == NULL) {
    anchor->depth = 0;
  }

  /* Pass 2: walk the same path again, assigning depths from the bottom up. */
  int32_t depth = anchor->depth + steps;
  for (tree_node *cur = node; cur != anchor; cur = cur->parent) {
    cur->depth = depth--;
  }
}

int main() {
  long num_nodes, num_queries;
  scanf("%ld %ld", &num_nodes, &num_queries);
  tree_node *nodes = calloc(num_nodes, sizeof(tree_node));
  tree_node **order = calloc(num_nodes, sizeof(tree_node *));
  aux_info *info = calloc(num_nodes, sizeof(aux_info));
  for (long i = 0; i  <  num_nodes; ++i) {
    tree_node *node = &nodes[i];
    node->num = i + 1;
    node->depth = -1;
    order[i] = &nodes[i];
  }
  for (long i = 0; i  <  num_nodes - 1; i++) {
    long a, b;
    scanf("%ld %ld", &a, &b);
    tree_node *node_a = &nodes[a - 1];
    tree_node *node_b = &nodes[b - 1];
    if (node_b->parent == NULL) {
      node_b->parent = node_a;
    } else if (node_a->parent == NULL) {
      node_a->parent = node_b;
    } else {
      exit(1);
    }
  }
  for (long i = 0; i  <  num_nodes; ++i) {
    add_depth(&nodes[i]);
  }
  qsort(order, num_nodes, sizeof(tree_node *), order_tree);
  for (long i = 0; i  <  num_queries; ++i) {
    unsigned long k;
    scanf("%ld", &k);
    for (long j = 0; j  <  k; j++) {
      long node_num;
      scanf("%ld", &node_num);
      info[node_num - 1].marker = 1;
    }

    uint64_t total = 0;
    for (long j = num_nodes - 1; j >= 0; --j) {
      tree_node node = *order[j];
      uint64_t node_num = node.num;
      uint64_t node_index = node_num - 1;
      aux_info node_info = info[node_index];
      if (node_info.marker == 0 && node.depth == 0) {
        continue;
      }
      uint64_t node_simple_sum = node_info.simple_sum;
      uint64_t node_level_sum = node_info.level_sum;
      if (node_info.marker != 0) {
        // Add all the combintations made with this node and its children
        total = total + node_level_sum * node_num;
        if (total > LIMIT) {
          total = total % LIMIT;
        }
        node_simple_sum += node_num;
      } else if (node_simple_sum == 0) {
        continue;
      }
      // Increment the level
      node_level_sum += node_simple_sum;
      tree_node *parent = node.parent;
      if (parent != NULL) {
        uint64_t parent_index = parent->num - 1;
        aux_info parent_info = info[parent_index];
        uint64_t parent_simple_sum = parent_info.simple_sum;
        uint64_t parent_level_sum = parent_info.level_sum;
        // Add the combinations that this subtree makes with all sibling
        // subtrees processed so far
        total = (total + (parent_simple_sum * node_level_sum) +
                 (parent_level_sum * node_simple_sum));
        if (total > LIMIT) {
          total = total % LIMIT;
        }
        parent_simple_sum = parent_simple_sum + node_simple_sum;
        if (parent_simple_sum > LIMIT) {
          parent_simple_sum = parent_simple_sum % LIMIT;
        }
        parent_level_sum = parent_level_sum + node_level_sum;
        if (parent_level_sum > LIMIT) {
          parent_level_sum = parent_level_sum % LIMIT;
        }
        info[parent_index].simple_sum = parent_simple_sum;
        info[parent_index].level_sum = parent_level_sum;
      }
    }

    memset(info, 0, sizeof(aux_info) * num_nodes);
    long ans = total;
    printf("%ld\n", ans);
  }

  return 0;
}
Copy The Code & Try With Live Editor

#2 Code Example with C++ Programming

Code - C++ Programming


#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <utility>
using namespace std;

// Shared types, limits and global state. Node labels (1..n) index the
// arrays directly.
using LL = long long;
using pii = pair<LL, LL>;
const int MAX_N = 200000 + 6;
const int MAX_P = 19;          // rows of dis[]: one per centroid-tree level
const LL mod = 1000000007LL;

vector<int> edg[MAX_N];        // adjacency lists of the input tree
int dis[MAX_P][MAX_N];         // dis[d][x]: distance from x to its centroid ancestor at level d
bool visit[MAX_N];             // removed-centroid marks, also used as a traversal "seen" flag

// Per-centroid aggregates over the current query set, indexed by label.
struct Cen {
    int par;        // parent in the centroid tree, -1 for the top-level centroid
    int depth;      // level in the centroid tree (row of dis[] to use)
    pii val_v_av;   // first: sum of x * dist(x, this); second: minus sum of x * dist(x, par), mod `mod`
    pii val_v;      // first: sum of labels x below this centroid; second: minus that sum, mod `mod`
} cen[MAX_N];

vector<int> v;      // nodes of the component currently scanned by dfs2
int sz[MAX_N];      // subtree sizes inside that component, written by dfs2
int mx[MAX_N];      // largest-child-subtree scratch written by dfs2, read by get_cen

// Collects the component of `id` (nodes not marked in visit[]) into `v`,
// marking each collected node in visit[]. With the component rooted at `id`,
// it also computes for each collected node u:
//   sz[u] = size of u's subtree, mx[u] = size of u's largest child subtree.
// get_cen's centroid test reads mx, so it must be filled in here.
// Iterative (BFS, then a reverse sweep) so a long path cannot overflow the stack.
void dfs2(int id) {
    const size_t first = v.size();
    vector<int> up;  // up[j]: node that discovered v[first + j] (-1 for id)
    v.push_back(id);
    up.push_back(-1);
    visit[id] = 1;
    for (size_t j = first; j < v.size(); ++j) {
        const int u = v[j];
        sz[u] = 1;
        mx[u] = 0;
        for (int w : edg[u]) {
            if (!visit[w]) {
                visit[w] = 1;
                v.push_back(w);
                up.push_back(u);
            }
        }
    }
    // Children appear after their parent in v, so sweeping backwards finishes
    // every subtree before it is folded into its parent.
    for (size_t j = v.size(); j-- > first + 1;) {
        const int u = v[j];
        const int p = up[j - first];
        sz[p] += sz[u];
        mx[p] = max(mx[p], sz[u]);
    }
}

#define SZ(x) ((int)(x).size())

// Returns a centroid of the component containing `id`, i.e. a node whose
// largest remaining piece after deleting it has at most tot/2 nodes. That
// piece is either its biggest child subtree (mx[]) or the part above it
// (tot - sz[]). dfs2 collects the component into `v` and fills sz/mx; the
// visit[] marks it set are cleared again here. Returns the last node of `v`
// that passes the test, or -1 if none does.
int get_cen(int id) {
    v.clear();
    dfs2(id);
    const int tot = SZ(v);
    int found = -1;  // named so it does not shadow the global cen[] array
    for (int i : v) {
        if (max(mx[i], tot - sz[i]) <= tot / 2) {
            found = i;
        }
        visit[i] = false;
    }
    return found;
}

// Sets dis[cen_depth][u] = dist + (edges from id to u) for every node u
// reachable from `id` without entering visit[]-marked nodes. The walk never
// steps back to the node it came from (`par` for id itself).
// Uses an explicit stack so long paths cannot overflow the call stack.
void dfs3(int id, int par, int cen_depth, int dist) {
    struct Frame { int node, from, d; };
    vector<Frame> pending;
    pending.push_back({id, par, dist});
    while (!pending.empty()) {
        const Frame f = pending.back();
        pending.pop_back();
        dis[cen_depth][f.node] = f.d;
        for (int i : edg[f.node]) {
            if (!visit[i] && i != f.from) {
                pending.push_back({i, f.node, f.d + 1});
            }
        }
    }
}

void dfs(int id,int cen_par,int cen_depth) {
    int ccen=get_cen(id);
    dfs3(ccen,ccen,cen_depth,0);
    cen[ccen]={cen_par,cen_depth,{0,0},{0,0}};
    visit[ccen]=1;
    for (int i:edg[ccen]) {
        if (!visit[i]) dfs(i,ccen,cen_depth+1);
    }
}

// Component-wise arithmetic on pii, used to update both halves of a Cen
// aggregate in one statement. The compound assignments modify the left
// operand in place and return a reference to it, like the built-in operators
// (previously they returned a copy).
pii operator+(const pii &p1, const pii &p2) {
    return make_pair(p1.first + p2.first, p1.second + p2.second);
}

pii operator-(const pii &p1, const pii &p2) {
    return make_pair(p1.first - p2.first, p1.second - p2.second);
}

pii &operator+=(pii &p1, const pii &p2) {
    p1.first += p2.first;
    p1.second += p2.second;
    return p1;
}

pii &operator-=(pii &p1, const pii &p2) {
    p1.first -= p2.first;
    p1.second -= p2.second;
    return p1;
}

// Normalises both components of `p` into the range [0, mod), also for
// negative inputs.
void Pure(pii &p) {
    auto norm = [](LL value) { return ((value % mod) + mod) % mod; };
    p.first = norm(p.first);
    p.second = norm(p.second);
}

// Adds label x to the current query set. It walks from cen[x] up the
// centroid tree. At each centroid p it adds x and x * dist(x, p) to the
// `.first` halves of p's aggregates. If p has a parent centroid, it also
// subtracts x and x * dist(x, parent) from the `.second` halves. Both
// aggregates are kept normalised into [0, mod).
void addd(LL x) {
    for (LL p = x; p != -1; p = cen[p].par) {
        Cen &c = cen[p];
        c.val_v += {x, 0};
        c.val_v_av += {x * dis[c.depth][x], 0};
        const int up = c.par;
        if (up != -1) {
            c.val_v -= {0, x};
            c.val_v_av -= {0, x * dis[cen[up].depth][x]};
        }
        Pure(c.val_v);
        Pure(c.val_v_av);
    }
}

// Removes label x from the current query set. This exactly undoes the
// updates addd(x) made to x's centroid ancestors, and keeps the aggregates
// normalised into [0, mod).
void dell(LL x) {
    for (LL p = x; p != -1; p = cen[p].par) {
        Cen &c = cen[p];
        c.val_v -= {x, 0};
        c.val_v_av -= {x * dis[c.depth][x], 0};
        const int up = c.par;
        if (up != -1) {
            c.val_v += {0, x};
            c.val_v_av += {0, x * dis[cen[up].depth][x]};
        }
        Pure(c.val_v);
        Pure(c.val_v_av);
    }
}

// Returns x * (sum over set members y of y * dist(x, y)), mod `mod`, from the
// aggregates maintained by addd/dell. At each centroid ancestor p of x, only
// members below p but not below the previous (child) centroid are counted.
// They contribute x * (their distance sum to p) + x * dist(x, p) * (their
// label sum). Every partial product is reduced before the next
// multiplication: x * dist * sum can reach about 2e5 * 2e5 * 2e9 = 8e19,
// which would overflow long long.
LL query(LL x) {
    LL ret = 0;
    LL v = 0;     // label-sum correction carried up from the child centroid
    LL v_av = 0;  // distance-weighted correction carried up likewise
    int p = x;
    while (p != -1) {
        v = (v + cen[p].val_v.first) % mod;
        v_av = (v_av + cen[p].val_v_av.first) % mod;
        ret = (ret + x % mod * v_av) % mod;
        ret = (ret + x * dis[cen[p].depth][x] % mod * v) % mod;
        v = cen[p].val_v.second;
        v_av = cen[p].val_v_av.second;
        p = cen[p].par;
    }
    return ret;
}

// Modular exponentiation: returns a^n modulo `mod` for n >= 0 (1 % mod for
// n == 0) by iterative square-and-multiply. The base is reduced into
// [0, mod) first, so the result is always reduced, including for n == 1.
// Assumes mod * mod fits in a long long.
LL pow(LL a, LL n, LL mod) {
    LL base = ((a % mod) + mod) % mod;
    LL result = 1 % mod;
    for (; n > 0; n >>= 1) {
        if (n & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}

int main () {
    int n,q;
    scanf("%d %d",&n,&q);
    for (int i=1;n-1>=i;i++) {
        int a,b;
        scanf("%d %d",&a,&b);
        edg[a].push_back(b);
        edg[b].push_back(a);
    }
    dfs(1,-1,0);
    while (q--) {
        int k;
        scanf("%d",&k);
        vector<int> v;
        while (k--) {
            int x;
            scanf("%d",&x);
            v.push_back(x);
        }
        for (int i:v) addd(i);
        LL ans=0;
        for (int i:v) {
            ans += query(i);
            ans%=mod;
        }
        for (int i:v) dell(i);
        printf("%lld\n",(ans*pow(2,mod-2,mod) + mod)%mod);
    }
}
Copy The Code & Try With Live Editor

#3 Code Example with Java Programming

Code - Java Programming


import java.io.*;
import java.util.*;

/**
 * HackerRank "Kitty's Calculations on a Tree".
 *
 * <p>Reads a tree with n nodes and q query sets from stdin. For each set it
 * writes the sum over pairs {u, v} of u * v * dist(u, v), modulo
 * 1_000_000_007, to the file named by the OUTPUT_PATH environment variable.
 * Nodes are 0-based internally (label - 1). The tree is rooted at node 0 and
 * stored as parent-to-child linked lists (ptr/nxt/succ).
 */
public class Solution {

  static final long MOD = 1_000_000_007;

  /** Returns ((x * y) mod MOD * z) mod MOD; each product must fit in a long. */
  static int mul(long x, long y, long z) {
    return (int) ((((x * y) % MOD) * z) % MOD);
  }

  /** Returns (x * y) mod MOD; the product must fit in a long. */
  static int mul(long x, long y) {
    return (int) ((x * y) % MOD);
  }

  /** Returns (x + y) mod MOD. */
  static int sum(long x, long y) {
    return (int) ((x + y) % MOD);
  }

  /** Returns (x + y + z) mod MOD. */
  static int sum(long x, long y, long z) {
    return (int) ((x + y + z) % MOD);
  }

  // Child lists of the rooted tree, stored as linked lists of edge slots:
  // ptr[u] is u's first slot (0 = no children), nxt[s] is the slot after s,
  // and succ[s] is the child held by slot s. Slots are numbered from 1.
  static int[] nxt;
  static int[] succ;
  static int[] ptr;
  static int[] set;      // current query set, 0-based node ids
  static int[] dep;      // depth of each node below the root (node 0)
  static int[] parent;   // parent of each node; parent[0] stays 0
  static int index = 1;  // next unused edge slot

  /** Adds v to u's child list and records u as v's parent. */
  static void addEdge(int u, int v) {
    nxt[index] = ptr[u];
    ptr[u] = index;
    parent[v] = u;
    succ[index++] = v;
  }

  /**
   * Roots the tree at node 0. {@code ends} holds the undirected edges as
   * consecutive pairs (ends[2i], ends[2i + 1]). A BFS from node 0 calls
   * addEdge(parent, child) for every tree edge. Edges must be oriented like
   * this because the input gives them in arbitrary order and direction: the
   * smaller label is not necessarily the parent.
   */
  static void rootTree(int n, int[] ends) {
    int[] start = new int[n + 1];
    for (int end : ends) {
      start[end + 1]++;
    }
    for (int i = 0; i < n; i++) {
      start[i + 1] += start[i];
    }
    int[] fill = Arrays.copyOf(start, n);
    int[] adjacency = new int[ends.length];
    for (int i = 0; i + 1 < ends.length; i += 2) {
      int a = ends[i];
      int b = ends[i + 1];
      adjacency[fill[a]++] = b;
      adjacency[fill[b]++] = a;
    }

    boolean[] seen = new boolean[n];
    int[] queue = new int[n];
    int head = 0;
    int tail = 0;
    queue[tail++] = 0;
    seen[0] = true;
    while (head < tail) {
      int u = queue[head++];
      for (int s = start[u]; s < start[u + 1]; s++) {
        int w = adjacency[s];
        if (!seen[w]) {
          seen[w] = true;
          addEdge(u, w);
          queue[tail++] = w;
        }
      }
    }
  }

  /** Fills dep[] by BFS over the child lists, counting from dep[source]. */
  static void bfsDeep(int source) {
    Queue<Integer> q = new ArrayDeque<>();
    q.add(source);
    while (!q.isEmpty()) {
      int u = q.poll();
      for (int i = ptr[u]; i > 0; i = nxt[i]) {
        int v = succ[i];
        q.add(v);
        dep[v] = dep[u] + 1;
      }
    }
  }

  /** Lowest common ancestor of u and v by climbing parent links (O(depth)). */
  static int lowestCommonAncestor(int u, int v) {
    if (dep[u] < dep[v]) {
      int temp = u;
      u = v;
      v = temp;
    }
    while (dep[u] > dep[v]) {
      u = parent[u];
    }

    if (u == v) {
      return u;
    }
    while (parent[u] != parent[v]) {
      u = parent[u];
      v = parent[v];
    }

    return parent[u];
  }

  static boolean[] visited;

  /**
   * Same as lowestCommonAncestor, but also clears visited[] on every node of
   * the u-to-v path (both endpoints and the LCA included). This marks the
   * subtree that dfs(int) will walk.
   */
  static int lowestCommonAncestorVis(int u, int v) {
    if (dep[u] < dep[v]) {
      int temp = u;
      u = v;
      v = temp;
    }
    visited[u] = false;
    visited[v] = false;
    while (dep[u] > dep[v]) {
      u = parent[u];
      visited[u] = false;
    }

    if (u == v) {
      return u;
    }
    while (parent[u] != parent[v]) {
      u = parent[u];
      v = parent[v];
      visited[u] = false;
      visited[v] = false;
    }
    visited[parent[u]] = false;

    return parent[u];
  }

  static boolean[] isSet;

  /**
   * One frame of the explicit DFS stack used by {@link #dfs(int)}. After a
   * frame finishes: sumNode is the label sum of the set members in its
   * subtree, parzialInv is the sum of label * distance to u over those
   * members, and tot is the answer restricted to pairs inside the subtree.
   */
  static class NodeDfs {
    int u;                  // node (0-based) at the top of this frame's chain
    int count = 1;          // 1 + unmarked single-child chain nodes merged below u
    long parzialInv = 0;
    long sumNode = 0;       // the frame's own label is added unreduced
    long tot = 0;
    long parz2 = 0;         // sum over finished children c of (c.parzialInv + c.sumNode) * c.sumNode
    NodeDfs parent = null;  // frame this one is folded into (null for the root frame)
    boolean start = true;   // true until the frame's children have been pushed

    public NodeDfs() {
    }

    /** Reinitialises this pooled frame for node u under the given parent frame. */
    public void reset(int u, NodeDfs parent) {
      this.u = u;
      this.parent = parent;
      parzialInv = tot = parz2 = sumNode = 0;
      start = true;
      count = 1;
    }
  }

  static int stackIndex = 0;
  static NodeDfs[] nodes;

  /**
   * Evaluates the current query on the subtree rooted at u (the LCA of the
   * set). The walk covers only nodes whose visited[] flag was cleared. Runs
   * of unmarked nodes with exactly one such child are compressed into a
   * single frame. Returns the root frame; its tot is the answer.
   */
  static NodeDfs dfs(int u) {
    NodeDfs root = nodes[0];
    root.reset(u, null);
    stackIndex = 1;

    while (stackIndex > 0) {
      NodeDfs node = nodes[stackIndex - 1];
      if (node.start) {
        // First visit: push the children that belong to the marked subtree.
        visited[node.u] = true;

        if (isSet[node.u]) {
          for (int i = ptr[node.u]; i > 0; i = nxt[i]) {
            if (!visited[succ[i]]) {
              nodes[stackIndex].reset(succ[i], node);
              stackIndex++;
            }
          }
        } else {
          int uu = node.u;
          while (true) {
            int j = 0;
            int v = 0;
            for (int i = ptr[uu]; i > 0; i = nxt[i]) {
              if (!visited[succ[i]]) {
                nodes[stackIndex++].reset(v = succ[i], node);
                j++;
              }
            }
            if (isSet[v] || j != 1) {
              break;
            }
            // Single unmarked child: absorb it into this frame and keep going.
            node.count++;
            stackIndex--;
            uu = v;
          }
        }

        node.start = false;
      } else {
        // Second visit: every child is finished; combine and fold into the parent.
        node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
        if (node.count > 1) {
          // Shift distances from the chain's bottom node up to node.u.
          node.parzialInv = sum(node.parzialInv, mul(node.count - 1, node.sumNode));
        }
        if (isSet[node.u]) {
          node.sumNode += node.u + 1;
        }
        if (node.u != u) {
          NodeDfs nodeP = node.parent;
          nodeP.sumNode = sum(nodeP.sumNode, node.sumNode);
          nodeP.parzialInv = sum(nodeP.parzialInv, node.parzialInv, node.sumNode);
          nodeP.parz2 = sum(nodeP.parz2, mul(node.parzialInv + node.sumNode, node.sumNode));
          if (isSet[nodeP.u]) {
            nodeP.tot = sum(nodeP.tot, node.tot, mul(node.sumNode + node.parzialInv, (nodeP.u + 1)));
          } else {
            nodeP.tot = sum(nodeP.tot, node.tot);
          }
        }

        stackIndex--;
      }
    }

    return root;
  }

  static final int MAX_SIMPLY = 3;

  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

    StringTokenizer st = new StringTokenizer(br.readLine());
    int n = Integer.parseInt(st.nextToken());
    int q = Integer.parseInt(st.nextToken());

    nxt = new int[2 * n];
    succ = new int[2 * n];
    ptr = new int[n];
    dep = new int[n];
    parent = new int[n];
    nodes = new NodeDfs[n];
    for (int i = 0; i < n; i++) {
      nodes[i] = new NodeDfs();
    }

    int[] ends = new int[2 * (n - 1)];
    for (int i = 0; i < n - 1; i++) {
      st = new StringTokenizer(br.readLine());
      ends[2 * i] = Integer.parseInt(st.nextToken()) - 1;
      ends[2 * i + 1] = Integer.parseInt(st.nextToken()) - 1;
    }
    rootTree(n, ends);
    bfsDeep(0);

    visited = new boolean[n];
    isSet = new boolean[n];

    for (int h = 1; h <= q; h++) {
      st = new StringTokenizer(br.readLine());
      int k = Integer.parseInt(st.nextToken());
      st = new StringTokenizer(br.readLine());
      set = new int[k];
      if (k >= MAX_SIMPLY) {
        // Only dfs() reads isSet, so it only needs clearing for those queries.
        Arrays.fill(isSet, false);
      }
      for (int i = 0; i < k; i++) {
        int u = Integer.parseInt(st.nextToken()) - 1;
        isSet[u] = true;
        set[i] = u;
      }

      long result = 0;
      if (k < MAX_SIMPLY) {
        // At most one pair: evaluate it directly through the LCA.
        for (int i = 0; i < k - 1; i++) {
          int x = set[i];
          for (int j = i + 1; j < k; j++) {
            int y = set[j];
            int z = lowestCommonAncestor(x, y);
            int dist = dep[y] + dep[x] - 2 * dep[z];
            result = sum(result, mul(x + 1, y + 1, dist));
          }
        }
      } else {
        // Clear visited[] along the paths joining the set, find their common
        // top node x, then aggregate over that subtree.
        Arrays.fill(visited, true);
        Arrays.sort(set);
        int x = set[set.length - 1];
        for (int i = k - 2; i >= 0; i--) {
          if (visited[set[i]]) {
            x = lowestCommonAncestorVis(x, set[i]);
          }
        }
        NodeDfs node = dfs(x);
        result = node.tot;
      }
      bw.write(String.valueOf(result));
      bw.newLine();
    }

    bw.close();
    br.close();
  }
}
Copy The Code & Try With Live Editor

#4 Code Example with Python Programming

Code - Python Programming


from collections import Counter, defaultdict

MOD = 10**9 + 7

def read_row():
    """Read one line from stdin and return an iterator over its integers."""
    return map(int, input().split())

def mul(x, y):
    """Product of x and y reduced modulo MOD."""
    product = x * y
    return product % MOD

def add(*args):
    """Sum of all arguments reduced modulo MOD (0 when called with none)."""
    total = 0
    for value in args:
        total += value
    return total % MOD

def sub(x, y):
    """Difference x - y reduced into the range [0, MOD)."""
    difference = x - y
    return difference % MOD

n, q = read_row()

# Adjacency list of the tree (nodes are labeled 1..n).
adj_list = defaultdict(list)

for _ in range(n - 1):
    u, v = read_row()
    adj_list[u].append(v)
    adj_list[v].append(u)

# Element -> indices of the query sets that contain it. Built from
# range(1, n + 1) rather than from adj_list, so a single-node tree (no edges)
# still works.
elements = {v: set() for v in range(1, n + 1)}

for set_no in range(q):
    read_row()  # the set size k; the next line's elements are all we need
    for x in read_row():
        elements[x].add(set_no)

# BFS from node 1 to find each node's parent and depth. `order` lists the
# nodes level by level, so reversed(order) visits children before parents.
root = 1
current = [root]
current_depth = 0
order = []
parent = {root: None}
depth = {root: current_depth}

while current:
    current_depth += 1
    order.extend(current)
    nxt = []
    for node in current:
        for neighbor in adj_list[node]:
            if neighbor not in parent:
                parent[neighbor] = node
                depth[neighbor] = current_depth
                nxt.append(neighbor)

    current = nxt

# score[set_no]: running answer for that query set.
score = Counter()
# state[node] = {set_no: (ref_depth, label_sum, flow)} for the set members in
# node's subtree. flow is the sum of label * distance from each member to the
# subtree node at depth ref_depth (the node where the entry was last merged).
state = {}
for node in reversed(order):
    # Children's dictionaries are popped: each is consumed exactly once.
    states = [state.pop(child) for child in adj_list[node] if child != parent[node]]
    largest = {s: [depth[node], node, 0] for s in elements[node]}

    # Small-to-large merging: reuse the biggest dictionary as this node's own
    # and merge the others into it.
    if states:
        max_index = max(range(len(states)), key=lambda i: len(states[i]))
        if len(states[max_index]) > len(largest):
            states[max_index], largest = largest, states[max_index]

    sets = defaultdict(list)
    for cur_state in states:
        for set_no, entry in cur_state.items():
            sets[set_no].append(entry)

    for set_no, group in sets.items():
        if len(group) == 1 and set_no not in largest:
            # A single branch holds this set: nothing meets here, pass it up.
            largest[set_no] = group[0]
            continue

        if set_no in largest:
            group.append(largest.pop(set_no))

        total_flow = 0
        total_node_sum = 0

        for node_depth, node_sum, node_flow in group:
            flow_delta = mul(node_depth - depth[node], node_sum)
            total_flow = add(total_flow, flow_delta, node_flow)
            total_node_sum = add(total_node_sum, node_sum)

        # Pairs from different branches meet at `node`:
        # sum_i label_sum_i * (total_flow - flow_i).
        set_score = 0

        for node_depth, node_sum, node_flow in group:
            node_flow = add(mul(node_depth - depth[node], node_sum), node_flow)
            diff = mul(sub(total_flow, node_flow), node_sum)
            set_score = add(set_score, diff)

        score[set_no] = add(score[set_no], set_score)
        largest[set_no] = (depth[node], total_node_sum, total_flow)

    state[node] = largest

print(*(score[i] for i in range(q)), sep='\n')
Copy The Code & Try With Live Editor
Advertisements

Demonstration


Previous
[Solved] Array Manipulation in HackerRank - HackerRank solution C, C++, Java, JS, Python
Next
[Solved] Square-Ten Tree solution in HackerRank - HackerRank solution C, C++, Java, JS, Python