## Algorithm

Problem Name: Data Structures - Kitty's Calculations on a Tree

This page collects solutions to the HackerRank Data Structures problem "Kitty's Calculations on a Tree".

Kitty has a tree, $T$, consisting of $n$ nodes, where each node is uniquely labeled from $1$ to $n$. Her friend Alex gave her $q$ sets, where each set contains $k$ distinct nodes. Kitty needs to calculate the following expression on each set:

$$\sum_{\{u,v\}} u \cdot v \cdot \operatorname{dist}(u, v)$$

where:

- $\{u, v\}$ denotes an unordered pair of nodes belonging to the set.
- $\operatorname{dist}(u, v)$ denotes the number of edges on the unique (shortest) path between nodes $u$ and $v$.

Given $T$ and $q$ sets of $k$ distinct nodes, calculate the expression for each set. For each set of nodes, print the value of the expression modulo $10^9 + 7$ on a new line.

Input Format

The first line contains two space-separated integers, the respective values of $n$ (the number of nodes in tree $T$) and $q$ (the number of sets).
Each of the $n - 1$ subsequent lines contains two space-separated integers, $a$ and $b$, that describe an undirected edge between nodes $a$ and $b$.
The $2 \cdot q$ subsequent lines define each set over two lines in the following format:

The first line contains an integer, k  , the size of the set.
The second line contains  k space-separated integers, the set's elements.

Output Format

Print $q$ lines of output where line $i$ contains the value of the expression for the $i$-th set, modulo $10^9 + 7$.

## Code Examples

### #1 Code Example with C Programming

```Code - C Programming```

``````
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define LIMIT 1000000007

/* Singly linked list cell that carries a pointer to a tree node.
   NOTE(review): no code in this listing uses this type. */
typedef struct tree_node_list tree_node_list;
struct tree_node_list {
  struct tree_node *node; /* payload */
  tree_node_list *next;   /* following cell of the list */
};

/* One vertex of the rooted tree. */
typedef struct tree_node tree_node;
struct tree_node {
  tree_node *parent; /* parent vertex; NULL marks the root */
  uint32_t num;      /* 1-based label, also the weight multiplied into the answer */
  int32_t depth;     /* edges between this vertex and the root; -1 while unknown */
};

/* Per-vertex scratch state for one query; main() zeroes it between queries. */
typedef struct aux_info aux_info;
struct aux_info {
  uint32_t simple_sum; /* labels of marked vertices below this one, reduced mod LIMIT */
  uint32_t level_sum;  /* those labels weighted by their distance to this vertex */
  uint32_t marker;     /* nonzero when the vertex belongs to the current set */
};

static void print_node(tree_node *node) {
printf("(num: %ld, parent: %ld, depth: %ld) ", node->num,
node->parent != NULL ? node->parent->num : 0, node->depth);
}

/* Debug helper: prints the first `count` entries of `nodes` on one line.
   The index is a size_t so it compares cleanly with `count` (the original
   signed int index mixed signedness and could overflow for huge counts). */
void print_tree(tree_node *nodes, size_t count) {
  for (size_t i = 0; i < count; i++) {
    print_node(&nodes[i]);
  }
  printf("\n");
}

static int order_tree(const void *lhs, const void *rhs) {
tree_node *a = *((tree_node **)lhs);
tree_node *b = *((tree_node **)rhs);
return a->depth - b->depth;
}

/* NOTE(review): the function signature that should precede this body is
   missing from this listing, so this fragment does not compile as-is. The
   body operates on a `tree_node *node`: a node without a parent becomes a
   root at depth 0, and a node whose depth is still unset (-1) takes its
   parent's depth + 1. It does not compute the parent's depth itself --
   presumably the caller must visit parents first; TODO confirm. */
if (node->parent == NULL) {
node->depth = 0;
} else if (node->depth == -1) {
/* Only valid if the parent's depth is already final. */
node->depth = node->parent->depth + 1;
}
}

