Saturday, October 6, 2012

Java Fork Join & Parallel Programming

With the emergence of multi-core processors, application development has shifted toward using those cores by forking many threads or processes, especially for processing-intensive work.

Java came late to this paradigm shift: the Fork/Join framework arrived in Java 7 as part of the JSR 166 concurrency work (it was initially planned for Java 5). The framework is expected to be improved further in Java 8 to simplify today's fairly complex usage.

The algorithms best suited to parallelization are divide-and-conquer algorithms.
In this post we will walk through an example of using fork-join to build a parallel version of quicksort.

1) Framework important parts:
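-ForkJoinTask: a task that is forked (started asynchronously) and later joined to wait for its result.
Its public lifecycle methods include fork(), join() and invoke() (implemented internally by doJoin, doInvoke and related methods), and each task carries a completion status.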

- RecursiveAction extends ForkJoinTask: represents a recursive action that produces no result.
Its key method is compute().
We will extend this class and override compute() to do the sorting work.

-ForkJoinPool: the pool that executes fork-join tasks. You can size it by the number of cores or by a fixed number of threads; the no-argument constructor uses the number of available processors.
Note: To get the number of processors/cores:
int processors = Runtime.getRuntime().availableProcessors();

2) Non-parallel Quick sort:
Static methods that sort the numbers[] array in place:


//non-parallel (single-threaded) quicksort over the shared static array below.
private static int[] numbers;

/**
 * Sorts numbers[low..high] (both inclusive) in place, using the middle element as pivot
 * and two cursors that move toward each other.
 *
 * @param low  index of the first element of the range
 * @param high index of the last element of the range
 */
private static void quicksort(int low, int high) {
    // Ranges with fewer than two elements are already sorted. Without this guard,
    // quicksort(0, numbers.length - 1) on an empty array would read numbers[0] and throw.
    if (low >= high) {
        return;
    }
    int i = low, j = high;
    // Get the pivot element from the middle of the range; low + (high - low) / 2
    // avoids the int overflow that (low + high) / 2 can cause.
    int pivot = numbers[low + (high - low) / 2];

    // Divide into two lists
    while (i <= j) {
        // If the current value from the left list is smaller than the pivot
        // element then get the next element from the left list
        while (numbers[i] < pivot) {
            i++;
        }
        // If the current value from the right list is larger than the pivot
        // element then get the next element from the right list
        while (numbers[j] > pivot) {
            j--;
        }

        // numbers[i] >= pivot sits on the left and numbers[j] <= pivot sits on the right:
        // exchange them, then move both cursors past the exchanged values.
        if (i <= j) {
            exchange(i, j);
            i++;
            j--;
        }
    }
    // Recursion into the two partitions
    if (low < j) {
        quicksort(low, j);
    }
    if (i < high) {
        quicksort(i, high);
    }
}

/** Swaps numbers[i] and numbers[j]. */
private static void exchange(int i, int j) {
    int temp = numbers[i];
    numbers[i] = numbers[j];
    numbers[j] = temp;
}



3) Using Fork-Join:


public class ParallelQuickSort extends RecursiveAction {
    Phaser phaser;
    int[] arr = null;
    int left;
    int right;

    ParallelQuickSort(Phaser phaser, int[] arr) {
        this(phaser, arr, 0, arr.length - 1);
    }

    ParallelQuickSort(Phaser phaser, int[] arr, int left, int right) {
        this.phaser = phaser;
        this.arr = arr;
        this.left = left;
        this.right = right;
        phaser.register();  //important
    }


    private ParallelQuickSort leftSorter(int pivotI) {
        return new ParallelQuickSort(phaser, arr, left, --pivotI);
    }

    private ParallelQuickSort rightSorter(int pivotI) {
        return new ParallelQuickSort(phaser, arr, pivotI, right);
    }

    private void recurSort(int leftI, int rightI) {
        if (rightI - leftI > 7) {
            int pIdx = partition(leftI, rightI, getPivot(arr, leftI, rightI));
            recurSort(leftI, pIdx - 1);
            recurSort(pIdx, rightI);
        } else if (rightI - leftI > 0) {
            insertionSort(leftI, rightI);
        }
    }


    @Override
    protected void compute() {
        if (right - left > 1000) {   // if more than 1000 (totally arbitrary number i chose) try doing it parallelly
            int pIdx = partition(left, right, getPivot(arr, left, right));
            leftSorter(pIdx).fork();
            rightSorter(pIdx).fork();

        } else if (right - left > 7) {  // less than 1000 sort recursively in this thread
            recurSort(left, right);

        } else if (right - left > 0) {  //if less than 7 try simple insertion sort
            insertionSort(left, right);
        }

        if (isRoot()) { //if this instance is the root one (the one that started the sort process), wait for others
                        // to complete.
            phaser.arriveAndAwaitAdvance();
        } else {  // all not root one just arrive and de register not waiting for others.
            phaser.arriveAndDeregister();
        }
    }

    /**
     * Partitions the array segment arr[startI..endI] around the value {@code pivot},
     * moving two cursors toward each other and swapping out-of-place pairs.
     *
     * <p>On return, elements before the returned index are &lt;= pivot and elements from the
     * returned index through endI are &gt;= pivot. Callers then sort [startI, result - 1]
     * and [result, endI] separately.
     *
     * @param startI first index of the segment (inclusive)
     * @param endI   last index of the segment (inclusive)
     * @param pivot  pivot value; assumed to occur within the segment (callers obtain it from
     *               getPivot), which is what stops the unbounded left-cursor scan in range
     * @return the split index
     */
    private int partition(int startI, int endI, int pivot) {
        for (int si = startI - 1, ei = endI + 1; ; ) {
            // Advance si to the next element >= pivot.
            for (; arr[++si] < pivot;) ;
            // Move ei back to the previous element <= pivot, stopping at startI at the latest.
            for (; ei > startI && arr[--ei] > pivot ; ) ;
            // Cursors met or crossed: si is the split point.
            if (si >= ei) {
                return si;
            }
            // Both elements are on the wrong side of the pivot: swap them.
            swap(si, ei);
        }
    }

    private void insertionSort(int leftI, int rightI) {
        for (int i = leftI; i < rightI + 1; i++)
            for (int j = i; j > leftI && arr[j - 1] > arr[j]; j--)
                swap(j, j - 1);

    }

    private void swap(int startI, int endI) {
        int temp = arr[startI];
        arr[startI] = arr[endI];
        arr[endI] = temp;
    }

    /**
     * Check to see if this instance is the root, i.e the first one used to sort the array.
     * @return
     */
    private boolean isRoot() {
        return arr.length == (right - left) + 1;
    }

    /**
     * copied from java.util.Arrays
     */
    private int getPivot(int[] arr, int startI, int endI) {
        int len = (endI - startI) + 1;
        // Choose a partition element, v
        int m = startI + (len >> 1);       // Small arrays, middle element
        if (len > 7) {
            int l = startI;
            int n = startI + len - 1;
            if (len > 40) {        // Big arrays, pseudomedian of 9
                int s = len / 8;
                l = med3(arr, l, l + s, l + 2 * s);
                m = med3(arr, m - s, m, m + s);
                n = med3(arr, n - 2 * s, n - s, n);
            }
            m = med3(arr, l, m, n); // Mid-size, med of 3
        }
        int v = arr[m];
        return v;
    }

    /**
     * copied from java.util.Arrays
     */
    private static int med3(int x[], int a, int b, int c) {
        return (x[a] < x[b] ?
                (x[b] < x[c] ? b : x[a] < x[c] ? c : a) :
                (x[b] > x[c] ? b : x[a] > x[c] ? c : a));
    }


4) Testing:
Generate a large random array, sort it with both implementations (and with Arrays.sort for reference), and compare the results:

       private static int[] getRandom(int i) {
        Random randomGenerator = new Random(i);
        int[] array = new int[i];
        for (int n = 0; n < i; n++) {
            array[n] = randomGenerator.nextInt();
        }
        return array;
    }

    public static void main(String[] args) throws InterruptedException {
        int[] arr = getRandom(1000000);
        numbers=Arrays.copyOf(arr, arr.length);
        int[] arr2=Arrays.copyOf(arr, arr.length);
        System.out.println("show: " + arr.length+" "+numbers.length);
        System.out.println("show: " + arr[0]+" "+arr[arr.length-1]);
        System.out.println("show: " + numbers[0]+" "+numbers[numbers.length-1]);
        ForkJoinPool pool = new ForkJoinPool();
        StopWatch  stopWatch=new StopWatch();
        Phaser phaser = new Phaser();
        pool.invoke(new ParallelQuickSort(phaser, arr));
        stopWatch.stop();
        System.out.println("Elapsed Time: " + stopWatch.getElapsedTime());
        System.out.println("show: " + arr[0]+" "+arr[arr.length-1]);
        System.out.println("show: " + numbers[0]+" "+numbers[numbers.length-1]);
        numbers=getRandom(1000000);
        stopWatch=new StopWatch();        
        quicksort(0,numbers.length-1);
        stopWatch.stop();
        System.out.println("Elapsed Time: " + stopWatch.getElapsedTime());
        System.out.println("show: " + arr[0]+" "+arr[arr.length-1]);
        System.out.println("show: " + numbers[0]+" "+numbers[numbers.length-1]);     
        stopWatch=new StopWatch();        
        Arrays.sort(arr2);
        stopWatch.stop();
        System.out.println("Elapsed Time: " + stopWatch.getElapsedTime());
    }



5) Output :

Example of output of this code:


show: 1000000 1000000
show: 1608240105 -356486679
show: 1608240105 -356486679
Elapsed Time: 68
show: -2147481329 2147476538
show: 1608240105 -356486679
Elapsed Time: 114
show: -2147481329 2147476538
show: -2147481329 2147476538
Elapsed Time: 92



So sorting 1 million elements took 68 milliseconds using fork-join, 114 milliseconds using the single-threaded quicksort, and 92 milliseconds using Java's optimized Arrays.sort. These are single runs on a JVM without warm-up, so treat them as rough indications only.

No comments:

Post a Comment