Algorithm


Problem Name: Data Structures - Number Game on a Tree

Problem Link: https://www.hackerrank.com/challenges/number-game-on-a-tree/problem?isFullScreen=true

In this HackerRank Data Structures problem, Number Game on a Tree:

Andy and Lily love playing games with numbers and trees. Today they have a tree consisting of n nodes and n - 1 edges. Each edge i has an integer weight, w_i. Before the game starts, Andy chooses an unordered pair of distinct nodes, (u, v), and uses all the edge weights present on the unique path from node u to node v to construct a list of numbers. For example, in the diagram below, Andy constructs a list from the edge weights along the path (2, 6):

[Diagram omitted: the example tree from the original problem statement is not reproduced here.]

 

Andy then uses this list to play the following game with Lily:

 

  • Two players move in alternating turns, and both players play optimally (meaning they will not make a move that causes them to lose the game if some better, winning move exists).
  • Andy always starts the game by removing a single integer from the list.
  • During each subsequent move, the current player removes an integer less than or equal to the integer removed in the last move.
  • The first player to be unable to move loses the game.

For example, if the list of integers is {1, 1, 2, 3, 3, 4} and Andy starts the game by removing 3, the list becomes {1, 1, 2, 3, 4}. Then, on Lily's move, she must remove a remaining integer less than or equal to 3 (i.e., 1, 1, 2, or 3).

The two friends decide to play g games, where each game is in the form of a tree. For each game, calculate the number of unordered pairs of nodes that Andy can choose to ensure he always wins the game.

Input Format

The first line contains a single integer, g, denoting the number of games. The subsequent lines describe each game in the following format:

  1. The first line contains an integer, n, denoting the number of nodes in the tree.
  2. Each line i of the n - 1 subsequent lines contains three space-separated integers, u_i, v_i, and w_i, describing the i-th edge, which connects nodes u_i and v_i with weight w_i.

Constraints

  • $1 \le g \le 10$
  • $1 \le n \le 5 \times 10^5$
  • $1 \le u_i, v_i \le n$
  • $0 \le w_i \le 10^9$
  • The sum of $n$ over all games does not exceed $5 \times 10^5$.

Output Format

For each game, print an integer on a new line describing the number of unordered pairs Andy can choose to construct a list that allows him to win the game.

Sample Input 0

3
5
1 2 2
1 3 1
3 4 1
3 5 2
5
1 2 0
2 3 2
3 4 2
4 5 0
5
1 2 0
1 3 1
3 4 0
3 5 2

Sample Output 0

9
8
10

 

 

 

Code Examples

#1 Code Example with C++ Programming

Code - C++ Programming


#include <bits/stdc++.h>

using namespace std ;

// Competitive-programming shorthand used throughout the C++ solution.
#define ft first
#define sd second
#define pb push_back
#define all(x) (x).begin(), (x).end()

#define ll long long int
#define vi vector<int>
#define vii vector < pair<int,int> >
#define pii pair<int,int>
// NOTE(review): in this copy the template arguments of plii, piii, viii, vll,
// pll and pli were garbled (e.g. `pair<pair, int>`); they are reconstructed
// here. None of these aliases is used in the code shown.
#define plii pair<pair<ll, int>, int>
#define piii pair<pii, int>
#define viii vector<pair<pii, int> >
#define vl vector<ll>
#define vll vector<pair<ll, ll> >
#define pll pair<ll, ll>
#define pli pair<ll, int>
#define mp make_pair
#define ms(x, v) memset(x, v, sizeof x)

#define sc1(x) scanf("%d",&x)
#define sc2(x,y) scanf("%d%d",&x,&y)
#define sc3(x,y,z) scanf("%d%d%d",&x,&y,&z)

#define scll1(x) scanf("%lld",&x)
#define scll2(x,y) scanf("%lld%lld",&x,&y)
#define scll3(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)

#define pr1(x) printf("%d\n",x)
#define pr2(x,y) printf("%d %d\n",x,y)
#define pr3(x,y,z) printf("%d %d %d\n",x,y,z)

