Algorithm


Problem Name: Data Structures - Kundu and Tree

Problem Link: https://www.hackerrank.com/challenges/kundu-and-tree/problem?isFullScreen=true

This post presents solutions to the HackerRank Data Structures problem "Kundu and Tree".

Kundu is a true tree lover. A tree is a connected graph with N vertices and N-1 edges. Today he got a tree and colored each edge either red (r) or black (b). He wants to know how many triplets (a, b, c) of vertices there are such that each of the three paths — from vertex a to b, from vertex b to c, and from vertex c to a — contains at least one red edge. Note that (a, b, c), (b, a, c), and all other permutations count as the same triplet.

If the answer is greater than $10^9 + 7$, print the answer modulo (%) $10^9 + 7$.

Input Format
The first line contains an integer N, i.e., the number of vertices in tree.
The next N-1 lines describe the edges: two space-separated integers giving the endpoints of an edge, followed by the color of that edge. The color is a lowercase English letter, either red (r) or black (b).

Output Format
Print a single number i.e. the number of triplets.

Constraints
$1 \le N \le 10^5$
A node is numbered between 1 to N.

Sample Input

5
1 2 b
2 3 r
3 4 r
4 5 b

Sample Output

4

Explanation

The given tree is the path 1 —(b)— 2 —(r)— 3 —(r)— 4 —(b)— 5 (the original figure is not reproduced here).

(2,3,4) is one such triplet because every one of its paths — 2 to 3, 3 to 4, and 2 to 4 — contains at least one red edge.
(2,3,5), (1,3,4) and (1,3,5) are other such triplets.
Note that (1,2,3) is NOT a triplet, because the path from 1 to 2 does not have an edge with red color.

 

 

Code Examples

#1 Code Example with C Programming

Code - C Programming


#include <stdint.h>
#include <stdio.h>
/* One element of the disjoint-set (union-find) forest that groups vertices
 * joined by black edges. */
struct node {
  struct node* parent;  /* Parent in the forest; a root points to itself. */
  int size;             /* Number of nodes in the set; kept up to date on roots by merge(). */
};

/* Capacity of the node pool; the problem bounds N by 10^5. */
enum { max_nodes = 100000 };
/* Node pool: input vertex v (1-based) lives at nodes[v - 1]. */
struct node nodes[max_nodes];
/* Per-cluster scratch sums. Each row has one slot beyond max_nodes so a
 * sentinel may be stored one past the last cluster even when there are
 * max_nodes clusters (every edge red); with only max_nodes slots that
 * sentinel write runs off the end of the array. */
int64_t sums[2][max_nodes + 1];

void init() {
  for (int i = 0; i  <  max_nodes; i++) {
    nodes[i].parent = &nodes[i];
    nodes[i].size = 1;
  }
}

/* Return the root of the set containing `node`.
 * Uses path splitting: every node visited on the way up is re-pointed at its
 * grandparent before the walk moves on to its former parent. */
struct node* find(struct node* node) {
  struct node* cur = node;
  while (cur != cur->parent) {
    struct node* up = cur->parent;
    cur->parent = up->parent;
    cur = up;
  }
  return cur;
}

/* Union the sets containing `l` and `r` by size and return the surviving
 * root. On a tie the root of `l` stays on top. If both are already in the
 * same set, that set's root is returned unchanged. */
struct node* merge(struct node* l, struct node* r) {
  struct node* big = find(l);
  struct node* small = find(r);
  if (big != small) {
    if (big->size < small->size) {
      struct node* swap_tmp = big;
      big = small;
      small = swap_tmp;
    }
    small->parent = big;
    big->size += small->size;
  }
  return big;
}

