Algorithm


Problem Name: Data Structures - Lazy White Falcon

Problem Link: https://www.hackerrank.com/challenges/lazy-white-falcon/problem?isFullScreen=true

In this post we solve the HackerRank Data Structures problem "Lazy White Falcon".

White Falcon just solved the data structure problem below using heavy-light decomposition. Can you help her find a new solution that doesn't require implementing any fancy techniques?

There are 2 types of query operations that can be performed on a tree:

  1. 1 u x: Assign x as the value of node u.
  2. 2 u v: Print the sum of the node values in the unique path from node u to node v.

Given a tree with N nodes where each node's value is initially 0, execute Q queries.

Input Format

The first line contains 2 space-separated integers, N and Q, respectively.
The N - 1 subsequent lines each contain 2 space-separated integers describing an undirected edge in the tree.
Each of the Q subsequent lines contains a query you must execute.

Constraints

  • 1 <= N, Q <= 10^5
  • 1 <= x <= 1000
  • It is guaranteed that the input describes a connected tree with N nodes.
  • Nodes are enumerated with 0-based indexing.

Output Format

For each type-2 query, print its integer result on a new line.

Sample Input

3 3
0 1
1 2
1 0 1
1 1 2
2 0 2

Sample Output

3

Explanation

After the first 2 queries, the value of node 0 is 1 and the value of node 1 is 2. The third query asks for the sum of the node values on the path from node 0 to node 2, which is 1 + 2 + 0 = 3. Thus, we print 3 on a new line.

 

 

 

 

 

Code Examples

#1 Code Example with C Programming

Code - C Programming


#include <stdio.h>
#include <stdlib.h>
/* Adjacency-list cell: one direction of an undirected edge. */
typedef struct _lnode{
/* neighbouring vertex */
int x;
/* edge weight; main always passes 1 and no function reads it */
int w;
/* next cell in the same vertex's list */
struct _lnode *next;
} lnode;
/* One segment-tree cell holding the sum of the range it covers. */
typedef struct _tree{
int sum;
} tree;
/* Forward declarations; the definitions follow main. */
void insert_edge(int x,int y,int w);
void dfs0(int u);
void dfs1(int u,int c);
void preprocess();
int lca(int a,int b);
int sum(int v,int tl,
int tr,int l,int r,tree *t);
void update(int v,int tl,
int tr,int pos,int new_val,tree *t);
int min(int x,int y);
int max(int x,int y);
int solve(int x,int ancestor);
/* Global state, sized for at most 100000 vertices:
 *   N                number of vertices
 *   cn               number of heavy chains created so far
 *   level[u]         depth of u (root 0 has depth 0)
 *   DP[k][u]         2^k-th ancestor of u (binary lifting; 18 levels, 2^17 > 100000)
 *   subtree_size[u]  number of vertices in u's subtree
 *   special[u]       heavy child of u (largest subtree), -1 if u is a leaf
 *   node_chain[u]    index of the chain containing u
 *   node_idx[u]      position of u inside its chain, 0 for the chain head
 *   chain_head[c]    topmost vertex of chain c
 *   chain_len[c]     number of vertices in chain c
 *   table[u]         adjacency list of u
 *   chain[c]         segment tree (root cell at index 1) over chain c's values */
int N,cn,level[100000],DP[18][100000],
subtree_size[100000],special[100000],
node_chain[100000],node_idx[100000],
chain_head[100000],chain_len[100000]={0};
lnode *table[100000]={0};
tree *chain[100000];

/* Reads N, Q, the N-1 edges and the Q queries from stdin and prints the answer
 * of every type-2 query. Returns 1 if the header or an edge cannot be read or
 * N does not fit the fixed-size arrays; stops early if the queries are
 * truncated, instead of acting on uninitialised values. */
int main(){
    int Q,type,x,y,i;
    if(scanf("%d%d",&N,&Q)!=2 || N<1 || N>100000)
        return 1;
    for(i=0;i < N-1;i++){
        if(scanf("%d%d",&x,&y)!=2)
            return 1;
        insert_edge(x,y,1);
    }
    preprocess();
    while(Q--){
        if(scanf("%d%d%d",&type,&x,&y)!=3)
            break;
        if(type==1){
            /* point assignment inside x's chain */
            update(1,0,chain_len[node_chain[x]]-1,
                   node_idx[x],y,chain[node_chain[x]]);
        }else{
            /* (x..lca) + (y..lca) counts the lca twice, so subtract its value once */
            i=lca(x,y);
            printf("%d\n",
                   solve(x,i)+solve(y,i)-
                   sum(1,0,chain_len[node_chain[i]]-1,
                       node_idx[i],node_idx[i],chain[node_chain[i]]));
        }
    }
    return 0;
}
/* Prepends vertex `to` (edge weight w) to the adjacency list of `from`.
 * Terminates the program with an error message if allocation fails,
 * instead of dereferencing a NULL pointer. */
