Algorithm
Problem Name: Data Structures -
In this HackerRank problem from the Data Structures domain:
Taylor loves trees, and this new challenge has him stumped!
Consider a tree, $t$, consisting of $n$ nodes. Each node is numbered from $1$ to $n$, and each node $i$ has an integer, $c_i$, attached to it.
A query on tree $t$ takes the form `w x y z`. To process a query, you must print the count of ordered pairs of integers $(i, j)$ such that the following four conditions are all satisfied:
- $i \neq j$
- $i \in$ the path from node $w$ to node $x$.
- $j \in$ the path from node $y$ to node $z$.
- $c_i = c_j$
Given $t$ and $q$ queries, process each query in order, printing the pair count for each query on a new line.
Input Format
The first line contains two space-separated integers describing the respective values of n (the number of nodes) and q (the number of queries).
The second line contains $n$ space-separated integers describing the respective values of each node (i.e., $c_1, c_2, \ldots, c_n$).
Each of the $n - 1$ subsequent lines contains two space-separated integers, $u$ and $v$, defining a bidirectional edge between nodes $u$ and $v$.
Each of the $q$ subsequent lines contains a `w x y z` query, as defined above.
Constraints
- $1 \le n \le 10^5$
- $1 \le q \le 50000$
- $1 \le c_i \le 10^9$
- $1 \le u, v, w, x, y, z \le n$
Scoring for this problem is binary: you must pass all test cases to receive a positive score.
Output Format
For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.
Sample Input
10 5
10 2 3 5 10 5 3 6 2 1
1 2
1 3
3 4
3 5
3 6
4 7
5 8
7 9
2 10
8 5 2 10
3 8 4 9
1 9 5 9
4 6 4 6
5 8 5 8
Sample Output
0
1
3
2
0
Code Examples
#1 Code Example with C Programming
Code -
C Programming
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>
/* floor(log2(x)) for a non-zero unsigned x, i.e. the index of its highest set
 * bit. __builtin_clz (GCC/Clang) counts leading zero bits and is undefined for
 * x == 0. Assuming a 32-bit unsigned int, clz lies in [0, 31], where XOR with
 * 31 equals 31 - clz. */
#define floor_log2_X86(self) (31U ^ (unsigned)__builtin_clz(self))
/* Name used by the rest of the program. */
#define floor_log2 floor_log2_X86
/*
 * In-place heap sort of the index array `self` (of `length` entries) into
 * ascending order of weights[self[k]]. Only `self` is permuted; `weights` is
 * only read. Entries with equal weights end up in an unspecified order.
 *
 * The heap is addressed 1-based through `self` decremented by one, so the
 * children of heap slot k are slots 2k and 2k + 1.
 */
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
    unsigned at = length >> 1, member, node;

    /* Nothing to do; also avoids forming self - 1 for empty ranges. */
    if (length < 2)
        return;

    /* Phase 1: turn the array into a max-heap by sifting down every
     * internal slot, from the last one up to the root. */
    for (self--; at; self[node >> 1] = member) {
        member = self[at];
        for (node = at-- << 1; node <= length; node <<= 1) {
            /* Move to the heavier of the two children. */
            node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
            if (weights[self[node]] < weights[member])
                break;
            self[node >> 1] = self[node];
        }
    }

    /* Phase 2: repeatedly swap the maximum to the end of the shrinking heap
     * and sift the displaced element down from the root. */
    for (; length > 1; self[at >> 1] = member) {
        member = self[length];
        self[length--] = self[1];
        for (at = 2; at <= length; at <<= 1) {
            at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
            if (weights[self[at]] < weights[member])
                break;
            self[at >> 1] = self[at];
        }
    }
}
/*
 * Replaces every entry of `values` by a dense rank in [0, number of distinct
 * values). Values that occur more often receive smaller ranks; values with the
 * same number of occurrences are ranked in ascending order of value.
 *
 * main() relies on this ordering: rank 0 is the most frequent colour, so all
 * colours that occur more than once form a prefix of the rank range.
 *
 * length - number of entries in `values` (0 is a no-op).
 * values - rewritten in place.
 */
