Algorithm


Problem Name: Data Structures - Heavy Light 2 White Falcon

Problem Link:https://www.hackerrank.com/challenges/heavy-light-2-white-falcon/problem?isFullScreen=true

This post presents solutions to the HackerRank Data Structures challenge "Heavy Light 2 White Falcon".

White Falcon was amazed by what she can do with heavy-light decomposition on trees. As a result, she wants to improve her expertise in heavy-light decomposition. Her teacher gave her another assignment, which requires path updates. As always, White Falcon needs your help with the assignment.

You are given a tree with $N$ nodes, and each node's value $val_i$ is initially 0.

Let's denote the path from node $u$ to node $v$ as $p_1, p_2, p_3, \ldots, p_k$, where $p_1 = u$, $p_k = v$, and $p_i$ and $p_{i+1}$ are connected by an edge.

The problem asks you to process the following two types of queries on the tree:

  • "1 u v x": add $x$ to $val_{p_1}$, $2x$ to $val_{p_2}$, $3x$ to $val_{p_3}$, $\ldots$, $kx$ to $val_{p_k}$.
  • "2 u v": print the sum of the node values on the path between $u$ and $v$, modulo $10^9 + 7$.

Input Format

The first line consists of two space-separated integers $N$ and $Q$.
Each of the following $N-1$ lines contains two integers denoting an undirected edge of the tree.
Each of the following $Q$ lines contains one query of one of the types described above.

Note: Nodes are numbered by using 0-based indexing.

Constraints

$1 \le N, Q \le 50000$

$0 \le x \le 10^9 + 7$

Output Format

For every query of second type print a single integer.

Sample Input

3 2
0 1
1 2
1 0 2 1
2 1 2

Sample Output

5

 

 

 

 

 

 

 

Code Examples

#1 Code Example with C Programming

Code - C Programming


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Adjacency-list entry: one directed half of an undirected edge. */
typedef struct lnode {
  int x;                /* neighbouring vertex */
  int w;                /* edge weight (main always passes 1) */
  struct lnode *next;   /* next entry in the same vertex's list */
} lnode;

/*
 * Segment-tree node over one heavy chain.
 *   sum     - sum of the node's range (mod MOD), excluding this node's own
 *             pending offsets until push() folds them in
 *   offset1 - pending progression: value added at the first index of the range
 *   offset2 - pending progression: difference between neighbouring indices
 */
typedef struct tree {
  long long sum;
  long long offset1;
  long long offset2;
} tree;

#define MOD 1000000007

void insert_edge(int x,int y,int w);
void dfs0(int u);
void dfs1(int u,int c);
void preprocess(void);
int lca(int a,int b);
long long sum(int v,int tl,int tr,int l,int r,tree *t);
void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t);
void push(int v,int tl,int tr,tree *t);
void range_solve(int x,int y,int z);
int min(int x,int y);
int max(int x,int y);
long long solve(int x,int ancestor);

int N;                        /* number of tree nodes */
int cn;                       /* number of heavy chains created by dfs1 */
int level[100000];            /* depth of each node, root 0 is at depth 0 */
int DP[18][100000];           /* DP[k][v] = 2^k-th ancestor of v (root maps to itself) */
int subtree_size[100000];     /* size of each node's subtree */
int special[100000];          /* heavy child (largest subtree), -1 for leaves */
int node_chain[100000];       /* chain id of each node */
int node_idx[100000];         /* position of a node inside its chain, head = 0 */
int chain_head[100000];       /* topmost node of each chain */
int chain_len[100000]={0};    /* number of nodes in each chain */
lnode *table[100000]={0};     /* adjacency lists */
tree *chain[100000];          /* one segment tree (4*chain_len nodes) per chain */

/*
 * Reads the tree and the queries from stdin and answers them.
 *   "1 u v x" - adds x, 2x, ..., kx along the path u..v (range_solve)
 *   "2 u v"   - prints the sum of values on the path u..v modulo MOD
 * Stops quietly when the input ends early or is malformed instead of
 * running queries on stale values.
 */
int main(){
  int Q,x,y,i;
  if(scanf("%d%d",&N,&Q)!=2)
    return 0;
  for(i = 0; i < N-1; i++){
    if(scanf("%d%d",&x,&y)!=2)
      return 0;
    insert_edge(x,y,1);
  }
  preprocess();
  while(Q--){
    int type;
    if(scanf("%d",&type)!=1)
      break;
    if(type==1){
      if(scanf("%d%d%d",&x,&y,&i)!=3)
        break;
      range_solve(x,y,i);
    }
    else{
      if(scanf("%d%d",&x,&y)!=2)
        break;
      i=lca(x,y);
      /* both half-paths include the LCA, so its value is subtracted once */
      long long at_lca=sum(1,0,chain_len[node_chain[i]]-1,node_idx[i],node_idx[i],chain[node_chain[i]]);
      printf("%lld\n",(solve(x,i)+solve(y,i)-at_lca+MOD)%MOD);
    }
  }
  return 0;
}
/* Prepends a half-edge from -> to (weight w) to from's adjacency list; exits on out-of-memory. */
static void add_half_edge(int from,int to,int w){
  lnode *e=malloc(sizeof *e);
  if(!e){
    fputs("insert_edge: out of memory\n",stderr);
    exit(EXIT_FAILURE);
  }
  e->x=to;
  e->w=w;
  e->next=table[from];
  table[from]=e;
}

/* Adds the undirected edge x--y with weight w, stored as two half-edges. */
void insert_edge(int x,int y,int w){
  add_half_edge(x,y,w);
  add_half_edge(y,x,w);
}
/*
 * First DFS from the root. For every node below u it records the parent
 * (DP[0]) and depth (level), then computes subtree_size[u] and
 * special[u] = the child with the largest subtree (-1 if u has none).
 * Recursive: the call depth equals the height of the tree.
 */