/* Reads the num_nodes - 1 undirected edges from stdin and roots the tree at
 * node 1 with a breadth-first search.
 *
 * The input does not say which endpoint of an edge is the parent, so edges
 * are stored in adjacency lists (each edge in both directions) and oriented
 * during the BFS. On success every node's parent (NULL for the root) and
 * depth are set, order[] holds all nodes in BFS order -- i.e. by
 * non-decreasing depth, so walking it backwards visits children before their
 * parents -- and 0 is returned. Returns -1 on malformed input, allocation
 * failure, or edges that do not connect every node. */
static int build_tree(tree_node *nodes, tree_node **order, long num_nodes) {
  long num_entries = 2 * (num_nodes - 1);
  /* head[v]: first adjacency entry of node v (-1 = none); next_entry[e]:
     following entry of the same node; target[e]: 0-based neighbour. */
  long *head = malloc(num_nodes * sizeof(long));
  long *next_entry = malloc((num_entries + 1) * sizeof(long));
  long *target = malloc((num_entries + 1) * sizeof(long));
  int ok = head != NULL && next_entry != NULL && target != NULL;

  for (long i = 0; ok && i < num_nodes; ++i) {
    nodes[i].num = (uint32_t)(i + 1);
    nodes[i].depth = -1;
    nodes[i].parent = NULL;
    head[i] = -1;
  }
  for (long e = 0; ok && e < num_entries; e += 2) {
    long a, b;
    if (scanf("%ld %ld", &a, &b) != 2 || a < 1 || a > num_nodes || b < 1 ||
        b > num_nodes) {
      ok = 0;
      break;
    }
    target[e] = b - 1;
    next_entry[e] = head[a - 1];
    head[a - 1] = e;
    target[e + 1] = a - 1;
    next_entry[e + 1] = head[b - 1];
    head[b - 1] = e + 1;
  }
  if (ok) {
    /* order[] doubles as the BFS queue; a depth of -1 means "not reached". */
    long queued = 0;
    nodes[0].depth = 0;
    order[queued++] = &nodes[0];
    for (long front = 0; front < queued; ++front) {
      tree_node *current = order[front];
      for (long e = head[current->num - 1]; e != -1; e = next_entry[e]) {
        tree_node *neighbour = &nodes[target[e]];
        if (neighbour->depth == -1) {
          neighbour->parent = current;
          neighbour->depth = current->depth + 1;
          order[queued++] = neighbour;
        }
      }
    }
    ok = queued == num_nodes;
  }
  free(target);
  free(next_entry);
  free(head);
  return ok ? 0 : -1;
}

/* Returns the sum over unordered pairs {u, v} of marked nodes of
 * u * v * dist(u, v), modulo LIMIT.
 *
 * Nodes are visited children-first (order[] backwards). Each node's aux_info
 * entry has, by the time it is visited, simple_sum = sum of marked labels
 * below it and level_sum = sum of label * distance-to-it over those nodes;
 * the node folds its own contribution in and pushes both sums (shifted one
 * level up) into its parent's entry. On entry info[] must be zero except for
 * the marker fields. All stored sums are residues mod LIMIT, and every
 * product is reduced immediately, so no intermediate exceeds 2^62. */