int main() {
  init();
  
  int n;
  scanf("%d", &n);
  for (int i = 0; i  <  n; i++) {
    int a, b;
    char c;
    scanf("%d %d %c", &a, &b, &c);
    if (c == 'b') merge(&nodes[a - 1], &nodes[b - 1]);
  }
  // Remove all nodes which aren't roots.
  int j = 0;
  for (int i = 0; i  <  n; i++) {
    if (nodes[i].parent == &nodes[i]) {
      nodes[j].size = nodes[i].size;
      nodes[j].parent = &nodes[j];
      j++;
    }
  }
  const int num_clusters = j;
  // For each i in [0..num_clusters), compute the sum of sizes[i..num_clusters).
  sums[0][num_clusters] = 0;
  for (int i = num_clusters - 1; i >= 0; i--) {
    sums[0][i] = sums[0][i + 1] + nodes[i].size;
  }
  sums[1][num_clusters] = 0;
  for (int i = num_clusters - 1; i >= 0; i--) {
    sums[1][i] = sums[1][i + 1] + sums[0][i + 1] * nodes[i].size;
  }
  // Iterate over triplets of clusters and count the number of triplets that
  // can be constructed from that triplet of clusters.
  int64_t total = 0;
  for (int a = 0; a  <  num_clusters; a++) {
    total += sums[1][a + 1] * nodes[a].size;
  }
  printf("%d\n", (int)(total % 1000000007));
}
Copy The Code & Try With Live Editor

#2 Code Example with C++ Programming

Code - C++ Programming


#include<iostream>
#include <fstream>
#include <string>
#include <cstdio>
#include <memory.h>
#include <vector>
#include <sstream>
#include <algorithm>
#include <set>
#include <map>
#include <queue>
#include <complex>
 
using namespace std;
 
 
// Loop helpers: REP(i, n) runs i over [0, n); FOR(i, lo, hi) runs i over
// [lo, hi). Both cast the bound to int so unsigned sizes compare cleanly.
#define REP(a,b) for (int a = 0; a < (int)(b); ++a)
#define FOR(a,b,c) for (int a = (b); a < (int)(c); ++a)

// Capacity for vertex-indexed arrays; the problem bounds N by 10^5.
const int MAXN = 100010;

// Adjacency lists holding only black edges (0-based vertex ids).
std::vector<int> tree[MAXN];
// BFS visited flags, one per vertex.
char vis[MAXN];

int main() {
    int n, a, b;
    long long res, bad;
    char s[4];
    
    scanf("%d", &n);
    
    REP(i,n-1) {
        scanf("%d%d%s", &a, &b, s);
        if (s[0] == 'b') {
            tree[a-1].push_back(b-1);
            tree[b-1].push_back(a-1);
        }
    }
    
    res = n;
    res = res*(res-1)*(res-2);
    res /= 6;
    
    memset(vis, 0, sizeof(vis));
    
    REP(i,n) if (vis[i] == 0) {
        int v;
        long long cs = 0;
        queue  < int> q;
        q.push(i); vis[i] = 1; cs = 1;
        while (!q.empty()) {
            v = q.front(); q.pop();
            REP(j,tree[v].size()) {
                int next = tree[v][j];
                if (vis[next] == 0) {
                    q.push(next);
                    vis[next] = 1;
                    ++cs;
                }
            }
        }
        
        bad = cs*(cs-1)*(n-cs)/2;
        bad += cs*(cs-1)*(cs-2)/6;
        res -= bad;
    }
    
    printf("%d\n", (int)(res%1000000007));

    return 0;
} 
Copy The Code & Try With Live Editor

#3 Code Example with Java Programming

Code - Java Programming


import java.io.*;
import java.util.*;

public class Solution {

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        Node[] nodes = new Node[n + 1];
        for(int i = 1; i  < = n; i++){
            nodes[i] = new Node();
        }
        for(int i = 1; i  <  n; i++){
            int i1 = sc.nextInt();
            int i2 = sc.nextInt();
            String c = sc.next();
            if(c.equals("b")){
                merge(nodes[i1], nodes[i2]);
            }
        }
        
        long sum = (long)n * (n - 1) * (n - 2) / 6;
        for(int i = 1; i <= n; i++){
            Node r = nodes[i].getRoot();
            if(r.size == 1) continue;
            if(r.seen) continue;
            r.seen = true;
            
            sum -= (long)r.size * (r.size - 1) * (r.size - 2) / 6;
            sum -= (long)r.size * (r.size - 1) / 2 * (n - r.size);
        }
        
