Algorithm
Problem Name: Data Structures - Heavy Light 2 White Falcon
https://www.hackerrank.com/challenges/heavy-light-2-white-falcon/problem?isFullScreen=true
In this HackerRank in Data Structures -
White Falcon was amazed by what she can do with heavy-light decomposition on trees. As a result, she wants to improve her expertise in heavy-light decomposition. Her teacher gave her another assignment, which requires path updates. As always, White Falcon needs your help with the assignment.
You are given a tree with N nodes, and each node's value val_i is initially 0.
Let's denote the path from node u to node v like this: p_1, p_2, p_3, ..., p_k, where p_1 = u, p_k = v, and p_i and p_(i+1) are connected.
The problem asks you to operate the following two types of queries on the tree:
- "1 u v x": Add x to val_(p_1), 2x to val_(p_2), 3x to val_(p_3), ..., kx to val_(p_k).
- "2 u v": Print the sum of the nodes' values on the path between u and v, modulo 10**9 + 7.
Input Format
The first line consists of two integers N and Q separated by a space.
Each of the following N-1 lines contains two integers which denote an undirected edge of the tree.
Each of the following Q lines contains one query of the types described above.
Note: Nodes are numbered by using 0-based indexing.
Constraints
1 <= N,Q <= 50000
0 <= x <= 10**9 + 7
Output Format
For every query of second type print a single integer.
Sample Input
3 2
0 1
1 2
1 0 2 1
2 1 2
Sample Output
5
Code Examples
#1 Code Example with C Programming
Code -
C Programming
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Adjacency-list cell: one direction of an undirected tree edge. */
typedef struct _lnode{
int x; /* index of the neighbouring node this cell points to */
int w; /* edge weight passed to insert_edge; stored but never read in the code shown here */
struct _lnode *next; /* next cell in the same node's list, or NULL at the end */
} lnode;
/* One segment-tree node covering a range of positions on a chain. */
typedef struct _tree{
long long sum; /* sum of the covered values mod MOD, not yet including this node's pending offsets */
long long offset1; /* pending lazy update: amount to add to the first position of the range */
long long offset2; /* pending lazy update: additional amount added per step to each later position */
} tree;
/* Modulus applied to every stored value and every reported sum. */
#define MOD 1000000007
/* Forward declarations of the functions defined after main. */
void insert_edge(int x,int y,int w);
void dfs0(int u);
void dfs1(int u,int c);
void preprocess();
int lca(int a,int b);
long long sum(int v,int tl,int tr,int l,int r,tree *t);
void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t);
void push(int v,int tl,int tr,tree *t);
void range_solve(int x,int y,int z);
int min(int x,int y);
int max(int x,int y);
long long solve(int x,int ancestor);
/*
 * Global state shared by all functions.
 * N: number of nodes; cn: number of chains created so far.
 * level[v]: depth of v with node 0 as the root.
 * DP[i][v]: the 2^i-th ancestor of v (the root is stored as its own parent).
 * subtree_size[v]: number of nodes in v's subtree.
 * special[v]: child of v with the largest subtree, or -1 when v has no child.
 * node_chain[v], node_idx[v]: chain containing v and v's position on it (0 is the chain head).
 * chain_head[c]: topmost node of chain c; chain_len[c]: number of nodes on chain c.
 */
int N,cn,level[100000],DP[18][100000],subtree_size[100000],special[100000],node_chain[100000],node_idx[100000],chain_head[100000],chain_len[100000]={0};
/* table[v]: head of node v's adjacency list. */
lnode *table[100000]={0};
/* chain[c]: segment-tree array for chain c, allocated in preprocess. */
tree *chain[100000];
/*
 * Program entry point.
 * Reads N and Q, then N-1 edges, builds the decomposition, and processes
 * Q queries from standard input:
 *   "1 u v x" applies the progressive path update via range_solve;
 *   any other type reads "u v" and prints the path sum modulo MOD.
 */
int main(){
    int Q, u, v, k;

    scanf("%d%d", &N, &Q);
    for (k = 0; k < N - 1; k++) {
        scanf("%d%d", &u, &v);
        insert_edge(u, v, 1);
    }
    preprocess();

    for (; Q; Q--) {
        /* u temporarily holds the query type, as in the input layout */
        scanf("%d", &u);
        if (u == 1) {
            scanf("%d%d%d", &u, &v, &k);
            range_solve(u, v, k);
        } else {
            long long from_u, from_v, at_top;

            scanf("%d%d", &u, &v);
            k = lca(u, v);
            from_u = solve(u, k);
            from_v = solve(v, k);
            /* the common ancestor is counted by both halves; subtract it once */
            at_top = sum(1, 0, chain_len[node_chain[k]] - 1, node_idx[k], node_idx[k], chain[node_chain[k]]);
            printf("%lld\n", (from_u + from_v - at_top + MOD) % MOD);
        }
    }
    return 0;
}
/*
 * Records the undirected edge x-y with weight w by pushing one cell onto
 * the front of table[x] and one onto the front of table[y].
 */
void insert_edge(int x,int y,int w){
    lnode *cell;

    /* cell for the direction x -> y */
    cell = malloc(sizeof *cell);
    cell->x = y;
    cell->w = w;
    cell->next = table[x];
    table[x] = cell;

    /* cell for the direction y -> x */
    cell = malloc(sizeof *cell);
    cell->x = x;
    cell->w = w;
    cell->next = table[y];
    table[y] = cell;
}
/*
 * First depth-first pass from u.
 * For every node below u it fills in the parent (DP[0]), the depth (level)
 * and the subtree size, and records in special[] the child with the largest
 * subtree (-1 when there is no child).
 */