static uint64_t query_total(tree_node **order, aux_info *info, long num_nodes) {
  uint64_t total = 0;
  for (long j = num_nodes - 1; j >= 0; --j) {
    const tree_node *node = order[j];
    uint64_t node_num = node->num;
    const aux_info *node_info = &info[node_num - 1];
    uint64_t simple_sum = node_info->simple_sum;
    uint64_t level_sum = node_info->level_sum;
    if (node_info->marker != 0) {
      /* Pairs (this node, marked descendant). */
      total = (total + level_sum * node_num) % LIMIT;
      simple_sum = (simple_sum + node_num) % LIMIT;
    } else if (simple_sum == 0 && level_sum == 0) {
      /* Nothing to propagate. Both residues must be checked: a nonzero
         subtree label sum can reduce to 0 mod LIMIT. */
      continue;
    }
    if (node->parent == NULL) {
      continue; /* root: there are no sibling subtrees to pair with */
    }
    /* Re-measure the distances from the parent. */
    level_sum = (level_sum + simple_sum) % LIMIT;
    aux_info *parent_info = &info[node->parent->num - 1];
    uint64_t parent_simple = parent_info->simple_sum;
    uint64_t parent_level = parent_info->level_sum;
    /* Pairs between this subtree and the sibling subtrees folded in so far:
       sum u*w*(d(u,p) + d(w,p)) = S_siblings * L_this + L_siblings * S_this. */
    total = (total + parent_simple * level_sum % LIMIT +
             parent_level * simple_sum % LIMIT) % LIMIT;
    parent_info->simple_sum = (uint32_t)((parent_simple + simple_sum) % LIMIT);
    parent_info->level_sum = (uint32_t)((parent_level + level_sum) % LIMIT);
  }
  return total;
}

/* Reads the tree and the query sets from stdin and prints one answer per set.
   Returns 0 on success and 1 on malformed input or allocation failure. */
int main() {
  long num_nodes, num_queries;
  if (scanf("%ld %ld", &num_nodes, &num_queries) != 2 || num_nodes < 1) {
    return 1;
  }
  tree_node *nodes = calloc(num_nodes, sizeof(tree_node));
  tree_node **order = calloc(num_nodes, sizeof(tree_node *));
  aux_info *info = calloc(num_nodes, sizeof(aux_info));
  int status = 1;

  if (nodes != NULL && order != NULL && info != NULL &&
      build_tree(nodes, order, num_nodes) == 0) {
    status = 0;
    for (long i = 0; i < num_queries && status == 0; ++i) {
      long k;
      if (scanf("%ld", &k) != 1) {
        status = 1;
        break;
      }
      for (long j = 0; j < k; j++) {
        long node_num;
        if (scanf("%ld", &node_num) != 1 || node_num < 1 ||
            node_num > num_nodes) {
          status = 1;
          break;
        }
        info[node_num - 1].marker = 1;
      }
      if (status != 0) {
        break;
      }
      uint64_t total = query_total(order, info, num_nodes);
      /* Reset markers and partial sums for the next set. */
      memset(info, 0, sizeof(aux_info) * num_nodes);
      printf("%llu\n", (unsigned long long)total);
    }
  }

  free(info);
  free(order);
  free(nodes);
  return status;
}
``````
Copy The Code &

### #2 Code Example with C++ Programming

```Code - C++ Programming```

``````
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <utility>
using namespace std;

using LL = long long;
using pii = pair<LL, LL>;
const int MAX_N = 200000 + 6;  // node labels are used directly as indices
const int MAX_P = 19;          // number of levels kept for the centroid tree
const LL mod = 1000000007LL;

// Adjacency lists, indexed by node label.
vector<int> edg[MAX_N];
// dis[d][x]: distance from x to the centroid at centroid-tree depth d whose
// component contains x (filled by dfs3).
int dis[MAX_P][MAX_N];
// Flags set on removed centroids and, temporarily, on nodes walked by dfs2.
bool visit[MAX_N];

// Aggregates stored at each centroid for the current query set.
struct Cen {
  int par;       // parent centroid, -1 for the top-level one
  int depth;     // depth in the centroid tree
  pii val_v_av;  // first --> val, second --> minus (label * distance sums)
  pii val_v;     // first --> val, second --> minus (label sums)
};
Cen cen[MAX_N];

vector<int> v;  // nodes collected by dfs2
int sz[MAX_N];  // subtree sizes written by dfs2
int mx[MAX_N];  // per-node value written by dfs2, read by get_cen's balance test