void compress(unsigned length, unsigned values[length]) {
    if (length == 0)
        return;

    unsigned at, order[length];
    /* order[k] = k: positions to be sorted by value. (Filled element by
     * element; the previous 64-bit store through an `unsigned long *` broke
     * strict aliasing, could be misaligned and assumed LP64 little-endian.) */
    for (at = 0; at < length; at++)
        order[at] = at;
    heap_sort(order, values, length);

    /* roots[g] (g >= 1) is the position in `order` of the last member of the
     * g-th run of equal values; roots[0] = -1U so that roots[g - 1] + 1 is the
     * first member of run g. With all values distinct there are `length` runs,
     * so roots needs length + 1 slots. */
    unsigned roots[length + 1], seen = 1, max = 0, others;
    for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
        for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);
        if (max < (at - others))
            max = (at - others);
    }

    /* Counting sort of the runs by size, largest run first; ties keep the
     * runs in ascending order of value. */
    unsigned indices[max + 1], ranks[seen];
    memset(indices, 0, sizeof(indices));
    for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
    for (at = max; at--; indices[at] += indices[at + 1]);
    for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);

    /* ranks[r] is the run that receives rank r: write r into its members. */
    for (; at < (seen - 1); at++)
        for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
}
/*
 * Lowest common ancestor of the vertices `lower` and `upper` in the renumbered
 * tree built by main().
 *
 * NOTE(review): the layout assumptions below come from how main() builds the
 * arrays and calls this function; confirm them before reusing it elsewhere:
 *  - the subtree of vertex v occupies the ids [v, v + weights[v]), and parents
 *    have smaller ids than their children;
 *  - the caller passes lower <= upper;
 *  - depths[v] (main's base_depths) counts the chain heads on the path from
 *    the root to v; it is indexed by vertex id even though it is declared
 *    with base_cnt entries;
 *  - bases[base_ids[v]][k] (main's ancestral_bases) is the vertex reached
 *    from v's chain after k + 1 jumps to the parent of a chain head, and reads
 *    as vertex 0 once the root has been passed.
 *
 * depth      - entries per row of `bases` (main passes `dist`); also the
 *              length of the binary search below.
 * base_cnt   - number of rows of `bases` (number of chains).
 * vertex_cnt - declared size of base_ids and weights.
 * Returns the id of the common ancestor.
 */
static inline unsigned nearest_common_ancestor(
unsigned depth,
unsigned base_cnt,
unsigned vertex_cnt,
unsigned base_ids[vertex_cnt],
unsigned bases[base_cnt][depth],
unsigned char depths[base_cnt],
unsigned weights[vertex_cnt],
unsigned lower,
unsigned upper
) {
/* upper lies inside lower's id interval, i.e. in lower's subtree. */
if (upper < (lower + weights[lower]))
return lower;
/* Lift upper over the extra chain heads so it sits at lower's chain depth. */
if (depths[upper] > depths[lower])
upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];
/* NOTE(review): a lifted vertex with a smaller id than lower is returned as
 * the answer; presumably it is then an ancestor of lower. */
if (upper < lower)
return upper;
/* Binary search over upper's chain-jump list for the first entry whose id
 * does not exceed lower (ids presumably shrink towards the root). */
unsigned *others = bases[base_ids[upper]];
for (; depth > 1; depth >>= 1)
if (others[depth >> 1] > lower) {
others += depth >> 1;
depth += depth & 1U;
}
return others[others[0] > lower];
}
/*
 * A closed range [low, high] of vertex ids (or, inside count_pairs(), block
 * indices). The whole range can also be handled as one word through `packd`.
 *
 * The program builds `packd` as low | (unsigned long)high << 32, swaps ranges
 * by XOR-ing `packd`, rotates it by 32 bits and subtracts 0x0000000100000001UL
 * to decrement both halves at once. All of this only works when unsigned long
 * is exactly two ints wide, which is enforced below. It also assumes a
 * little-endian target (low in the least significant half), which cannot be
 * checked portably in C11.
 */
typedef union {
    unsigned long packd;
    struct {
        int low, high;
    };
} range_t;

_Static_assert(sizeof(unsigned long) == 2 * sizeof(int),
               "range_t packs two ints into one unsigned long (requires LP64)");
/*
 * Colour-indexed view of the renumbered tree, consumed by query_all().
 *   members   - vertex ids grouped by colour (main() fills each colour's
 *               slice in ascending id order)
 *   colors    - compressed colour of each vertex id
 *   indices   - members[indices[c] .. indices[c + 1]) is colour c's slice
 *   locations - locations[v] is the position of vertex v inside members
 */
typedef struct {
    unsigned *members;
    unsigned *colors;
    unsigned *indices;
    unsigned *locations;
} colored_tree_t;
/*
 * Counts the vertices v != at with id in [other.low, other.high] whose colour
 * equals colors[at].
 *
 * The members of colour c are members[indices[c] .. indices[c + 1]), sorted by
 * id as main() builds them, and locations[v] is v's position in that list.
 * Two binary searches locate the first member >= other.low and the last
 * member <= other.high. The known position of `at` narrows each search, and
 * an endpoint that itself has the colour is located directly.
 */
unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
    unsigned
        color = self->colors[at],
        length = self->indices[color + 1] - self->indices[color],
        *base = &self->members[self->indices[color]];

    /* The colour's members all lie outside the range. */
    if (other.high < base[0] || other.low > base[length - 1])
        return 0;

    /* base <- first member with id >= other.low. */
    if (self->colors[other.low] != color) {
        if (at < other.low) {
            /* The answer lies at or after `at`: search that suffix. */
            base += self->locations[at] - self->indices[color];
            length = self->indices[color + 1] - self->locations[at];
        } else {
            /* at > other.low: the answer lies at or before `at`. */
            length = self->locations[at] - self->indices[color];
        }
        for (; length > 1; length >>= 1)
            if (base[length >> 1] < other.low) {
                base += length >> 1;
                length += length & 1;
            }
        base += (base[0] < other.low);
    } else {
        base += (self->locations[other.low] - self->indices[color]);
    }
    if (base[0] > other.high)
        return 0;

    /* ceil <- last member with id <= other.high. */
    unsigned *ceil;
    if (self->colors[other.high] != color) {
        ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;
        for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
            if (ceil[length >> 1] <= other.high) {
                ceil += length >> 1;
                length += length & 1;
            }
        ceil -= (ceil[0] > other.high);
    } else {
        ceil = &self->members[self->locations[other.high]];
    }

    /* Members in [base, ceil], excluding `at` itself when it is in range. */
    return ceil - base + 1 - (at >= other.low && at <= other.high);
}
/* Inclusive 2-D prefix-sum lookup in which a negative row or column denotes
 * the empty prefix (value 0). This keeps block 0 from indexing row/column -1. */