#define prll1(x) printf("%lld\n",x)
#define prll2(x,y) printf("%lld %lld\n",x,y)
#define prll3(x,y,z) printf("%lld %lld %lld\n",x,y,z)

#define pr_vec(v) for(int i=0;i<(v).size();i++) cout << (v)[i] << " " ;

#define f_in(st) freopen(st,"r",stdin)
#define f_out(st) freopen(st,"w",stdout)

// Inclusive ascending / descending loops over an already-declared variable i.
#define fr(i, a, b) for((i) = (a); (i) <= (b); (i)++)
#define fb(i, a, b) for((i) = (a); (i) >= (b); (i)--)
// Asserts l <= x <= r.
#define ASST(x, l, r) assert( (x) <= (r) && (x) >= (l) )

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

const int mod = 1000000007;  // default modulus (10^9 + 7) for the helpers below

// Modular addition: for residues a, b in [0, m) returns (a + b) mod m.
// The sum is formed in 64-bit so it cannot overflow int when a caller passes
// a modulus larger than INT_MAX / 2.
int ADD(int a, int b, int m = mod) {
    long long s = static_cast<long long>(a) + b;
    if (s >= m)
        s -= m;
    return static_cast<int>(s);
}

// Modular multiplication: returns (a * b) mod m, with the product taken in
// 64-bit so it does not overflow int.
int MUL(int a, int b, int m = mod) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % m);
}

// Binary exponentiation: returns a^b mod m (intended for b >= 0).
// The accumulator starts at 1 % m so that m == 1 correctly yields 0 even
// when b == 0.
int power(int a, int b, int m = mod) {
    int res = 1 % m;
    while( b ) {
        if( b & 1 ) {
            res = 1LL * res * a % m;
        }
        a = 1LL * a * a % m;  // square the base for the next bit of b
        b /= 2;
    }
    return res;
}

// Number of unordered pairs that can be formed from x items: x * (x - 1) / 2.
long long nC2(long long x) {
    const long long orderedPairs = x * (x - 1);
    return orderedPairs / 2;
}

const int maxn = 500000 + 5;  // 5 * 10^5 nodes plus a little slack

// Per-game state shared by dfs() and main().
int t;            // number of games
int n;            // number of nodes in the current tree
int vis[maxn];    // 0/1 parity of each weight id on the current DFS path
int cnt;          // next unused weight id
map<int, int> M;  // edge weight -> compact weight id
vii adj[maxn];    // adjacency lists: (neighbour, weight)

// Double hashing of the set of weight ids with odd parity:
// base[k][id] = prime_k^id mod mod_k (filled in by main).
int prime1 = 23;
int prime2 = 7;
int base[2][maxn];
int mod1 = 1589917477;
int mod2 = 1897266401;