// Depth-first walk over the not-yet-visited nodes reachable from `id`.
// Appends every reached node to `v` in preorder, sets its visit flag, and
// records its subtree size in sz[] and its largest child-subtree size in
// mx[]. get_cen reads mx[] in its balance test. The original left mx[] at 0,
// so that test only looked at the part above each node.
void dfs2(int id) {
  v.push_back(id);
  visit[id] = 1;
  sz[id] = 1;
  mx[id] = 0;
  for (int i : edg[id]) {
    if (!visit[i]) {
      dfs2(i);
      sz[id] += sz[i];
      mx[id] = max(mx[id], sz[i]);
    }
  }
}

#define SZ(x) ((int)(x).size())

// Returns a centroid of the component containing `id`, i.e. of the nodes
// reachable from it without crossing a node whose visit flag is set.
// dfs2 collects the component into `v` and fills sz[]/mx[]. A node passes
// when both mx[node] and the size of the part above it (total - sz[node])
// are at most half the component; the last passing node in preorder is
// returned. dfs2's visit flags are cleared again before returning.
int get_cen(int id) {
  v.clear();
  dfs2(id);
  const int total = SZ(v);
  int centroid = -1;  // renamed: the old local `cen` shadowed the global array
  for (int node : v) {
    const int heaviest_part = max(mx[node], total - sz[node]);
    if (heaviest_part <= total / 2) {
      centroid = node;
    }
    visit[node] = false;
  }
  return centroid;
}

// Stores, for each node of the current component, its distance from the
// node the walk started at, in row `cen_depth` of dis[]. Removed centroids
// (visit set) and the node we came from are not entered.
void dfs3(int id, int par, int cen_depth, int dist) {
  dis[cen_depth][id] = dist;
  for (int next : edg[id]) {
    if (visit[next] || next == par) {
      continue;
    }
    dfs3(next, id, cen_depth, dist + 1);
  }
}

void dfs(int id,int cen_par,int cen_depth) {
int ccen=get_cen(id);
dfs3(ccen,ccen,cen_depth,0);
cen[ccen]={cen_par,cen_depth,{0,0},{0,0}};
visit[ccen]=1;
for (int i:edg[ccen]) {
if (!visit[i]) dfs(i,ccen,cen_depth+1);
}
}

// Component-wise sum of two pairs.
pii operator+(const pii &p1, const pii &p2) {
  pii out = p1;
  out.first += p2.first;
  out.second += p2.second;
  return out;
}

// Component-wise difference of two pairs.
pii operator-(const pii &p1, const pii &p2) {
  pii out = p1;
  out.first -= p2.first;
  out.second -= p2.second;
  return out;
}

// Component-wise compound assignment. Returns a reference to the updated
// left operand, like the built-in compound assignments (the original
// returned a copy, which made `(a += b) += c` fail to compile).
pii &operator+=(pii &p1, const pii &p2) {
  p1 = p1 + p2;
  return p1;
}

pii &operator-=(pii &p1, const pii &p2) {
  p1 = p1 - p2;
  return p1;
}

// Brings both components of `p` into [0, mod). The value is reduced and then
// shifted by mod because C++ `%` keeps the sign of a negative dividend.
void Pure(pii &p) {
  auto normalize = [](LL value) { return (value % mod + mod) % mod; };
  p.first = normalize(p.first);
  p.second = normalize(p.second);
}

LL p=x;
while (p!=-1) {
cen[p].val_v += {x,0};
cen[p].val_v_av += {x*dis[cen[p].depth][x],0};
if (cen[p].par != -1) {
int par=cen[p].par;
cen[p].val_v -= {0,x};
cen[p].val_v_av -= {0,x*dis[cen[par].depth][x]};
}
Pure(cen[p].val_v);
Pure(cen[p].val_v_av);
p=cen[p].par;
}
}

// Removes node x from the current query set, undoing exactly what the
// matching insertion added at every centroid-tree ancestor of x.
void dell(LL x) {
  for (LL p = x; p != -1; p = cen[p].par) {
    Cen &node = cen[p];
    node.val_v -= {x, 0};
    node.val_v_av -= {x * dis[node.depth][x], 0};
    if (node.par != -1) {
      const int par = node.par;
      node.val_v += {0, x};
      node.val_v_av += {0, x * dis[cen[par].depth][x]};
    }
    Pure(node.val_v);
    Pure(node.val_v_av);
  }
}