        System.out.println(sum % 1000000007);
        
    }
    
    
    static void merge(Node n1, Node n2){
        Node r1 = n1.getRoot();
        Node r2 = n2.getRoot();
        if(r1 == r2> return;
        if(r1.size > r2.size){
            r1.size += r2.size;
            r2.parent = r1;
        } else {
            r2.size += r1.size;
            r1.parent = r2;
        }
    }
    
    static class Node{
        
        Node(){
            size = 1;
        }
        
        boolean seen;
        int size;
        Node parent;
        
        Node getRoot(){
            Node r = this;
            while(r.parent != null) r = r.parent;
            return r;
        }
        
    }
}
Copy The Code & Try With Live Editor

#4 Code Example with Javascript Programming

Code - Javascript Programming


/**
 * Multiset of keys backed by a Map from key -> occurrence count.
 * A key whose count would drop to zero is removed, so `map` only ever holds
 * positive counts.
 */
const CounterMap = class {
    constructor() {
        this.map = new Map();
    }

    /** Adds one occurrence of `k`. */
    increment(k) {
        const next = this.map.has(k) ? this.map.get(k) + 1 : 1;
        this.map.set(k, next);
    }

    /** Removes one occurrence of `k`; a no-op when `k` is absent. */
    decrement(k) {
        if (!this.map.has(k)) {
            return;
        }
        const current = this.map.get(k);
        if (current > 1) {
            this.map.set(k, current - 1);
        } else {
            this.map.delete(k);
        }
    }
}

/**
 * Union-find over ids 0..len with union by size.
 * Every union of two distinct sets is reported to `counters`: both old sizes
 * are decremented and the merged size is incremented. Size-1 components are
 * never incremented here, so `counters` reflects only merged components.
 */
const UF = class {
    /**
     * @param {number} len - largest id; ids 0..len are allocated.
     * @param {{increment: Function, decrement: Function}} counters - size multiset to keep in sync.
     */
    constructor(len, counters) {
        this.parents = Array.from({ length: len + 1 }, (_, i) => i);
        this.sizes = Array(len + 1).fill(1);
        this.counters = counters;
    }

    /** Returns the root of `a`'s set, halving the path as it walks up. */
    find(a) {
        while (a !== this.parents[a]) {
            this.parents[a] = this.parents[this.parents[a]];
            a = this.parents[a];
        }
        return a;
    }

    /** Merges the sets containing `a` and `b`; the smaller root goes under the larger. */
    union(a, b) {
        const rootOfA = this.find(a);
        const rootOfB = this.find(b);
        if (rootOfA !== rootOfB) {
            const sizeOfA = this.sizes[rootOfA];
            const sizeOfB = this.sizes[rootOfB];

            this.counters.decrement(sizeOfA);
            this.counters.decrement(sizeOfB);

            if (sizeOfA < sizeOfB) {
                this.parents[rootOfA] = rootOfB;
                this.sizes[rootOfB] += this.sizes[rootOfA];
                this.counters.increment(this.sizes[rootOfB]);
            } else {
                this.parents[rootOfB] = rootOfA;
                this.sizes[rootOfA] += this.sizes[rootOfB];
                this.counters.increment(this.sizes[rootOfA]);
            }
        }
    }
}

/** Number of unordered pairs chosen from n items: n * (n - 1) / 2. */
const nC2 = (n) => {
    return (n * (n - 1)) / 2;
}

/** Number of unordered triples chosen from n items: n * (n - 1) * (n - 2) / 6. */
const nC3 = (n) => n * (n - 1) * (n - 2) / 6;

/**
 * Counts vertex triplets whose three connecting paths each contain a red edge.
 * Black edges are unioned. Every triplet that puts at least two vertices in
 * one black component is then subtracted from C(V, 3).
 *
 * @param {Array<Array>} tree - edges as [u, v, colour]; colour 'b' marks black.
 * @returns {number} the count modulo 1000000007 (also logged to stdout).
 */