void dfs0(int u){
  subtree_size[u]=1;
  special[u]=-1;
  for(lnode *e=table[u];e!=NULL;e=e->next){
    int child=e->x;
    if(child==DP[0][u])
      continue;               /* edge back to the parent */
    DP[0][child]=u;
    level[child]=level[u]+1;
    dfs0(child);
    subtree_size[u]+=subtree_size[child];
    if(special[u]==-1 || subtree_size[child]>subtree_size[special[u]])
      special[u]=child;       /* keep the heaviest child seen so far */
  }
}
/*
 * Second DFS: appends u to chain c (node_idx = position, head at 0).
 * The heavy child special[u] continues chain c; every other child opens
 * a new chain numbered cn and becomes that chain's head.
 */
void dfs1(int u,int c){
  node_chain[u]=c;
  node_idx[u]=chain_len[c];
  chain_len[c]++;
  for(lnode *e=table[u];e;e=e->next){
    int child=e->x;
    if(child==DP[0][u])
      continue;
    if(child!=special[u]){
      int fresh=cn++;
      chain_head[fresh]=child;
      dfs1(child,fresh);
    }else{
      dfs1(child,c);
    }
  }
}
/*
 * Builds all static structures:
 *  - parents, depths, subtree sizes and heavy children (dfs0 from root 0)
 *  - the binary-lifting table DP
 *  - the heavy chains (dfs1)
 *  - one zero-initialised segment tree of 4*chain_len nodes per chain
 * Exits if a segment tree cannot be allocated.
 */