static void push_adjacent(int from,int to,int w){
    lnode *t=malloc(sizeof(lnode));
    if(t==NULL){
        fputs("out of memory\n",stderr);
        exit(EXIT_FAILURE);
    }
    t->x=to;
    t->w=w;
    t->next=table[from];
    table[from]=t;
}

/* Records the undirected edge {x, y} with weight w by adding y to x's list
 * and then x to y's list. */
void insert_edge(int x,int y,int w){
    push_adjacent(x,y,w);
    push_adjacent(y,x,w);
}
/* First DFS below u. For every child it stores the parent (DP[0]) and depth,
 * accumulates subtree sizes, and picks the heavy child special[u]: the child
 * with the largest subtree (the first one on ties), or -1 when u is a leaf.
 * The parent of u must already be in DP[0][u]. */
void dfs0(int u){
    lnode *edge=table[u];
    subtree_size[u]=1;
    special[u]=-1;
    while(edge){
        int child=edge->x;
        edge=edge->next;
        if(child==DP[0][u])
            continue;          /* skip the parent */
        DP[0][child]=u;
        level[child]=level[u]+1;
        dfs0(child);
        subtree_size[u]+=subtree_size[child];
        if(special[u]==-1 || subtree_size[child]>subtree_size[special[u]])
            special[u]=child;
    }
}
/* Second DFS: appends u to chain c at the next free position, continues chain c
 * through u's heavy child, and opens a new chain (numbered cn) headed by each
 * light child. */
void dfs1(int u,int c){
    lnode *edge;
    node_chain[u]=c;
    node_idx[u]=chain_len[c];
    chain_len[c]+=1;
    for(edge=table[u];edge;edge=edge->next){
        int v=edge->x;
        if(v==DP[0][u])
            continue;                 /* skip the parent */
        if(v!=special[u]){
            int fresh=cn;             /* light edge: start a new chain */
            cn+=1;
            chain_head[fresh]=v;
            dfs1(v,fresh);
        }else{
            dfs1(v,c);                /* heavy edge: stay on chain c */
        }
    }
}
/* Builds all query structures once the edges are loaded:
 * - parents, depths, subtree sizes and heavy children (dfs0);
 * - the binary-lifting table DP;
 * - the heavy-light chains (dfs1);
 * - one zero-filled segment tree per chain.
 * Terminates the program if a segment tree cannot be allocated. */