// Returns the sum over nodes y currently in the set of x * y * dist(x, y),
// modulo `mod`. It climbs x's centroid-tree ancestors p. At each one it uses
// p's totals minus the .second parts that the child centroid on x's side
// stored, and combines label sums with label * distance sums.
// Every product is reduced before the next multiplication. Previously
// x * dis * v could reach about 2e5 * 1e5 * 2e9 and overflow long long.
LL query(LL x) {
  LL ret = 0;
  LL label_sum = 0;     // sum of labels outside the child component, at p
  LL weighted_sum = 0;  // same nodes, each label times its distance to p
  int p = x;
  while (p != -1) {
    label_sum += cen[p].val_v.first;        // stays below 2 * mod
    weighted_sum += cen[p].val_v_av.first;  // stays below 2 * mod
    ret = (ret + x % mod * (weighted_sum % mod)) % mod;
    const LL x_times_dist = x * dis[cen[p].depth][x] % mod;
    ret = (ret + x_times_dist * (label_sum % mod)) % mod;
    label_sum = cen[p].val_v.second;
    weighted_sum = cen[p].val_v_av.second;
    p = cen[p].par;
  }
  return ret;
}

// Modular exponentiation: returns a^n mod m in [0, m) for n >= 0 and
// 1 <= m < 2^31 (so that products of two residues fit in a long long).
// The base is reduced first. The original returned `a` unreduced for n == 1
// and could overflow when squaring an unreduced base. The modulus parameter
// is named `m` so it no longer shadows the global `mod`. `long long` is the
// same type as this file's LL typedef.
long long pow(long long a, long long n, long long m) {
  long long result = 1 % m;
  long long base = (a % m + m) % m;
  while (n > 0) {
    if (n & 1) {
      result = result * base % m;
    }
    base = base * base % m;
    n >>= 1;
  }
  return result;
}

int main () {
int n,q;
scanf("%d %d",&n,&q);
for (int i=1;n-1>=i;i++) {
int a,b;
scanf("%d %d",&a,&b);
edg[a].push_back(b);
edg[b].push_back(a);
}
dfs(1,-1,0);
while (q--) {
int k;
scanf("%d",&k);
vector<int> v;
while (k--) {
int x;
scanf("%d",&x);
v.push_back(x);
}
LL ans=0;
for (int i:v) {
ans += query(i);
ans%=mod;
}
for (int i:v) dell(i);
printf("%lld\n",(ans*pow(2,mod-2,mod) + mod)%mod);
}
}
``````
Copy The Code &

### #3 Code Example with Java Programming

```Code - Java Programming```

``````
import java.io.*;
import java.util.*;

public class Solution {

    static final long MOD = 1_000_000_007;

    /** Returns (x * y * z) mod MOD. */
    static int mul(long x, long y, long z) {
        return (int) ((((x * y) % MOD) * z) % MOD);
    }

    /** Returns (x * y) mod MOD. */
    static int mul(long x, long y) {
        return (int) ((x * y) % MOD);
    }

    /** Returns (x + y) mod MOD. */
    static int sum(long x, long y) {
        return (int) ((x + y) % MOD);
    }

    /** Returns (x + y + z) mod MOD. */
    static int sum(long x, long y, long z) {
        return (int) ((x + y + z) % MOD);
    }

    // Child lists of the rooted tree as linked edge records: ptr[u] is the
    // first edge of u (0 = none), nxt[e] the next edge of the same node and
    // succ[e] the child reached through edge e.
    static int[] nxt;
    static int[] succ;
    static int[] ptr;
    static int[] set;      // 0-based nodes of the current query
    static int[] dep;      // depth below the root (node 0)
    static int[] parent;   // parent of each node; the root's entry stays 0
    static int index = 1;  // next free edge record (0 terminates the lists)