void dfs0(int u){
    lnode *e;

    subtree_size[u] = 1;
    special[u] = -1;
    for (e = table[u]; e; e = e->next) {
        int child = e->x;

        if (child == DP[0][u])
            continue;               /* do not walk back to the parent */
        DP[0][child] = u;
        level[child] = level[u] + 1;
        dfs0(child);
        subtree_size[u] += subtree_size[child];
        if (special[u] == -1 || subtree_size[child] > subtree_size[special[u]])
            special[u] = child;
    }
}
/*
 * Second depth-first pass: places u at the end of chain c, continues the
 * same chain through u's special child, and starts a fresh chain (numbered
 * with the global counter cn) at every other child.
 */
void dfs1(int u,int c){
    lnode *e;

    node_chain[u] = c;
    node_idx[u] = chain_len[c]++;
    for (e = table[u]; e; e = e->next) {
        int child = e->x;

        if (child == DP[0][u])
            continue;               /* skip the parent */
        if (child == special[u]) {
            dfs1(child, c);
        } else {
            chain_head[cn] = child;
            dfs1(child, cn++);
        }
    }
}
void preprocess(){
int i,j;
level[0]=0;
DP[0][0]=0;
dfs0(0);
for(i =1;i < 18; i++)
for(j = 0;j < N; j++)
DP[i][j] = DP[i-1][DP[i-1][j]];
cn=1;
chain_head[0]=0;
dfs1(0,0);
for(i = 0; i < cn; i++){
chain[i]=(tree*)malloc(4*chain_len[i]*sizeof(tree));
memset(chain[i],0,4*chain_len[i]*sizeof(tree));
}
return;
}
int lca(int a,int b>{
int i;
if(level[a]>level[b]){
i=a;
a=b;
b=i;
}
int d = level[b]-level[a];
for(i = 0; i < 18; i++)
if(d&(1<<i))
b=DP[i][b];
if(a==b>return a;
for(i=17;i>=0;i--)
if(DP[i][a]!=DP[i][b])
a=DP[i][a],b=DP[i][b];
return DP[0][a];
}
/*
 * Returns the sum, modulo MOD, of positions [l, r] in the segment tree t.
 * v is the current node and [tl, tr] the range it covers. Pending offsets
 * of every visited node are applied first. An empty range (l > r) gives 0.
 */
long long sum(int v,int tl,int tr,int l,int r,tree *t){
    int mid;
    long long left_part, right_part;

    push(v, tl, tr, t);
    if (l > r)
        return 0;
    if (tl == l && tr == r)
        return t[v].sum;

    mid = (tl + tr) / 2;
    left_part = sum(v*2, tl, mid, l, min(r, mid), t);
    right_part = sum(v*2+1, mid+1, tr, max(l, mid+1), r, t);
    return (left_part + right_part) % MOD;
}
/*
 * Adds an arithmetic progression to positions [pos1, pos2] of the segment
 * tree t: position pos1 receives o1 and each following position receives
 * o2 more than the one before it (all modulo MOD).
 *
 * v is the current node and [tl, tr] the range it covers. A node fully
 * inside the update range only records the progression as pending offsets;
 * a partially covered node recurses and then rebuilds its sum from its
 * children.
 */
void range_update(int v,int tl,int tr,int pos1,int pos2,long long o1,long long o2,tree *t){
    int tm;

    push(v, tl, tr, t);
    if (pos2 < tl || pos1 > tr)
        return;

    if (pos1 <= tl && pos2 >= tr) {
        /* push() above cleared the pending offsets, so assigning is enough;
           offset1 is the progression's value at this node's first position */
        t[v].offset1 = (o1 + o2 * (tl - pos1)) % MOD;
        t[v].offset2 = o2;
        return;
    }

    tm = (tl + tr) / 2;
    range_update(v*2, tl, tm, pos1, pos2, o1, o2, t);
    range_update(v*2+1, tm+1, tr, pos1, pos2, o1, o2, t);
    /* apply what the children still have pending before reading their sums */
    push(v*2, tl, tm, t);
    push(v*2+1, tm+1, tr, t);
    t[v].sum = (t[v*2].sum + t[v*2+1].sum) % MOD;
}
/*
 * Applies the pending arithmetic progression stored on node v (covering
 * [tl, tr]) to the node's sum, hands it down to the two children when v is
 * not a leaf, and clears it. Does nothing if both offsets are zero.
 */
void push(int v,int tl,int tr,tree *t){
    tree *node = &t[v];
    long long first, step;
    int span;

    if (node->offset1 == 0 && node->offset2 == 0)
        return;

    first = node->offset1;
    step = node->offset2;
    span = tr - tl;

    /* sum of the terms first, first+step, ..., first+span*step */
    node->sum = (node->sum + (first*2 + step*span)*(span+1)/2%MOD)%MOD;

    if (tl != tr) {
        int tm = (tl + tr) / 2;
        tree *left = &t[v*2];
        tree *right = &t[v*2+1];

        left->offset1 = (left->offset1 + first) % MOD;
        /* the right child starts tm-tl+1 positions further along */
        right->offset1 = (right->offset1 + first + step*(tm-tl+1)) % MOD;
        left->offset2 = (left->offset2 + step) % MOD;
        right->offset2 = (right->offset2 + step) % MOD;
    }
    node->offset1 = node->offset2 = 0;
}
/*
 * Handles the update query "1 x y z": walking the path from x to y, the
 * i-th node on the path (starting at 1) has i*z added to its value.
 *
 * The path is split at the lowest common ancestor ca. The half from x up to
 * ca is updated chain by chain with progressions that decrease by z per
 * chain position (step MOD-z); the half from ca down to y uses progressions
 * that increase by z per position.
 */
void range_solve(int x,int y,int z){
int ca=lca(x,y),ty=y;
/* cac: number of path nodes counted so far on the x side; cay: same role on the y side */
long long cac=0,cay=0;
/* climb from x, one whole chain prefix at a time, until ca's chain is reached */
while(node_chain[x]!=node_chain[ca]){
cac+=node_idx[x]+1;
/* the chain head is the cac-th path node, so it receives z*cac; values shrink by z towards x */
range_update(1,0,chain_len[node_chain[x]]-1,0,node_idx[x],z*cac%MOD,MOD-z,chain[node_chain[x]]);
x=DP[0][chain_head[node_chain[x]]];
}
/* last piece on ca's own chain, from ca down to x */
cac+=node_idx[x]-node_idx[ca]+1;
range_update(1,0,chain_len[node_chain[x]]-1,node_idx[ca],node_idx[x],z*cac%MOD,MOD-z,chain[node_chain[x]]);
/* from here on cac holds the amount that was added to ca itself */
cac=z*cac%MOD;
/* dry run up from y to count the nodes strictly below ca on the y side */
while(node_chain[ty]!=node_chain[ca]){
cay+=node_idx[ty]+1;
ty=DP[0][chain_head[node_chain[ty]]];
}
cay+=node_idx[ty]-node_idx[ca];
/* cay now holds the amount to add to y, the last node on the path */
cay=(cac+z*cay)%MOD;
/* climb from y; each chain prefix gets a progression that grows by z per position */
while(node_chain[y]!=node_chain[ca]){
/* amount for the head of y's current chain */
cay=(cay-z*(long long)node_idx[y]%MOD+MOD)%MOD;
range_update(1,0,chain_len[node_chain[y]]-1,0,node_idx[y],cay,z,chain[node_chain[y]]);
/* amount for the node directly above that chain head */
cay=(cay-z+MOD)%MOD;
y=DP[0][chain_head[node_chain[y]]];
}
/* amount for the position just below ca on ca's chain */
cay=(cay-z*(long long)(node_idx[y]-node_idx[ca]-1)%MOD+MOD)%MOD;
/* NOTE(review): debugging consistency check -- one step further up must match
   the amount given to ca; on a mismatch the program spins forever here */
if((cay-z+MOD)%MOD!=cac)
while(1);
/* ca was already updated by the x side, so start one position below it */
if(node_idx[y]!=node_idx[ca])
range_update(1,0,chain_len[node_chain[y]]-1,node_idx[ca]+1,node_idx[y],cay,z,chain[node_chain[y]]);
return;
}
/* Returns the smaller of x and y (y when they are equal). */
int min(int x,int y){
    if (y <= x)
        return y;
    return x;
}
/* Returns the larger of x and y (y when they are equal). */
int max(int x,int y){
    return (x > y) ? x : y;
}
/*
 * Returns the sum, modulo MOD, of the node values on the path from x up to
 * ancestor, both ends included. The path is consumed one chain at a time:
 * a full prefix for every chain below the ancestor's chain, then the
 * stretch between ancestor and x on the ancestor's own chain.
 */
long long solve(int x,int ancestor){
    long long total = 0;
    int c;

    for (c = node_chain[x]; c != node_chain[ancestor]; c = node_chain[x]) {
        total = (total + sum(1, 0, chain_len[c]-1, 0, node_idx[x], chain[c])) % MOD;
        /* jump to the parent of this chain's head */
        x = DP[0][chain_head[c]];
    }
    total = (total + sum(1, 0, chain_len[c]-1, node_idx[ancestor], node_idx[x], chain[c])) % MOD;
    return total;
}
Copy The Code &
Try With Live Editor
#2 Code Example with C++ Programming
Code -
C++ Programming
#include <bits/stdc++.h>
using namespace std;
// Shorthand for a 64-bit signed integer.
using ll = long long;
// Modulus used for all arithmetic in this solution.
const int mod = 1e9 + 7;
// Integer modulo the compile-time constant MOD.
// The stored residue x is always kept in the range [0, MOD).
template <int MOD>
struct mod_int {
    static const int Mod = MOD;
    unsigned x;

    mod_int() : x(0) {}

    // Construct from a signed value; negative inputs are mapped into [0, MOD).
    mod_int(int sig) {
        int r = sig % MOD;
        if (r < 0) r += MOD;
        x = r;
    }
    mod_int(long long sig) {
        int r = sig % MOD;
        if (r < 0) r += MOD;
        x = r;
    }

    // The residue as a plain int.
    int get() const { return (int)x; }

    mod_int &operator+=(mod_int that) {
        x += that.x;
        if (x >= MOD) x -= MOD;
        return *this;
    }
    mod_int &operator-=(mod_int that) {
        // add the negation of that.x so the intermediate stays non-negative
        x += MOD - that.x;
        if (x >= MOD) x -= MOD;
        return *this;
    }
    mod_int &operator*=(mod_int that) {
        // widen to 64 bits so the product cannot overflow
        x = (unsigned long long)x * that.x % MOD;
        return *this;
    }
    mod_int &operator/=(mod_int that) {
        return *this *= that.inverse();
    }

    mod_int operator+(mod_int that) const {
        mod_int result(*this);
        result += that;
        return result;
    }
    mod_int operator-(mod_int that) const {
        mod_int result(*this);
        result -= that;
        return result;
    }
    mod_int operator*(mod_int that) const {
        mod_int result(*this);
        result *= that;
        return result;
    }
    mod_int operator/(mod_int that) const {
        mod_int result(*this);
        result /= that;
        return result;
    }

    // Multiplicative inverse computed with the extended Euclidean algorithm.
    mod_int inverse() const {
        long long a = x, b = MOD, u = 1, v = 0;
        while (b != 0) {
            const long long q = a / b;
            const long long next_b = a - q * b;
            const long long next_v = u - q * v;
            a = b;
            b = next_b;
            u = v;
            v = next_v;
        }
        return mod_int(u);
    }
};
// Modular integer type specialised for the problem's modulus.
using mint = mod_int < mod>;
// Aggregation used by the tree nodes: values are combined by addition.
struct RS {
using type = mint;
// Neutral element of op (zero).
static type id() { return 0; }
// Combines two aggregated values by adding them.
static type op(const type& l, const type & r) {
return l + r;
}
};
// Node of a link-cut tree built on splay trees.
// Besides its own value, each node keeps aggregates for the splay subtree rooted
// at it (size and sum), a pending reversal flag, and a pending update that adds
// an arithmetic progression along the subtree's in-order sequence.
class lct_node {
using M = RS;
using T = typename M::type;
// Pending update: (first term, common difference) of an arithmetic progression.
using U = pair < mint, mint>;
// Left child, right child and parent pointer. p may also point to a node that
// does not list this node as a child; pos() then reports 0.
lct_node *l, *r, *p;
// True when the children of this node still have to be swapped.
bool rev;
// val: this node's own value; all: sum of the values in this splay subtree.
T val, all;
// Number of nodes in this splay subtree.
int size;
// True when `lazy` still has to be handed down to the children.
bool flag;
U lazy;
// Returns 1 if this node is the left child of p, 3 if it is the right child,
// and 0 if it has no parent or its parent does not reference it as a child.
int pos() {
if (p && p->l == this) return 1;
if (p && p->r == this) return 3;
return 0;
}
// Recomputes size and all from the children and this node's value.
void update() {
size = (l ? l->size : 0) + (r ? r->size : 0) + 1;
all = M::op(l ? l->all : M::id(), M::op(val, r ? r->all : M::id()));
}
// Applies progression v to this whole subtree: the k-th node in order (k from 0)
// gains v.first + k * v.second. Only val/all are updated here; the progression
// is accumulated in `lazy` for the children.
void update_lazy(const U& v) {
if (!flag) lazy = make_pair(0, 0);
// number of nodes ordered before this one; a pending reversal swaps the sides
int ls = !rev ? (l ? l->size : 0) : (r ? r->size : 0);
val += v.first + v.second * ls;
// sum of the first `size` terms of the progression
all += v.first * size + ((v.second * (size - 1)) * size) / 2;
lazy = make_pair(M::op(lazy.first, v.first), M::op(lazy.second, v.second));
flag = true;
}
// Rewrites the pending progression for the reversed order: the old last term
// becomes the first term and the common difference changes sign.
void rev_data() {
lazy = make_pair(lazy.first + lazy.second * (size - 1), mint(0) - lazy.second);
}
// Pushes pending reversals and progressions down, starting at the root of this
// node's splay tree and ending at this node.
void push() {
if (pos()) p->push();
if (rev) {
swap(l, r);
if (l) l->rev ^= true, l->rev_data();
if (r) r->rev ^= true, r->rev_data();
rev = false;
}
if (flag) {
if (l) l->update_lazy(lazy);
// the right subtree starts after the left subtree and this node
if (r) r->update_lazy(make_pair(lazy.first + lazy.second * (l ? l->size + 1 : 1), lazy.second));
flag = false;
}
}
// Rotates this node above its parent and refreshes both nodes' aggregates.
void rot() {
lct_node *par = p;
lct_node *mid;
if (p->l == this) {
mid = r;
r = par;
par->l = mid;
}
else {
mid = l;
l = par;
par->r = mid;
}
if (mid) mid->p = par;
p = par->p;
par->p = this;
// re-attach to the grandparent only if it referenced par as a child
if (p && p->l == par) p->l = this;
if (p && p->r == par) p->r = this;
par->update();
update();
}
// Brings this node to the root of its splay tree.
void splay() {
push();
while (pos()) {
// st is 0 when this node and its parent are children on the same side,
// 2 when they are on opposite sides, and 1 or 3 when the parent is a splay root
int st = pos() ^ p->pos();
if (!st) p->rot(), rot();
else if (st == 2) rot(), rot();
else rot();
}
}
public:
lct_node() : l(nullptr), r(nullptr), p(nullptr), rev(false), val(M::id()), all(M::id()), size(1), flag(false) {}
// Walks up through the parent pointers, splaying each node and replacing its
// right child with the part already processed, then splays this node to the top.
void expose() {
for (lct_node *x = this, *y = nullptr; x; y = x, x = x->p) x->splay(), x->r = y, x->update();
splay();
}
// Attaches this node below x by setting its parent pointer to x.
void link(lct_node *x) {
x->expose();
expose();
p = x;
}
// Exposes this node and marks its splay tree as reversed
// (used by main to make this node the start of the next path).
void evert() {
expose();
rev = true;
rev_data();
}
// Returns the aggregated sum of the splay tree obtained by exposing this node.
T find() {
expose();
return all;
}
// Applies progression v to the splay tree obtained by exposing this node.
void update(U v) {
expose();
update_lazy(v);
}
};
// Capacity of the node pool (the largest N allowed by the constraints).
const int MAX = 5e4;
// One link-cut node per tree vertex, indexed by vertex number.
lct_node lct[MAX];
// Depth-first walk over adjacency lists G starting at v (prev is the vertex we
// came from, -1 at the root). Every neighbour other than prev is linked below v
// in the link-cut tree and then processed recursively.
void build(int v, int prev, const vector<vector<int>>& G) {
    for (size_t k = 0; k < G[v].size(); k++) {
        const int to = G[v][k];
        if (to == prev) continue;
        lct[to].link(&lct[v]);
        build(to, v, G);
    }
}
// Reads the tree and the queries from standard input.
// "1 u v x": makes u the start of the path and adds x, 2x, 3x, ... along the path to v.
// "2 u v":   prints the sum of the values on the path between u and v.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);

    int N, Q;
    cin >> N >> Q;

    vector<vector<int>> G(N);
    for (int e = 1; e < N; e++) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    build(0, -1, G);

    for (; Q; Q--) {
        int com, u, v;
        cin >> com >> u >> v;
        if (com != 1) {
            lct[u].evert();
            printf("%d\n", lct[v].find().get());
            continue;
        }
        int x;
        cin >> x;
        lct[u].evert();
        // first term x and common difference x give x, 2x, 3x, ...
        lct[v].update(make_pair(mint(x), mint(x)));
    }
    return 0;
}
Copy The Code &
Try With Live Editor
#3 Code Example with Java Programming
Code -
Java Programming
import java.io.*;
import java.util.*;
/**
 * Heavy-light decomposition with one lazy, non-recursive segment tree over the
 * DFS numbering. Path updates add an arithmetic progression; path queries
 * return the sum of the values modulo MOD.
 */
public class Solution {
// adj[u]: neighbours of node u.
static List < Integer>[] adj;
// After dfs(): chain[u] is the child of u with the largest subtree, or -1 if none.
// After hld(): chain[u] is the topmost node of the chain containing u.
static int[] chain;
// dep[u]: depth of u; par[u]: parent of u (the value passed as p for the root).
static int[] dep;
static int[] par;
// Stack frame of the iterative subtree-size DFS.
static class NodeDfs {
// Size of the subtree collected so far (starts with the node itself).
long size = 1;
// Largest child-subtree size seen so far.
long maxs = 0;
int u;
int p;
// True until the node has been expanded (its children pushed).
boolean start = true;
// Frame of the parent node, null for the root.
NodeDfs nodep = null;
public NodeDfs(int u, int p, NodeDfs nodep) {
this.u = u;
this.p = p;
this.nodep = nodep;
}
}
// Iterative DFS from u (with parent p): fills par and dep (relative to dep[u])
// and stores in chain[] the child with the largest subtree for every node.
static void dfs(int u, int p) {
Deque < NodeDfs> deque = new LinkedList<>();
deque.add(new NodeDfs(u, p, null));
while (!deque.isEmpty()) {
NodeDfs node = deque.peekLast();
if (node.start) {
// first visit: record the parent and push all children
par[node.u] = node.p;
chain[node.u] = -1;
for (int v: adj[node.u]) {
if (v != node.p) {
dep[v] = dep[node.u]+1;
deque.add(new NodeDfs(v, node.u, node));
}
}
node.start = false;
} else {
// second visit: the subtree is complete, report it to the parent frame
if (node.nodep != null) {
node.nodep.size += node.size;
if (node.size > node.nodep.maxs) {
node.nodep.maxs = node.size;
chain[node.nodep.u] = node.u;
}
}
deque.removeLast();
}
}
}
// Stack frame of the iterative decomposition pass.
static class NodeHld {
int u;
int p;
// Topmost node of the chain that u belongs to.
int top;
// 0: not yet numbered; 1: heavy child pushed, light children pending; 2: finished.
int start = 0;
public NodeHld(int u, int p, int top) {
this.u = u;
this.p = p;
this.top = top;
}
}
// dfn[u]: position of u in the numbering produced by hld().
static int[] dfn;
// Next free position in that numbering.
static int tick = 0;
// Iterative decomposition: numbers the nodes so that each node is followed
// directly by its heavy child, then rewrites chain[u] to the top of u's chain.
static void hld(int u, int p, int top) {
Deque < NodeHld> deque = new LinkedList<>();
deque.add(new NodeHld(u, p, top));
while (!deque.isEmpty()) {
NodeHld node = deque.peekLast();
if (node.start == 0) {
dfn[node.u] = tick++;
if (chain[node.u] >= 0) {
// the heavy child continues the same chain
deque.add(new NodeHld(chain[node.u], node.u, node.top));
node.start = 1;
} else {
node.start = 2;
}
} else if (node.start == 1) {
// every other child starts a chain of its own
for (int v: adj[node.u]) {
if (v != node.p && v != chain[node.u]) {
deque.add(new NodeHld(v, node.u, v));
}
}
node.start = 2;
} else {
// all children are done; chain[] may now switch meaning to "chain top"
chain[node.u] = node.top;
deque.removeLast();
}
}
}
// Plain pair of ints; used both for dfn intervals and for (first term, step) updates.
static class Pair {
int first = 0;
int second = 0;
Pair() {
}
Pair(int first, int second) {
this.first = first;
this.second = second;
}
}
// Splits the path u -> v into half-open dfn intervals [first, second), listed in
// path order. Intervals walked upwards (from the u side) are stored with both
// ends bitwise-complemented, which makes them negative; intervals walked
// downwards (towards v) are stored as they are.
static List < Pair> path(int u, int v) {
// NOTE(review): ps0 and ps1 are raw Lists here; they presumably were List<Pair>
// before the page stripped the type argument -- confirm against the original.
List ps0 = new ArrayList<>();
List ps1 = new ArrayList<>();
while (chain[u] != chain[v]) {
if (dep[chain[u]] > dep[chain[v]]) {
ps0.add(new Pair(~ dfn[chain[u]], ~ (dfn[u]+1)));
u = par[chain[u]];
} else {
ps1.add(new Pair(dfn[chain[v]], dfn[v]+1));
v = par[chain[v]];
}
}
if (dep[u] > dep[v]) {
ps0.add(new Pair(~ dfn[v], ~ (dfn[u]+1)));
} else {
ps1.add(new Pair(dfn[u], dfn[v]+1));
}
// the v-side pieces were collected bottom-up; append them top-down
for (int i = ps1.size()-1; i >= 0; i--) {
ps0.add(ps1.get(i));
}
return ps0;
}
// Height of the segment tree: floor(log2(50000-1)) + 1.
static final int LN = 63-Long.numberOfLeadingZeros(50000-1)+1;
// Number of leaves; leaf for position q is stored at index NN + q.
static final int NN = 1 << LN;
static final int MOD = 1_000_000_007;
// Multiplicative inverse of 2 modulo MOD.
static final int INV2 = (MOD+1)/2;
// ap[i]: pending progression of node i, expressed relative to the node's leftmost leaf.
static Pair[] ap = new Pair[2*NN];
// sum[i]: sum of the values under node i.
static long[] sum = new long[2*NN];
// (a + b) % MOD; the result can be negative when an operand is negative.
static long sum(long a, long b) {
return (a + b) % MOD;
}
// (a * b) % MOD; the result can be negative when an operand is negative.
static long mult(long a, long b) {
return (a * b) % MOD;
}
// Applies to node i the progression x whose first term sits at position start:
// adds the progression's total over the node's leaves to sum[i] and accumulates
// the progression in ap[i].
static void apply(int i, long start, Pair x) {
// h: height of node i above the leaves; k: number of leaves below it
long h = LN-(63-Long.numberOfLeadingZeros(i));
long k = 1L << h;
// value of the progression at the node's leftmost leaf
long first = sum(x.first, mult((i<<h) - NN - start, x.second));
// k * (2*first + (k-1)*step) / 2, with the division done through INV2
sum[i] = sum(sum[i], mult(mult(sum(2*first, mult(k-1, x.second)), k), INV2));
ap[i].first = (int) sum(ap[i].first, first);
ap[i].second = (int) sum(ap[i].second, x.second);
}
// Pushes all pending progressions on the root-to-leaf path of position i down
// to the children. Positions outside [0, NN) are ignored.
static void untag(int i) {
if (i < 0 || i >= NN) {
return;
}
i += NN;
for (int j, h = LN; h > 0; h--) {
if ((j = i >> h) > 0 && ap[j].first != 0 || ap[j].second != 0) {
// (j << h) - NN is the leftmost position covered by node j
apply(2*j, (j << h) - NN, ap[j]);
apply(2*j+1, (j << h) - NN, ap[j]);
ap[j].first = 0;
ap[j].second = 0;
}
}
}
// Recomputes sum[i] from its two children.
static void mconcat(int i) {
sum[i] = sum(sum[2*i], sum[2*i+1]);
}
// Returns the sum of the values at positions [l, r).
static long getSum(int l, int r) {
long s = 0;
// clear the pending updates along both borders of the interval first
untag(l-1);
untag(r);
for (l += NN, r += NN; l < r; l >>= 1, r >>= 1) {
if ((l & 1) > 0) {
s = sum(s, sum[l++]);
}
if ((r & 1) > 0) {
s = sum(s, sum[--r]);
}
}
return s;
}
// Adds progression x to positions [l, r): position l receives x.first and each
// following position receives x.second more than the one before.
static void modify(int l, int r, Pair x) {
int start = l;
// lf / rf: whether a node on the left / right border has been changed so far
boolean lf = false;
boolean rf = false;
untag(l-1);
untag(r);
for (l += NN, r += NN; l < r; ) {
if ((l & 1) > 0) {
lf = true;
apply(l++, start, x);
}
l >>= 1;
if (lf) {
mconcat(l-1);
}
if ((r & 1) > 0) {
rf = true;
apply(--r, start, x);
}
r >>= 1;
if (rf) {
mconcat(r);
}
}
// refresh the remaining ancestors of both borders up to the root
for (l--; (l >>= 1) > 0 && (r >>= 1) > 0; ) {
if (lf || l == r) {
mconcat(l);
}
if (rf && l != r) {
mconcat(r);
}
}
}
// Reads the tree and the queries from standard input and writes the answers to
// the file named by the OUTPUT_PATH environment variable.
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int q = Integer.parseInt(st.nextToken());
adj = new List[n];
for (int i = 0; i < n; i++) {
adj[i] = new ArrayList<>();
}
for (int i = 0; i < n-1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
adj[u].add(v);
adj[v].add(u);
}
chain = new int[n];
dep = new int[n];
par = new int[n];
dfs(0, -1);
dfn = new int[n];
hld(0, -1, 0);
for (int i = 0; i < ap.length; i++) {
ap[i] = new Pair();
}
while (q-- > 0) {
st = new StringTokenizer(br.readLine());
int op = Integer.parseInt(st.nextToken());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
List < Pair> ps = path(u, v);
if (op == 1) {
int x = Integer.parseInt(st.nextToken());
// y: amount for the first node of the next piece of the path
int y = x;
for (Pair p: ps) {
u = p.first;
v = p.second;
if (u >= 0) {
// piece walked in increasing dfn order: start at y, step +x
modify(u, v, new Pair(y, x));
y = (int) sum(y, mult(v-u, x));
} else {
// piece walked in decreasing dfn order: the lowest dfn gets the largest
// amount, so start there and step by -x (u-v is the piece's length)
modify(~ u, ~ v, new Pair((int)sum(y, mult(u-v-1, x)), - x));
y = (int) sum(y, mult(u-v, x));
}
}
} else {
long ans = 0;
for (Pair p: ps) {
u = p.first;
v = p.second;
if (u < 0) {
u = ~ u;
v = ~ v;
}
ans = sum(ans, getSum(u, v));
}
// adding MOD lifts a negative remainder into [0, MOD)
bw.write(sum(ans, MOD) + "\n");
}
}
bw.newLine();
bw.close();
br.close();
}
}
Copy The Code &
Try With Live Editor
#4 Code Example with Python Programming
Code -
Python Programming
from operator import attrgetter
# Modulus applied to every reported sum.
MOD = 10**9 + 7
def solve(edges, queries):
    """Process all queries on the tree described by ``edges``.

    ``edges`` is a list of ``[u, v]`` pairs and ``queries`` a list of
    ``[1, u, v, x]`` (path update) or ``[2, u, v]`` (path sum) lists.
    Returns the answers of the type-2 queries in input order.
    """
    nodes, leaves = make_tree(edges)
    hld(leaves)
    results = []
    for query in queries:
        if query[0] == 1:
            update(nodes[query[1]], nodes[query[2]], query[3])
        elif query[0] == 2:
            results.append(sum_range(nodes[query[1]], nodes[query[2]]))
    return results