vii a;            // a[x - 1] = hash pair computed for node x
// Traverses the tree from u and records, for every reached node x, a pair of
// hashes of the set of edge weights occurring an odd number of times on the
// path u -> x: weight id `id` contributes base[k][id] to hash k (mod mod_k)
// exactly when its parity on the path is odd. Barring hash collisions, two
// nodes get equal hash pairs iff every weight on the path between them occurs
// an even number of times.
//
//   u          start node (nodes are 1-indexed)
//   p          node treated as u's parent and never entered (0 = none)
//   cst1/cst2  hash pair stored for u itself
//
// Results go to a[x - 1]. Previously unseen weights are assigned ids
// cnt, cnt + 1, ... in M in the same order as a recursive DFS would. vis[]
// is used as path-parity scratch space and is restored before returning.
//
// An explicit stack replaces recursion: n can be as large as 5 * 10^5, and a
// path-shaped tree would otherwise recurse that deep and may overflow the
// call stack.
void dfs(int u, int p = 0, long long cst1 = 0, long long cst2 = 0) {
    struct Frame {
        int node;         // node being expanded
        int parent;       // neighbour we arrived from (skipped)
        int weight_id;    // id of the weight on parent -> node, -1 for u
        size_t next;      // index of the next adjacency entry to scan
    };

    a[u - 1].first = static_cast<int>(cst1);
    a[u - 1].second = static_cast<int>(cst2);

    std::vector<Frame> stack;
    stack.push_back({u, p, -1, 0});
    while (!stack.empty()) {
        Frame& top = stack.back();
        if (top.next == adj[top.node].size()) {
            // Leaving this node: undo the parity toggle made when entering it.
            if (top.weight_id >= 0)
                vis[top.weight_id] ^= 1;
            stack.pop_back();
            continue;
        }

        const int node = top.node;
        const int child = adj[node][top.next].first;
        const int weight = adj[node][top.next].second;
        ++top.next;
        if (child == top.parent)
            continue;

        auto found = M.find(weight);
        if (found == M.end())
            found = M.emplace(weight, cnt++).first;
        const int id = found->second;

        // Toggle the weight's parity and update the parent's hashes with it.
        // Sums are done in 64-bit: both operands can exceed INT_MAX / 2.
        vis[id] ^= 1;
        long long h1 = static_cast<long long>(a[node - 1].first) +
                       (vis[id] ? base[0][id] : -base[0][id]);
        long long h2 = static_cast<long long>(a[node - 1].second) +
                       (vis[id] ? base[1][id] : -base[1][id]);
        if (h1 >= mod1) h1 -= mod1;
        if (h1 < 0) h1 += mod1;
        if (h2 >= mod2) h2 -= mod2;
        if (h2 < 0) h2 += mod2;
        a[child - 1].first = static_cast<int>(h1);
        a[child - 1].second = static_cast<int>(h2);

        // `top` must not be used after this push (the vector may reallocate).
        stack.push_back({child, node, id, 0});
    }
}

