

Problem Name: Data Structures - Balanced Forest

Problem Link: https://www.hackerrank.com/challenges/balanced-forest/problem?isFullScreen=true

In this post, we solve the HackerRank Data Structures problem Balanced Forest.

Greg has a tree of nodes containing integer data. He wants to insert a node with some non-zero integer value somewhere into the tree. His goal is to be able to cut two edges and have the values of each of the three new trees sum to the same amount. This is called a balanced forest. Being frugal, the data value he inserts should be minimal. Determine the minimal amount that a new node can have to allow creation of a balanced forest. If it's not possible to create a balanced forest, return -1.

For example, you are given node values c = [15,12,8,14,13] and edges = [[1,2],[1,3],[1,4],[4,5]]. They form the following tree:

[Figure: the example tree]

The blue node is the root; in each node, the first number is the node number and the second is its value. Cuts can be made between nodes 1 and 3 and between nodes 1 and 4 to leave three trees with sums 27, 27 and 8. Adding a new node w with c[w] = 19 to the third tree completes the solution.
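To make the arithmetic in this example concrete, here is a small Python sketch that recomputes the subtree sums. The `children` map is a hand-rooted version of the edge list (rooted at node 1), and all names are illustrative only.

```python
# Example tree from the statement: node number -> value.
c = {1: 15, 2: 12, 3: 8, 4: 14, 5: 13}
# Edges [[1,2],[1,3],[1,4],[4,5]] rooted by hand at node 1.
children = {1: [2, 3, 4], 2: [], 3: [], 4: [5], 5: []}


def subtree_sum(v):
    """Sum of the values in the subtree rooted at v."""
    return c[v] + sum(subtree_sum(u) for u in children[v])


total = subtree_sum(1)            # 62
cut3 = subtree_sum(3)             # 8: tree split off by cutting edge 1-3
cut4 = subtree_sum(4)             # 27: nodes 4 and 5, split off by cutting edge 1-4
rest = total - cut3 - cut4        # 27: nodes 1 and 2
pieces = sorted([cut3, cut4, rest])
w = pieces[2] - pieces[0]         # value needed to top up the smallest tree
print(pieces, w)                  # [8, 27, 27] 19
```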

Function Description

Complete the balancedForest function in the editor below. It must return an integer representing the minimum value of c[w] that can be added to allow creation of a balanced forest, or -1 if it is not possible.

balancedForest has the following parameter(s):


  • c: an array of integers, the data values for each node
  • edges: an array of 2 element arrays, the node pairs per edge
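For small inputs, a brute-force reference is handy for checking faster solutions. The sketch below is an illustrative O(n^2) version under our own assumptions, not one of the solutions that follow: `balanced_forest_bruteforce` is a hypothetical name, the tree is rooted at node 1, and, like the C++ solution further down, it accepts w = 0 when a split is already balanced.

```python
def balanced_forest_bruteforce(c, edges):
    """Minimal c[w] for a balanced forest, or -1 (O(n^2) reference sketch).

    c[i] is the value of node i + 1; edges are 1-based [x, y] pairs.
    """
    n = len(c)
    adj = [[] for _ in range(n)]
    for x, y in edges:
        adj[x - 1].append(y - 1)
        adj[y - 1].append(x - 1)

    # Iterative DFS from node 1 (index 0); every node appears after its parent.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order, stack = [], [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for u in adj[v]:
            if not seen[u]:
                seen[u] = True
                parent[u] = v
                stack.append(u)

    # Subtree sums (descendants first) and ancestor sets (parents first).
    sub = list(c)
    for v in reversed(order[1:]):
        sub[parent[v]] += sub[v]
    anc = [set() for _ in range(n)]
    for v in order[1:]:
        anc[v] = anc[parent[v]] | {parent[v]}

    total = sub[0]
    best = -1
    # A non-root node v stands for the edge (v, parent[v]).
    for v in range(1, n):
        if 2 * sub[v] == total:  # one cut: the new node alone is the third tree
            best = sub[v] if best == -1 else min(best, sub[v])
    for i in range(1, n):
        for j in range(i + 1, n):
            if i in anc[j]:      # edge j lies below edge i
                parts = [sub[j], sub[i] - sub[j], total - sub[i]]
            elif j in anc[i]:    # edge i lies below edge j
                parts = [sub[i], sub[j] - sub[i], total - sub[j]]
            else:                # disjoint subtrees
                parts = [sub[i], sub[j], total - sub[i] - sub[j]]
            small, middle, large = sorted(parts)
            if middle == large:  # two equal trees: top up the smallest one
                w = large - small
                best = w if best == -1 else min(best, w)
    return best


print(balanced_forest_bruteforce([1, 2, 2, 1, 1], [[1, 2], [1, 3], [3, 5], [1, 4]]))  # 2
print(balanced_forest_bruteforce([1, 3, 5], [[1, 3], [1, 2]]))                        # -1
```

It tries every single cut (the new node then forms the third tree on its own) and every pair of cuts (the new node tops up the smallest tree), so it is far too slow for n = 5 * 10**4.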

Input Format

The first line contains a single integer, q, the number of queries. Each of the following q sets of lines is as follows:

  • The first line contains an integer, n, the number of nodes in the tree.
  • The second line contains n space-separated integers describing the respective values of c[1],c[2], ... , c[n], where each c[i] denotes the value at node i.
  • Each of the following n - 1 lines contains two space-separated integers, x[j] and y[j], describing edge j connecting nodes x[j] and y[j].
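The input can be read token by token rather than line by line, which tolerates stray spacing. A sketch, with `parse_queries` as a hypothetical helper and the problem's sample input embedded as a string:

```python
def parse_queries(text):
    """Turn raw input in the format above into a list of (c, edges) pairs."""
    tokens = iter(text.split())
    q = int(next(tokens))
    queries = []
    for _ in range(q):
        n = int(next(tokens))
        c = [int(next(tokens)) for _ in range(n)]
        # Each edge is two consecutive tokens x[j], y[j].
        edges = [[int(next(tokens)), int(next(tokens))] for _ in range(n - 1)]
        queries.append((c, edges))
    return queries


sample = """2
5
1 2 2 1 1
1 2
1 3
3 5
1 4
3
1 3 5
1 3
1 2
"""
qs = parse_queries(sample)
for c, edges in qs:
    print(c, edges)
```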

Constraints

  • 1 <= q <= 5
  • 1 <= n <= 5 * 10**4
  • 1 <= c[i] <= 10**9
  • Each query forms a valid undirected tree
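These bounds mean a tree sum can reach 5 * 10**4 * 10**9 = 5 * 10**13, which does not fit in a 32-bit int; this is why the C, C++ and Java solutions below keep sums in `int64_t`, `long long` and `long`. A quick check of the arithmetic (JavaScript numbers stay exact up to 2**53 - 1, which also covers this range):

```python
MAX_N = 5 * 10**4          # largest n
MAX_C = 10**9              # largest node value
max_total = MAX_N * MAX_C  # largest possible tree sum: 5 * 10**13

INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1
JS_MAX_SAFE = 2**53 - 1    # Number.MAX_SAFE_INTEGER in JavaScript

print(max_total > INT32_MAX)    # True: overflows a 32-bit int
print(max_total < INT64_MAX)    # True: fits in a 64-bit integer
print(max_total < JS_MAX_SAFE)  # True: exact as a JavaScript number
```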

Output Format

For each query, return the minimum value of the integer c[w]. If no such value exists, return -1 instead.

Sample Input

2
5
1 2 2 1 1
1 2
1 3
3 5
1 4
3
1 3 5
1 3
1 2

Sample Output

2
-1

Explanation

We perform the following two queries:

  1. The tree initially looks like this:

[Figure: the initial tree for the first query]
Greg can add a new node w = 6 with c[w] = 2 and create a new edge connecting nodes 4 and 6. Then he cuts the edge connecting nodes 1 and 4 and the edge connecting nodes 1 and 3. We now have a three-tree balanced forest where each tree has a sum of 3.
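One way to confirm the resulting forest is to group the nodes that stay connected after the two cuts. This sketch uses a tiny union-find; the `components` helper and the variable names are hypothetical.

```python
# Sample query 1 after the change: node 6 (value 2) hangs off node 4.
values = {1: 1, 2: 2, 3: 2, 4: 1, 5: 1, 6: 2}
remaining = [(1, 2), (3, 5), (4, 6)]  # edges left after cutting 1-4 and 1-3


def components(nodes, edges):
    """Group nodes into connected components with a small union-find."""
    parent = {v: v for v in nodes}

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    for a, b in edges:
        parent[find(a)] = find(b)
    groups = {}
    for v in nodes:
        groups.setdefault(find(v), []).append(v)
    return list(groups.values())


sums = sorted(sum(values[v] for v in comp) for comp in components(values, remaining))
print(sums)  # [3, 3, 3]
```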



In the second query, it's impossible to add a node in such a way that we can split the tree into a three-tree balanced forest, so we return -1.
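The -1 can be checked by hand: the tree is the path 2 - 1 - 3 with values 3, 1 and 5 (total 9). The sketch below enumerates both kinds of cut; the variable names are illustrative.

```python
# Sample query 2: path 2 - 1 - 3 with values 3, 1, 5.
values = {1: 1, 2: 3, 3: 5}
total = sum(values.values())  # 9

# One cut leaves two trees; with the new node alone as the third tree,
# both halves would have to be equal.
one_cut = [(values[2], total - values[2]), (values[3], total - values[3])]
print(one_cut)  # [(3, 6), (5, 4)]: never equal

# Two cuts leave {1}, {3}, {5}; topping up one tree only helps
# if the two larger trees already match.
a, b, x = sorted(values.values())
print(a, b, x, b == x)  # 1 3 5 False
```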


Code Examples

#1 Code Example with C Programming

Code - C Programming


#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

char* readline();
char** split_string(char*);

struct Node
{
    int64_t sum; // sum of all node in this tree.
    int64_t testSum;
    int data;
    int parent;
    int *child;
    int childCnt;
};
typedef struct Node Node;

void sumTree(Node *root, int index)
{
    int i;
    root[index].sum = root[index].data;
    for (i = 0; i < root[index].childCnt; i++)
    {
        int child = root[index].child[i];
        if (child == root[index].parent) continue;
        sumTree(root, child);
        root[index].sum += root[child].sum;
    }
    root[index].testSum = root[index].sum;
}

void updateTree(Node *tree, int root, int parent)
{
    tree[root].parent = parent;
    int i;
    for (i = 0; i < tree[root].childCnt; i++)
    {
        if (tree[root].child[i] == parent) continue;
        updateTree(tree, tree[root].child[i], root);
    }
}

int64_t childSum(Node *tree, int root, int  branch_root, int64_t targetSum, bool *bFound)
{
    int i;
    int64_t currSum = 0;

    if (tree[root].testSum < targetSum) return tree[root].testSum;

    for (i = 0; (i < tree[root].childCnt) && (*bFound==0); i++)
    {
        int child = tree[root].child[i];
        if (child == tree[root].parent) continue;
        if (child == branch_root) continue;
        int64_t chSum = childSum(tree, child, branch_root, targetSum, bFound);

        if (chSum == targetSum)
        {
            *bFound = 1;
            break;
        }

        currSum += chSum;
    }
    return currSum + tree[root].data;
}

// Complete the balancedForest function below.
int64_t balancedForest(int c_count, int* c, int edges_rows, int edges_columns, int** edges) {
    int i, j;
    // build tree.
    Node *tree = (Node *)calloc(c_count, sizeof(Node));
    for (i = 0; i < c_count; i++)
    {
        tree[i].data = c[i];
        tree[i].childCnt = 0;
        tree[i].child = NULL;
        tree[i].parent = -1;
        tree[i].sum = 0;
    }
    for (i = 0; i < edges_rows; i++)
    {
        int pa = edges[i][0] - 1;
        int ch = edges[i][1] - 1;
        tree[pa].child = (int *)realloc(tree[pa].child, (1 + tree[pa].childCnt) * sizeof(int));
        tree[pa].child[tree[pa].childCnt] = ch;
        tree[pa].childCnt++;
        tree[ch].child = (int *)realloc(tree[ch].child, (1 + tree[ch].childCnt) * sizeof(int));
        tree[ch].child[tree[ch].childCnt] = pa;
        tree[ch].childCnt++;
    }
    // Now update the parent_node;
    int root = 0; // pick the first one as root.
    updateTree(tree, root, -1);
    sumTree(tree, root);

    int64_t treeSum = 0;
    treeSum = tree[root].sum;
    
    int64_t maxSum = (treeSum - 1) / 2 + 1;
    int64_t minSum = treeSum / 3 - 1;
    int64_t minW = -1;
    for (i = 0; i < c_count; i++)
    {
        if (i == root) continue;

        //if (tree[i].sum >= minSum && tree[i].sum <= maxSum)
        {
            int64_t sumI = tree[i].sum;

            // Check for special case.
            int64_t sumJ = treeSum - sumI;

            if (sumI == sumJ)
            {
                if (minW < 0 || minW>sumI) minW = sumI;
            }
            else
            {
                bool bFound = 0;
                int64_t targetSum;
                int searchRoot = root;
                int branchRoot = i;
                int64_t w = 0;
                if (sumI > sumJ)
                {
                    targetSum = sumI;
                    sumI = sumJ;
                    sumJ = targetSum;
                    searchRoot = i;
                    branchRoot = root;
                }
                if ((sumI << 1) < sumJ)
                {
                    targetSum = sumJ >> 1;
                    if (sumJ - targetSum != targetSum) continue;
                    w = targetSum - sumI;
                }
                else
                {
                    targetSum = sumI;
                    w = targetSum - (sumJ - sumI);
                }

                if (minW >= 0 && minW < w) continue;

                // search in the main tree
                
                // first, update the testSum;
                if (searchRoot == root)
                {
                    int curr = tree[branchRoot].parent;
                    int64_t branchSum = tree[branchRoot].sum;
                    while (curr != -1)
                    {
                        tree[curr].testSum -= branchSum;
                        curr = tree[curr].parent;
                    }
                }

                childSum(tree, searchRoot, branchRoot, targetSum, &bFound);
                if (bFound)
                {
                    if (minW == -1 || minW > w) minW = w;
                }

                // last, restore the testSum
                if (searchRoot == root)
                {
                    int curr = tree[branchRoot].parent;
                    while (curr != -1)
                    {
                        tree[curr].testSum = tree[curr].sum;
                        curr = tree[curr].parent;
                    }
                }
            }
        }
    }
    return minW;
}

int main()
{
    FILE* fptr = fopen(getenv("OUTPUT_PATH"), "w");

    char* q_endptr;
    char* q_str = readline();
    int q = strtol(q_str, &q_endptr, 10);

    if (q_endptr == q_str || *q_endptr != '\0') { exit(EXIT_FAILURE); }

    for (int q_itr = 0; q_itr < q; q_itr++) {
        char* n_endptr;
        char* n_str = readline();
        int n = strtol(n_str, &n_endptr, 10);

        if (n_endptr == n_str || *n_endptr != '\0') { exit(EXIT_FAILURE); }

        char** c_temp = split_string(readline());

        int* c = malloc(n * sizeof(int));

        for (int i = 0; i < n; i++) {
            char* c_item_endptr;
            char* c_item_str = *(c_temp + i);
            int c_item = strtol(c_item_str, &c_item_endptr, 10);

            if (c_item_endptr == c_item_str || *c_item_endptr != '\0') { exit(EXIT_FAILURE); }

            *(c + i) = c_item;
        }

        int c_count = n;

        int** edges = malloc((n - 1) * sizeof(int*));

        for (int i = 0; i < n - 1; i++) {
            *(edges + i) = malloc(2 * (sizeof(int)));

            char** edges_item_temp = split_string(readline());

            for (int j = 0; j < 2; j++) {
                char* edges_item_endptr;
                char* edges_item_str = *(edges_item_temp + j);
                int edges_item = strtol(edges_item_str, &edges_item_endptr, 10);

                if (edges_item_endptr == edges_item_str || *edges_item_endptr != '\0') { exit(EXIT_FAILURE); }

                *(*(edges + i) + j) = edges_item;
            }
        }

        int edges_rows = n - 1;
        int edges_columns = 2;

        int64_t result = balancedForest(c_count, c, edges_rows, edges_columns, edges);

        fprintf(fptr, "%lld\n", (long long)result);
    }

    fclose(fptr);

    return 0;
}

char* readline() {
    size_t alloc_length = 1024;
    size_t data_length = 0;
    char* data = malloc(alloc_length);

    while (true) {
        char* cursor = data + data_length;
        char* line = fgets(cursor, alloc_length - data_length, stdin);

        if (!line) {
            break;
        }

        data_length += strlen(cursor);

        if (data_length < alloc_length - 1 || data[data_length - 1] == '\n') {
            break;
        }

        alloc_length <<= 1;

        data = realloc(data, alloc_length);

        if (!line) {
            break;
        }
    }

    if (data[data_length - 1] == '\n') {
        data[data_length - 1] = '\0';

        data = realloc(data, data_length);
    } else {
        data = realloc(data, data_length + 1);

        data[data_length] = '\0';
    }

    return data;
}

char** split_string(char* str) {
    char** splits = NULL;
    char* token = strtok(str, " ");

    int spaces = 0;

    while (token) {
        splits = realloc(splits, sizeof(char*) * ++spaces);

        if (!splits) {
            return splits;
        }

        splits[spaces - 1] = token;

        token = strtok(NULL, " ");
    }

    return splits;
}
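The heart of the C solution is the case analysis that runs after cutting a single edge (the `sumI`/`sumJ` branch in `balancedForest`). Here is a Python port of just that arithmetic; `plan_after_one_cut` is a hypothetical name, and the subtree search done by `childSum` is left out.

```python
def plan_after_one_cut(part_a, part_b):
    """Port of the case analysis in the C balancedForest loop.

    part_a and part_b are the sums of the two trees left by one cut.
    Returns (w, target), where target is the subtree sum that must still be
    found inside the larger tree, or None if this cut cannot work.
    target is None when the two trees are already equal.
    """
    small, large = sorted((part_a, part_b))
    if small == large:
        return small, None  # the new node alone becomes the third tree
    if 2 * small < large:
        # The larger tree must split into two equal halves; top up the smaller one.
        if large % 2:
            return None
        target = large // 2
        return target - small, target
    # Otherwise split a tree of sum `small` off the larger one and top up the rest.
    return small - (large - small), small


print(plan_after_one_cut(8, 54))   # (19, 27)
print(plan_after_one_cut(35, 27))  # (19, 27)
print(plan_after_one_cut(3, 7))    # None
```

Both calls on the statement example (cutting edge 1-3 or edge 1-4) ask for the same w = 19; the C code then only has to confirm that a subtree with the target sum exists on the larger side.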

#2 Code Example with C++ Programming

Code - C++ Programming


#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
#include <string>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <deque>
#include <cassert>
#include <stdlib.h>

using namespace std;


typedef long long ll;

const ll INF = (ll) 1e18;
const int N = (int) 5e4 + 10;

vector<int> g[N];
ll c[N];
ll f[N];
ll res = INF;
ll tot = 0;
bool was[N];

void upd(ll a, ll b, ll c) {
    if (a == b && c <= a)
        res = min(res, a - c);
    if (a == c && b <= a)
        res = min(res, a - b);
    if (b == c && a <= b)
        res = min(res, b - a);
}

set<ll>* unite(set<ll>* a, set<ll>* b) {
    if (a->size() > b->size())
        swap(a, b);
    for (ll x : *a) {
        if (b->count(tot - 2 * x))
            upd(tot - 2 * x, x, x);
        if (b->count(x))
            upd(x, x, tot - 2 * x);
        if ((tot - x) % 2 == 0 && b->count((tot - x) / 2))
            upd((tot - x) / 2, x, (tot - x) / 2);
    }
    for (ll x : *a) {
        b->insert(x);
    }
    delete a;
    return b;
}

set<ll>* dfs(int v) {
    was[v] = true;
    f[v] = c[v];
    set<ll>* sv = new set<ll>();
    for (int to : g[v])
        if (!was[to]) {
            set<ll>* sto = dfs(to);
            f[v] += f[to];
            sv = unite(sv, sto);
        }
    if (f[v] % 2 == 0 && sv->count(f[v] / 2))
        upd(f[v] / 2, f[v] / 2, tot - f[v]);
    if (sv->count(tot - f[v]))
        upd(tot - f[v], 2 * f[v] - tot, tot - f[v]);
    if (sv->count(2 * f[v] - tot))
        upd(2 * f[v] - tot, tot - f[v], tot - f[v]);
    sv->insert(f[v]);
    return sv;
}

void solve() {
    int n;
    cin >> n;
    for (int i = 0; i < N; i++) {
        was[i] = false;
        g[i].clear();
        c[i] = 0;
    }
    tot = 0;
    res = INF;
    for (int i = 0; i < n; i++) {
        cin >> c[i];
        tot += c[i];
    }
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        cin >> x >> y;
        --x;
        --y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    set<ll>* s = dfs(0);
    //for (int i = 0; i < n; i++)
    //    cerr << f[i] << " ";
    //cerr << endl;
    delete s;
    if (res == INF)
        res = -1;
    cout << res << endl;
    // cerr << "----------" << endl;
}

int main() {
    ios_base::sync_with_stdio(0);
    int p;
    cin >> p;
    while (p--) {
        solve();
    }
    return 0;
}

#3 Code Example with Java Programming

Code - Java Programming


import java.io.*;
import java.util.*;

public class Solution {

    public static void main(String[] args) {
        /* Enter your code here. Read input from STDIN. Print output to STDOUT. Your class should be named Solution. */
        Scanner scanner = new Scanner(System.in) ;
        int q = scanner.nextInt() ;
        if (q < 1 || q > 5) {
            throw new IllegalArgumentException("1<=Q<=5") ;
        }
        while(q>0) {
            int n = scanner.nextInt() ;
            if (n  <  1 || n > 50000) {
                throw new IllegalArgumentException("1<=N<=50000") ;
            }
            List<Node> nodes = new ArrayList<>(n) ;
            List<GraphNode> graph = new ArrayList<>(n) ;
            List<TreeNode> tree = new ArrayList<>(n) ;
            for(int i=1;i<=n;i++) {
				Node node = new Node(i,scanner.nextInt()) ;
                nodes.add(node) ;
				graph.add(new GraphNode(node)) ;
				tree.add(new TreeNode(node)) ;
            }
            List<Edge> edges = new ArrayList<>() ;
            for(int i=0;i<n-1;i++) {
                Edge edge = new Edge(scanner.nextInt(), scanner.nextInt()) ;
                edges.add(edge) ;
                addEdge(graph,edge) ;
            }
			graphsToTree(graph.get(0),tree,new HashSet<Node>()) ;
            //System.out.println(tree.get(0)) ;
            //System.out.println(edges) ;
            System.out.println(findMinCw(tree, edges, tree.get(0))) ;
            q-- ;
        }
    }
    
	private static class Node {
		int nodeId ;
        long coins=0 ;
		
		public Node(int nodeId, long coins) {
            this.nodeId = nodeId ;
            this.coins = coins ;
        }
        public int getNodeId() {
            return nodeId ;
        }
        public long getCoins() {
            return coins ;
        }
		
		public boolean equals(Node node) {
            return (node!=null && node.getNodeId() == nodeId) ;
        }
        
        public int hashCode() {
            return nodeId ;
        }
		public String toString() {
			return String.format("nodeId:%s, coins:%d",nodeId,coins) ;
		}
	}
	
	private static class GraphNode {
		
		final Node node ;
		Set<GraphNode> connectedNodes = new HashSet<>() ;
		
		public GraphNode(Node node) {
			this.node = node ;
		}
		public Node getNode() {
			return node ;
		}
		public void addConnection(GraphNode addNode) {
			connectedNodes.add(addNode) ;
		}
		public void removeConnection(GraphNode removeNode) {
			connectedNodes.remove(removeNode) ;
		}
		public Set<GraphNode> getConnectedNodes() {
			return connectedNodes ;
		}        
        public String toString() {
			return String.format("Node:%s, connectedNodes[%s]",node.toString(),connectedNodesToString()) ;
        }
		private String connectedNodesToString() {
			StringBuilder builder = new StringBuilder() ;
			for(GraphNode node : connectedNodes) {
				builder.append("Node:").append(node).append(",") ;
			}
			return builder.toString() ;
		}
    }
	
	private static class TreeNode {
		final Node node ;
		TreeNode parentNode ;
		Set<TreeNode> childNodes = new HashSet<>() ;
		long totalCoins ;
		
		public TreeNode(Node node) {
			this.node = node ;
			totalCoins = node.getCoins() ;
		}
		public Node getNode() {
			return node ;
		}
		public long getTotalCoins() {
			return totalCoins ;
		}
		public Set<TreeNode> getChildNodes() {
			return childNodes ;
		}
		
		public void setParent(TreeNode node) {
            if(parentNode!=null && node!=null){
                throw new RuntimeException("Multiple parent is not supported. parent:"+parentNode+" current:"+this+" new parent:"+node);
            }
            parentNode = node ;
        }
		
		public void addChildNode(TreeNode node) {
            childNodes.add(node) ;
            totalCoins+=node.getTotalCoins() ;
            node.setParent(this) ;
            if (parentNode!=null) {
                parentNode.addChildCoins(node.getTotalCoins()) ;
            }
        }
		
		public void addChildCoins(long childCoins) {
            totalCoins += childCoins ;
            if (parentNode!=null) {
                parentNode.addChildCoins(childCoins) ;
            }
        }
		
		public void removeChildNode(TreeNode node) {
            childNodes.remove(node) ;
            totalCoins-=node.getTotalCoins() ;
            node.setParent(null) ;
            if (parentNode!=null) {
                parentNode.removeChildCoins(node.getTotalCoins()) ;
            }
        }
		
		public void removeChildCoins(long childCoins) {
            totalCoins -= childCoins ;
            if (parentNode!=null) {
                parentNode.removeChildCoins(childCoins) ;
            }
        }
		
		public boolean isParentOf(TreeNode childNode) {
			return childNodes.contains(childNode) ;
		}
		
		public boolean isRoot() {
			return parentNode == null ;
		}
		
		public String toString() {
			return String.format("Node:%s, totalCoins:%d, parentNode:%s, childNodes:[%s]",node.toString(), totalCoins, 
				(parentNode!=null)?parentNode.getNode().toString():"NULL",childNodes.toString()) ;
		}
	}
	
    private static class Edge {
        int node1 ;
        int node2 ;
        
        public Edge(int node1, int node2) {
			this.node1 = node1 ;
			this.node2 = node2 ;
        }

		public int getNode1() {
			return node1 ;
		}

		public int getNode2() {
			return node2 ;
		}
        
		public void swapNode() {
			int tmpNode = node1 ;
			node1 = node2 ;
			node2 = tmpNode ;
		}
		
        public String toString() {
            return "node1:"+node1+" node2:"+node2 ;
        }
    }
    
    private static void addEdge(List<GraphNode> nodes, Edge edge) {
        nodes.get(edge.getNode1()-1).addConnection(nodes.get(edge.getNode2()-1)) ;
        nodes.get(edge.getNode2()-1).addConnection(nodes.get(edge.getNode1()-1)) ;
    }
    
    private static void removeEdge(List<GraphNode> nodes, Edge edge) {
        nodes.get(edge.getNode1()-1).removeConnection(nodes.get(edge.getNode2()-1)) ;
        nodes.get(edge.getNode2()-1).removeConnection(nodes.get(edge.getNode1()-1)) ;
    }
	
	private static void graphsToTree(GraphNode graph, List<TreeNode> treeNodes, Set<Node> visitedNode) {
		TreeNode treeNode = treeNodes.get(graph.getNode().getNodeId()-1) ;
		visitedNode.add(graph.getNode()) ;
		for(GraphNode connectedNode : graph.getConnectedNodes()) {
			if (!visitedNode.contains(connectedNode.getNode())) {
				treeNode.addChildNode(treeNodes.get(connectedNode.getNode().getNodeId()-1)) ;
				visitedNode.add(connectedNode.getNode()) ;
				graphsToTree(connectedNode, treeNodes, visitedNode) ;
			}			
		}
	}
	
	private static TreeNode removeTreeEdge(List<TreeNode> nodes, Edge edge) {
		TreeNode node1 = nodes.get(edge.getNode1()-1) ;
		TreeNode node2 = nodes.get(edge.getNode2()-1) ;
		
		if (node1.isParentOf(node2)) {
			// node1 is parent of node2
			node1.removeChildNode(node2) ;			
			return node2 ;
		} else {
			node2.removeChildNode(node1) ;
			edge.swapNode() ;
			return node1 ;
		}
	}
	
	private static void addTreeEdge(List<TreeNode> nodes, Edge edge, TreeNode rootNode) {
		TreeNode node1 = nodes.get(edge.getNode1()-1) ;
		TreeNode node2 = nodes.get(edge.getNode2()-1) ;
		
		node1.addChildNode(node2) ;	
	}
	
	//DFS on subTree with expected value
	private static boolean findSubTreeWithValue(TreeNode searchRoot, TreeNode tree, long expectedValue) {
		if (searchRoot.getTotalCoins() <= expectedValue || tree.getTotalCoins() <= expectedValue) {
			return false ;
		}
		for(TreeNode subTree : tree.getChildNodes()) {
			long subTreeCoins = subTree.getTotalCoins() ;
			long remainingCoins = searchRoot.getTotalCoins()-subTreeCoins ;
			
			if (subTreeCoins == expectedValue || remainingCoins==expectedValue) {
				return true ;
			}
			if (findSubTreeWithValue(searchRoot,subTree,expectedValue)) {
				return true ;
			}
		}
		return false ;
	}
	
	public static long findMinCw(List<TreeNode> nodes, List<Edge> edges, TreeNode rootNode) {
        long minCw = -1 ;
        for (int i = 0; i < edges.size() ;i++) {
			Edge removeEdge1 = edges.get(i) ;
			TreeNode tree1 = removeTreeEdge(nodes,removeEdge1) ;
            
            long nodes1Coins = rootNode.getTotalCoins() ;
            long nodes2Coins = tree1.getTotalCoins()  ;

            long largeSetCoins, smallSetCoins ;
            TreeNode treeToSplit = null ;
			
            if (nodes1Coins == nodes2Coins) {
                long cw = nodes1Coins ;
                if (minCw < 0 || cw < minCw) {
                    minCw = cw ;
                }
				addTreeEdge(nodes, removeEdge1, rootNode);
                continue ;
            } else if (nodes1Coins>nodes2Coins) {
                largeSetCoins = nodes1Coins ;
                smallSetCoins = nodes2Coins ;
				treeToSplit = rootNode ;
            } else {
                largeSetCoins = nodes2Coins ;
                smallSetCoins = nodes1Coins ;                
				treeToSplit = tree1 ;
            }

            long expectedCw = -1 ;
            long expectedCw1 = -1 ;
			long searchValue ;
            if (largeSetCoins%2 == 0) {
                expectedCw1 = largeSetCoins/2l - smallSetCoins ;
            }
            long expectedCw2 = smallSetCoins - (largeSetCoins - smallSetCoins) ;

            if (expectedCw1 >= 0 && expectedCw2 >=0) {
                expectedCw = Math.min(expectedCw1,expectedCw2) ;
            } else if (expectedCw1 >= 0) {
                expectedCw = expectedCw1 ;
            } else if (expectedCw2 >= 0) {
                expectedCw = expectedCw2 ;
            }
            
            if (expectedCw < 0 || (minCw >0 && expectedCw > minCw)) {
                addTreeEdge(nodes, removeEdge1, rootNode);
                continue ;
            }

			if (expectedCw == expectedCw1) {
				searchValue = largeSetCoins/2l ;
			} else {
				searchValue = smallSetCoins ;
			}

			if (findSubTreeWithValue(treeToSplit, treeToSplit, searchValue)) {
				if (minCw < 0 || expectedCw < minCw) {
                    minCw = expectedCw ;
                }
			}
		
            addTreeEdge(nodes, removeEdge1, rootNode);
        }
        return minCw ;
    }

}

#4 Code Example with Javascript Programming

Code - Javascript Programming


'use strict';

const fs = require('fs');

process.stdin.resume();
process.stdin.setEncoding('utf-8');

let inputString = '';
let currentLine = 0;

process.stdin.on('data', inputStdin => {
    inputString += inputStdin;
});

process.stdin.on('end', function() {
    inputString = inputString.replace(/\s*$/, '')
        .split('\n')
        .map(str => str.replace(/\s*$/, ''));

    main();
});

function readLine() {
    return inputString[currentLine++];
}

// Complete the balancedForest function below.
function balancedForest(c, edges) {
    const nodes = c.map(cost => ({ cost, adj: [], visited: false, solved: false }));
    
    for(let [a,b] of edges) {
        nodes[a-1].adj.push(b-1);
        nodes[b-1].adj.push(a-1);
    }
    
    const dfs = n => {
        if (n.visited) return 0;
        n.visited = true;
        
        for (let a of n.adj)
            n.cost += dfs(nodes[a]);
        return n.cost;
    }
    
    const sum = dfs(nodes[0]);
    //console.log(sum, nodes);

    let min = sum;
    const excsum = {};
    const incsum = {};
    
    const solve = n => {
        if (n.solved) return;
        n.solved = true;
        
        const cost_a = 3 * n.cost - sum;
        const cost_b = (sum - n.cost) / 2 - n.cost;
        //console.log("solve", n, { incsum, excsum }, { min, cost_a, cost_b });

        // can split in two equal subtrees?
        if (sum % 2 === 0 && n.cost === (sum / 2)) min = Math.min(min, sum / 2);

        if (cost_a >= 0 && (
            excsum[n.cost] // another subtree with equal cost?
            || excsum[sum - 2 * n.cost] // another subtree with 1/3 cost
            || incsum[sum - n.cost]) // edge to remove
        ) min = Math.min(min, cost_a);

        if (cost_b >= 0 && (sum - n.cost) % 2 === 0) {
            if (excsum[(sum - n.cost) / 2] || incsum[(sum + n.cost) / 2]) 
                min = Math.min(min, cost_b);
        }

        incsum[n.cost] = true;
        for (let a of n.adj) solve(nodes[a]);
        delete incsum[n.cost];
        excsum[n.cost] = true;
    }
    
    solve(nodes[0]);
    return min === sum ? -1 : min;
}

function main() {
    const ws = fs.createWriteStream(process.env.OUTPUT_PATH);

    const q = parseInt(readLine(), 10);

    for (let qItr = 0; qItr < q; qItr++) {
        const n = parseInt(readLine(), 10);

        const c = readLine().split(' ').map(cTemp => parseInt(cTemp, 10));

        let edges = Array(n - 1);

        for (let i = 0; i < n - 1; i++) {
            edges[i] = readLine().split(' ').map(edgesTemp => parseInt(edgesTemp, 10));
        }

        const result = balancedForest(c, edges);

        ws.write(result + '\n');
    }

    ws.end();
}
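The two costs in the JavaScript `solve` function follow from simple algebra. If the subtree rooted at `n` (sum t) is one of the two equal trees, the third tree holds sum - 2t, so topping it up costs t - (sum - 2t) = 3t - sum (`cost_a`). If that subtree (sum s) is instead the tree to top up, the rest must split into two halves of (sum - s) / 2, which gives `cost_b`. A Python sketch of just these formulas; the function names are illustrative, and the `excsum`/`incsum` lookups that confirm a matching subtree exists are not reproduced.

```python
def cost_a(total, t):
    """Subtree sum t is one of the two equal trees; the third tree holds total - 2*t."""
    return t - (total - 2 * t)  # simplifies to 3*t - total


def cost_b(total, s):
    """Subtree sum s is the tree to top up; the rest splits into two equal halves.

    Returns None when total - s is odd, mirroring the parity check in the JS code.
    """
    if (total - s) % 2:
        return None
    return (total - s) // 2 - s


# Statement example (total 62) and sample query 1 (total 7).
print(cost_a(62, 27), cost_b(62, 8))  # 19 19
print(cost_a(7, 3), cost_b(7, 1))     # 2 2
```

A negative result means the corresponding split cannot work, which is why the JavaScript code only uses a cost when it is `>= 0`.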

#5 Code Example with Python Programming

Code - Python Programming


from operator import attrgetter
from itertools import groupby
from sys import stderr

class Node:
    def __init__(self, index, value):
        self.index = index
        self.value = value
        self.children = []
        
def readtree():
    size = int(input())
    values = readints()
    assert size == len(values)
    nodes = [Node(i, v) for i, v in enumerate(values)]
    for _ in range(size - 1):
        x, y = readints()
        nodes[x-1].children.append(nodes[y-1])
        nodes[y-1].children.append(nodes[x-1])
    return nodes

def readints():
    return [int(fld) for fld in input().strip().split()]

def findbestbal(nodes):
    if len(nodes) == 1:
        return -1
    rootify(nodes[0])
#    print([(n.index, n.value, n.totalval) for n in nodes], file=stderr)
    best = total = nodes[0].totalval
    dummynode = Node(None, None)
    dummynode.totalval = 0
    sortnode = []
    for k, g in groupby(sorted([dummynode] + nodes, key = attrgetter('totalval')), attrgetter('totalval')):
        sortnode.append(list(g))
    total = nodes[0].totalval
    for ihi, n in enumerate(sortnode):
        if 3 * n[0].totalval >= total:
            break
    else:
        assert False
    ilo = ihi - 1
    for ihi in range(ihi, len(sortnode)):
        hi = sortnode[ihi][0].totalval
        lo = sortnode[ilo][0].totalval
        while 2 * hi + lo > total:
            if lo == 0:
                return -1
            if (total - lo) % 2 == 0:
                x = (total - lo) // 2
                for lonode in sortnode[ilo]:
                    if uptototalval(lonode, x + lo):
                        return x - lo
            ilo -= 1
            lo = sortnode[ilo][0].totalval
        if len(sortnode[ihi]) > 1:
            return 3 * hi - total
        hinode = sortnode[ihi][0]
        if 2 * hi + lo == total:
            for lonode in sortnode[ilo]:
                if uptototalval(lonode, hi) != hinode:
                    return hi - lo
        y = total - 2 * hi
        if uptototalval(hinode, 2 * hi) or uptototalval(hinode, hi + y):
            return hi - y

def rootify(root):
    root.parent = root.jumpup = None
    root.depth = 0
    bfnode = [root]
    i = 0
    while i < len(bfnode):
        node = bfnode[i]
        depth = node.depth + 1
        jumpup = uptodepth(node, depth & (depth - 1))
        for child in node.children:
            child.parent = node
            child.children.remove(node)
            child.depth = depth
            child.jumpup = jumpup
            bfnode.append(child)
        i += 1
    for node in reversed(bfnode):
        node.totalval = node.value + sum(child.totalval for child in node.children)
            
def uptodepth(node, depth):
    while node.depth > depth:
        if node.jumpup.depth <= depth:
            node = node.jumpup
        else:
            node = node.parent
    return node
            
def uptototalval(node, totalval):
    try:
        # print('uptototalval(%s,%s)' % (node.index, totalval), file=stderr)
        while node.totalval < totalval:
            if node.parent is None:
                return None
            if node.jumpup.totalval <= totalval:
                node = node.jumpup
            else:
                node = node.parent
            # print((node.index, node.totalval), file=stderr)
        if node.totalval == totalval:
            return node
        else:
            return None
    except Exception:
        return None
    
ncases = int(input())
for _ in range(ncases):
    print(findbestbal(readtree()))