def make_tree(edges):
    """Build a tree rooted at node 0 from an undirected edge list.

    Returns ``(nodes, leaves)``: ``nodes[i]`` is the ``Node`` for vertex i
    with ``parent``, ``children`` and ``depth`` filled in, and ``leaves``
    lists every non-root node that has no children.
    """
    nodes = [
        Node(i)
        for i in range(len(edges) + 1)
    ]

    # the tree is a graph for now
    # as we don't know the direction of the edges
    for edge in edges:
        nodes[edge[0]].children.append(nodes[edge[1]])
        nodes[edge[1]].children.append(nodes[edge[0]])

    # pick the root of the tree
    root = nodes[0]
    root.depth = 0

    # for each node, remove its parent from its children
    stack = []
    leaves = []
    for child in root.children:
        stack.append((child, root, 1))
    # the list grows while it is being iterated, which gives a
    # breadth-first walk over the whole tree
    for node, parent, depth in stack:
        node.children.remove(parent)
        node.parent = parent
        node.depth = depth
        if len(node.children) == 0:
            leaves.append(node)
            continue
        for child in node.children:
            stack.append((child, node, depth + 1))
    return nodes, leaves
def hld(leaves):
    """Split the tree into chains, starting from the leaves.

    Leaves are handled deepest first. Each leaf opens a new chain (index 0
    is the leaf) and climbs towards the root, claiming every ancestor that
    is not yet on a chain. When it meets a node that already belongs to a
    chain, the new chain is attached to that chain and the climb stops.
    A Fenwick tree is created for each chain once its size is final.
    """
    leaves = sorted(leaves, key=attrgetter('depth'), reverse=True)
    for leaf in leaves:
        leaf.chain = Chain()
        leaf.chain_i = 0
        curr_node = leaf
        # bound up front so the check after the loop is safe even for a
        # leaf that has no parent
        curr_chain = leaf.chain
        while curr_node.parent is not None:
            curr_chain = curr_node.chain
            if curr_node.parent.chain is not None:
                # the parent is already taken: hang this chain below it
                curr_chain.init_fenwick_tree()
                curr_chain.parent = curr_node.parent.chain
                curr_chain.parent_i = curr_node.parent.chain_i
                break
            curr_node.parent.chain = curr_chain
            curr_node.parent.chain_i = curr_chain.size
            curr_node.chain.size += 1
            curr_node = curr_node.parent
        # the climb reached the root without meeting another chain
        if curr_node.parent is None:
            curr_chain.init_fenwick_tree()