int main() {
    cin >> t;
    int sum = 0;
    while( t-- ) {
        cin >> n; sum += n;
            assert(sum  < = 500000);
        int i; base[0][0] = base[1][0] = 1;
        fr(i, 1, n) {
            base[0][i] = 1LL * base[0][i-1] * prime1 % mod1;
            base[1][i] = 1LL * base[1][i-1] * prime2 % mod2;
        }
        fr(i, 1, n-1) {
            int u, v, cst; 
            cin >> u >> v >> cst;
            adj[u].pb( {v, cst} );
            adj[v].pb( {u, cst} );
        }
        cnt = 0;
        a.resize(n);
        dfs(1, 0, 0);
        assert(a.size() == n);
        sort(all(a));
        i = 0;
        ll ans = 0;
        while( i  <  n ) {
            pii x = a[i]; int c = 0;
            while( i  <  n && x == a[i] ) {
                c ++; i ++;
            }
            ans += 1LL * c * (c-1) / 2;
        }
        ans = nC2(n) - ans;
        cout << ans << "\n";
        M.clear(); a.clear();
        fr(i, 0, n) {
            adj[i].clear(); 
            vis[i] = base[0][i] = base[1][i] = 0;
        }
    }
    assert(n  < = 500000>;
    return 0;
}
Copy The Code & Try With Live Editor

#2 Code Example with Java Programming

Code - Java Programming


import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Random;

public class D2 {
    InputStream is;
    PrintWriter out;
    String INPUT = "";

    // One random 64-bit key per compressed weight rank (ranks are < n - 1 <= 5 * 10^5).
    Random gen = new Random();
    long[] zh = gen.longs(500005).toArray();

    /**
     * Solves every game on the input. For each tree, prints the number of
     * unordered node pairs whose path contains some edge weight an odd number
     * of times.
     */
    void solve()
    {
        for(int T = ni();T > 0;T--){
            int n = ni();
            int[] from = new int[n - 1];
            int[] to = new int[n - 1];
            int[] w = new int[n - 1];
            for (int i = 0; i < n - 1; i++) {
                from[i] = ni() - 1;
                to[i] = ni() - 1;
                w[i] = ni();
            }
            // Replace raw weights (up to 10^9) by dense ranks so they can index zh.
            w = shrink(w);

            int[][][] g = packWU(n, from, to, w);
            int[][] pars = parents(g, 0);
            int[] par = pars[0], ord = pars[1];
            int[] pw = pars[4];

            // dp[v] = XOR of zh[rank] over the edges on the root->v path. Keys of
            // a weight seen an even number of times cancel, so (up to hash
            // collisions) dp[u] == dp[v] iff every weight on the u-v path occurs
            // an even number of times.
            long[] dp = new long[n];
            for(int i = 1;i < n;i++){
                int cur = ord[i]; // BFS order: par[cur] was processed earlier
                dp[cur] = dp[par[cur]] ^ zh[pw[cur]];
            }

            // Answer = all pairs minus the pairs inside groups of equal dp values.
            Arrays.sort(dp);
            long ret = (long)n*(n-1)/2;
            for(int i = 0;i < n;){
                int j = i;
                while(j < n && dp[i] == dp[j])j++;
                ret -= (long)(j-i)*(j-i-1)/2;
                i = j;
            }
            out.println(ret);
        }
    }

    /**
     * Coordinate compression: returns an array whose i-th entry is the 0-based
     * rank of a[i] among the distinct values of a, in ascending order.
     */
    public static int[] shrink(int[] a) {
        int n = a.length;
        long[] b = new long[n];
        for (int i = 0; i < n; i++)
            b[i] = (long) a[i] << 32 | i; // value in the high bits, index in the low bits
        Arrays.sort(b);
        int[] ret = new int[n];
        int p = 0;
        for (int i = 0; i < n; i++) {
            if (i > 0 && (b[i] ^ b[i - 1]) >> 32 != 0)
                p++;
            ret[(int) b[i]] = p;
        }
        return ret;
    }

    /**
     * Breadth-first search from root over a weighted tree built by packWU.
     * Returns { par, order, depth, distance, parentEdgeWeight }, where
     * par[root] == -1, order is the BFS visiting order, and distance is the sum
     * of edge weights from the root.
     */
    public static int[][] parents(int[][][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);
        int[] dw = new int[n];
        int[] pw = new int[n];
        int[] dep = new int[n];

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int[] nex : g[cur]) {
                if (par[cur] != nex[0]) {
                    q[r++] = nex[0];
                    par[nex[0]] = cur;
                    dep[nex[0]] = dep[cur] + 1;
                    dw[nex[0]] = dw[cur] + nex[1];
                    pw[nex[0]] = nex[1];
                }
            }
        }
        return new int[][] { par, q, dep, dw, pw };
    }

    /**
     * Builds an undirected weighted adjacency structure: g[v][k] = { neighbour, weight }
     * for each of the edges (from[i], to[i], w[i]).
     */
    public static int[][][] packWU(int n, int[] from, int[] to, int[] w) {
        int[][][] g = new int[n][][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]][2];
        for (int i = 0; i < from.length; i++) {
            --p[from[i]];
            g[from[i]][p[from[i]]][0] = to[i];
            g[from[i]][p[from[i]]][1] = w[i];
            --p[to[i]];
            g[to[i]][p[to[i]]][0] = from[i];
            g[to[i]][p[to[i]]][1] = w[i];
        }
        return g;
    }

    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
    }

    public static void main(String[] args) throws Exception { new D2().run(); }

    // ---- Buffered byte-level input helpers ----

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /** Returns the next input byte, or -1 once the stream is exhausted. */
    private int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Treats everything outside printable ASCII 33..126 as whitespace. */
    private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }

    private double nd() { return Double.parseDouble(ns()); }
    private char nc() { return (char)skip(); }

    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n; i++)a[i] = ni();
        return a;
    }

    /** Reads the next (optionally negative) decimal int. */
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long. */
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}
Copy The Code & Try With Live Editor
Advertisements

Demonstration


Previous
[Solved] Sum of the Maximums solution in HackerRank - HackerRank solution C, C++, Java, JS, Python
Next
[Solved] Heavy Light 2 White Falcon solution in HackerRank - HackerRank solution C, C++, Java, JS, Python