void preprocess(void){
  int i,j;
  level[0]=0;
  DP[0][0]=0;            /* root is its own parent, so lifting saturates at 0 */
  dfs0(0);
  for(i = 1; i < 18; i++)
    for(j = 0; j < N; j++)
      DP[i][j] = DP[i-1][DP[i-1][j]];
  cn=1;                  /* chain 0 is the root's chain */
  chain_head[0]=0;
  dfs1(0,0);
  for(i = 0; i < cn; i++){
    /* calloc zero-fills sum/offset1/offset2 and checks the size product */
    chain[i]=calloc(4*(size_t)chain_len[i],sizeof *chain[i]);
    if(!chain[i]){
      fputs("preprocess: out of memory\n",stderr);
      exit(EXIT_FAILURE);
    }
  }
}
int lca(int a,int b>{
  int i;
  if(level[a]>level[b]){
    i=a;
    a=b;
    b=i;
  }
  int d = level[b]-level[a];
  for(i = 0; i < 18; i++)
    if(d&(1<<i))
      b=DP[i][b];
  if(a==b>return a;
  for(i=17;i>=0;i--)
    if(DP[i][a]!=DP[i][b])
      a=DP[i][a],b=DP[i][b];
  return DP[0][a];
}
/*
 * Sum (mod MOD) of positions l..r in chain tree t, where node v covers
 * [tl,tr]. An empty range (l > r) yields 0. Pending progressions are
 * pushed on every visited node before it is read.
 */
long long sum(int v,int tl,int tr,int l,int r,tree *t){
  push(v,tl,tr,t);
  if(l>r)
    return 0;
  if(tl==l && tr==r)
    return t[v].sum;
  int mid=(tl+tr)/2;
  long long left_part=sum(2*v,tl,mid,l,min(r,mid),t);
  long long right_part=sum(2*v+1,mid+1,tr,max(l,mid+1),r,t);
  return (left_part+right_part)%MOD;
}
/*
 * Adds an arithmetic progression to positions pos1..pos2 of chain tree t:
 * position p receives o1 + o2*(p - pos1) (mod MOD). Node v covers [tl,tr].
 * Fully covered nodes only record the progression (push() applies it later);
 * partially covered nodes recurse and then recompute their sum.
 */
void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t){
  push(v,tl,tr,t);        /* also guarantees t[v].offset1 == offset2 == 0 below */
  if(pos2 < tl || pos1 > tr)
    return;
  if(pos1 <= tl && pos2 >= tr){
    /* re-base the progression so that offset1 is the value at tl */
    t[v].offset1=(o1+o2*(tl-pos1))%MOD;
    t[v].offset2=o2;
    return;
  }
  int tm=(tl+tr)/2;
  range_update(v*2,tl,tm,pos1,pos2,o1,o2,t);
  range_update(v*2+1,tm+1,tr,pos1,pos2,o1,o2,t);
  /* children may still hold freshly recorded offsets; fold them in first */
  push(v*2,tl,tm,t);
  push(v*2+1,tm+1,tr,t);
  t[v].sum=(t[v*2].sum+t[v*2+1].sum)%MOD;
}
/*
 * Applies node v's pending progression to t[v].sum and hands it to the children.
 * The progression adds offset1 at index tl and grows by offset2 per index,
 * so the node's range [tl,tr] of length n = tr-tl+1 gains
 * n*(2*offset1 + offset2*(n-1))/2.
 * The right child starts at tm+1, so its first value is offset1 + offset2*(tm-tl+1).
 * Both offsets are cleared afterwards.
 */
void push(int v,int tl,int tr,tree *t){
  if(!t[v].offset1 && !t[v].offset2)
    return;
  /* (2a + d(n-1))*n is always even, so the /2 is exact; with chain lengths
     bounded by N (<= 50000 per the constraints) the product fits in long long */
  t[v].sum=(t[v].sum+(t[v].offset1*2+t[v].offset2*(tr-tl))*(tr-tl+1)/2%MOD)%MOD;
  if(tl!=tr){
    int tm=(tl+tr)/2;
    t[v*2].offset1=(t[v*2].offset1+t[v].offset1)%MOD;
    t[v*2+1].offset1=(t[v*2+1].offset1+t[v].offset1+t[v].offset2*(tm-tl+1))%MOD;
    t[v*2].offset2=(t[v*2].offset2+t[v].offset2)%MOD;
    t[v*2+1].offset2=(t[v*2+1].offset2+t[v].offset2)%MOD;
  }
  t[v].offset1=t[v].offset2=0;
  return;
}
/*
 * Query type 1: adds z, 2z, 3z, ... along the path x = p1, ..., pk = y (mod MOD).
 * Let ca = lca(x,y).
 *  - x side: walking up from x the multiplier grows. Each chain segment,
 *    read from its head downwards, therefore gets a progression with step -z.
 *  - y side: walking up from y the multiplier shrinks. Segments get step +z,
 *    counted back from the value assigned to y.
 * ca itself is updated once, on the x side.
 */
void range_solve(int x,int y,int z){
  int ca=lca(x,y),ty=y;
  long long cac=0,cay=0;
  /* x -> ca: after each +=, cac is the path position of the current chain head */
  while(node_chain[x]!=node_chain[ca]){
    cac+=node_idx[x]+1;
    range_update(1,0,chain_len[node_chain[x]]-1,0,node_idx[x],z*cac%MOD,MOD-z,chain[node_chain[x]]);
    x=DP[0][chain_head[node_chain[x]]];
  }
  cac+=node_idx[x]-node_idx[ca]+1;            /* path position of ca */
  range_update(1,0,chain_len[node_chain[x]]-1,node_idx[ca],node_idx[x],z*cac%MOD,MOD-z,chain[node_chain[x]]);
  cac=z*cac%MOD;                              /* value added to ca */
  /* dist(ca, y): number of path nodes strictly below ca on the y side */
  while(node_chain[ty]!=node_chain[ca]){
    cay+=node_idx[ty]+1;
    ty=DP[0][chain_head[node_chain[ty]]];
  }
  cay+=node_idx[ty]-node_idx[ca];
  cay=(cac+z*cay)%MOD;                        /* value added to y */
  /* y -> ca (exclusive): cay tracks the value of the node being processed */
  while(node_chain[y]!=node_chain[ca]){
    cay=(cay-z*(long long)node_idx[y]%MOD+MOD)%MOD;   /* value at this chain's head */
    range_update(1,0,chain_len[node_chain[y]]-1,0,node_idx[y],cay,z,chain[node_chain[y]]);
    cay=(cay-z+MOD)%MOD;                               /* value at the head's parent */
    y=DP[0][chain_head[node_chain[y]]];
  }
  /* value at position node_idx[ca]+1 of ca's chain. By construction
     (cay - z) mod MOD == cac, i.e. ca's own value was already added above. */
  cay=(cay-z*(long long)(node_idx[y]-node_idx[ca]-1)%MOD+MOD)%MOD;
  if(node_idx[y]!=node_idx[ca])
    range_update(1,0,chain_len[node_chain[y]]-1,node_idx[ca]+1,node_idx[y],cay,z,chain[node_chain[y]]);
}
/* Smaller of two ints (used by sum() to clip the left half's range). */
int min(int x,int y){
  return (x<y)?x:y;
}
/* Larger of two ints (used by sum() to clip the right half's range). */
int max(int x,int y){
  return (x>y)?x:y;
}
/*
 * Sum (mod MOD) of the values on the vertical path from x up to and
 * including `ancestor` (main passes the LCA, which is an ancestor of x).
 * Climbs chain by chain, summing the prefix [head .. x] of each chain,
 * then the final segment [ancestor .. x] on the ancestor's chain.
 */
long long solve(int x,int ancestor){
  long long total=0;
  int target=node_chain[ancestor];
  for(;;){
    int c=node_chain[x];
    if(c==target){
      total+=sum(1,0,chain_len[c]-1,node_idx[ancestor],node_idx[x],chain[c]);
      return total%MOD;
    }
    total=(total+sum(1,0,chain_len[c]-1,0,node_idx[x],chain[c]))%MOD;
    x=DP[0][chain_head[c]];
  }
}
Copy The Code & Try With Live Editor

#2 Code Example with C++ Programming

Code - C++ Programming


#include <bits/stdc++.h>
using namespace std;
using ll = long long;  // shorthand for an at-least-64-bit signed integer

// Answer modulus required by the problem: 10^9 + 7.
const int mod = 1000000007;

// Residue modulo the compile-time modulus MOD, stored as an unsigned value in [0, MOD).
template < int MOD>
struct mod_int {
    static const int Mod = MOD;
    unsigned x;

    mod_int() : x(0) { }

    // Reduce a possibly negative integer into [0, MOD).
    mod_int(int sig) {
        int r = sig % MOD;
        if (r < 0)
            r += MOD;
        x = r;
    }
    mod_int(long long sig) {
        int r = sig % MOD;
        if (r < 0)
            r += MOD;
        x = r;
    }

    int get() const { return (int)x; }

    mod_int &operator+=(mod_int that) {
        x += that.x;
        if (x >= (unsigned)MOD)
            x -= MOD;
        return *this;
    }
    mod_int &operator-=(mod_int that) {
        x += MOD - that.x;  // stays below 2*MOD, which fits in unsigned
        if (x >= (unsigned)MOD)
            x -= MOD;
        return *this;
    }
    mod_int &operator*=(mod_int that) {
        unsigned long long prod = (unsigned long long)x * that.x;
        x = (unsigned)(prod % MOD);
        return *this;
    }
    mod_int &operator/=(mod_int that) {
        mod_int inv = that.inverse();
        return *this *= inv;
    }

    mod_int operator+(mod_int that) const { mod_int res(*this); res += that; return res; }
    mod_int operator-(mod_int that) const { mod_int res(*this); res -= that; return res; }
    mod_int operator*(mod_int that) const { mod_int res(*this); res *= that; return res; }
    mod_int operator/(mod_int that) const { mod_int res(*this); res /= that; return res; }

    // Multiplicative inverse via the extended Euclidean algorithm
    // (only meaningful when gcd(x, MOD) == 1).
    mod_int inverse() const {
        long long r0 = x, r1 = MOD;
        long long s0 = 1, s1 = 0;
        while (r1 != 0) {
            long long q = r0 / r1;
            long long r2 = r0 - q * r1;
            r0 = r1;
            r1 = r2;
            long long s2 = s0 - q * s1;
            s0 = s1;
            s1 = s2;
        }
        return mod_int(s0);
    }
};

using mint = mod_int < mod>;

// Additive monoid over mint used as the link-cut tree aggregate:
// identity element 0, operation +.
struct RS {
    using type = mint;
    static type id() { return type(); }
    static type op(const type& a, const type& b) {
        type s = a;
        s += b;
        return s;
    }
};

// Link-cut tree node for one tree vertex. Preferred paths are stored as splay
// trees whose in-order sequence is the path order (left = towards the root);
// `rev` lazily reverses that order and `lazy` lazily adds an arithmetic
// progression (first, step) over the in-order sequence.
class lct_node {
    using M = RS;
    using T = typename M::type;
    using U = pair < mint, mint>;  // progression: (value at in-order index 0, step)

    lct_node *l, *r, *p;  // splay children; p is the splay parent or, at a splay root, the path-parent
    bool rev;             // pending reversal of this splay subtree
    T val, all;           // this vertex's value; sum of `val` over the splay subtree
    int size;             // number of nodes in the splay subtree
    bool flag;            // true when `lazy` holds a progression not yet pushed to children
    U lazy;

    // 1 if this is p's left child, 3 if p's right child, 0 if this is a splay root.
    // The values are chosen so that pos() ^ p->pos() distinguishes zig-zig (0) from zig-zag (2).
    int pos() {
        if (p && p->l == this) return 1;
        if (p && p->r == this) return 3;
        return 0;
    }
    // Recompute size and aggregate from the children.
    void update() {
        size = (l ? l->size : 0) + (r ? r->size : 0) + 1;
        all = M::op(l ? l->all : M::id(), M::op(val, r ? r->all : M::id()));
    }
    // Add progression v over this subtree's in-order sequence: element j gets v.first + v.second*j.
    void update_lazy(const U& v) {
        if (!flag) lazy = make_pair(0, 0);
        // this node's in-order index; with a pending reversal the right subtree comes first
        int ls = !rev ? (l ? l->size : 0) : (r ? r->size : 0);
        val += v.first + v.second * ls;
        // sum of the progression over `size` elements; "/ 2" is a modular division (inverse of 2)
        all += v.first * size + ((v.second * (size - 1)) * size) / 2;
        lazy = make_pair(M::op(lazy.first, v.first), M::op(lazy.second, v.second));
        flag = true;
    }
    // Re-express the pending progression for the reversed order: it now starts from the old last element.
    void rev_data() {
        lazy = make_pair(lazy.first + lazy.second * (size - 1), mint(0) - lazy.second);
    }
    // Push pending tags from the splay root down to this node (recursing through splay ancestors first).
    void push() {
        if (pos()) p->push();
        if (rev) {
            swap(l, r);
            if (l) l->rev ^= true, l->rev_data();
            if (r) r->rev ^= true, r->rev_data();
            rev = false;
        }
        if (flag) {
            if (l) l->update_lazy(lazy);
            // right subtree starts after the left subtree and this node
            if (r) r->update_lazy(make_pair(lazy.first + lazy.second * (l ? l->size + 1 : 1), lazy.second));
            flag = false;
        }
    }
    // Single rotation of this node above its splay parent, preserving in-order order.
    void rot() {
        lct_node *par = p;
        lct_node *mid;
        if (p->l == this) {
            mid = r;
            r = par;
            par->l = mid;
        }
        else {
            mid = l;
            l = par;
            par->r = mid;
        }
        if (mid) mid->p = par;
        p = par->p;
        par->p = this;
        if (p && p->l == par) p->l = this;
        if (p && p->r == par) p->r = this;
        par->update();
        update();
    }
    // Move this node to the root of its splay tree.
    void splay() {
        push();
        while (pos()) {
            int st = pos() ^ p->pos();
            if (!st) p->rot(), rot();           // same side twice: zig-zig
            else if (st == 2) rot(), rot();     // opposite sides: zig-zag
            else rot();                         // parent is the splay root: zig
        }
    }

public:
    lct_node() : l(nullptr), r(nullptr), p(nullptr), rev(false), val(M::id()), all(M::id()), size(1), flag(false) {}
    // Make the path from the tree root to this node preferred; afterwards this
    // node is the splay root and its splay tree holds exactly that path.
    void expose() {
        for (lct_node *x = this, *y = nullptr; x; y = x, x = x->p) x->splay(), x->r = y, x->update();
        splay();
    }
    // Attach this node below x (presumably this node is the root of its own tree — build() calls it
    // on a vertex that has not been linked yet).
    void link(lct_node *x) {
        x->expose();
        expose();
        p = x;
    }
    // Re-root the represented tree at this node by reversing the exposed root..this path.
    void evert() {
        expose();
        rev = true;
        rev_data();
    }
    // Sum of `val` on the path from the current tree root to this node.
    T find() {
        expose();
        return all;
    }
    // Add progression v along the path root..this: the root gets v.first, each next node v.second more.
    void update(U v) {
        expose();
        update_lazy(v);
    }
};

// Node pool, one link-cut node per vertex (N <= 50000 per the constraints).
const int MAX = 5e4;
lct_node lct[MAX];

// Recursively links every vertex below v to its parent; prev is v's parent (-1 for the root).
void build(int v, int prev, const vector < vector<int>>& G) {
    const vector<int>& nbrs = G[v];
    for (size_t k = 0; k < nbrs.size(); ++k) {
        const int child = nbrs[k];
        if (child == prev)
            continue;
        lct[child].link(&lct[v]);
        build(child, v, G);
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int N, Q;
    cin >> N >> Q;

    // adjacency lists of the undirected input tree
    vector < vector<int>> G(N);
    for (int e = 0; e < N - 1; ++e) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    build(0, -1, G);

    for (; Q != 0; --Q) {
        int com, u, v;
        cin >> com >> u >> v;
        if (com != 1) {
            // re-root at u: the exposed root..v path is then exactly u..v
            lct[u].evert();
            printf("%d\n", lct[v].find().get());
            continue;
        }
        int x;
        cin >> x;
        // re-rooted at u, the progression (x, x) gives u x, the next node 2x, ..., v kx
        lct[u].evert();
        lct[v].update(make_pair(mint(x), mint(x)));
    }
    return 0;
}
Copy The Code & Try With Live Editor

#3 Code Example with Java Programming

Code - Java Programming


import java.io.*;
import java.util.*;

/**
 * HackerRank "Heavy Light 2 White Falcon".
 *
 * Heavy-light decomposition lays every heavy chain out on a contiguous range
 * of dfn indices. One bottom-up lazy segment tree over those indices supports
 * "add an arithmetic progression to a range" and "sum a range" modulo MOD.
 */
public class Solution {

  static List<Integer>[] adj;   // adjacency lists of the input tree
  static int[] chain;           // after dfs(): heavy child (-1 if none); after hld(): top node of the chain
  static int[] dep;             // depth, root = 0
  static int[] par;             // parent, -1 for the root

  /** Explicit-stack frame for dfs(). */
  static class NodeDfs {
    long size = 1;            // subtree size accumulated so far
    long maxs = 0;            // largest child subtree seen so far
    int u;
    int p;
    boolean start = true;     // true until the children have been pushed
    NodeDfs nodep = null;     // frame of the parent

    public NodeDfs(int u, int p, NodeDfs nodep) {
      this.u = u;
      this.p = p;
      this.nodep = nodep;
    }
  }

  /**
   * Iterative DFS from u (with parent p): fills par[] and dep[], and sets
   * chain[x] to the child of x with the largest subtree (-1 for leaves).
   */
  static void dfs(int u, int p) {
    Deque<NodeDfs> deque = new LinkedList<>();
    deque.add(new NodeDfs(u, p, null));
    while (!deque.isEmpty()) {
      NodeDfs node = deque.peekLast();
      if (node.start) {
        par[node.u] = node.p;
        chain[node.u] = -1;
        for (int v : adj[node.u]) {
          if (v != node.p) {
            dep[v] = dep[node.u] + 1;
            deque.add(new NodeDfs(v, node.u, node));
          }
        }
        node.start = false;
      } else {
        // post-order: fold this subtree into the parent's frame
        if (node.nodep != null) {
          node.nodep.size += node.size;
          if (node.size > node.nodep.maxs) {
            node.nodep.maxs = node.size;
            chain[node.nodep.u] = node.u;
          }
        }
        deque.removeLast();
      }
    }
  }

  /** Explicit-stack frame for hld(). */
  static class NodeHld {
    int u;
    int p;
    int top;          // head of the chain u belongs to
    int start = 0;    // 0: number u, 1: push light children, 2: finish

    public NodeHld(int u, int p, int top) {
      this.u = u;
      this.p = p;
      this.top = top;
    }
  }

  static int[] dfn;     // decomposition order index of each node
  static int tick = 0;

  /**
   * Iterative decomposition from u. The heavy child's whole subtree is
   * numbered right after its parent, so every chain is a contiguous dfn range.
   * When a node finishes, chain[x] is rewritten to the top of x's chain.
   */
  static void hld(int u, int p, int top) {
    Deque<NodeHld> deque = new LinkedList<>();
    deque.add(new NodeHld(u, p, top));
    while (!deque.isEmpty()) {
      NodeHld node = deque.peekLast();
      if (node.start == 0) {
        dfn[node.u] = tick++;
        if (chain[node.u] >= 0) {
          deque.add(new NodeHld(chain[node.u], node.u, node.top));
          node.start = 1;
        } else {
          node.start = 2;
        }
      } else if (node.start == 1) {
        for (int v : adj[node.u]) {
          if (v != node.p && v != chain[node.u]) {
            deque.add(new NodeHld(v, node.u, v));   // light child starts its own chain
          }
        }
        node.start = 2;
      } else {
        chain[node.u] = node.top;
        deque.removeLast();
      }
    }
  }

  static class Pair {
    int first = 0;
    int second = 0;

    Pair() {
    }

    Pair(int first, int second) {
      this.first = first;
      this.second = second;
    }
  }

  /**
   * Decomposes the path u -> v into half-open dfn ranges [first, second), in path order.
   * Ranges on u's side are walked upwards (towards smaller dfn) and are stored
   * bit-complemented (~first, ~second) to mark that direction. Ranges on v's
   * side are walked downwards and stored as is.
   */
  static List<Pair> path(int u, int v) {
    List<Pair> ps0 = new ArrayList<>();
    List<Pair> ps1 = new ArrayList<>();
    while (chain[u] != chain[v]) {
      if (dep[chain[u]] > dep[chain[v]]) {
        ps0.add(new Pair(~dfn[chain[u]], ~(dfn[u] + 1)));
        u = par[chain[u]];
      } else {
        ps1.add(new Pair(dfn[chain[v]], dfn[v] + 1));
        v = par[chain[v]];
      }
    }
    if (dep[u] > dep[v]) {
      ps0.add(new Pair(~dfn[v], ~(dfn[u] + 1)));
    } else {
      ps1.add(new Pair(dfn[u], dfn[v] + 1));
    }
    // v's side was collected from v upwards; append it reversed to keep path order
    for (int i = ps1.size() - 1; i >= 0; i--) {
      ps0.add(ps1.get(i));
    }
    return ps0;
  }

  static final int LN = 63 - Long.numberOfLeadingZeros(50000 - 1) + 1;  // tree height: NN >= 50000 leaves
  static final int NN = 1 << LN;
  static final int MOD = 1_000_000_007;
  static final int INV2 = (MOD + 1) / 2;   // modular inverse of 2

  static Pair[] ap = new Pair[2 * NN];     // pending progression per node: (value at node's first leaf, step)
  static long[] sum = new long[2 * NN];    // sum of the node's leaves, including its own pending tag

  /** (a + b) % MOD; keeps the sign of a negative sum. */
  static long sum(long a, long b) {
    return (a + b) % MOD;
  }

  /** (a * b) % MOD; keeps the sign of a negative product. */
  static long mult(long a, long b) {
    return (a * b) % MOD;
  }

  /**
   * Adds x.first + x.second * (p - start) to every leaf p under node i.
   * Updates sum[i] immediately and stores the progression, re-based to i's
   * first leaf, in ap[i] for later pushing.
   */
  static void apply(int i, long start, Pair x) {
    long h = LN - (63 - Long.numberOfLeadingZeros(i));   // height of node i
    long k = 1L << h;                                    // number of leaves under i
    long first = sum(x.first, mult((i << h) - NN - start, x.second));
    sum[i] = sum(sum[i], mult(mult(sum(2 * first, mult(k - 1, x.second)), k), INV2));
    ap[i].first = (int) sum(ap[i].first, first);
    ap[i].second = (int) sum(ap[i].second, x.second);
  }

  /** Pushes pending tags from the root down to the parent of leaf i; no-op outside [0, NN). */
  static void untag(int i) {
    if (i < 0 || i >= NN) {
      return;
    }
    i += NN;
    for (int j, h = LN; h > 0; h--) {
      // j = i >> h is always >= 1 here; the tag test is grouped explicitly
      if ((j = i >> h) > 0 && (ap[j].first != 0 || ap[j].second != 0)) {
        apply(2 * j, (j << h) - NN, ap[j]);
        apply(2 * j + 1, (j << h) - NN, ap[j]);
        ap[j].first = 0;
        ap[j].second = 0;
      }
    }
  }

  static void mconcat(int i) {
    sum[i] = sum(sum[2 * i], sum[2 * i + 1]);
  }

  /** Sum of leaves [l, r) modulo MOD; may be negative because negative steps are stored. */
  static long getSum(int l, int r) {
    long s = 0;
    untag(l - 1);
    untag(r);
    for (l += NN, r += NN; l < r; l >>= 1, r >>= 1) {
      if ((l & 1) > 0) {
        s = sum(s, sum[l++]);
      }
      if ((r & 1) > 0) {
        s = sum(s, sum[--r]);
      }
    }
    return s;
  }

  /** Adds x.first + x.second * (p - l) to every leaf p in [l, r), then refreshes affected ancestors. */
  static void modify(int l, int r, Pair x) {
    int start = l;
    boolean lf = false;
    boolean rf = false;
    untag(l - 1);
    untag(r);
    for (l += NN, r += NN; l < r; ) {
      if ((l & 1) > 0) {
        lf = true;
        apply(l++, start, x);
      }
      l >>= 1;
      if (lf) {
        mconcat(l - 1);
      }
      if ((r & 1) > 0) {
        rf = true;
        apply(--r, start, x);
      }
      r >>= 1;
      if (rf) {
        mconcat(r);
      }
    }
    for (l--; (l >>= 1) > 0 && (r >>= 1) > 0; ) {
      if (lf || l == r) {
        mconcat(l);
      }
      if (rf && l != r) {
        mconcat(r);
      }
    }
  }

  @SuppressWarnings("unchecked")   // generic array creation for adj
  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

    StringTokenizer st = new StringTokenizer(br.readLine());
    int n = Integer.parseInt(st.nextToken());
    int q = Integer.parseInt(st.nextToken());
    adj = new List[n];
    for (int i = 0; i < n; i++) {
      adj[i] = new ArrayList<>();
    }
    for (int i = 0; i < n - 1; i++) {
      st = new StringTokenizer(br.readLine());
      int u = Integer.parseInt(st.nextToken());
      int v = Integer.parseInt(st.nextToken());
      adj[u].add(v);
      adj[v].add(u);
    }
    chain = new int[n];
    dep = new int[n];
    par = new int[n];
    dfs(0, -1);

    dfn = new int[n];
    hld(0, -1, 0);

    for (int i = 0; i < ap.length; i++) {
      ap[i] = new Pair();
    }

    while (q-- > 0) {
      st = new StringTokenizer(br.readLine());
      int op = Integer.parseInt(st.nextToken());
      int u = Integer.parseInt(st.nextToken());
      int v = Integer.parseInt(st.nextToken());
      List<Pair> ps = path(u, v);
      if (op == 1) {
        int x = Integer.parseInt(st.nextToken());
        int y = x;   // value for the first node of the next segment
        for (Pair p : ps) {
          u = p.first;
          v = p.second;
          if (u >= 0) {
            // downward segment: values grow with dfn
            modify(u, v, new Pair(y, x));
            y = (int) sum(y, mult(v - u, x));
          } else {
            // upward segment: the top (smallest dfn) gets the largest value, step -x
            modify(~u, ~v, new Pair((int) sum(y, mult(u - v - 1, x)), -x));
            y = (int) sum(y, mult(u - v, x));
          }
        }
      } else {
        long ans = 0;
        for (Pair p : ps) {
          u = p.first;
          v = p.second;
          if (u < 0) {
            u = ~u;
            v = ~v;
          }
          ans = sum(ans, getSum(u, v));
        }
        bw.write(sum(ans, MOD) + "\n");   // shift a possibly negative residue into [0, MOD)
      }
    }

    bw.newLine();
    bw.close();
    br.close();
  }
}
Copy The Code & Try With Live Editor

#4 Code Example with Python Programming

Code - Python Programming


from operator import attrgetter

MOD = 1000000007  # answer modulus, 10**9 + 7

def solve(edges, queries):
    """Answer every query on the tree described by ``edges``.

    Queries of kind 1 (``[1, u, v, x]``) update the tree in place; each
    query of kind 2 (``[2, u, v]``) appends one path sum. Returns the
    collected sums in query order.
    """
    nodes, leaves = make_tree(edges)
    hld(leaves)

    answers = []
    for query in queries:
        kind = query[0]
        if kind == 1:
            update(nodes[query[1]], nodes[query[2]], query[3])
        elif kind == 2:
            answers.append(sum_range(nodes[query[1]], nodes[query[2]]))

    return answers

def make_tree(edges):
    """Build Node objects from an undirected edge list and root the tree at node 0.

    Returns ``(nodes, leaves)``: ``nodes[i]`` is the Node for vertex i with
    ``parent``, ``depth`` and ``children`` (parent removed) filled in, and
    ``leaves`` lists every node without children. A single-node tree reports
    its root as its only leaf, so hld() still gives it a chain.
    """
    nodes = [
        Node(i)
        for i in range(len(edges) + 1)
    ]

    # the tree is a graph for now
    # as we don't know the direction of the edges
    for edge in edges:
        nodes[edge[0]].children.append(nodes[edge[1]])
        nodes[edge[1]].children.append(nodes[edge[0]])

    # pick the root of the tree
    root = nodes[0]
    root.depth = 0

    # for each node, remove its parent from its children
    stack = []
    leaves = []
    if not root.children:
        # N == 1: the root is itself a leaf and must still receive a chain
        leaves.append(root)
    for child in root.children:
        stack.append((child, root, 1))
    # `stack` grows while it is iterated, which visits nodes breadth-first
    for node, parent, depth in stack:
        node.children.remove(parent)
        node.parent = parent
        node.depth = depth

        if len(node.children) == 0:
            leaves.append(node)
            continue

        for child in node.children:
            stack.append((child, node, depth + 1))

    return nodes, leaves


def hld(leaves):
    """Partition the rooted tree into vertical chains, each with a Fenwick tree.

    Leaves are processed deepest first. Each leaf starts a new Chain (chain_i 0)
    and climbs towards the root, appending every ancestor that has no chain yet
    (chain_i counts upwards). The climb stops at the first ancestor already on a
    chain, which becomes this chain's ``parent`` / ``parent_i``, or at the root.
    Note: chains are chosen by depth, not by subtree size.
    """
    leaves = sorted(leaves, key=attrgetter('depth'), reverse=True)

    for leaf in leaves:
        leaf.chain = Chain()
        leaf.chain_i = 0

        curr_node = leaf
        # bound before the loop: a leaf that is the root never enters it
        curr_chain = leaf.chain
        while curr_node.parent is not None:
            curr_chain = curr_node.chain
            if curr_node.parent.chain is not None:
                curr_chain.init_fenwick_tree()
                curr_chain.parent = curr_node.parent.chain
                curr_chain.parent_i = curr_node.parent.chain_i
                break

            curr_node.parent.chain = curr_chain
            curr_node.parent.chain_i = curr_chain.size
            curr_node.chain.size += 1
            curr_node = curr_node.parent

        if curr_node.parent is None:
            # reached the root: this chain has no parent chain
            curr_chain.init_fenwick_tree()


def update(node1, node2, x):
    """Query type 1: add x, 2x, ..., k*x to the nodes on the path node1 .. node2.

    First climbs both endpoints to a common chain, always moving the side
    whose chain top is deeper. It records the climbed chain segments and
    counts the path length k. Then it writes progressions:
    - node1's segments count up from x;
    - node2's segments count down from k*x;
    - the last segment on the shared chain continues whichever side lies below.
    """
    path_len = 0
    chain1 = node1.chain
    chain_i1 = node1.chain_i
    depth1 = node1.depth
    chains1 = []
    chain2 = node2.chain
    chain_i2 = node2.chain_i
    depth2 = node2.depth
    chains2 = []

    # step = nodes from the current position up to the chain top (chain_i grows
    # upwards), so depth - step is the depth of the node just above the top.
    while chain1 is not chain2:
        step1 = chain1.size - chain_i1
        step2 = chain2.size - chain_i2

        if depth1 - step1 > depth2 - step2:
            path_len += step1
            chains1.append((chain1, chain_i1))
            depth1 -= step1
            chain_i1 = chain1.parent_i
            chain1 = chain1.parent
        else:
            path_len += step2
            chains2.append((chain2, chain_i2))
            depth2 -= step2
            chain_i2 = chain2.parent_i
            chain2 = chain2.parent

    # nodes on the shared chain between the two positions, inclusive
    path_len += abs(chain_i1 - chain_i2) + 1

    # node1's side. ftree.add(l, r, k, x) gives index l the value k + x, so
    # curr_val1 is the value of the node just before each segment.
    curr_val1 = 0
    for (chain, chain_i) in chains1:
        chain.ftree.add(chain_i, chain.size-1, curr_val1, x)
        curr_val1 += (chain.size - chain_i) * x

    # node2's side: node2 is position path_len, and values shrink by x going up
    curr_val2 = (path_len + 1) * x
    for (chain, chain_i) in chains2:
        chain.ftree.add(chain_i, chain.size-1, curr_val2, -x)
        curr_val2 -= (chain.size - chain_i) * x

    # shared chain: extend from whichever endpoint side is lower (smaller chain_i)
    if chain_i1 <= chain_i2:
        chain1.ftree.add(chain_i1, chain_i2, curr_val1, x)
    else:
        chain1.ftree.add(chain_i2, chain_i1, curr_val2, -x)


def sum_range(node1, node2):
    """Query type 2: sum of node values on the path node1 .. node2, modulo MOD.

    Repeatedly climbs the side whose chain top is deeper, adding that chain's
    segment, until both sides share a chain; then adds the segment between them.
    """
    total = 0
    c1, i1, d1 = node1.chain, node1.chain_i, node1.depth
    c2, i2, d2 = node2.chain, node2.chain_i, node2.depth
    while c1 is not c2:
        s1 = c1.size - i1
        s2 = c2.size - i2
        if d1 - s1 > d2 - s2:
            total += c1.ftree.range_sum(i1, c1.size - 1)
            d1 -= s1
            c1, i1 = c1.parent, c1.parent_i
        else:
            total += c2.ftree.range_sum(i2, c2.size - 1)
            d2 -= s2
            c2, i2 = c2.parent, c2.parent_i

    lo, hi = min(i1, i2), max(i1, i2)
    total += c1.ftree.range_sum(lo, hi)

    return int(total % MOD)

class Node():
    """One tree vertex plus its decomposition bookkeeping.

    i        -- vertex id
    val      -- starts at 0; not read elsewhere in this solution
    parent   -- parent Node, None for the root
    children -- adjacent Nodes (make_tree removes the parent)
    depth    -- distance from the root, None until make_tree sets it
    chain    -- Chain containing this vertex, None until hld assigns one
    chain_i  -- index inside that chain (0 at the chain's deepest node), -1 if unset
    """
    __slots__ = ['i', 'val', 'parent', 'children', 'depth', 'chain', 'chain_i']

    def __init__(self, i):
        self.i, self.val = i, 0
        self.parent = self.depth = self.chain = None
        self.children = []
        self.chain_i = -1


class Chain():
    """A vertical run of nodes sharing one Fenwick tree.

    size      -- number of nodes in the chain (starts at 1)
    ftree     -- RURQFenwickTree over the chain, None until init_fenwick_tree()
    parent    -- Chain holding the node just above this chain's top, None at the root
    parent_i  -- index of that node inside ``parent``, -1 if none
    """
    __slots__ = ['size', 'ftree', 'parent', 'parent_i']

    def __init__(self):
        self.size = 1
        self.ftree = self.parent = None
        self.parent_i = -1

    def init_fenwick_tree(self):
        """Allocate the Fenwick tree; call once ``size`` is final."""
        self.ftree = RURQFenwickTree(self.size)

def g(i):
    """Clear the trailing run of 1-bits of i (Fenwick prefix step: next index is g(i) - 1)."""
    following = i + 1
    return i & following

def h(i):
    """Set the lowest 0-bit of i (Fenwick update step: next index to touch)."""
    following = i + 1
    return i | following

class RURQFenwickTree():
    """Range-add of arithmetic progressions with range-sum queries.

    add(l, r, k, x) adds k + x*(j - l + 1) to every position j in [l, r]
    (so position l receives k + x). prefix_sum/range_sum return sums modulo
    MOD. Internally, twice the prefix sum up to i is kept as the quadratic
    i**2 * A + i * B + C, with A, B and C stored in three point-query
    Fenwick trees.
    """

    def __init__(self, size):
        self.tree1 = RUPQFenwickTree(size)  # coefficient A of i**2
        self.tree2 = RUPQFenwickTree(size)  # coefficient B of i
        self.tree3 = RUPQFenwickTree(size)  # constant term C

    def add(self, l, r, k, x):
        k2 = k * 2
        self.tree1.add(l, x)
        self.tree1.add(r+1, -x)
        self.tree2.add(l, (3 - 2*l) * x + k2)
        self.tree2.add(r+1, -((3 - 2*l) * x + k2))
        self.tree3.add(l, (l**2 - 3*l + 2) * x + k2 * (1 - l))
        # beyond r only the constant term remains; together with the entry at l
        # it equals x*n*(n+1) + 2*k*n for n = r - l + 1, twice the segment's total
        self.tree3.add(r+1, (r**2 + 3*r - 2*r*l) * x + k2 * r)

    def prefix_sum(self, i):
        """Sum of positions 0..i modulo MOD (0 for i == -1), as an int."""
        sum_ = i**2 * self.tree1.point_query(i)
        sum_ += i * self.tree2.point_query(i)
        sum_ += self.tree3.point_query(i)

        # sum_ is exactly twice the prefix sum, so halving its residue mod
        # 2*MOD is exact; floor division keeps the result an int instead of
        # the float that true division '/' produced.
        return ((sum_ % (2 * MOD)) // 2) % MOD

    def range_sum(self, l, r):
        """Sum of positions l..r; in (-MOD, MOD), so callers reduce it mod MOD."""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)

class RUPQFenwickTree():
    """Fenwick tree over indices 0..size-1 with point add and prefix-sum query.

    RURQFenwickTree uses it as range-add/point-query by adding at l and
    subtracting at r+1.
    """

    def __init__(self, size):
        self.size = size
        self.tree = [0 for _ in range(size)]

    def add(self, i, x):
        """Add x at index i; indices >= size are ignored."""
        while i < self.size:
            self.tree[i] += x
            i = h(i)

    def point_query(self, i):
        """Return the total added at indices 0..i (0 when i < 0)."""
        total = 0
        while i >= 0:
            total += self.tree[i]
            i = g(i) - 1
        return total

if __name__ == '__main__':
    header = input().split()
    n, q = int(header[0]), int(header[1])

    # N - 1 edge lines, then Q query lines, each parsed as a list of ints
    tree = [list(map(int, input().rstrip().split())) for _ in range(n - 1)]
    queries = [list(map(int, input().rstrip().split())) for _ in range(q)]

    results = solve(tree, queries)

    print('\n'.join(str(r) for r in results))
Copy The Code & Try With Live Editor
Advertisements

Demonstration


Previous
[Solved] Number Game on a Tree solution in HackerRank - HackerRank solution C, C++, Java, JS, Python
Next
[Solved] Library Query solution in HackerRank - HackerRank solution C, C++, Java, JS, Python