static unsigned long count_pairs_prefix(unsigned cnt, unsigned long pairs[cnt][cnt], int row, int col) {
    return (row < 0 || col < 0) ? 0UL : pairs[row][col];
}

/* Sum of the underlying (un-prefixed) matrix over rows [row_lo, row_hi] and
 * columns [col_lo, col_hi]; 0 when the row range is empty (row_hi = row_lo - 1). */
static unsigned long count_pairs_rect(unsigned cnt, unsigned long pairs[cnt][cnt],
                                      int row_lo, int row_hi, int col_lo, int col_hi) {
    return count_pairs_prefix(cnt, pairs, row_hi, col_hi)
         - count_pairs_prefix(cnt, pairs, row_hi, col_lo - 1)
         - count_pairs_prefix(cnt, pairs, row_lo - 1, col_hi)
         + count_pairs_prefix(cnt, pairs, row_lo - 1, col_lo - 1);
}

/*
 * Counts ordered pairs (i, j), i != j, with i in the id range `self`, j in the
 * id range `other` and equal colours.
 *
 * Vertices in partial blocks at either end of a range are handled one by one
 * with query_all(). The remaining whole blocks (of `length` ids each) are
 * answered from the prefix sums main() precomputes:
 *   pairs[r][c]    - 2-D inclusive prefix sum of the number of same-colour
 *                    vertex pairs with one vertex in block r and the other in
 *                    a later block c;
 *   overlapping[b] - prefix sum over blocks of ordered same-colour pairs that
 *                    lie inside a single block.
 *
 * Previously, two ranges both starting at block 0 made `high` wrap to
 * UINT_MAX and read pairs[UINT_MAX], overlapping[-1] and column -1 of pairs.
 * The prefix helpers above now treat index -1 as an empty prefix.
 */
unsigned long count_pairs(
    unsigned cnt,
    unsigned length,
    unsigned long pairs[cnt][cnt],
    unsigned *overlapping,
    colored_tree_t *tree,
    range_t self,
    range_t other
) {
    unsigned long count = 0;

    /* Peel single vertices off both ends of `self` until it is block aligned. */
    for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
    for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));
    if (self.low > self.high)
        return count;

    /* Same for `other`, now pairing only with the aligned part of `self`. */
    for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
    for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));
    if (other.low > other.high)
        return count;

    /* From here on both ranges are expressed in block indices. */
    self.low /= length;
    self.high /= length;
    other.low /= length;
    other.high /= length;
    if (self.low > other.low) {
        self.packd ^= other.packd;
        other.packd ^= self.packd;
        self.packd ^= other.packd;
    }

    /* Blocks of `self` that lie strictly before `other` pair with all of it.
     * `high` may be -1 (empty row range) when both start at block 0. */
    int high = (self.high < other.low) ? self.high : (other.low - 1);
    count += count_pairs_rect(cnt, pairs, self.low, high, other.low, other.high);
    self.low = high + 1;

    /* Both ranges now start at the same block; let `self` be the shorter. */
    if (self.high > other.high) {
        self.packd ^= other.packd;
        other.packd ^= self.packd;
        self.packd ^= other.packd;
    }
    if (self.low <= self.high) {
        unsigned long same_block = overlapping[self.high]
            - (self.low > 0 ? overlapping[self.low - 1] : 0UL);
        count += same_block
            /* cross-block pairs inside the shared blocks, in both orders */
            + (count_pairs_rect(cnt, pairs, self.low, self.high, self.low, self.high) << 1)
            /* shared blocks against the blocks only `other` covers */
            + count_pairs_rect(cnt, pairs, self.low, self.high, self.high + 1, other.high);
    }
    return count;
}
/*
 * Reads the tree and the queries from stdin and prints one pair count per
 * query.
 *
 * Outline:
 *  1. Read the colours and compress them (compress() ranks by frequency).
 *     A sentinel colour is stored in slot vertex_cnt.
 *  2. Read the edges and orient them so that vertex 0 (input node 1) is the
 *     root.
 *  3. Renumber the vertices in a DFS order in which every subtree occupies a
 *     contiguous id range and the heaviest child directly follows its
 *     parent. For each vertex, record its chain head (bases), subtree size
 *     (weights) and number of chain heads above it (base_depths).
 *  4. Build per chain the list of ancestors reached by jumping over chain
 *     heads (ancestral_bases) for nearest_common_ancestor().
 *  5. Group vertices by colour and precompute, over blocks of `levels`
 *     consecutive ids, prefix sums of same-colour pair counts.
 *  6. For every query, split both paths into chain segments (contiguous id
 *     ranges) and sum count_pairs() over all segment pairs.
 *
 * Returns 0, or 1 when the block table cannot be allocated.
 */