def update(node1, node2, x):
    """Add x, 2x, 3x, ... to the nodes on the path from node1 to node2.

    The path is first split into chain segments: segments climbed from
    node1 go into ``chains1``, segments climbed from node2 into ``chains2``,
    until both sides sit on the same chain. Each segment then receives an
    arithmetic progression through its chain's Fenwick tree.
    """
    path_len = 0

    chain1 = node1.chain
    chain_i1 = node1.chain_i
    depth1 = node1.depth
    chains1 = []

    chain2 = node2.chain
    chain_i2 = node2.chain_i
    depth2 = node2.depth
    chains2 = []

    while chain1 is not chain2:
        # number of nodes from the current position up to the chain's top
        step1 = chain1.size - chain_i1
        step2 = chain2.size - chain_i2
        # leave the chain whose exit point is deeper in the tree
        if depth1 - step1 > depth2 - step2:
            path_len += step1
            chains1.append((chain1, chain_i1))
            depth1 -= step1
            chain_i1 = chain1.parent_i
            chain1 = chain1.parent
        else:
            path_len += step2
            chains2.append((chain2, chain_i2))
            depth2 -= step2
            chain_i2 = chain2.parent_i
            chain2 = chain2.parent
    path_len += abs(chain_i1 - chain_i2) + 1

    # node1 side: amounts grow while climbing
    curr_val1 = 0
    for (chain, chain_i) in chains1:
        chain.ftree.add(chain_i, chain.size-1, curr_val1, x)
        curr_val1 += (chain.size - chain_i) * x

    # node2 side: node2 receives path_len * x and amounts shrink while climbing
    curr_val2 = (path_len + 1) * x
    for (chain, chain_i) in chains2:
        chain.ftree.add(chain_i, chain.size-1, curr_val2, -x)
        curr_val2 -= (chain.size - chain_i) * x

    # the shared chain: the direction depends on which end has the lower index
    if chain_i1 <= chain_i2:
        chain1.ftree.add(chain_i1, chain_i2, curr_val1, x)
    else:
        chain1.ftree.add(chain_i2, chain_i1, curr_val2, -x)