const kunduTree = (tree) => {
    const vertexCount = tree.length + 1;
    const sizeCounts = new CounterMap();
    const dsu = new UF(vertexCount, sizeCounts);

    tree.forEach(([u, v, colour]) => {
        if (colour === 'b') {
            dsu.union(u, v);
        }
    });

    let answer = nC3(vertexCount);
    for (const [size, count] of sizeCounts.map) {
        answer -= nC3(size) * count;
        answer -= nC2(size) * (vertexCount - size) * count;
    }
    answer = answer % 1000000007;
    console.log(answer);
    return answer;
}

// Lines of stdin, installed by main(); readLine() hands them out in order.
let inputString = '';
// Index of the next line readLine() will return.
let currentLine = 0;

/** Returns the next unread input line (undefined once input is exhausted). */
const readLine = () => inputString[currentLine++];

/**
 * Parses the raw stdin text into an edge list and solves it.
 * @param {string} data - full stdin contents.
 * @returns {number[]} single-element array holding the answer.
 */
const main = (data) => {
    inputString = data.split('\n');
    const n = parseInt(readLine(), 10);

    const tree = Array(n - 1);

    for (let treeRowItr = 0; treeRowItr < n - 1; treeRowItr++) {
        // trim() drops a trailing '\r' from CRLF input; otherwise the colour
        // would read 'b\r' and black edges would be silently ignored.
        const fields = readLine().trim().split(/\s+/);
        tree[treeRowItr] = [parseInt(fields[0], 10), parseInt(fields[1], 10), fields[2]];
    }
    const result = kunduTree(tree);
    return [result];
}

// Buffer all of stdin, then hand the complete text to main() when input ends.
process.stdin.resume();
process.stdin.setEncoding("ascii");
// Declared explicitly; an undeclared assignment throws in strict mode.
let _input = "";
process.stdin.on("data", function (input) {
    _input += input;
});

process.stdin.on("end", function () {
   main(_input);
});
Copy The Code & Try With Live Editor

#5 Code Example with Python Programming

Code - Python Programming


def find(x):
    """Return the root of ``x``'s set in the module-level union-find list ``uf``.

    A negative entry ``uf[i]`` marks ``i`` as a root (its magnitude is the set
    size); a non-negative entry is ``i``'s parent. Path halving: whenever the
    grandparent is itself a tree node, ``x`` is re-pointed at it before
    stepping up.
    """
    parent = uf
    node = x
    while parent[node] >= 0:
        up = parent[node]
        if parent[up] >= 0:
            parent[node] = parent[up]
        node = parent[node]
    return node

MOD = 10**9 + 7

n = int(input())
# uf[i] < 0: i is a root and -uf[i] is its component size; otherwise uf[i] is i's parent.
uf = [-1]*n
for i in range(n-1):
    x, y, col = input().split()
    if col == 'b':
        # Only black edges merge components; red edges are what make triplets valid.
        x, y = find(int(x)-1), find(int(y)-1)
        if x == y:
            # Already connected. Merging a root into itself would set
            # uf[x] = x and make find() loop forever.
            continue
        if uf[x] > uf[y]:
            # Keep x as the larger component (its size is more negative).
            x, y = y, x
        uf[x] += uf[y]
        uf[y] = x
# A triplet is valid iff its vertices lie in three distinct black components,
# so the answer is the third elementary symmetric sum of the component sizes,
# built incrementally as a (sum), b (pairwise products), c (triple products).
a, b, c = 0, 0, 0
for i in range(n):
    if uf[i] < 0:
        size = -uf[i]
        c += b*size
        b += a*size
        a += size
print(c % MOD)
Copy The Code & Try With Live Editor
Advertisements

Demonstration


Previous
[Solved] Median Updates solution in HackerRank - HackerRank solution C, C++, Java, JS, Python
Next
[Solved] Find the Running Median solution in HackerRank - HackerRank solution C, C++, Java, JS, Python