int main() {
    unsigned at, vertex_cnt;
    unsigned short query_cnt;
    scanf("%u %hu", &vertex_cnt, &query_cnt);

    /* Colours plus a sentinel that compress() turns into the largest colour. */
    unsigned colors[vertex_cnt + 1];
    for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
    colors[at] = 0xFFFFFFFFU;
    compress(at + 1, colors);

    /* ancestors[v]: parent of v; 0xFFFFFFFF means "no parent yet". */
    unsigned ancestors[at + 1];
    {
        unsigned ancestor, descendant;
        for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
            scanf("%u %u", &ancestor, &descendant);
            --ancestor;
            if (ancestors[--descendant] != 0xFFFFFFFFU) {
                /* descendant already has a parent: hang ancestor's component
                 * below descendant by reversing ancestor's root path, then
                 * re-root descendant's component at descendant. */
                unsigned root = descendant, next;
                for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
                    next = ancestors[ancestor];
                    ancestors[ancestor] = root;
                    root = ancestor;
                }
                for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
                    next = ancestors[descendant];
                    ancestors[descendant] = ancestor;
                    ancestor = descendant;
                }
            }
        }
        /* at == 0 here: reverse the path from vertex 0 to the current root so
         * that vertex 0 becomes the root. */
        for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
            descendant = ancestors[at];
            ancestors[at] = ancestor;
            ancestor = at;
        }
    }

    unsigned
        others,
        ids[vertex_cnt + 1],
        weights[vertex_cnt],
        bases[vertex_cnt + 1],
        history[vertex_cnt];
    unsigned char
        base_depths[vertex_cnt],
        dist = 0;
    {
        unsigned
            history[vertex_cnt],
            indices[vertex_cnt + 1],
            descendants[vertex_cnt];

        /* Children lists (CSR form): descendants[indices[v] .. indices[v+1]). */
        memset(indices, 0, sizeof(indices));
        for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
        for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
        for (; --at; descendants[--indices[ancestors[at]]] = at);

        /* Subtree sizes with an explicit stack: a vertex is pushed with its
         * children on first visit and summed on the second. */
        history[0] = 0;
        memset(weights, 0, sizeof(weights));
        for (at = 1; at--; ) {
            if (weights[history[at]]) {
                for (others = indices[history[at]];
                     others < indices[history[at] + 1];
                     weights[history[at]] += weights[descendants[others++]]);
            } else {
                weights[history[at]] = 1;
                memcpy(
                    &history[at + 1],
                    &descendants[indices[history[at]]],
                    (indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
                );
                at += indices[history[at] + 1] - indices[history[at]] + 1;
            }
        }

        /* Renumber: children are sorted by subtree size and the heaviest one
         * gets id parent + 1, continuing the parent's chain; every other
         * child starts a new chain. */
        unsigned
            orig_ancestors[vertex_cnt + 1],
            orig_colors[vertex_cnt + 1],
            orig_weights[vertex_cnt];
        memcpy(orig_ancestors, ancestors, sizeof(ancestors));
        memcpy(orig_weights, weights, sizeof(weights));
        memcpy(orig_colors, colors, sizeof(colors));
        base_depths[0] = (bases[0] = (ids[0] = 0));
        bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
        for (at = 1; at--;) {
            unsigned
                id = ids[history[at]],
                base = bases[id++],
                branches = indices[history[at] + 1] - indices[history[at]];
            heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
            memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));
            for (others = (at += branches); branches--; base = id) {
                ids[history[--others]] = id;
                ancestors[id] = ids[orig_ancestors[history[others]]];
                weights[id] = orig_weights[history[others]];
                colors[id] = orig_colors[history[others]];
                bases[id] = base;
                base_depths[id] = base_depths[ancestors[id]] + (base == id);
                if (dist < base_depths[id])
                    dist = base_depths[id];
                id += weights[id];
            }
        }
    }

    /* A single-chain tree (a path hanging from vertex 0, or n == 1) leaves
     * dist == 0. That would make ancestral_bases a zero-length VLA whose
     * column 0 is still written below, so keep at least one column. */
    if (dist == 0)
        dist = 1;

    /* base_ids[v]: index of v's chain; base_ids[vertex_cnt] = chain count. */
    unsigned base_ids[vertex_cnt + 1];
    for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
        for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);

    /* ancestral_bases[b][k]: vertex reached after k + 1 jumps over chain heads. */
    unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
    for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
    while ((++others + 1) < dist)
        for (at = 0; ++at < base_ids[vertex_cnt];
             ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);

    /* Vertices grouped by colour, ascending id within each colour. */
    unsigned
        indexed_colors[colors[vertex_cnt] + 2],
        members[vertex_cnt + 1];
    memset(indexed_colors, 0, sizeof(indexed_colors));
    for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
    for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
    for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
    indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];

    unsigned
        levels = floor_log2(vertex_cnt) + 1,
        block_cnt = (vertex_cnt / levels) + 1,
        locations[vertex_cnt + 1],
        overlapping[block_cnt];

    /* Block pair table, offset by block_cnt + 1 elements into a zeroed
     * buffer so that row -1 reads as zeros. */
    unsigned long *pairs_storage = calloc(
        (1 + block_cnt) * (1 + block_cnt),
        sizeof(unsigned long)
    );
    if (pairs_storage == NULL) {
        fputs("out of memory\n", stderr);
        return 1;
    }
    unsigned long (*pairs)[block_cnt][block_cnt] = (void *)&pairs_storage[block_cnt + 1];

    for (at = vertex_cnt + 1; at--; locations[members[at]] = at);
    memset(overlapping, 0, sizeof(overlapping));

    /* Only colours occurring more than once (a prefix, see compress()). */
    for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
        others = indexed_colors[at];
        unsigned
            block_bases[indexed_colors[at + 1] - others + 1],
            cnt = 1;
        /* First member of this colour in each block it touches. */
        for (block_bases[0] = members[others]; at == colors[members[++others]]; )
            if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
                block_bases[cnt++] = members[others];
        block_bases[cnt] = members[others];
        for (others = 0; others < cnt; others++) {
            unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
            overlapping[block_bases[others] / levels] += density * (density - 1);
            unsigned block = others;
            for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
                 += density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
        }
    }

    /* Turn overlapping into a 1-D and pairs into a 2-D inclusive prefix sum. */
    for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
        pairs[0][0][at] += pairs[0][0][at - 1];
    for (at = 0; ++at < block_cnt; )
        for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);
    for (at = 0; ++at < block_cnt; )
        for (others = 0; others < block_cnt; others++)
            pairs[0][at][others] += pairs[0][at - 1][others];

    colored_tree_t *tree = &(colored_tree_t) {
        .members = members,
        .colors = colors,
        .indices = indexed_colors,
        .locations = locations
    };

    while (query_cnt--) {
        range_t left, right;
        /* range_t fields are int, so read them with %d. */
        scanf("%d %d %d %d", &left.low, &left.high, &right.low, &right.high);
        left.packd -= 0x0000000100000001UL;
        right.packd -= 0x0000000100000001UL;
        left.low = ids[left.low];
        left.high = ids[left.high];
        right.low = ids[right.low];
        right.high = ids[right.high];
        if (left.high < left.low)
            left.packd = (left.packd << 32) | (left.packd >> 32);
        if (right.high < right.low)
            right.packd = (right.packd << 32) | (right.packd >> 32);
        if (right.high < left.low) {
            left.packd ^= right.packd;
            right.packd ^= left.packd;
            left.packd ^= right.packd;
        }

        /* Chain segments [chain head or LCA, vertex] of each path. */
        struct {
            range_t members[32];
            unsigned cnt;
        }
        a = {.cnt = 0},
        b = {.cnt = 0};

        unsigned common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            left.low, left.high
        );
        for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);
        for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);
        a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            right.low, right.high
        );
        for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);
        for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);
        b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        unsigned long total = 0;
        for (at = 0; at < a.cnt; at++)
            for (others = 0; others < b.cnt;
                 total += count_pairs(
                     block_cnt, levels, pairs[0], overlapping, tree,
                     a.members[at], b.members[others++]
                 )
            );
        printf("%lu\n", total);
    }

    free(pairs_storage);
    return 0;
}
Copy The Code &
Try With Live Editor
#2 Code Example with C++ Programming
Code -
C++ Programming
#include <cstdlib>
#include <cstdio>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <cassert>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
#define SIZE(x) (int((x).size()))
// rep(i, l, r): i runs from l up to r, inclusive.
// (The scraped source read "i < =(r)", which does not compile.)
#define rep(i,l,r) for (int i=(l); i<=(r); i++)
// repd(i, r, l): i runs from r down to l, inclusive.
#define repd(i,r,l) for (int i=(r); i>=(l); i--)
// rept(it, c): iterate it over every element of container c.
#define rept(i,c) for (__typeof((c).begin()) i=(c).begin(); i!=(c).end(); i++)
#ifndef ONLINE_JUDGE
#define debug(x) { cerr<<#x<<" = "<<(x)<<endl; }
#else
#define debug(x) {}
#endif
// Capacity of the per-vertex / per-query arrays (n <= 1e5 plus slack).
#define maxn 100010
// Colours with at most LIM vertices are handled with pairwise events;
// larger colours are counted per query (see lemon()).
#define LIM 100
// Fenwick (binary indexed) tree over positions 1 .. maxn-1, used here for
// range-add / point-query through a difference array.
int ta[maxn];