def sum_range(node1, node2):
    """Return the sum of the values on the path node1 .. node2, modulo MOD.

    Both ends climb chain by chain, always leaving the chain whose exit
    point is deeper, until they are on the same chain; the segment between
    them on that chain is added last.
    """
    sum_ = 0

    chain1 = node1.chain
    chain_i1 = node1.chain_i
    depth1 = node1.depth

    chain2 = node2.chain
    chain_i2 = node2.chain_i
    depth2 = node2.depth

    while chain1 is not chain2:
        # number of nodes from the current position up to the chain's top
        step1 = chain1.size - chain_i1
        step2 = chain2.size - chain_i2
        if depth1 - step1 > depth2 - step2:
            sum_ += chain1.ftree.range_sum(chain_i1, chain1.size - 1)
            depth1 -= step1
            chain_i1 = chain1.parent_i
            chain1 = chain1.parent
        else:
            sum_ += chain2.ftree.range_sum(chain_i2, chain2.size - 1)
            depth2 -= step2
            chain_i2 = chain2.parent_i
            chain2 = chain2.parent

    if chain_i1 > chain_i2:
        chain_i1, chain_i2 = chain_i2, chain_i1
    sum_ += chain1.ftree.range_sum(chain_i1, chain_i2)

    return int(sum_ % MOD)
