Algorithm

Problem Name: Data Structures - Max Transform

This article presents solutions to the HackerRank Data Structures problem "Max Transform".

Transforming data into some other data is typical of a programming job. This problem is about a particular kind of transformation which we'll call the max transform.

Let $A$ be a zero-indexed array of integers. For $0 \le i \le j < \text{length}(A)$, let $A_{i \ldots j}$ denote the subarray of $A$ from index $i$ to index $j$, inclusive. Let's define the max transform of $A$ as the array obtained by the following procedure:

• Let B be a list, initially empty.
• For $k$ from $0$ to $\text{length}(A) - 1$: for $i$ from $0$ to $\text{length}(A) - k - 1$, let $j = i + k$ and append $\max(A_{i \ldots j})$ to the end of $B$.
• Return B.

The returned array is defined as the max transform of $A$. We denote it by $S(A)$.

Complete the function solve that takes an integer array $A$ as input. Given the array, find the sum of the elements of $S(S(A))$, i.e., the max transform of the max transform of $A$. Since the answer may be very large, only find it modulo $10^9 + 7$.

Input Format

The first line of input contains a single integer n denoting the length of A.

The second line contains $n$ space-separated integers $A_0, A_1, \ldots, A_{n-1}$ denoting the elements of $A$.

Constraints

• $1 \le n \le 2 \times 10^5$
• $1 \le A_i \le 10^6$

Output Format

Print a single line containing a single integer denoting the answer.

Sample Input 0

3
3 2 1


Sample Output 0

58


Explanation 0

In the sample case, we have:

A = [3,2,1]

S(A)= [3,2,1,3,2,3]

S(S(A)) = [3,2,1,3,2,3,3,2,3,3,3,3,3,3,3,3,3,3,3,3,3]

Therefore, the sum of the elements of S(S(A)) is 58

Code Examples

#1 Code Example with C Programming

Code - C Programming


#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse4")
#include<stdio.h>
#include<string.h>
#include<stdlib.h>
/* mod: modulus of the answer, 10^9 + 7.
   _2: inverse of 2 modulo mod (2 * 500000004 = 1000000008 = mod + 1); main uses it to halve M. */
const int mod = 1000000007, _2 = 500000004;
/* N: input length; MX: largest input value; tp: top index of the stack st[].
   a[1..N]: the input, 1-indexed.
   i_1[i]: 1 + 2 + ... + i reduced modulo mod (i_1[0] == 0).
   mxl[i] / mxr[i]: leftmost / rightmost index of the widest window around i in which a[i] is
   a maximum. The left pass extends through equal values; the right pass stops before them.
   sxl[i] / sxr[i]: maximum of a[1..i] / a[i..N]; sxr[N + 1] stays 0 as a sentinel. */
int N, MX = 0, tp, a[200010], i_1[200010], st[200010], mxl[200010], mxr[200010], sxl[200010], sxr[200010];
/* M: number of elements of S(S(A)) modulo mod.
   CNT: how many of those elements calc/calcl/calcr have not yet charged; main charges MX for each.
   ANS: running answer modulo mod. */
long long M, CNT, ANS = 0;
/* Charges k elements of S(S(A)) with value w: adds w * k to ANS and removes k from CNT.
   k depends only on the larger (hi) and smaller (lo) of the two extents x and y, and is
   built from the triangular prefix sums in i_1[].
   NOTE(review): main passes the window extents around a[i] on each side; the closed form
   itself is not derived in this file. */
void calc(int w, int x, int y)
{
    int hi = x > y ? x : y;
    int lo = x > y ? y : x;
    long long base = (long long)(hi + lo) * i_1[lo] % mod;
    int k;

    if (hi == lo) {
        long long square = (long long)hi * hi % mod;
        k = (int)(((base - square) % mod + mod) % mod);
    } else {
        /* i_1 entries are already reduced, so the difference may be negative; the
           final "+ mod) % mod" normalises it back into [0, mod). */
        long long span = (long long)lo * (i_1[hi - 1] - i_1[lo]) % mod;
        k = (int)(((span + base) % mod + mod) % mod);
    }

    ANS = (ANS + (long long)w * k) % mod;
    if ((CNT -= k) < 0) {
        CNT += mod;
    }
}
/* Left-side correction: charges k elements of S(S(A)) with value w, adding w * k to ANS
   and removing k from CNT. Calls with x == 1 or y == 0 charge nothing. */