// Adds y at position x (and to every Fenwick node covering it).
void ta_modify(int x, int y)
{
    for (; x < maxn; x += x & -x)
        ta[x] += y;
}

// Sum of the underlying array over positions 1 .. x.
int ta_query(int x)
{
    int sum = 0;
    for (; x; x -= x & -x)
        sum += ta[x];
    return sum;
}

// Adds c to every position in [l, r].
void ds_modify(int l, int r, int c)
{
    ta_modify(l, c);
    ta_modify(r + 1, -c);
}

// Current value at position v after all ds_modify() calls.
int ds_query(int v)
{
    return ta_query(v);
}
int dfsN;
int dfsLeft[maxn], dfsRight[maxn];
int lg2[maxn], p[maxn][17], depth[maxn];
vector<int> e[maxn];
void dfs(int cur, int pre, int dep)
{
dfsN++; dfsLeft[cur]=dfsN;
depth[cur]=dep;
p[cur][0]=pre;
rep(i,1,lg2[dep]) p[cur][i]=p[p[cur][i-1]][i-1];
rept(it,e[cur]) if (*it!=pre) dfs(*it,cur,dep+1);
dfsRight[cur]=dfsN;
}
// Climbs y levels from x using the lifting table; returns 0 for negative y.
int movedep(int x, int y)
{
    if (y < 0)
        return 0;
    while (y) {
        int lowBit = y & -y;
        x = p[x][lg2[lowBit]];
        y -= lowBit;
    }
    return x;
}