class Node():
    """A tree vertex together with its place in the chain decomposition."""

    __slots__ = ['i', 'val', 'parent', 'children', 'depth', 'chain', 'chain_i']

    def __init__(self, i):
        # vertex number
        self.i = i
        self.val = 0
        # parent Node, None for the root
        self.parent = None
        # distance from the root, None until the tree is built
        self.depth = None
        # neighbouring Nodes; reduced to the children once the tree is rooted
        self.children = []
        # chain this node belongs to and its index on that chain
        self.chain = None
        self.chain_i = -1
class Chain():
    """One chain of the decomposition and the Fenwick tree that stores its values."""

    __slots__ = ['size', 'ftree', 'parent', 'parent_i']

    def __init__(self):
        # number of nodes on the chain
        self.size = 1
        # created by init_fenwick_tree() once the size is final
        self.ftree = None
        # chain directly above this one, and the index on it where this chain hangs
        self.parent = None
        self.parent_i = -1

    def init_fenwick_tree(self):
        """Allocate the Fenwick tree for the chain's current size."""
        self.ftree = RURQFenwickTree(self.size)
def g(i):
    """Return i with its trailing run of 1 bits cleared (Fenwick query step)."""
    return i & (i + 1)
def h(i):
    """Return i with its lowest 0 bit set (Fenwick update step)."""
    return i | (i + 1)