    /** Records v as a child of u. */
    static void addEdge(int u, int v) {
        nxt[index] = ptr[u];
        ptr[u] = index;
        parent[v] = u;
        succ[index++] = v;
    }

    /**
     * Roots the undirected tree with edges (eu[i], ev[i]) at node 0 and
     * registers every parent -> child link with addEdge. The input does not
     * order an edge's endpoints, so the orientation comes from a BFS.
     */
    static void orientTree(int n, int[] eu, int[] ev) {
        int m = eu.length;
        int[] head = new int[n];
        Arrays.fill(head, -1);
        int[] link = new int[2 * m];
        int[] to = new int[2 * m];
        for (int i = 0; i < m; i++) {
            to[2 * i] = ev[i];
            link[2 * i] = head[eu[i]];
            head[eu[i]] = 2 * i;
            to[2 * i + 1] = eu[i];
            link[2 * i + 1] = head[ev[i]];
            head[ev[i]] = 2 * i + 1;
        }
        boolean[] seen = new boolean[n];
        int[] queue = new int[n];
        int front = 0;
        int back = 0;
        queue[back++] = 0;
        seen[0] = true;
        while (front < back) {
            int u = queue[front++];
            for (int e = head[u]; e != -1; e = link[e]) {
                int v = to[e];
                if (!seen[v]) {
                    seen[v] = true;
                    addEdge(u, v);
                    queue[back++] = v;
                }
            }
        }
    }

    /**
     * Fills dep[] for every node below source by BFS over the child lists.
     * The original never enqueued anything, so every depth stayed 0.
     */
    static void bfsDeep(int source) {
        Queue<Integer> q = new LinkedList<>();
        dep[source] = 0;
        q.add(source);
        while (!q.isEmpty()) {
            int u = q.poll();
            for (int i = ptr[u]; i > 0; i = nxt[i]) {
                int v = succ[i];
                dep[v] = dep[u] + 1;
                q.add(v);
            }
        }
    }

    /** Returns the lowest common ancestor of u and v by climbing parent[]. */
    static int lowestCommonAncestor(int u, int v) {
        if (dep[u] < dep[v]) {
            int temp = u;
            u = v;
            v = temp;
        }
        while (dep[u] > dep[v]) {
            u = parent[u];
        }

        if (u == v) {
            return u;
        }
        while (parent[u] != parent[v]) {
            u = parent[u];
            v = parent[v];
        }

        return parent[u];
    }

    static boolean[] visited;

    /**
     * Same as lowestCommonAncestor, but also clears visited[] on every node
     * of both climbs and on the ancestor itself, so the whole u-v path is
     * left marked as unvisited for dfs().
     */
    static int lowestCommonAncestorVis(int u, int v) {
        if (dep[u] < dep[v]) {
            int temp = u;
            u = v;
            v = temp;
        }
        visited[u] = false;
        visited[v] = false;
        while (dep[u] > dep[v]) {
            u = parent[u];
            visited[u] = false;
        }

        if (u == v) {
            return u;
        }
        while (parent[u] != parent[v]) {
            u = parent[u];
            v = parent[v];
            visited[u] = false;
            visited[v] = false;
        }
        visited[parent[u]] = false;

        return parent[u];
    }

    static boolean[] isSet;

    /**
     * Frame of the explicit DFS stack used by dfs(). The accumulator fields
     * hold running sums (mod MOD) that are folded into the parent frame when
     * this frame is popped; count is the number of tree nodes merged into
     * the frame by chain compression (1 = just u).
     */
    static class NodeDfs {
        int u;
        int count = 1;
        long parzialInv = 0;
        long sumNode = 0;
        long tot = 0;
        long parz2 = 0;
        NodeDfs parent = null;
        boolean start = true;

        public NodeDfs() {
        }

        public void reset(int u, NodeDfs parent) {
            this.u = u;
            this.parent = parent;
            parzialInv = tot = parz2 = sumNode = 0;
            start = true;
            count = 1;
        }
    }