// Lowest common ancestor of x and y via binary lifting.
int lca(int x, int y)
{
    if (depth[x] < depth[y])
        swap(x, y);
    x = movedep(x, depth[x] - depth[y]);
    for (int k = 16; k >= 0; k--) {
        if (p[x][k] != p[y][k]) {
            x = p[x][k];
            y = p[y][k];
        }
    }
    return (x == y) ? x : p[x][0];
}

// Number of vertices on the path between x and y, both endpoints included.
int get_dist(int x, int y)
{
    int top = lca(x, y);
    return depth[x] + depth[y] - 2 * depth[top] + 1;
}
// Segments shared by the two query paths, found so far for the current query:
// ti[1..all] = (lower end, upper end). lemon() resets `all` per query and
// calls check_intersect() four times, so at most 4 entries are used.
int all, ti[5][2];
// Given vertical paths p1 -> p2 and q1 -> q2 (lemon() passes each path's LCA
// as p2 / q2, so presumably p2 is an ancestor of p1 and q2 of q1), records the
// segment (lca(p1, q1), q2) they share, if any, skipping segments already
// recorded in either orientation.
void check_intersect(int p1, int p2, int q1, int q2)
{
// Order the paths so that q2 is the deeper of the two tops.
if (depth[p2]>depth[q2])
{
swap(p1,q1); swap(p2,q2);
}
// q2 must lie on p1 -> p2: no deeper than p1, an ancestor of p1, and below p2.
if (depth[p1] < depth[q2]) return;
if (lca(p1,q2)!=q2 || lca(q2,p2)!=p2) return;
// The shared part runs from where the paths diverge (z) up to q2.
int z=lca(p1,q1);
rep(i,1,all) if (ti[i][0]==z && ti[i][1]==q2) return;
rep(i,1,all) if (ti[i][1]==z && ti[i][0]==q2) return;
// NOTE(review): disabled alternative de-duplication rule kept from the original:
//if (z==q2) rep(i,1,all) if (ti[i][0]==z || ti[i][1]==z) return;
all++; ti[all][0]=z; ti[all][1]=q2;
}
// A sweep event. For additions, [x, y] is an Euler-time interval and c the
// delta. For queries, x is the Euler time to probe, y the query index and c
// the sign of the contribution.
struct tasktype
{
    int x, y, c;
    tasktype() {}
    tasktype(int x_, int y_, int c_) : x(x_), y(y_), c(c_) {}
};

// Events bucketed by the Euler time of the first coordinate.
vector<tasktype> eventAddList[maxn], eventQueryList[maxn];

// Schedules ans[i] += c * (value at (dfsLeft[p1], dfsLeft[q1])).
void addQueryEvent(int i, int p1, int q1, int c)
{
    int row = dfsLeft[p1];
    int col = dfsLeft[q1];
    eventQueryList[row].push_back(tasktype(col, i, c));
}

// Adds +1 over the rectangle [p1, p2] x [q1, q2] of Euler times: switched on
// when the sweep reaches p1 and off again at p2 + 1.
void addContributionEvent(int p1, int p2, int q1, int q2)
{
    eventAddList[p1].push_back(tasktype(q1, q2, 1));
    eventAddList[p2 + 1].push_back(tasktype(q1, q2, -1));
}

// For the vertical paths p1 -> p2 and q1 -> q2, queues the (up to four)
// root-prefix queries whose inclusion-exclusion gives their pair count.
// Any endpoint equal to 0 (lemon() passes 0 for an empty half-path) makes
// this a no-op.
void add_task(int i, int p1, int p2, int q1, int q2)
{
    if (p1 == 0 || p2 == 0 || q1 == 0 || q2 == 0)
        return;
    int pAbove = p[p2][0];
    int qAbove = p[q2][0];
    addQueryEvent(i, p1, q1, 1);
    if (qAbove)
        addQueryEvent(i, p1, qAbove, -1);
    if (pAbove)
        addQueryEvent(i, pAbove, q1, -1);
    if (pAbove && qAbove)
        addQueryEvent(i, pAbove, qAbove, 1);
}
// Colour value -> vertices (1-based) that carry it.
map < int, vector<int> > clist;
int color[maxn];
// q[i][0..3] = w, x, y, z of query i; q[i][4], q[i][5] = LCAs of its two paths.
int q[maxn][6];
// An answer can reach about 1e10 (two 1e5-vertex paths of a single colour),
// so it must not be kept in an int.
long long ans[maxn];