class RURQFenwickTree():
    """Range-update / range-query structure for arithmetic progressions.

    Twice the prefix sum up to index i is kept as the quadratic
    ``a*i**2 + b*i + c``; the three coefficients live in three
    range-update / point-query Fenwick trees.
    """

    def __init__(self, size):
        self.tree1 = RUPQFenwickTree(size)
        self.tree2 = RUPQFenwickTree(size)
        self.tree3 = RUPQFenwickTree(size)

    def add(self, l, r, k, x):
        """Add ``k + x * (j - l + 1)`` to every position j in ``[l, r]``."""
        k2 = k * 2
        self.tree1.add(l, x)
        self.tree1.add(r+1, -x)
        self.tree2.add(l, (3 - 2*l) * x + k2)
        self.tree2.add(r+1, -((3 - 2*l) * x + k2))
        self.tree3.add(l, (l**2 - 3*l + 2) * x + k2 * (1 - l))
        # past r the prefix sum stays constant at the full range total
        self.tree3.add(r+1, (r**2 + 3*r - 2*r*l) * x + k2 * r)

    def prefix_sum(self, i):
        """Return the sum of positions ``0..i`` modulo MOD (0 for i == -1)."""
        sum_ = i**2 * self.tree1.point_query(i)
        sum_ += i * self.tree2.point_query(i)
        sum_ += self.tree3.point_query(i)
        # sum_ is twice the prefix sum, so it is even; reducing modulo 2*MOD
        # first keeps the halving exact, and // keeps the result an int
        return ((sum_ % (2 * MOD)) // 2) % MOD

    def range_sum(self, l, r):
        """Return prefix_sum(r) - prefix_sum(l - 1); the result may be negative."""
        return self.prefix_sum(r) - self.prefix_sum(l - 1)
class RUPQFenwickTree():
    """Fenwick tree over ``size`` slots with suffix add and point query."""

    def __init__(self, size):
        self.size = size
        self.tree = [0] * size

    def add(self, i, x):
        """Add x to the value seen by every point query at index >= i.

        An index of ``size`` or more is ignored.
        """
        j = i
        while j < self.size:
            self.tree[j] += x
            j = h(j)

    def point_query(self, i):
        """Return the accumulated value at index i (0 for i < 0)."""
        res = 0
        j = i
        while j >= 0:
            res += self.tree[j]
            j = g(j) - 1
        return res
# Script entry point: read "N Q", N-1 edges and Q queries from standard
# input, then print one line per type-2 query.
if __name__ == '__main__':
    nq = input().split()
    n = int(nq[0])
    q = int(nq[1])

    tree = []
    for _ in range(n-1):
        tree.append(list(map(int, input().rstrip().split())))

    queries = []
    for _ in range(q):
        queries.append(list(map(int, input().rstrip().split())))

    results = solve(tree, queries)
    print('\n'.join(map(str, results)))
Copy The Code &
Try With Live Editor