    static int stackIndex = 0;
    static NodeDfs[] nodes;

    /**
     * Evaluates the current query over the nodes below u that are still
     * unvisited (those on paths between query nodes, as marked by
     * lowestCommonAncestorVis). It uses an explicit stack in nodes[].
     * Chains of non-query nodes with a single unvisited child are collapsed
     * into one frame. Returns the root frame; main reads the answer from
     * its tot field.
     */
    static NodeDfs dfs(int u) {
        NodeDfs root = nodes[0];
        root.reset(u, null);
        stackIndex = 1;

        while (stackIndex > 0) {
            NodeDfs node = nodes[stackIndex - 1];
            if (node.start) {
                visited[node.u] = true;

                if (isSet[node.u]) {
                    for (int i = ptr[node.u]; i > 0; i = nxt[i]) {
                        if (!visited[succ[i]]) {
                            nodes[stackIndex].reset(succ[i], node);
                            stackIndex++;
                        }
                    }
                } else {
                    int uu = node.u;
                    while (true) {
                        int j = 0;
                        int v = 0;
                        for (int i = ptr[uu]; i > 0; i = nxt[i]) {
                            if (!visited[succ[i]]) {
                                nodes[stackIndex++].reset(v = succ[i], node);
                                j++;
                            }
                        }
                        if (isSet[v] || j != 1) {
                            break;
                        }
                        // Exactly one non-query child: absorb it into this frame.
                        node.count++;
                        stackIndex--;
                        uu = v;
                    }
                }

                node.start = false;
            } else {
                if (node.count > 1) {
                    node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
                    node.parzialInv = sum(node.parzialInv, mul(node.count - 1, node.sumNode));
                } else {
                    node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
                }
                if (isSet[node.u]) {
                    node.sumNode += node.u + 1;
                }
                if (node.u != u) {
                    NodeDfs nodeP = node.parent;
                    nodeP.sumNode = sum(nodeP.sumNode, node.sumNode);
                    nodeP.parzialInv = sum(nodeP.parzialInv, node.parzialInv, node.sumNode);
                    nodeP.parz2 = sum(nodeP.parz2, mul(node.parzialInv + node.sumNode, node.sumNode));
                    if (isSet[nodeP.u]) {
                        nodeP.tot = sum(nodeP.tot, node.tot, mul(node.sumNode + node.parzialInv, (nodeP.u + 1)));
                    } else {
                        nodeP.tot = sum(nodeP.tot, node.tot);
                    }
                }

                stackIndex--;
            }
        }

        return root;
    }

    /** Query sets smaller than this are answered by direct pairwise LCA. */
    static final int MAX_SIMPLY = 3;

    private static BufferedReader br;
    private static StringTokenizer st;

    /** Reads the next whitespace-separated integer, crossing line breaks. */
    private static int nextInt() throws IOException {
        while (st == null || !st.hasMoreTokens()) {
            String line = br.readLine();
            if (line == null) {
                throw new EOFException("unexpected end of input");
            }
            st = new StringTokenizer(line);
        }
        return Integer.parseInt(st.nextToken());
    }