void preprocess(){
    int i,j;
    level[0]=0;
    DP[0][0]=0;        /* the root is its own parent, so lifting saturates at 0 */
    dfs0(0);
    for(i=1;i < 18;i++)
        for(j=0;j < N;j++)
            DP[i][j]=DP[i-1][DP[i-1][j]];
    cn=1;
    chain_head[0]=0;
    dfs1(0,0);
    /* calloc zero-fills every cell, which is exactly the "all values 0" start
     * state, so no per-vertex update pass is needed. A recursive segment tree
     * rooted at index 1 over len leaves uses fewer than 4*len cells. */
    for(i=0;i < cn;i++){
        chain[i]=calloc((size_t)4*chain_len[i],sizeof(tree));
        if(chain[i]==NULL){
            fputs("out of memory\n",stderr);
            exit(EXIT_FAILURE);
        }
    }
}
int lca(int a,int b>{
int i;
if(level[a]>level[b]){
i=a;
a=b;
b=i;
}
int d = level[b]-level[a];
for(i = 0; i < 18; i++)
if(d&(1<<i))
b=DP[i][b];
if(a==b>return a;
for(i =17; i >= 0; i--)
if(DP[i][a]!=DP[i][b])
a=DP[i][a],b=DP[i][b];
return DP[0][a];
}
/* Sum of positions [l, r] in segment tree t, where cell v covers [tl, tr].
 * An empty range (l > r) contributes 0. */
int sum(int v,int tl,int tr,int l,
int r,tree *t){
    int mid,left_part,right_part;
    if(l>r)
        return 0;
    if(l==tl && r==tr)
        return t[v].sum;
    mid=(tl+tr)/2;
    left_part=sum(2*v,tl,mid,l,min(r,mid),t);
    right_part=sum(2*v+1,mid+1,tr,max(l,mid+1),r,t);
    return left_part+right_part;
}
/* Sets leaf `pos` of segment tree t (cell v covers [tl, tr]) to new_val and
 * recomputes the sums of the cells on the way back up. */
void update(int v,int tl,int tr,
int pos,int new_val,tree *t){
    int mid;
    if(tl!=tr){
        mid=(tl+tr)/2;
        if(pos>mid)
            update(2*v+1,mid+1,tr,pos,new_val,t);
        else
            update(2*v,tl,mid,pos,new_val,t);
        t[v].sum=t[2*v].sum+t[2*v+1].sum;
        return;
    }
    t[v].sum=new_val;
}
/* Returns the smaller of x and y. */
int min(int x,int y){
    if(y<x)
        return y;
    return x;
}
/* Returns the larger of x and y (fixes the scraped `int y>{` signature). */
int max(int x,int y){
    return (x>y)?x:y;
}
/* Sum of vertex values on the vertical path from x up to and including
 * `ancestor`. Whole chain prefixes are consumed until x reaches ancestor's
 * chain; the final partial range covers ancestor..x inside that chain. */
int solve(int x,int ancestor){
    int total=0;
    int target=node_chain[ancestor];
    for(;;){
        int c=node_chain[x];
        int hi=chain_len[c]-1;
        if(c==target){
            total+=sum(1,0,hi,node_idx[ancestor],node_idx[x],chain[c]);
            break;
        }
        total+=sum(1,0,hi,0,node_idx[x],chain[c]);
        x=DP[0][chain_head[c]];    /* jump above the current chain */
    }
    return total;
}
Copy The Code & Try With Live Editor

#2 Code Example with C++ Programming

Code - C++ Programming


#include <bits/stdc++.h>

using namespace std;

vector < string> split_string(string);

// Per-vertex bookkeeping for the heavy-light decomposition.
// Every member has a default, so a Node is never read uninitialised.
// Tree::init assigns size, depth and parent; Tree::createPath assigns path.
struct Node
{
    int path = 0;     // index into Tree::paths of the chain containing the vertex
    int size = 0;     // number of vertices in the vertex's subtree
    int depth = 0;    // distance from the root (vertex 0)
    int parent = -1;  // parent vertex, or -1 for the root
};

// One heavy path (chain) of the decomposition. Position i of the chain holds
// the vertex at depth `depth + i`. Values are stored in a recursive sum segment
// tree whose node 1 covers positions [0, size - 1].
// (Fixes the scraped `< =` tokens, which did not compile.)
struct Path
{
    int root;               // topmost vertex of the chain
    int depth;              // depth of `root`
    int size;               // number of vertices on the chain
    std::vector<int> sums;  // segment-tree cells, all 0 initially

    // Sum of positions [l, r] restricted to node ti, which covers [tl, tr].
    // An empty range (l > r) contributes 0.
    int sum(int ti, int tl, int tr, int l, int r)
    {
        if (l > r)
        {
            return 0;
        }
        if (l == tl && r == tr)
        {
            return sums[ti];
        }
        int tm = (tl + tr) / 2;
        return sum(2 * ti, tl, tm, l, std::min(r, tm)) + sum(2 * ti + 1, tm + 1, tr, std::max(l, tm + 1), r);
    }

    // Writes `value` at position i below node ti ([tl, tr]) and recomputes the
    // sums of the nodes on the way back up.
    void assign(int value, int ti, int tl, int tr, int i)
    {
        if (i == tl && i == tr)
        {
            sums[ti] = value;
            return;
        }
        int tm = (tl + tr) / 2;
        if (i <= tm)
        {
            assign(value, 2 * ti, tl, tm, i);
        }
        else
        {
            assign(value, 2 * ti + 1, tm + 1, tr, i);
        }
        sums[ti] = sums[2 * ti] + sums[2 * ti + 1];
    }

    // Sets position i (0-based from the chain's top) to v.
    void assign(int i, int v)
    {
        assign(v, 1, 0, size - 1, i);
    }

    // Sum of positions [l, r], inclusive and 0-based from the chain's top.
    int sum(int l, int r)
    {
        return sum(1, 0, size - 1, l, r);
    }

    // Creates an all-zero chain of `size` vertices whose top vertex is `root`,
    // at depth `depth`. The cell array has 2 * (smallest power of two > size)
    // entries, which covers every index the recursion from node 1 can reach.
    Path(int root, int depth, int size) : root(root), depth(depth), size(size)
    {
        int temp = 1;
        while (temp <= size)
        {
            temp <<= 1;
        }
        sums.resize(2 * temp, 0);
    }
};

class Tree
{
    vector < Path> paths;
    vector<Node> nodes;

    bool isHeavy(int node)
    {
        int parent = nodes[node].parent;
        return (parent  <  0) ? false : (2 * nodes[node].size >= nodes[parent].size);
    }

public:
    Tree(vector<vector<int>>& edges)
    {
        vector < vector<int>> tree(edges.size() + 1);
        for (int i = 0; i  <  edges.size(); i++)
        {
            tree[edges[i][0]].push_back(edges[i][1]);
            tree[edges[i][1]].push_back(edges[i][0]);
        }

        nodes.resize(edges.size() + 1);
        init(0, -1, 0, tree);
        createPaths(0, -1, tree);
    }

    void init(int curr, int parent, int depth, vector < vector<int>>& tree)
    {
        nodes[curr].size = 1;
        nodes[curr].depth = depth;
        nodes[curr].parent = parent;
        for (int i = 0; i  <  tree[curr].size(); i++)
        {
            int next = tree[curr][i];
            if (next != parent)
            {
                init(next, curr, depth + 1, tree);
                nodes[curr].size += nodes[next].size;
            }
        }
    }

    void createPaths(int curr, int parent, vector < vector<int>>& tree)
    {
        bool hasHeavy = false;
        for (int i = 0; i  <  tree[curr].size(); i++)
        {
            int next = tree[curr][i];
            if (next != parent)
            {
                createPaths(next, curr, tree);
                if (isHeavy(next))
                {
                    hasHeavy = true;
                }
            }
        }
        if (!hasHeavy)
        {
            createPath(curr);
        }
    }

    void createPath(int node)
    {
        int length = 1;
        while (true)
        {
            nodes[node].path = paths.size();
            if (!isHeavy(node))
            {
                break;
            }
            node = nodes[node].parent;
            length += 1;
        }
        paths.push_back(Path(node, nodes[node].depth, length));
    }

    void assign(int node, int value)
    {
        int path = nodes[node].path;
        paths[path].assign(nodes[node].depth - paths[path].depth, value);
    }

    int sum(int u, int v)
    {
        if (nodes[u].path == nodes[v].path)
        {
            int path = nodes[u].path;
            int l = std::min(nodes[u].depth, nodes[v].depth);
            int r = std::max(nodes[u].depth, nodes[v].depth);
            return paths[path].sum(l - paths[path].depth, r - paths[path].depth);
        }
        int rootU = paths[nodes[u].path].root;
        int rootV = paths[nodes[v].path].root;
        if (nodes[rootU].depth  <  nodes[rootV].depth)
        {
            int path = nodes[v].path;
            return paths[path].sum(0, nodes[v].depth - paths[path].depth) + sum(u, nodes[rootV].parent);
        }
        else
        {
            int path = nodes[u].path;
            return paths[path].sum(0, nodes[u].depth - paths[path].depth) + sum(nodes[rootU].parent, v);
        }
    }
};

// Runs the queries in order against a Tree built from `edges`. A query of
// type 1 assigns query[2] to vertex query[1]. Any other type appends the
// query[1]-query[2] path sum to the returned answers.
vector<int> solve(vector<vector<int>>& edges, vector<vector<int>>& queries)
{
    Tree hld(edges);
    vector<int> answers;
    for (const vector<int>& query : queries)
    {
        const int kind = query[0];
        if (kind != 1)
        {
            answers.push_back(hld.sum(query[1], query[2]));
            continue;
        }
        hld.assign(query[1], query[2]);
    }
    return answers;
}

int main()
{
    ofstream fout(getenv("OUTPUT_PATH"));

    string nq_temp;
    getline(cin, nq_temp);

    vector < string> nq = split_string(nq_temp);

    int n = stoi(nq[0]);

    int q = stoi(nq[1]);

    vector < vector<int>> tree(n-1);
    for (int tree_row_itr = 0; tree_row_itr  <  n-1; tree_row_itr++) {
        tree[tree_row_itr].resize(2);

        for (int tree_column_itr = 0; tree_column_itr  <  2; tree_column_itr++) {
            cin >> tree[tree_row_itr][tree_column_itr];
        }

        cin.ignore(numeric_limits < streamsize>::max(), '\n');
    }

    vector<vector<int>> queries(q);
    for (int queries_row_itr = 0; queries_row_itr  <  q; queries_row_itr++) {
        queries[queries_row_itr].resize(3);

        for (int queries_column_itr = 0; queries_column_itr  <  3; queries_column_itr++) {
            cin >> queries[queries_row_itr][queries_column_itr];
        }

        cin.ignore(numeric_limits < streamsize>::max(), '\n');
    }

    vector<int> result = solve(tree, queries);

    for (int result_itr = 0; result_itr  <  result.size(); result_itr++) {
        fout << result[result_itr];

        if (result_itr != result.size() - 1) {
            fout << "\n";
        }
    }

    fout << "\n";

    fout.close();

    return 0;
}

// Splits a line on spaces. Runs of spaces are collapsed and trailing spaces
// are dropped. A leading space still yields an empty first token, which is the
// historical behaviour. An empty or all-space line returns {""}; the original
// indexed input_string[length() - 1] unconditionally, which is out of bounds
// when the string is empty.
std::vector<std::string> split_string(std::string input_string) {
    auto new_end = std::unique(input_string.begin(), input_string.end(), [](const char& x, const char& y) {
        return x == y && x == ' ';
    });

    input_string.erase(new_end, input_string.end());

    while (!input_string.empty() && input_string.back() == ' ') {
        input_string.pop_back();
    }

    std::vector<std::string> splits;
    const char delimiter = ' ';

    std::size_t i = 0;
    std::size_t pos = input_string.find(delimiter);

    while (pos != std::string::npos) {
        splits.push_back(input_string.substr(i, pos - i));

        i = pos + 1;
        pos = input_string.find(delimiter, i);
    }

    // Last token: everything after the final delimiter.
    splits.push_back(input_string.substr(i));

    return splits;
}
Copy The Code & Try With Live Editor

#3 Code Example with Java Programming

Code - Java Programming


import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;
import java.util.stream.Collectors;

    /**
     * A tree vertex. Set up the root (vertex 0) in two steps:
     * processRoot() turns the undirected connections into parent/children links;
     * processLCA(nodes) then builds the Euler tour, the depth-minimum tree used
     * for LCA queries, and the root-path-sum tree.
     * updateValue, rootAggregation and LCAIndex read that state, so they must be
     * called on the root. The scraped version had lost its generic type
     * arguments (raw List/LinkedList), so it did not compile.
     */
    class Node {
        int id;
        Node parent;
        List<Node> children;
        List<Node> connections;
        int value;

        // The state below is filled in by processLCA (on the root only).
        private int[] eulerFirsts;   // first Euler-tour index of each vertex id
        private int[] eulerLasts;    // last Euler-tour index of each vertex id
        private int[] eulerPath;     // vertex id at each Euler-tour index
        private Node[] nodes;        // every vertex, indexed by id
        // Range-add / point-query tree over tour indices; get(eulerFirsts[v]) is
        // the sum of the values on the path root..v.
        private SegmentTreeSegmentedAddition rootAggregates;
        // Range-minimum (by depth) tree over the tour, used to find LCAs.
        private SegmentTreeMin<EulerNode> eulerNodes;

        public Node(int id) {
            this.id = id;
            this.children = new ArrayList<>();
            this.connections = new ArrayList<>();
        }

        public void addConnection(Node c) {
            this.connections.add(c);
        }

        public void addChild(Node child) {
            this.children.add(child);
            child.parent = this;
        }

        /** Sets the value of vertex nodeId (call on the root). */
        public void updateValue(int nodeId, int value) {
            Node target = nodes[nodeId];
            int increment = value - target.value;
            target.value = value;
            // The tour indices of target's subtree form exactly [first, last], so adding
            // the delta there shifts the root-path sum of precisely target's subtree.
            int first = eulerFirsts[target.id], last = eulerLasts[target.id];
            this.rootAggregates.update(Math.min(first, last), Math.max(first, last), increment);
        }

        /** Sum of the values on the path from the root to vertex nodeId, both included. */
        public int rootAggregation(int nodeId) {
            return this.rootAggregates.get(eulerFirsts[nodes[nodeId].id]);
        }

        public boolean isParent(Node another) {
            return this.parent != null && this.parent.id == another.id;
        }

        /** LCA of n1 and n2: the shallowest tour entry between their first occurrences. */
        public int LCAIndex(int n1, int n2) {
            int a = eulerFirsts[n1], b = eulerFirsts[n2];
            // SegmentTreeMin.min takes an inclusive range, so the upper end needs no "+ 1".
            EulerNode lcaNode = eulerNodes.min(Math.min(a, b), Math.max(a, b));
            return lcaNode.vertex;
        }

        /** Builds the Euler tour from this vertex (the root, depth 1) and the structures derived from it. */
        public void processLCA(Node[] nodes) {
            this.nodes = nodes;
            boolean[] visited = new boolean[nodes.length];
            this.eulerFirsts = new int[nodes.length];
            Arrays.fill(this.eulerFirsts, -1);
            this.eulerLasts = new int[nodes.length];
            Arrays.fill(this.eulerLasts, -1);
            List<Integer> eulerHeights = new ArrayList<>();
            List<Node> tour = new ArrayList<>();
            Deque<Node> stack = new ArrayDeque<>();
            stack.push(this);
            int depth = 1;
            while (!stack.isEmpty()) {
                Node current = stack.pop();
                tour.add(current);
                eulerHeights.add(depth);
                int eulerIndex = tour.size() - 1;
                if (this.eulerFirsts[current.id] == -1) {
                    this.eulerFirsts[current.id] = eulerIndex;
                }
                // Tour indices only grow, so the latest occurrence is the last one.
                this.eulerLasts[current.id] = eulerIndex;
                if (!visited[current.id]) {
                    visited[current.id] = true;
                    // Push (current, child) pairs so the tour returns to current after each child's subtree.
                    for (Node child : current.children) {
                        stack.push(current);
                        stack.push(child);
                    }
                }
                Node next = stack.peekFirst();
                if (next != null && next.isParent(current)) depth++;
                else depth--;
            }
            int[] vertices = tour.stream().mapToInt(n -> n.id).toArray();
            int[] heights = eulerHeights.stream().mapToInt(Integer::intValue).toArray();
            EulerNode[] tourNodes = new EulerNode[vertices.length];
            for (int i = 0; i < heights.length; i++) {
                tourNodes[i] = new EulerNode(heights[i], vertices[i]);
            }
            this.eulerNodes = new SegmentTreeMin<>(tourNodes);
            this.eulerPath = vertices;
            this.rootAggregates = new SegmentTreeSegmentedAddition(new int[this.eulerPath.length]);
        }

        /** Orients the undirected connections into parent/child links, starting from this vertex. */
        public void processRoot() {
            Deque<Node> stack = new ArrayDeque<>();
            stack.push(this);
            while (!stack.isEmpty()) {
                Node current = stack.pop();
                for (Node connection : current.connections) {
                    if (connection == current || connection == current.parent) continue;
                    stack.push(connection);
                    current.addChild(connection);
                }
            }
        }
    }

    class EulerNode implements Comparable {
        int height;
        int vertex;

        public EulerNode(int height, int vertex) {
            this.height = height;
            this.vertex = vertex;
        }

        @Override
        public int compareTo(Object o) {
            EulerNode other = (EulerNode)o;
            return this.height - other.height;
        }
    }

    /**
     * Segment tree supporting "add a constant to every position in [l, r]" and
     * "read one position". Leaves start from the supplied values. Every inner
     * node stores an addition pending for its whole range, so a point read sums
     * the additions along the root-to-leaf path. The scraped `< =` token did not
     * compile. Out-of-range indices now raise IndexOutOfBoundsException instead
     * of corrupting or overrunning the tree.
     */
    class SegmentTreeSegmentedAddition {

        private final int[] values;
        private final int[] tree;

        public SegmentTreeSegmentedAddition(int[] values) {
            this.values = values;
            this.tree = new int[values.length * 4];
            if (values.length > 0) {   // building over an empty range would index out of bounds
                build(1, 0, values.length - 1);
            }
        }

        private void build(int v, int tl, int tr) {
            if (tl == tr) {
                tree[v] = values[tl];
                return;
            }
            int tm = (tl + tr) / 2;
            build(v * 2, tl, tm);
            build(v * 2 + 1, tm + 1, tr);
            tree[v] = 0;
        }

        /** Adds add to every position in the inclusive range [l, r]; does nothing when l > r. */
        public void update(int l, int r, int add) {
            if (l > r)
                return;
            checkIndex(l);
            checkIndex(r);
            this.update(1, 0, values.length - 1, l, r, add);
        }

        private void update(int v, int tl, int tr, int l, int r, int add) {
            if (l > r)
                return;
            if (l == tl && r == tr) {
                tree[v] += add;
                return;
            }
            int tm = (tl + tr) / 2;
            update(v * 2, tl, tm, l, Math.min(r, tm), add);
            update(v * 2 + 1, tm + 1, tr, Math.max(l, tm + 1), r, add);
        }

        /** Current value at pos: its initial value plus every addition covering it. */
        public int get(int pos) {
            checkIndex(pos);
            return this.get(1, 0, values.length - 1, pos);
        }

        private int get(int v, int tl, int tr, int pos) {
            if (tl == tr)
                return tree[v];
            int tm = (tl + tr) / 2;
            if (pos <= tm)
                return tree[v] + get(v * 2, tl, tm, pos);
            return tree[v] + get(v * 2 + 1, tm + 1, tr, pos);
        }

        private void checkIndex(int pos) {
            if (pos < 0 || pos >= values.length)
                throw new IndexOutOfBoundsException("position " + pos + " outside [0, " + values.length + ")");
        }
    }

    /**
     * Static range-minimum segment tree over an array of comparable values.
     * When two candidates compare equal, the right-hand one is returned.
     * The backing list was a raw ArrayList in the scraped version, so
     * `T left = tree.get(...)` did not compile. The bound stays the raw
     * Comparable, so element types that implement the raw interface are still
     * accepted.
     */
    class SegmentTreeMin<T extends Comparable> {

        private final T[] values;
        private final List<T> tree;

        public SegmentTreeMin(T[] values) {
            this.values = values;
            int capacity = values.length * 4;
            this.tree = new ArrayList<>(capacity);
            for (int i = 0; i < capacity; i++) {
                this.tree.add(null);
            }
            if (values.length > 0) {
                build(1, 0, values.length - 1);
            }
        }

        /** The smaller of two non-null values; right wins ties. */
        @SuppressWarnings("unchecked")
        private T smaller(T left, T right) {
            return left.compareTo(right) < 0 ? left : right;
        }

        private void build(int v, int tl, int tr) {
            if (tl == tr) {
                tree.set(v, values[tl]);
                return;
            }
            int tm = (tl + tr) / 2;
            build(v * 2, tl, tm);
            build(v * 2 + 1, tm + 1, tr);
            tree.set(v, smaller(tree.get(v * 2), tree.get(v * 2 + 1)));
        }

        /** Minimum over the inclusive range [l, r]; null when l > r. */
        public T min(int l, int r) {
            return min(1, 0, values.length - 1, l, r);
        }

        private T min(int v, int tl, int tr, int l, int r) {
            if (l > r)
                return null; // neutral element for the reduction below
            if (l == tl && r == tr) {
                return tree.get(v);
            }
            int tm = (tl + tr) / 2;
            T left = min(v * 2, tl, tm, l, Math.min(r, tm));
            T right = min(v * 2 + 1, tm + 1, tr, Math.max(l, tm + 1), r);
            if (left == null && right == null) throw new NullPointerException("Cannot retrieve minimum value from 2 nulls");
            if (left == null) return right;
            if (right == null) return left;
            return smaller(left, right);
        }

    }


public class Solution {

    /**
     * Answers the queries against the tree described by edges (vertices 0..edges.length).
     * "1 u x" sets u's value to x; "2 u v" appends the sum of the values on the u-v path.
     * Raw List replaced with List&lt;Integer&gt;, whose raw form did not compile with mapToInt.
     */
    static int[] solve(int[][] edges, int[][] queries) {
        List<Integer> result = new ArrayList<>();
        // Create every vertex up front so a single-vertex tree (no edges) still has a root.
        Node[] nodes = new Node[edges.length + 1];
        for (int i = 0; i < nodes.length; i++) {
            nodes[i] = new Node(i);
        }
        for (int[] link : edges) {
            Node node1 = nodes[link[0]], node2 = nodes[link[1]];
            node1.addConnection(node2);
            node2.addConnection(node1);
        }

        Node root = nodes[0];
        root.processRoot();
        root.processLCA(nodes);

        for (int[] query : queries) {
            int operation = query[0], arg1 = query[1], arg2 = query[2];
            if (operation == 1) { // update value
                root.updateValue(arg1, arg2);
            } else if (operation == 2) { // path sum
                if (arg1 == arg2) {
                    result.add(nodes[arg1].value);
                } else {
                    // With A(v) = sum of the values on root..v:
                    // path(u, v) = A(u) + A(v) - A(lca) - A(parent(lca)).
                    int lcaIndex = root.LCAIndex(arg1, arg2);
                    int lcaValue = root.rootAggregation(lcaIndex);
                    int n1Aggregation = root.rootAggregation(arg1);
                    int n2Aggregation = root.rootAggregation(arg2);
                    int rootAggregationOffset = 0;
                    if (lcaIndex != root.id) {
                        rootAggregationOffset = root.rootAggregation(nodes[lcaIndex].parent.id);
                    }
                    result.add(n1Aggregation + n2Aggregation - lcaValue - rootAggregationOffset);
                }
            }
        }

        return result.stream().mapToInt(Integer::intValue).toArray();
    }

    private static final Scanner scanner = new Scanner(System.in);

    /** Reads the input, answers the queries, and writes one answer per line to $OUTPUT_PATH, or stdout if unset. */
    public static void main(String[] args) throws IOException {
        String outputPath = System.getenv("OUTPUT_PATH");
        Writer sink = outputPath != null ? new FileWriter(outputPath) : new OutputStreamWriter(System.out);
        BufferedWriter bufferedWriter = new BufferedWriter(sink);

        String[] nq = scanner.nextLine().trim().split("\\s+");

        int n = Integer.parseInt(nq[0]);

        int q = Integer.parseInt(nq[1]);

        int[][] tree = new int[n - 1][2];

        for (int treeRowItr = 0; treeRowItr < n - 1; treeRowItr++) {
            String[] treeRowItems = scanner.nextLine().trim().split("\\s+");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int treeColumnItr = 0; treeColumnItr < 2; treeColumnItr++) {
                tree[treeRowItr][treeColumnItr] = Integer.parseInt(treeRowItems[treeColumnItr]);
            }
        }

        int[][] queries = new int[q][3];

        for (int queriesRowItr = 0; queriesRowItr < q; queriesRowItr++) {
            String[] queriesRowItems = scanner.nextLine().trim().split("\\s+");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int queriesColumnItr = 0; queriesColumnItr < 3; queriesColumnItr++) {
                queries[queriesRowItr][queriesColumnItr] = Integer.parseInt(queriesRowItems[queriesColumnItr]);
            }
        }

        int[] result = solve(tree, queries);

        for (int resultItr = 0; resultItr < result.length; resultItr++) {
            bufferedWriter.write(String.valueOf(result[resultItr]));

            if (resultItr != result.length - 1) {
                bufferedWriter.write("\n");
            }
        }

        bufferedWriter.newLine();

        bufferedWriter.close();

        scanner.close();
    }
}
Copy The Code & Try With Live Editor