void calcl(int w, int x, int y)
{
    int k;

    if (x != 1 && y != 0) {
        if (y >= x) {
            k = (i_1[x - 1] + (long long)(y - x + 1) * (x - 1)) % mod;
        } else {
            k = i_1[y];
        }
        ANS = (ANS + (long long)w * k) % mod;
        CNT -= k;
        if (CNT < 0) {
            CNT += mod;
        }
    }
}
/* Right-side correction (counterpart of calcl): charges k elements of S(S(A)) with value w,
   adding w * k to ANS and removing k from CNT.
   A y of 0 used to reach k = i_1[y - 1] = i_1[-1], an out-of-bounds read (undefined
   behaviour) that the sample input 3 2 1 already triggers. Because i_1[0] == 0, y == 1
   yields k == 0. The sample answer 58 also requires the y == 0 calls to contribute 0, so
   every y <= 1 is now treated as "nothing to charge". */
void calcr(int w, int x, int y)
{
    int k;

    if (x == 0 || y <= 1) {
        return;
    }
    if (y + 1 <= x) {
        k = i_1[y - 1];
    } else {
        k = (i_1[x] + (long long)(y - x - 1) * x) % mod;
    }
    ANS = (ANS + (long long)w * k) % mod;
    CNT -= k;
    if (CNT < 0) {
        CNT += mod;
    }
}
int main()
{
int p;
scanf("%d", &N);
for( int i = 1 ; i  < = N ; i++ )
{
scanf("%d", &a[i]>;
MX = MX > a[i] ? MX : a[i];
}
M = ( (long long)N * ( N + 1 ) >> 1 ) % mod;
M = (long long)M * ( M + 1 ) % mod * _2 % mod;
CNT = M;
for( int i = 1 ; i  < = N ; i++ )
{
i_1[i] = ( i_1[i-1] + i ) % mod;
}
for( int i = 1 ; i  < = N ; i++ )
{
sxl[i] = sxl[i-1] > a[i] ? sxl[i-1] : a[i];
}
for( int i = N ; i ; i-- )

{
sxr[i] = sxr[i+1] > a[i] ? sxr[i+1] : a[i];
}
tp = 0;
for( int i = 1 ; i  < = N ; i++ )
{
while( tp > 0 && a[st[tp]] <= a[i] )
{
tp--;
}
if(tp)
{
mxl[i] = st[tp] + 1;
}
else
{
mxl[i] = 1;
}
st[++tp] = i;
}
tp = 0;
for( int i = N ; i ; i-- )
{
while( tp > 0 && a[st[tp]] < a[i] )
{
tp--;
}
if(tp)
{
mxr[i] = st[tp] - 1;
}
else
{
mxr[i] = N;
}
st[++tp] = i;
}
for( int i = 1 ; i  < = N ; i++ )
{
calc(a[i], i-mxl[i]+1, mxr[i]-i+1);
}
p = N;
for( int i = 1 ; i  < = N ; i++ >
{
int g = sxl[i];
while( p > i && sxr[p]  <  g )
{
p--;
}
while( p < i )
{
p++;
}
calcl(g, i, N-p);
}
p = 1;
for( int i = N ; i ; i-- )
{
int g = sxr[i];
while( p  <  i && sxl[p] <= g )
{
p++;
}
while( p > i )
{
p--;
}
calcr(g, N-i+1, p-1);
}
CNT = ( CNT % mod + mod ) % mod;
ANS = ( ANS + (long long)CNT * MX ) % mod;
printf("%lld", ANS);
return 0;
}

Copy The Code &

#3 Code Example with Java Programming

Code - Java Programming


import java.io.*;
import java.math.*;
import java.security.*;
import java.text.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.regex.*;

public class Solution {

    /** Modulus for the answer: 10^9 + 7. */
    static final int SUM_DIV = 1000000007;

    /**
     * A run of equal values covering indices {@code start..end} (inclusive) of the input.
     * The input is treated circularly: when {@code end < start} the run wraps from the last
     * index back to index 0.
     */
    static class Plateau {
        final int start;
        final int end;
        final int v;

        Plateau(int start, int end, int v) {
            this.start = start;
            this.end = end;
            this.v = v;
        }

        @Override
        public String toString() {
            return new StringJoiner(", ", "[", "]")
                    .add("start=" + start)
                    .add("end=" + end)
                    .add("v=" + v)
                    .toString();
        }
    }

    /**
     * Returns the sum of the elements of S(S(A)) modulo 10^9 + 7.
     *
     * <p>Every index starts as its own plateau. The walk repeats three steps until a single
     * plateau covers the whole (circular) array:
     * <ul>
     *   <li>merge the current plateau with an equal neighbour;</li>
     *   <li>when both neighbours are higher, raise it to the lower neighbour's value and
     *       subtract that raise times a count from {@link #calculateCounts} or
     *       {@link #countInverse};</li>
     *   <li>otherwise step to a lower neighbour.</li>
     * </ul>
     * The answer is then {@code value * |S(S(A))| + subtract}.
     *
     * <p>NOTE(review): the counts are presumably the number of elements of S(S(A)) whose
     * maximum lies only inside the raised plateau; their derivation is not shown here.
     *
     * @param input the array A (may be empty)
     * @return the sum of S(S(A)) modulo 10^9 + 7; 0 for an empty array
     */
    static int solve(int[] input) {
        if (input.length == 0) {
            // S(S([])) is empty; the walk below needs at least one element.
            return 0;
        }
        final Map<Integer, Plateau> mapStart = new HashMap<>(input.length * 2);
        final Map<Integer, Plateau> mapEnd = new HashMap<>(input.length * 2);
        for (int i = 0; i < input.length; ++i) {
            Plateau p = new Plateau(i, i, input[i]);
            mapStart.put(i, p);
            mapEnd.put(i, p);
        }
        // Accumulated correction; it only decreases and is kept within (-SUM_DIV, 0].
        long subtract = 0;
        Plateau cur = mapStart.remove(0);
        mapEnd.remove(0);

        for (;;) {
            if (mapStart.isEmpty()) {
                // cur now spans the whole array.
                long total = totalCount(input.length);
                long result = (((long) cur.v) * total + subtract + SUM_DIV) % SUM_DIV;
                return (int) result;
            }
            Plateau prev = mapEnd.get(normalize(cur.start - 1, input));
            if (prev.v == cur.v) {
                // Absorb the equal plateau on the left.
                cur = new Plateau(prev.start, cur.end, cur.v);
                mapStart.remove(prev.start);
                mapEnd.remove(prev.end);
                continue;
            }
            Plateau next = mapStart.get(normalize(cur.end + 1, input));
            if (next.v == cur.v) {
                // Absorb the equal plateau on the right.
                cur = new Plateau(cur.start, next.end, cur.v);
                mapStart.remove(next.start);
                mapEnd.remove(next.end);
                continue;
            }
            if (next.v > cur.v && prev.v > cur.v) {
                // Local minimum: raise it to the lower of its two neighbours.
                int nextV = Math.min(next.v, prev.v);
                long delta = nextV - cur.v;
                if (cur.end >= cur.start) {
                    delta *= calculateCounts(normalize(cur.end - cur.start + 1, input));
                } else {
                    // Wrapped plateau: (input.length - start) cells at the tail, total length l.
                    delta *= countInverse(input.length - cur.start,
                            normalize(cur.end + 1 - cur.start, input));
                }
                subtract -= delta;
                subtract %= SUM_DIV;
                cur = new Plateau(cur.start, cur.end, nextV);
                continue;
            }
            // Neither neighbour is equal, and they are not both higher, so at least one is
            // lower. Park cur and move to that neighbour, preferring the left one.
            Plateau successor = prev.v < cur.v ? prev : next;
            mapStart.remove(successor.start);
            mapEnd.remove(successor.end);
            mapStart.put(cur.start, cur);
            mapEnd.put(cur.end, cur);
            cur = successor;
        }
    }

    /** Maps {@code idx} in {@code [-length, length]} onto {@code [0, length)} circularly. */
    private static int normalize(int idx, int[] input) {
        return (idx + input.length) % input.length;
    }

    /**
     * Returns |S(S(A))| modulo 10^9 + 7 for |A| = n.
     *
     * <p>|S(A)| = n(n+1)/2 and |S(S(A))| = m(m+1)/2 with m = |S(A)|. Halving after the
     * reduction is exact, because s1Size * (s1Size + 1) is a product of consecutive integers
     * and fits in a long since s1Size is below 2^30.
     */
    private static long totalCount(long n) {
        long s1Size = n * (n + 1) / 2 % SUM_DIV;
        long s2Size = (s1Size * (s1Size + 1) / 2) % SUM_DIV;
        return s2Size;
    }

    /** Returns the tetrahedral number n(n+1)(n+2)/6 modulo 10^9 + 7 (exact for n up to ~10^6). */
    private static long calculateCounts(long n) {
        return (n * n * n + 3 * n * n + 2 * n) / 6 % SUM_DIV;
    }

    /**
     * Count used for a plateau that wraps past the end of the array.
     *
     * @param c1 number of plateau cells from {@code start} to the last index
     * @param l  total plateau length
     * @return the closed-form count modulo 10^9 + 7; for {@code c1 > l / 2} it is derived
     *     from the mirrored case {@code l - c1 - 1}
     */
    private static long countInverse(long c1, long l) {
        if (c1 <= l / 2) {
            return (-4 * c1 * c1 * c1 + l * l * l
                    + 6 * c1 * c1 * l - 3 * c1 * l * l
                    - 3 * c1 * l + 3 * l * l - 2 * c1 + 2 * l) / 6 % SUM_DIV;
        } else {
            return ((countInverse(l - c1 - 1, l) - temp(c1 + 1) % SUM_DIV + temp(l - c1) % SUM_DIV)
                    + SUM_DIV) % SUM_DIV;
        }
    }

    /** Returns the triangular number n(n+1)/2 (not reduced). */
    private static long temp(long n) {
        return (n * n + n) / 2;
    }

    private static final Scanner scanner = new Scanner(System.in);

    public static void main(String[] args) throws IOException {
        try (BufferedWriter bufferedWriter =
                new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
            int n = scanner.nextInt();
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            int[] A = new int[n];

            // Split on any whitespace run so doubled or trailing spaces do not break parsing.
            String[] AItems = scanner.nextLine().trim().split("\\s+");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int i = 0; i < n; i++) {
                A[i] = Integer.parseInt(AItems[i]);
            }

            bufferedWriter.write(String.valueOf(solve(A)));
            bufferedWriter.newLine();
        } finally {
            scanner.close();
        }
    }
}

Copy The Code &

#4 Code Example with Python Programming

Code - Python Programming

#!/bin/python3

import math
import os
import random
import re
import sys

# Complete the solve function below.

import math
import os
import random
import re
import sys
# NOTE(review): the routines below (brutesolo, make_brute, ends) walk their ranges with
# explicit stacks or loops rather than recursion, so this very high limit looks
# unnecessary; confirm nothing else relies on it before removing.
sys.setrecursionlimit(9999999)
from decimal import Decimal
def t1(n):
    """Return the triangular number n*(n+1)/2 as a Decimal.

    n*(n+1) is always even, so floor division is exact. The previous true division ``/``
    produced a binary float before the Decimal conversion, which silently rounded once the
    result no longer fit in 53 bits. For a Decimal argument the arithmetic is exact within
    the current decimal context precision.
    """
    return Decimal(n * (n + 1) // 2)

def t2(n):
    """Return the tetrahedral number n*(n+1)*(n+2)/6 as an exact Decimal.

    A product of three consecutive integers is always divisible by 6, so floor division
    is exact. True division went through a 53-bit float and lost precision for large n.
    """
    return Decimal(n * (n + 1) * (n + 2) // 6)

def u2(n):
    """Return n*(n+2)*(2n+5)/24 as a Decimal.

    The numerator is always divisible by 3 (one of n, n+2, 2n+5 is a multiple of 3), so
    the quotient is a multiple of 1/8. It is therefore represented exactly by Decimal
    division, within the context precision (28 significant digits by default). The value
    is not always an integer, e.g. u2(1) == 0.875. The previous float division rounded it
    to 53 bits before the Decimal conversion.
    """
    return Decimal(n * (n + 2) * (2 * n + 5)) / 24

def countzip(a, b):
    """Return u2(a + b) - u2(|a - b|) + t2(|a - b|).

    a + b and |a - b| always have the same parity, so the fractional parts of the two u2
    terms cancel.
    """
    gap = abs(a - b)
    return u2(a + b) - u2(gap) + t2(gap)

def countends(x, n, ex):
    """Return countzip(n, ex) - countzip(x, ex) - countzip(n - 1 - x, 0).

    Callers pass the position x of a range maximum within a span of length n, and an
    extra length ex (0 when none).
    """
    whole = countzip(n, ex)
    left = countzip(x, ex)
    right = countzip(n - 1 - x, 0)
    return whole - left - right

def countsplit(x, n):
    """Return t1(t1(n)) - t1(x) - countzip(n - x - 1, x - 1).

    t1(t1(n)) is the size of S(S(A)) for len(A) == n; solve() multiplies the result by
    the global maximum located at index x.
    """
    everything = t1(t1(n))
    return everything - t1(x) - countzip(n - x - 1, x - 1)

# K: number of sparse-table levels; make_rangemax supports arrays shorter than 2**K.
K = 20
# lg[m] == floor(log2(m)) for m >= 1; lg[0] stays 0.
lg = [0] * (1 << K)
for i in range(1, 1 << K):
    lg[i] = i.bit_length() - 1

def make_rangemax(A):
    """Build a sparse table over A and return a range-argmax query function.

    The returned ``rangemax(i, j)`` gives the index of a maximum of A[i:j] (half-open,
    i < j). On ties ``max`` keeps its first argument.
    """
    n = len(A)
    assert 1 << K > n

    def key(idx):
        return A[idx]

    # levels[k][p] is the index of a maximum of A[p:min(p + 2**k, n)].
    levels = [list(range(n))]
    for k in range(K - 1):
        half = 1 << k
        below = levels[k]
        row = list(below)
        for p in range(n - half):
            row[p] = max(below[p], below[p + half], key=key)
        levels.append(row)

    def rangemax(i, j):
        k = lg[j - i]
        return max(levels[k][i], levels[k][j - (1 << k)], key=key)

    return rangemax

def brutesolo(A):
    """Return the sum of max over every contiguous subarray of A.

    The range [lo, hi) is split at its maximum m. Exactly (m - lo + 1) * (hi - m)
    subarrays of [lo, hi) contain m, and each of them has maximum A[m].
    """
    rangemax = make_rangemax(A)
    pending = [(0, len(A))]
    total = 0
    while pending:
        lo, hi = pending.pop()
        if lo == hi:
            continue
        m = rangemax(lo, hi)
        pending.append((lo, m))
        pending.append((m + 1, hi))
        total += A[m] * (m - lo + 1) * (hi - m)
    return total

def make_brute(A):
    """Return ``(brute, rangemax)`` for A.

    ``brute(i, j)`` splits [i, j) recursively at range maxima, using an explicit stack.
    For each maximum m it adds A[m] * countends(m - lo, hi - lo, 0). ``rangemax`` is the
    query built by make_rangemax(A).
    """
    rangemax = make_rangemax(A)

    def brute(i, j):
        pending = [(i, j)]
        acc = 0
        while pending:
            lo, hi = pending.pop()
            if lo != hi:
                m = rangemax(lo, hi)
                pending.append((lo, m))
                pending.append((m + 1, hi))
                acc += A[m] * countends(m - lo, hi - lo, 0)
        return acc

    return brute, rangemax

def ends(A, B):
    """Combine the prefixes A[:i] and B[:j], starting from the full lengths.

    Each step takes the larger of the two prefix maxima. The part to the right of that
    maximum is handed to its side's ``brute``, the maximum is weighted with
    countends(pos, own_len, other_len), and the prefix is cut back to the maximum's index.
    Once one side is empty, the other is finished with its ``brute``. On a tie A's
    maximum is used. The terms are accumulated in the same order as an explicit-stack walk.
    """
    brute_a, max_a = make_brute(A)
    brute_b, max_b = make_brute(B)

    i, j = len(A), len(B)
    total = 0
    while True:
        if i == 0:
            total += brute_b(0, j)
            break
        if j == 0:
            total += brute_a(0, i)
            break
        x = max_a(0, i)
        y = max_b(0, j)
        if A[x] < B[y]:
            total += brute_b(y + 1, j)
            total += B[y] * countends(y, j, i)
            j = y
        else:
            total += brute_a(x + 1, i)
            total += A[x] * countends(x, i, j)
            i = x
    return total

def maxpairs(a):
    """Return [max(a[0], a[1]), max(a[1], a[2]), ...]; empty when len(a) < 2."""
    return [max(a[k], a[k + 1]) for k in range(len(a) - 1)]

def solve(A):
    """Return the sum of S(S(A)) modulo 10**9 + 7.

    A is split at the first occurrence of its maximum. The answer combines three terms:
    the subarray-max sum of the left part; the ends() merge of the reversed right part
    with the pairwise maxima of the left part; and the global maximum weighted by
    countsplit.
    """
    n = len(A)
    peak = max(range(n), key=lambda idx: A[idx])
    left = A[:peak]
    right_reversed = A[peak + 1:][::-1]
    total = (brutesolo(left)
             + ends(right_reversed, maxpairs(left))
             + A[peak] * countsplit(peak, n))
    return int(total % (10**9 + 7))

if __name__ == '__main__':
    # Output goes to the file named by OUTPUT_PATH. The input is the length on the first
    # line and the values on the second.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        n = int(input())
        A = list(map(int, input().rstrip().split()))
        result = solve(A)
        fptr.write(str(result) + '\n')

Copy The Code &