Taylor loves trees, and this new challenge has him stumped!

Consider a tree, t, consisting of n nodes. Each node is numbered from 1 to n, and each node i has an integer, c_i, attached to it.

A query on tree t takes the form w x y z. To process a query, you must print the count of ordered pairs of integers (i,j) such that the following four conditions are all satisfied:

  • i ≠ j
  • i ∈ the path from node w to node x
  • j ∈ the path from node y to node z
  • c_i = c_j

Given t and q queries, process each query in order, printing the pair count for each query on a new line.

Input Format

The first line contains two space-separated integers describing the respective values of n (the number of nodes) and q (the number of queries).
The second line contains n space-separated integers describing the respective values of each node (i.e., c_1, c_2, ..., c_n).

Each of the n - 1 subsequent lines contains two space-separated integers, u and v, defining a bidirectional edge between nodes u and v.

Each of the q subsequent lines contains a w x y z query, defined above.


Constraints

  • 1 <= n <= 10^5
  • 1 <= q <= 50000
  • 1 <= c_i <= 10^9
  • 1 <= u, v, w, x, y, z <= n

Scoring for this problem is binary: you must pass all of the test cases to get a positive score.

Output Format

For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.

Sample Input

10 5
10 2 3 5 10 5 3 6 2 1
1 2
1 3
3 4
3 5
3 6
4 7
5 8
7 9
2 10
8 5 2 10
3 8 4 9
1 9 5 9
4 6 4 6
5 8 5 8

Sample Output

0
1
3
2
0

Code Examples

C Programming

C Programming

#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>

#define floor_log2_X86(self) (__builtin_clz(self) ^ 31U)
#define floor_log2 floor_log2_X86

/*
 * Heap-sorts the index array `self` in place so that, on return,
 * weights[self[0]] <= weights[self[1]] <= ... <= weights[self[length - 1]].
 *
 *   self    - `length` indices into `weights`; reordered in place.
 *   weights - sort keys, addressed through the entries of `self`; not modified.
 *   length  - number of entries in `self` (0 and 1 are no-ops).
 *
 * The sort is not stable: entries with equal weights may end up in any order.
 */
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
    unsigned
        at = length >> 1,
        member,
        node;

    /* Phase 1: build a max-heap. `self--` switches to 1-based indexing so the
     * children of slot k are slots 2k and 2k+1. Each pass sifts the entry at
     * slot `at` down; the loop's step expression drops it into its final hole. */
    for (self--; at; self[node >> 1] = member) {
        member = self[at];

        for (node = at-- << 1; node <= length; node <<= 1) {
            /* Step to the heavier of the two children (if a right child exists). */
            node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
            if (weights[self[node]] < weights[member])
                break ;
            self[node >> 1] = self[node];
        }
    }

    /* Phase 2: repeatedly move the maximum (slot 1) to the end of the shrinking
     * heap and sift the displaced last entry down from the root. */
    for (; length > 1; self[at >> 1] = member) {
        member = self[length];
        self[length--] = self[1];

        for (at = 2; at <= length; at <<= 1) {
            at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
            if (weights[self[at]] < weights[member])
                break ;
            self[at >> 1] = self[at];
        }
    }
}