#4 Code Example with Python Programming

Code - Python Programming


from collections import defaultdict, deque
import sys

def read():
    """Read one line from stdin and return a lazy iterator over its integers."""
    line = sys.stdin.readline()
    return map(int, line.split())

# N vertices and Q queries, then the N - 1 undirected edges. Despite its name,
# adj_matrix is an adjacency list (vertex -> list of neighbours).
N, Q = read()
adj_matrix = defaultdict(list)
for _ in range(N - 1):
    x, y = read()
    adj_matrix[x].append(y)
    adj_matrix[y].append(x)
# Nodes whose level is a multiple of segment_size (the root included) act as
# segment parents for the nodes below them (see the BFS further down).
segment_size = 120
parents = [None] * N   # parent of each node; the root 0 is its own parent
s_parent = [None] * N  # segment parent: nearest proper ancestor at a multiple-of-segment_size level (root: itself)
s_weight = [0] * N     # cached sum of weight[] from parents[node] up to s_parent[node], inclusive
levels = [0] * N       # depth of each node; root 0 is at level 0
update = [-1] * N      # per segment parent: index of the last query that changed a weight its s_weight sums depend on
refresh = [-1] * N     # per node: query index at which s_weight[node] was last recomputed
weight = [0] * N       # current value of each node
def refresh_segment(node, sp, tick):
    """Bring s_weight[node] up to date for its segment parent sp.

    s_weight[node] caches the sum of weight[] over the ancestors of node, from
    parents[node] up to and including sp. The cache is recomputed only if the
    segment's last change (update[sp]) is newer than node's last refresh
    (refresh[node]); ancestors inside the segment are refreshed first. After a
    recompute, refresh[node] is stamped with tick.
    """
    up_to_date = node == sp or refresh[node] >= update[sp]
    if up_to_date:
        return
    above = parents[node]
    if above != sp:
        refresh_segment(above, sp, tick)
        s_weight[node] = s_weight[above] + weight[above]
    else:
        s_weight[node] = weight[above]
    refresh[node] = tick