// Reads the whole input, answers every query offline and prints the answers.
//
// Each path is split at its LCA into two vertical halves. add_task() turns
// every pair of halves into root-prefix queries over Euler times. Colours with
// at most LIM vertices add one rectangle event per ordered vertex pair; larger
// colours are counted per query with the Fenwick tree directly. Pairs with
// i == j are removed using the shared segments found by check_intersect().
void lemon()
{
    lg2[1]=0; rep(i,2,maxn-1) lg2[i]=lg2[i>>1]+1;
    int n,qa; scanf("%d%d",&n,&qa);
    rep(i,1,n)
    {
        scanf("%d",&color[i]);
        clist[color[i]].push_back(i);
    }
    rep(i,1,n-1)
    {
        int x,y; scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    rep(i,1,qa)
    {
        scanf("%d%d%d%d",&q[i][0],&q[i][1],&q[i][2],&q[i][3]);
    }
    dfsN=0;
    dfs(1,0,0);
    rep(i,1,qa) ans[i]=0;
    rep(i,1,qa)
    {
        all=0;
        int z1=lca(q[i][0],q[i][1]);
        int z2=lca(q[i][2],q[i][3]);
        q[i][4]=z1; q[i][5]=z2;
        check_intersect(q[i][0],z1,q[i][2],z2);
        check_intersect(q[i][0],z1,q[i][3],z2);
        check_intersect(q[i][1],z1,q[i][2],z2);
        check_intersect(q[i][1],z1,q[i][3],z2);
        // t1 / t2: child of the LCA on the x (resp. z) side; movedep()
        // returns 0 when that side is empty (x == LCA).
        int t1=movedep(q[i][1],depth[q[i][1]]-depth[z1]-1);
        int t2=movedep(q[i][3],depth[q[i][3]]-depth[z2]-1);
        add_task(i,q[i][0],z1,q[i][2],z2);
        add_task(i,q[i][0],z1,q[i][3],t2);
        add_task(i,q[i][1],t1,q[i][2],z2);
        add_task(i,q[i][1],t1,q[i][3],t2);
        if (all>0)
        {
            rep(k,1,all)
                ans[i]-=get_dist(ti[k][0],ti[k][1]);
            ans[i]+=all-1;
        }
    }
    rept(it,clist)
    {
        int cl=it->first;
        if (it->second.size()<=LIM)
        {
            // Small colour: one rectangle event per ordered vertex pair.
            int s=it->second.size();
            rep(i,0,s-1)
                rep(j,0,s-1)
                {
                    int i1=it->second[i], j1=it->second[j];
                    addContributionEvent(dfsLeft[i1], dfsRight[i1], dfsLeft[j1], dfsRight[j1]);
                }
        }
        else
        {
            // Large colour: count its vertices on each query path directly.
            rept(it2,it->second)
                ds_modify(dfsLeft[*it2],dfsRight[*it2],1);
            rep(i,1,qa)
            {
                int x1=ds_query(dfsLeft[q[i][0]])+ds_query(dfsLeft[q[i][1]])-2*ds_query(dfsLeft[q[i][4]]);
                if (color[q[i][4]]==cl) x1++;
                int x2=ds_query(dfsLeft[q[i][2]])+ds_query(dfsLeft[q[i][3]])-2*ds_query(dfsLeft[q[i][5]]);
                if (color[q[i][5]]==cl) x2++;
                // Both factors can be ~1e5: multiply in 64 bits.
                ans[i]+=(long long)x1*x2;
            }
            rept(it2,it->second)
                ds_modify(dfsLeft[*it2],dfsRight[*it2],-1);
        }
    }
    rep(i,1,n)
    {
        rept(it,eventAddList[i]) ds_modify(it->x,it->y,it->c);
        rept(it,eventQueryList[i]) ans[it->y]+=(long long)it->c*ds_query(it->x);
    }
    rep(i,1,qa) printf("%lld\n",ans[i]);
}
// Program entry point: C and C++ streams stay synchronised (the solver does
// its I/O with scanf/printf), then the solver runs.
int main()
{
    ios::sync_with_stdio(true);
    lemon();
    return 0;
}
Copy The Code &
Try With Live Editor
#3 Code Example with Java Programming
Code -
Java Programming
import java.io.*;
import java.util.*;
public class Solution {
static int[] nxt;
static int[] succ;
static int[] ptr;
static int index = 1;
static void addEdge(int u, int v) {
nxt[index] = ptr[u];
ptr[u] = index;
succ[index++] = v;
}
static int timer = 0;
static int[] tin;
static int[] tout;
static int[] pr;
static int[][] up;
static int[] w;
static class NodeDfs {
int v;
int lvl;
int p;
boolean start = true;
public NodeDfs(int v, int lvl, int p) {
this.v = v;
this.lvl = lvl;
this.p = p;
}
}
static void dfs() {
Deque < NodeDfs> q = new LinkedList<>();
q.add(new NodeDfs(1, 0, 1));
while (!q.isEmpty()) {
NodeDfs node = q.peekLast();
if (node.start) {
tin[node.v] = timer;
w[node.v] = node.lvl;
up[node.v][0] = node.p;
for (int i = 1; i < = 17; i++) {
up[node.v][i] = up[up[node.v][i - 1]][i - 1];
}
for (int i = ptr[node.v]; i > 0; i = nxt[i]) {
int to = succ[i];
if (to != node.p) {
q.add(new NodeDfs(to, node.lvl + 1, node.v));
pr[to] = node.v;
}
}
node.start = false;
} else {
tout[node.v] = timer++;
q.removeLast();
}
}
}
static boolean upper(int x, int y) {
return tout[x] >= tout[y] && tin[x] < = tin[y];
}
static int lca(int a, int b) {
if (upper(a, b))
return a;
if (upper(b, a))
return b;
for (int i = 17; i >= 0; --i)
if (!upper(up[a][i], b))
a = up[a][i];
return up[a][0];
}
static void normalize(int[] a) {
Map < Integer, Integer> trans = new HashMap<>();
int j = 0;
for (int i = 1; i < a.length; i++) {
if (!trans.containsKey(a[i])) {
trans.put(a[i], j++);
}
}
for (int i = 1; i < a.length; i++) {
a[i] = trans.get(a[i]);
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
int[] a = new int[n + 1];
st = new StringTokenizer(br.readLine());
for (int i = 1; i < = n; i++) {
a[i] = Integer.parseInt(st.nextToken());
}
normalize(a);
nxt = new int[n + 1];
succ = new int[n + 1];
ptr = new int[n + 1];
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
if (u < v) {
addEdge(u, v);
} else {
addEdge(v, u);
}
}
tin = new int[n + 1];
tout = new int[n + 1];
pr = new int[n + 1];
up = new int[n + 1][20];
w = new int[n + 1];
dfs();
int[] d = new int[m];
int[] X1 = new int[m];
int[] Y1 = new int[m];
int[] X2 = new int[m];
int[] Y2 = new int[m];
for (int i = 0; i < m; i++) {
st = new StringTokenizer(br.readLine());
int x1 = Integer.parseInt(st.nextToken());
int y1 = Integer.parseInt(st.nextToken());
int x2 = Integer.parseInt(st.nextToken());
int y2 = Integer.parseInt(st.nextToken());
X1[i] = x1;
Y1[i] = y1;
X2[i] = x2;
Y2[i] = y2;
int lc1 = lca(x1, y1);
int lc2 = lca(x2, y2);
int ans1 = 0;
if (lc1 == lc2) {
int lc3 = lca(x1, x2);
int lc4 = lca(x1, y2);
int lc5 = lca(y1, x2);
int lc6 = lca(y1, y2);
ans1++;
ans1 += w[lc3] - w[lc1];
ans1 += w[lc4] - w[lc1];
ans1 += w[lc5] - w[lc1];
ans1 += w[lc6] - w[lc1];
} else if (w[lc1] < w[lc2]) {
int lc3 = lca(x1, x2);
int lc4 = lca(x1, y2);
int lc5 = lca(y1, x2);
int lc6 = lca(y1, y2);
if (upper(lc2, x1) && upper(lc1, lc2)) {
ans1 += Math.abs(w[lc3] - w[lc4]) + 1;
}
if (upper(lc2, y1) && upper(lc1, lc2)) {
ans1 += Math.abs(w[lc5] - w[lc6]) + 1;
}
} else if (w[lc1] > w[lc2]) {
int lc3 = lca(x1, x2);
int lc4 = lca(x1, y2);
int lc5 = lca(y1, x2);
int lc6 = lca(y1, y2);
if (upper(lc1, x2) && upper(lc2, lc1)) {
ans1 += Math.abs(w[lc3] - w[lc5]) + 1;
}
if (upper(lc1, y2) && upper(lc2, lc1)) {
ans1 += Math.abs(w[lc4] - w[lc6]) + 1;
}
}
d[i] = ans1;
}
int[] b = new int[n + 1];
for (int i = 0; i < m; i++) {
int x1 = X1[i];
int y1 = Y1[i];
int x2 = X2[i];
int y2 = Y2[i];
int x3 = x1;
int y3 = y1;
while (!upper(x1, y1)) {
++b[a[x1]];
x1 = pr[x1];
}
while (y1 != x1) {
++b[a[y1]];
y1 = pr[y1];
}
b[a[x1]]++;
int ans = 0;
while (!upper(x2, y2)) {
ans += b[a[x2]];
x2 = pr[x2];
}
while (y2 != x2) {
ans += b[a[y2]];
y2 = pr[y2];
}
ans += b[a[x2]];
int tmp = x1;
x1 = x3;
y1 = y3;
while (x1 != tmp) {
b[a[x1]] = 0;
x1 = pr[x1];
}
while (y1 != tmp) {
b[a[y1]] = 0;
y1 = pr[y1];
}
b[a[x1]] = 0;
bw.write((ans - d[i]) + "\n");
}
bw.newLine();
bw.close();
br.close();
}
}
Copy The Code &
Try With Live Editor