void compress(unsigned length, unsigned values[length]) {

    unsigned long sum = 0x0000000100000000UL;
    for (at = 0; at  <  (length >> 1); sum += 0x0000000200000002UL)
        ((unsigned long *)order)[at++] = sum;
    order[length - 1] = length - 1;

    heap_sort(order, values, length);

    unsigned roots[length], seen = 1, max = 0, others;
    for (roots[at = 0] = -1U; at  <  length; roots[seen++] = at - 1) {
        for (others = at; (at  <  length) && values[order[at]] == values[order[others]]; at++);

        if (max  <  (at - others))
            max = (at - others);

        indices[max + 1],

    memset(indices, 0, sizeof(indices));
    for (at = 0; ++at  <  seen; indices[roots[at] - roots[at - 1]]++);
    for (at = max; at--; indices[at] += indices[at + 1]);
    for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);

    for (; at  <  (seen - 1); at++)
        for (others = roots[ranks[at] - 1]; ++others  < = roots[ranks[at]]; values[order[others]] = at);

static inline unsigned nearest_common_ancestor(
    unsigned depth,
    unsigned base_cnt,
    unsigned vertex_cnt,
    unsigned base_ids[vertex_cnt],
    unsigned bases[base_cnt][depth],
    unsigned char depths[base_cnt],
    unsigned weights[vertex_cnt],
    unsigned lower,
    unsigned upper
) {
    if (upper  <  (lower + weights[lower]))
        return lower;

    if (depths[upper] > depths[lower])
        upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];

    if (upper  <  lower)
        return upper;

    unsigned *others = bases[base_ids[upper]];
    for (; depth > 1; depth >>= 1)
        if (others[depth >> 1] > lower) {
            others += depth >> 1;
            depth += depth & 1U;

    return others[others[0] > lower];

/*
 * An inclusive [low, high] pair of indices.
 *
 * `packd` overlays both fields so a whole range can be copied, XOR-swapped or
 * have its halves exchanged with a single integer operation (see the callers
 * that shift `packd` by 32 bits).
 * NOTE(review): this relies on `unsigned long` being 64 bits wide and on `low`
 * occupying the least-significant half (little-endian layout) - confirm for
 * the target platform.
 */
typedef union {
    unsigned long packd;
    struct {
        int low, high;
    };
} range_t;

/*
 * Read-only view of the per-color node index used by query_all().
 * The field list was reconstructed from the uses visible in this file
 * (query_all and the designated initializer in main).
 *
 *   members   - node ids grouped by color; query_all binary-searches each
 *               group, so every group is expected to be in ascending order.
 *   colors    - colors[v] is the (compressed) color id of node v.
 *   indices   - the group of color c is members[indices[c] .. indices[c + 1] - 1].
 *   locations - locations[v] is the position of node v inside `members`.
 */
typedef struct {
    unsigned *members;
    unsigned *colors;
    unsigned *indices;
    unsigned *locations;
} colored_tree_t;

unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
        color = self->colors[at],
        length = self->indices[color + 1] - self->indices[color],
        *base = &self->members[self->indices[color]];

    if (other.high  <  base[0] || other.low > base[length - 1])
        return 0;

    if (self->colors[other.low] != color) {
        if (at < other.low) {
            base += self->locations[at] - self->indices[color];
            length = self->indices[color + 1] - self->locations[at];
        } else
            length = self->locations[at] - self->indices[color]; // at > other.low

        for (; length > 1; length >>= 1)
            if (base[length >> 1]  <  other.low) {
                base += length >> 1;
                length += length & 1;

        base += (base[0]  <  other.low);
    } else
        base += (self->locations[other.low] - self->indices[color]);

    if (base[0] > other.high)
        return 0;

    unsigned *ceil;
    if (self->colors[other.high] != color) {
        ceil = (at > base[0] && at  <  other.high) ? &self->members[self->locations[at]] : base;

        for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
            if (ceil[length >> 1]  < = other.high) {
                ceil += length >> 1;
                length += length & 1;

        ceil -= (ceil[0] > other.high);
    } else
        ceil = &self->members[self->locations[other.high]];

    return ceil - base + 1 - (at >= other.low && at  < = other.high);

unsigned long count_pairs(
    unsigned cnt,
    unsigned length,
    unsigned long pairs[cnt][cnt],
    unsigned *overlapping,
    colored_tree_t *tree,
    range_t self,
    range_t other
) {
    unsigned long count = 0;
    for (; (self.low % length) && (self.low  < = self.high); count += query_all(tree, self.low++, other));
    for (; ((self.high + 1) % length) && (self.low  < = self.high); count += query_all(tree, self.high--, other));

    if (self.low  < = self.high) {
        for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
        for (; ((other.high + 1) % length) && (other.low  < = other.high); count += query_all(tree, other.high--, self));

        if (other.low  < = other.high) {
            self.low   /= length;
            self.high  /= length;
            other.low  /= length;
            other.high /= length;

            if (self.low > other.low) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;

            unsigned high = (self.high  <  other.low) ? self.high : (other.low - 1);

            count +=
                    - pairs[high][other.low - 1UL]
                    - pairs[self.low - 1UL][other.high]
                    + pairs[self.low - 1UL][other.low - 1UL];

            self.low = high + 1;

            if (self.high > other.high) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;

            if (self.low  < = self.high)
                count +=
                    (overlapping[self.high] - overlapping[self.low - 1UL])
                        + ((
                            - pairs[self.high][self.low - 1UL]
                            - pairs[self.low - 1UL][self.high]
                            + pairs[self.low - 1UL][self.low - 1UL]
                    ) << 1) + (
                            - pairs[self.high][self.high]
                            - pairs[self.low - 1UL][other.high]
                            + pairs[self.low - 1UL][self.high]

    return count;

int main() {
    unsigned at, vertex_cnt;
    unsigned short query_cnt;
    scanf("%u %hu", &vertex_cnt, &query_cnt);

    unsigned colors[vertex_cnt + 1];
    for (at = 0; at  <  vertex_cnt; scanf("%u", &colors[at++]));
    colors[at] = 0xFFFFFFFFU;
    compress(at + 1, colors);

    unsigned ancestors[at + 1];
        unsigned ancestor, descendant;
        for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
            scanf("%u %u", &ancestor, &descendant);
            if (ancestors[--descendant] != 0xFFFFFFFFU) {
                unsigned root = descendant, next;
                for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
                    next = ancestors[ancestor];
                    ancestors[ancestor] = root;
                    root = ancestor;
                for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
                    next = ancestors[descendant];
                    ancestors[descendant] = ancestor;
                    ancestor = descendant;

        for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
            descendant = ancestors[at];
            ancestors[at] = ancestor;
            ancestor = at;

        ids[vertex_cnt + 1],
        bases[vertex_cnt + 1],

    unsigned char
        dist = 0;

            indices[vertex_cnt + 1],

        memset(indices, 0, sizeof(indices));
        for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
        for (; ++at  < = vertex_cnt; indices[at] += indices[at - 1]);
        for (; --at; descendants[--indices[ancestors[at]]] = at);

        history[0] = 0;
        memset(weights, 0, sizeof(weights));
        for (at = 1; at--; )
            if (weights[history[at]])
                for (others = indices[history[at]];
                     others  <  indices[history[at] + 1];
                     weights[history[at]] += weights[descendants[others++]]);
            else {
                weights[history[at]] = 1;
                    &history[at + 1],
                    (indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
                at += indices[history[at] + 1] - indices[history[at]] + 1;

            orig_ancestors[vertex_cnt + 1],
            orig_colors[vertex_cnt + 1],

        memcpy(orig_ancestors, ancestors, sizeof(ancestors));
        memcpy(orig_weights, weights, sizeof(weights));
        memcpy(orig_colors, colors, sizeof(colors));

        base_depths[0] = (bases[0] = (ids[0] = 0));
        bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
        for (at = 1; at--;) {
                id = ids[history[at]],
                base = bases[id++],
                branches = indices[history[at] + 1] - indices[history[at]];

            heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
            memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));

            for (others = (at += branches); branches--; base = id) {
                ids[history[--others]] = id;

                ancestors[id] = ids[orig_ancestors[history[others]]];
                weights[id] = orig_weights[history[others]];
                colors[id] = orig_colors[history[others]];

                bases[id] = base;
                base_depths[id] = base_depths[ancestors[id]] + (base == id);

                if (dist  <  base_depths[id])
                    dist = base_depths[id];

                id += weights[id];

    unsigned base_ids[vertex_cnt + 1];
    for (base_ids[0] = (others = (at = 0)); others  <  vertex_cnt; base_ids[others] = base_ids[at] + 1)
        for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);

    unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
    for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
    while ((++others + 1)  <  dist)
        for (at = 0; ++at < base_ids[vertex_cnt];
             ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);

        indexed_colors[colors[vertex_cnt] + 2],
        members[vertex_cnt + 1];

    memset(indexed_colors, 0, sizeof(indexed_colors));
    for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
    for (; ++at  <  colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
    for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
    indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];

        levels = floor_log2(vertex_cnt) + 1,
        block_cnt = (vertex_cnt / levels) + 1,
        locations[vertex_cnt + 1],

    unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
        (1 + block_cnt) * (1 + block_cnt),
    pairs = (void *)&pairs[0][1][1];

    for (at = vertex_cnt + 1; at--; locations[members[at]] = at);

    memset(overlapping, 0, sizeof(overlapping));
    for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
        others = indexed_colors[at];

            block_bases[indexed_colors[at + 1] - others + 1],
            cnt = 1;

        for (block_bases[0] = members[others]; at == colors[members[++others]]; )
            if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
                block_bases[cnt++] = members[others];

        block_bases[cnt] = members[others];
        for (others = 0; others  <  cnt; others++) {
            unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
            overlapping[block_bases[others] / levels] += density * (density - 1);

            unsigned block = others;
            for (; ++block  <  cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
                += density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));

    for (at = 0; ++at  <  block_cnt; overlapping[at] += overlapping[at - 1])
        pairs[0][0][at] += pairs[0][0][at - 1];

    for (at = 0; ++at  <  block_cnt; )
        for (others = 0; ++others  <  block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);

    for (at = 0; ++at  <  block_cnt; )
        for (others = 0; others  <  block_cnt; others++)
            pairs[0][at][others] += pairs[0][at - 1][others];

    colored_tree_t *tree = &(colored_tree_t) {
        .members = members,
        .colors = colors,
        .indices = indexed_colors,
        .locations = locations

    while (query_cnt--) {
        range_t left, right;
        scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
        left.packd -= 0x0000000100000001UL;
        right.packd -= 0x0000000100000001UL;

        left.low = ids[left.low];
        left.high = ids[left.high];

        right.low = ids[right.low];
        right.high = ids[right.high];

        if (left.high  <  left.low)
            left.packd = (left.packd << 32) | (left.packd >> 32);

        if (right.high  <  right.low)
            right.packd = (right.packd << 32) | (right.packd >> 32);

        if (right.high  <  left.low) {
            left.packd  ^= right.packd;
            right.packd ^= left.packd;
            left.packd  ^= right.packd;

        struct {
            range_t members[32];
            unsigned cnt;
            a = {.cnt = 0},
            b = {.cnt = 0};

        unsigned common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            left.low, left.high

        for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            right.low, right.high

        for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        unsigned long total = 0;
        for (at = 0; at  <  a.cnt; at++)
            for (others = 0; others  <  b.cnt;
                 total += count_pairs(
                     block_cnt, levels, pairs[0], overlapping, tree,
                     a.members[at], b.members[others++]

        printf("%lu\n", total);

    return 0;
C++ Programming

C++ Programming

#include <cstdlib>
#include <cstdio>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <cassert>

using namespace std;

typedef long long LL;
typedef unsigned long long ULL;

// Container size as a signed int (avoids signed/unsigned comparison noise).
#define SIZE(x) (int((x).size()))
// Inclusive ascending loop: i = l, l+1, ..., r.
#define rep(i,l,r) for (int i=(l); i<=(r); i++)
// Inclusive descending loop: i = r, r-1, ..., l.
#define repd(i,r,l) for (int i=(r); i>=(l); i--)
// Iterator loop over any container with begin()/end(); relies on the GNU __typeof extension.
#define rept(i,c) for (__typeof((c).begin()) i=(c).begin(); i!=(c).end(); i++)

#define debug(x) { cerr<<#x<<" = "<<(x)<<endl; }
#define debug(x) {}

#define maxn 100010
// Size threshold that lemon() compares against the number of nodes sharing one color.
#define LIM 100

// Fenwick (binary indexed) tree over positions 1..maxn-1.
int ta[maxn];

// Point update: adds y at position x. Positions outside 1..maxn-1 are ignored
// (x <= 0 would otherwise never advance, because x & -x is 0 for x == 0).
void ta_modify(int x, int y)
{
    while (x > 0 && x < maxn) ta[x]+=y, x+=x&-x;
}

// Prefix sum of positions 1..x; returns 0 for x <= 0.
int ta_query(int x)
{
    int ret=0;
    while (x > 0) ret+=ta[x], x-=x&-x;
    return ret;
}

// Range update: adds c to every position in [l, r]. Implemented on the
// difference array, so that the prefix sum read by ds_query() yields the
// current value of a single position.
void ds_modify(int l, int r, int c)
{
    ta_modify(l, c);
    ta_modify(r+1, -c);
}

// Point query: current value at position v after all ds_modify() calls.
int ds_query(int v)
{
    return ta_query(v);
}

int dfsN;
int dfsLeft[maxn], dfsRight[maxn];
int lg2[maxn], p[maxn][17], depth[maxn];
vector<int> e[maxn];

void dfs(int cur, int pre, int dep)
    dfsN++; dfsLeft[cur]=dfsN;
    rep(i,1,lg2[dep]) p[cur][i]=p[p[cur][i-1]][i-1];
    rept(it,e[cur]) if (*it!=pre) dfs(*it,cur,dep+1);

int movedep(int x, int y)
    if (y < 0) return 0;
    while (y) x=p[x][lg2[y&-y]], y-=y&-y;
    return x;

int lca(int x, int y)
    if (depth[x] < depth[y]) swap(x,y);
        if (p[x][i]!=p[y][i])
            x=p[x][i]; y=p[y][i];
    if (x==y) return x;
    return p[x][0];

int get_dist(int x, int y)
    int z=lca(x,y);
    return depth[x]+depth[y]-2*depth[z]+1;

int all, ti[5][2];

void check_intersect(int p1, int p2, int q1, int q2)
    if (depth[p2]>depth[q2])
        swap(p1,q1); swap(p2,q2);
    if (depth[p1] < depth[q2]) return;
    if (lca(p1,q2)!=q2 || lca(q2,p2)!=p2) return;
    int z=lca(p1,q1);
    rep(i,1,all) if (ti[i][0]==z && ti[i][1]==q2) return;
    rep(i,1,all) if (ti[i][1]==z && ti[i][0]==q2) return;
    //if (z==q2) rep(i,1,all) if (ti[i][0]==z || ti[i][1]==z) return;
    all++; ti[all][0]=z; ti[all][1]=q2;

// One deferred sweep event. The meaning of the fields depends on the list:
//  - eventAddList entries are replayed as ds_modify(x, y, c)
//    (range [x, y], delta c);
//  - eventQueryList entries are replayed as ans[y] += c * ds_query(x)
//    (position x, query id y, sign/weight c).
struct tasktype
{
    int x, y, c;
    // Zero-initialise so a default-constructed event never carries garbage.
    tasktype(): x(0), y(0), c(0) {}
    tasktype(int x, int y, int c): x(x), y(y), c(c) {}
};

vector < tasktype> eventAddList[maxn], eventQueryList[maxn];

void addQueryEvent(int i, int p1, int q1, int c)
    p1=dfsLeft[p1]; q1=dfsLeft[q1];

void addContributionEvent(int p1, int p2, int q1, int q2)

void add_task(int i, int p1, int p2, int q1, int q2)
    if (!p1 || !p2 || !q1 || !q2) return;
    if (p[q2][0]) addQueryEvent(i,p1,p[q2][0],-1);
    if (p[p2][0]) addQueryEvent(i,p[p2][0],q1,-1);
    if (p[p2][0] && p[q2][0]) addQueryEvent(i,p[p2][0],p[q2][0],1);

map < int, vector<int> > clist;
int color[maxn];
int q[maxn][6];
int ans[maxn];

void lemon()
    lg2[1]=0; rep(i,2,maxn-1) lg2[i]=lg2[i>>1]+1;
    int n,qa; scanf("%d%d",&n,&qa);
        int x,y; scanf("%d%d",&x,&y);
    rep(i,1,qa) ans[i]=0;
        int z1=lca(q[i][0],q[i][1]);
        int z2=lca(q[i][2],q[i][3]);
        q[i][4]=z1; q[i][5]=z2;
        int t1=movedep(q[i][1],depth[q[i][1]]-depth[z1]-1);
        int t2=movedep(q[i][3],depth[q[i][3]]-depth[z2]-1);
        if (all>0)
        int cl=it->first;
        if (it->second.size() < =LIM)
            int s=it->second.size();
                    int i1=it->second[i], j1=it->second[j];
                    addContributionEvent(dfsLeft[i1], dfsRight[i1], dfsLeft[j1], dfsRight[j1]);
                int x1=ds_query(dfsLeft[q[i][0]])+ds_query(dfsLeft[q[i][1]])-2*ds_query(dfsLeft[q[i][4]]);
                if (color[q[i][4]]==cl) x1++;
                //printf("%d: %d %d %d\n",cl,q[i][0],q[i][1],x1);
                int x2=ds_query(dfsLeft[q[i][2]])+ds_query(dfsLeft[q[i][3]])-2*ds_query(dfsLeft[q[i][5]]);
                if (color[q[i][5]]==cl) x2++;
                //printf("%d: %d %d %d\n",cl,q[i][2],q[i][3],x2);
        rept(it,eventAddList[i]) ds_modify(it->x,it->y,it->c);
        rept(it,eventQueryList[i]) ans[it->y]+=it->c*ds_query(it->x);
    rep(i,1,qa) printf("%d\n",ans[i]);

int main()
    #ifndef ONLINE_JUDGE
    return 0;
Java Programming

Java Programming

import java.util.*;

public class Solution {

  static int[] nxt;
  static int[] succ;
  static int[] ptr;
  static int index = 1;

  static void addEdge(int u, int v) {
    nxt[index] = ptr[u];
    ptr[u] = index;
    succ[index++] = v;

  static int timer = 0;
  static int[] tin;
  static int[] tout;
  static int[] pr;
  static int[][] up;
  static int[] w;

  static class NodeDfs {
    int v;
    int lvl;
    int p;
    boolean start = true;

    public NodeDfs(int v, int lvl, int p) {
      this.v = v;
      this.lvl = lvl;
      this.p = p;

  static void dfs() {
    Deque < NodeDfs> q = new LinkedList<>();
    q.add(new NodeDfs(1, 0, 1));
    while (!q.isEmpty()) {
      NodeDfs node = q.peekLast();
      if (node.start) {
        tin[node.v] = timer;
        w[node.v] = node.lvl;
        up[node.v][0] = node.p;
        for (int i = 1; i  < = 17; i++) {
          up[node.v][i] = up[up[node.v][i - 1]][i - 1];
        for (int i = ptr[node.v]; i > 0; i = nxt[i]) {
          int to = succ[i];
          if (to != node.p) {
            q.add(new NodeDfs(to, node.lvl + 1, node.v));
            pr[to] = node.v;
        node.start = false;
      } else {
        tout[node.v] = timer++;


  static boolean upper(int x, int y) {
    return tout[x] >= tout[y] && tin[x]  < = tin[y];

  static int lca(int a, int b) {
    if (upper(a, b))
      return a;
    if (upper(b, a))
      return b;
    for (int i = 17; i >= 0; --i)
      if (!upper(up[a][i], b))
        a = up[a][i];
    return up[a][0];

  static void normalize(int[] a) {
    Map < Integer, Integer> trans = new HashMap<>();
    int j = 0;
    for (int i = 1; i  <  a.length; i++) {
      if (!trans.containsKey(a[i])) {
        trans.put(a[i], j++);
    for (int i = 1; i  <  a.length; i++) {
      a[i] = trans.get(a[i]);

  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(;
    BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

    StringTokenizer st = new StringTokenizer(br.readLine());
    int n = Integer.parseInt(st.nextToken());
    int m = Integer.parseInt(st.nextToken());

    int[] a = new int[n + 1];
    st = new StringTokenizer(br.readLine());
    for (int i = 1; i  < = n; i++) {
      a[i] = Integer.parseInt(st.nextToken());

    nxt = new int[n + 1];
    succ = new int[n + 1];
    ptr = new int[n + 1];

    for (int i = 0; i  <  n - 1; i++) {
      st = new StringTokenizer(br.readLine());
      int u = Integer.parseInt(st.nextToken());
      int v = Integer.parseInt(st.nextToken());
      if (u  <  v) {
        addEdge(u, v);
      } else {
        addEdge(v, u);

    tin = new int[n + 1];
    tout = new int[n + 1];
    pr = new int[n + 1];
    up = new int[n + 1][20];
    w = new int[n + 1];


    int[] d = new int[m];
    int[] X1 = new int[m];
    int[] Y1 = new int[m];
    int[] X2 = new int[m];
    int[] Y2 = new int[m];

    for (int i = 0; i  <  m; i++) {
      st = new StringTokenizer(br.readLine());
      int x1 = Integer.parseInt(st.nextToken());
      int y1 = Integer.parseInt(st.nextToken());
      int x2 = Integer.parseInt(st.nextToken());
      int y2 = Integer.parseInt(st.nextToken());
      X1[i] = x1;
      Y1[i] = y1;
      X2[i] = x2;
      Y2[i] = y2;

      int lc1 = lca(x1, y1);
      int lc2 = lca(x2, y2);
      int ans1 = 0;
      if (lc1 == lc2) {

        int lc3 = lca(x1, x2);
        int lc4 = lca(x1, y2);
        int lc5 = lca(y1, x2);
        int lc6 = lca(y1, y2);
        ans1 += w[lc3] - w[lc1];
        ans1 += w[lc4] - w[lc1];
        ans1 += w[lc5] - w[lc1];
        ans1 += w[lc6] - w[lc1];
      } else if (w[lc1]  <  w[lc2]) {

        int lc3 = lca(x1, x2);
        int lc4 = lca(x1, y2);
        int lc5 = lca(y1, x2);
        int lc6 = lca(y1, y2);
        if (upper(lc2, x1) && upper(lc1, lc2)) {
          ans1 += Math.abs(w[lc3] - w[lc4]) + 1;
        if (upper(lc2, y1) && upper(lc1, lc2)) {
          ans1 += Math.abs(w[lc5] - w[lc6]) + 1;
      } else if (w[lc1] > w[lc2]) {

        int lc3 = lca(x1, x2);
        int lc4 = lca(x1, y2);
        int lc5 = lca(y1, x2);
        int lc6 = lca(y1, y2);
        if (upper(lc1, x2) && upper(lc2, lc1)) {
          ans1 += Math.abs(w[lc3] - w[lc5]) + 1;
        if (upper(lc1, y2) && upper(lc2, lc1)) {
          ans1 += Math.abs(w[lc4] - w[lc6]) + 1;
      d[i] = ans1;

    int[] b = new int[n + 1];

    for (int i = 0; i  <  m; i++) {
      int x1 = X1[i];
      int y1 = Y1[i];
      int x2 = X2[i];
      int y2 = Y2[i];
      int x3 = x1;
      int y3 = y1;
      while (!upper(x1, y1)) {
        x1 = pr[x1];
      while (y1 != x1) {
        y1 = pr[y1];

      int ans = 0;
      while (!upper(x2, y2)) {
        ans += b[a[x2]];
        x2 = pr[x2];
      while (y2 != x2) {
        ans += b[a[y2]];
        y2 = pr[y2];
      ans += b[a[x2]];

      int tmp = x1;
      x1 = x3;
      y1 = y3;
      while (x1 != tmp) {
        b[a[x1]] = 0;
        x1 = pr[x1];
      while (y1 != tmp) {
        b[a[y1]] = 0;
        y1 = pr[y1];
      b[a[x1]] = 0;
      bw.write((ans - d[i]) + "\n");