# Initialise the per-node arrays above with a BFS from root 0.
# Each queue entry is (segment parent, parent, node, level). A node whose level
# is a multiple of segment_size (the root included) passes itself down as the
# segment parent of its children; any other node passes on its own segment parent.
que = deque([(0, 0, 0, 0)])
while que:
    segment, parent, node, level = que.popleft()
    s_parent[node] = segment
    parents[node] = parent
    levels[node] = level
    child_segment = segment if level % segment_size else node
    for n in adj_matrix[node]:
        if n != parent:
            que.append((child_segment, node, n, level + 1))
results = []
for i in range(Q):
    # Query i is "1 u x" (assign x to u) or "2 u x" (sum of the u-x path).
    op, u, x = read()
    if op == 1:
        weight[u] = x
        # Stamp the segment whose cached s_weight values depend on weight[u]
        # with query index i: the segment u belongs to, or the segment u heads
        # when u sits at a multiple-of-segment_size level.
        if levels[u] % segment_size:
            update[s_parent[u]] = i
        else:
            update[u] = i
    else:
        # Make x the deeper endpoint.
        if levels[u] > levels[x]:
            u, x = x, u
        # Invariant from here on: result already counts weight[x] and weight[u].
        result = weight[x] + weight[u]
        # Traverse x upwards segment by segment until
        # levels[s_parent[x]] == levels[s_parent[u]]
        u_s = s_parent[u]
        while levels[s_parent[x]] > levels[u_s]:
            refresh_segment(x, s_parent[x], i)
            result += s_weight[x]
            x = s_parent[x]
        # Climb both endpoints one segment at a time until they share a segment parent.
        while s_parent[x] != s_parent[u]:
            refresh_segment(x, s_parent[x], i)
            refresh_segment(u, s_parent[u], i)
            result += s_weight[x]
            result += s_weight[u]
            x = s_parent[x]
            u = s_parent[u]
        # Finish node by node: lift x to u's level, then climb both until they meet.
        for _ in range(levels[x] - levels[u]):
            x = parents[x]
            result += weight[x]
        while u != x:
            x = parents[x]
            u = parents[u]
            result += weight[x]
            result += weight[u]
        # The meeting node (the LCA) has been counted twice.
        result -= weight[x]
        results.append(result)
print('\n'.join(str(x) for x in results))
Copy The Code & Try With Live Editor
Advertisements

Demonstration


Previous
[Solved] Functional Palindromes solution in HackerRank - HackerRank solution C, C++, Java, JS, Python
Next
[Solved] Ticket to Ride solution in HackerRank - HackerRank solution C, C++, Java, JS, Python