    public static void main(String[] args) throws IOException {
        br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

        int n = nextInt();
        int q = nextInt();

        nxt = new int[2 * n];
        succ = new int[2 * n];
        ptr = new int[n];
        dep = new int[n];
        parent = new int[n];
        nodes = new NodeDfs[n];
        for (int i = 0; i < n; i++) {
            nodes[i] = new NodeDfs();
        }

        // Collect the edges first; orientTree decides which end is the parent.
        int[] eu = new int[n - 1];
        int[] ev = new int[n - 1];
        for (int i = 0; i < n - 1; i++) {
            eu[i] = nextInt() - 1;
            ev[i] = nextInt() - 1;
        }
        orientTree(n, eu, ev);
        bfsDeep(0);

        visited = new boolean[n];
        isSet = new boolean[n];

        for (int h = 1; h <= q; h++) {
            int k = nextInt();
            set = new int[k];
            if (k >= MAX_SIMPLY) {
                Arrays.fill(isSet, false);
            }
            for (int i = 0; i < k; i++) {
                int u = nextInt() - 1;
                isSet[u] = true;
                set[i] = u;
            }

            long result = 0;
            if (k < MAX_SIMPLY) {
                for (int i = 0; i < k - 1; i++) {
                    int x = set[i];
                    for (int j = i + 1; j < k; j++) {
                        int y = set[j];
                        int z = lowestCommonAncestor(x, y);
                        int dist = dep[y] + dep[x] - 2 * dep[z];
                        result = sum(result, mul(x + 1, y + 1, dist));
                    }
                }
            } else {
                Arrays.fill(visited, true);
                Arrays.sort(set);
                int x = set[set.length - 1];
                for (int i = k - 2; i >= 0; i--) {
                    if (visited[set[i]]) {
                        x = lowestCommonAncestorVis(x, set[i]);
                    }
                }
                NodeDfs node = dfs(x);
                result = node.tot;
            }
            bw.write(String.valueOf(result));
            bw.newLine();
        }

        bw.close();
        br.close();
    }
}
``````
Copy The Code &

### #4 Code Example with Python Programming

```Code - Python Programming```

``````
from collections import Counter, defaultdict

# Modulus used by the arithmetic helpers below (10^9 + 7).
MOD = 10**9 + 7

# NOTE(review): the `def` line for the following body is missing from this
# listing, which leaves a module-level `return` (a SyntaxError). The body
# returns a generator over the integers on one line of standard input.
return (int(x) for x in input().split())

def mul(x, y):
    """Return the product of ``x`` and ``y`` reduced modulo ``MOD``."""
    product = x * y
    return product % MOD

def add(*args):
    """Return the sum of all arguments reduced modulo ``MOD``.

    NOTE(review): only the body of this helper survived in the listing. The
    signature is reconstructed from the body's use of ``args`` and from the
    ``add(a, b)`` call later in the script.
    """
    return sum(args) % MOD

def sub(x, y):
    """Return ``(x - y) mod MOD``; Python's ``%`` keeps it in ``[0, MOD)``."""
    difference = x - y
    return difference % MOD

# Construct adjacency list of the tree

for _ in range(n - 1):

# Construct element to set mapping {element: [sets it belongs to]}
elements = {v: set() for v in adj_list}

for set_no in range(q):

# Do BFS to find parent for each node and order them in reverse depth
current = [root]
current_depth = 0
order = []
parent = {root: None}
depth = {root: current_depth}

while current:
current_depth += 1
order.extend(current)
nxt = []
for node in current:
if neighbor not in parent:
parent[neighbor] = node
depth[neighbor] = current_depth
nxt.append(neighbor)

current = nxt

# Process nodes in the order created above
score = Counter()
# {node: {set_a: [depth, sum of nodes, flow]}}
state = {}
for node in reversed(order):
states = [state[neighbor] for neighbor in adj_list[node] if neighbor != parent[node]]
largest = {s: [depth[node], node, 0] for s in elements[node]}

if states:
max_index = max(range(len(states)), key=lambda x: len(states[x]))
if len(states[max_index]) > len(largest):
states[max_index], largest = largest, states[max_index]

sets = defaultdict(list)
for cur_state in states:
for set_no, v in cur_state.items():
sets[set_no].append(v)

for set_no, states in sets.items():
if len(states) == 1 and set_no not in largest:
largest[set_no] = states[0]
continue

if set_no in largest:
states.append(largest.pop(set_no))

total_flow = 0
total_node_sum = 0

for node_depth, node_sum, node_flow in states:
flow_delta = mul(node_depth - depth[node], node_sum)
total_node_sum += node_sum

set_score = 0

for node_depth, node_sum, node_flow in states:
node_flow = add(mul(node_depth - depth[node], node_sum), node_flow)
diff = mul(sub(total_flow, node_flow), node_sum)