Skip to content

Instantly share code, notes, and snippets.

@MastaP
Created March 14, 2012 11:21
Show Gist options
  • Save MastaP/2035851 to your computer and use it in GitHub Desktop.
Save MastaP/2035851 to your computer and use it in GitHub Desktop.
Counting inversions
--module AlgoClass where
import Data.List (elem, permutations)
import System.IO
import System.CPUTime
--Week 1
--Merge Sort
-- | Sort a list with top-down merge sort.
--
-- The input is divided by 'split' (which deals elements alternately into
-- two lists), each part is sorted recursively, and the two sorted parts
-- are combined with 'merge'.
mergesort :: Ord a => [a] -> [a]
mergesort xs = case xs of
  []  -> []
  [_] -> xs
  _   -> let (evens, odds) = split xs
         in merge (mergesort evens) (mergesort odds)
-- | Merge two ascending lists into a single ascending list.
--
-- Only a strictly smaller left head is emitted first; when the two heads
-- compare equal, the head of the second list goes first.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge xs@(x:xs') ys@(y:ys')
  | x < y     = x : merge xs' ys
  | otherwise = y : merge xs ys'
-- | Deal the elements of a list alternately into two lists: elements at
-- even positions (0, 2, 4, ...) go left, odd positions go right.
-- The left list is never shorter than the right one.
split :: [a] -> ([a], [a])
split = foldr deal ([], [])
  where
    -- The irrefutable pattern keeps the split lazy in the tail, matching
    -- the original direct recursion.
    deal x ~(mine, theirs) = (x : theirs, mine)
-- | Split a list into a front half and a back half at position
-- @length xs `div` 2@; for an odd-length list the back half gets the
-- extra element.
--
-- A singleton is special-cased to @([x], [])@ (the general rule would
-- give @([], [x])@).
split' :: [a] -> ([a], [a])
split' [] = ([], [])
split' [x] = ([x], [])
split' xs = splitAt (length xs `div` 2) xs
-- | Sort a fixed sample containing duplicates with 'mergesort'.
-- The result should be [1,3,3,4,5,5,8,8,8,12].
test_ms :: [Integer]
test_ms = let array = [8, 12, 5, 8, 3, 4, 8, 1, 3, 5] in mergesort array
-- | Check that 'mergesort' returns a permutation of its input.
--
-- Two lists are permutations of each other exactly when they have the same
-- length and every element occurs equally often in both. Comparing those
-- counts directly avoids enumerating @permutations array@, which for ten
-- elements would generate 3,628,800 candidate lists.
test_ms2 :: Bool
test_ms2 = length sorted == length array && all sameCount array
  where
    array = [8, 12, 5, 8, 3, 4, 8, 1, 3, 5] :: [Integer]
    sorted = mergesort array
    occurrences x = length . filter (== x)
    sameCount x = occurrences x array == occurrences x sorted
--Inversions
-- | Sort a list and count its inversions by divide and conquer.
--
-- The list is cut in half with split', each half is sorted and counted
-- recursively, and the halves are combined by mergeandcount. The total is
-- the two halves' counts plus the cross-half count mergeandcount reports.
-- Empty and singleton lists are already sorted with a count of 0.
sortandcount :: Ord a => [a] -> ([a],Int)
sortandcount xs = case xs of
  []  -> ([], 0)
  [_] -> (xs, 0)
  _   ->
    let (front, back) = split' xs
        (front', nFront) = sortandcount front
        (back', nBack) = sortandcount back
        (merged, nCross) = mergeandcount front' back'
    in (merged, nFront + nBack + nCross)
-- | Merge two ascending lists and count the cross inversions between them:
-- pairs @(l, r)@ with @l@ from the first list, @r@ from the second, and
-- @l > r@.
--
-- When the head of the right list is emitted first, it is strictly smaller
-- than every element still in the left list, so it forms @length ls@
-- inversions. Ties take the left element first, so equal values are not
-- counted as inversions.
--
-- Note: @length ls@ is recomputed on each such step, costing
-- O(length ls) every time; mergeandcount' avoids this by threading the
-- lengths through as arguments.
mergeandcount :: Ord a => [a] -> [a] -> ([a],Int)
mergeandcount [] rs = (rs, 0)
mergeandcount ls [] = (ls, 0)
mergeandcount ls@(l:lst) rs@(r:rst)
  | l <= r    = let (ds, dc) = mergeandcount lst rs in (l : ds, dc)
  | otherwise = let (ds, dc) = mergeandcount ls rst in (r : ds, dc + length ls)
-- | Number of inversions in a list as computed by 'sortandcount'; the
-- sorted list it also produces is discarded.
countInversions :: Ord a => [a] -> Int
countInversions xs = inversions
  where (_, inversions) = sortandcount xs
--countInversions performance:
--real 0m19.488s
--real 0m15.504s with -O2 flag
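-- Slightly optimized inversion counting: list lengths are threaded through the
-- recursion instead of being recomputed with `length` on every merge step,
-- which makes it about 10 times faster than the previous version.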
-- | Like 'sortandcount', but also returns the length of the sorted list, so
-- that mergeandcount' can be told each half's length instead of measuring
-- it. Result: (sorted list, its length, inversion count).
sortandcount' :: Ord a => [a] -> ([a],Int,Int)
sortandcount' xs = case xs of
  []  -> ([], 0, 0)
  [_] -> (xs, 1, 0)
  _   ->
    let (front, back) = split' xs
        (front', lenFront, nFront) = sortandcount' front
        (back', lenBack, nBack) = sortandcount' back
        (merged, nCross) = mergeandcount' front' lenFront back' lenBack
    in (merged, lenFront + lenBack, nFront + nBack + nCross)
-- | Merge two ascending lists and count the cross inversions between them,
-- given the lengths of the two lists (@lenls@, @lenrs@).
--
-- When the head of the right list is emitted first, it is strictly smaller
-- than all @lenls@ elements remaining on the left, so it forms @lenls@
-- inversions; tracking the length avoids recomputing it. Ties take the
-- left element first, so equal values are not counted as inversions.
-- The lengths are assumed to be accurate for the lists passed in.
mergeandcount' :: Ord a => [a] -> Int -> [a] -> Int -> ([a],Int)
mergeandcount' [] _ rs _ = (rs, 0)
mergeandcount' ls _ [] _ = (ls, 0)
mergeandcount' ls@(l:lst) lenls rs@(r:rst) lenrs
  | l <= r    = let (ds, dc) = mergeandcount' lst (lenls - 1) rs lenrs in (l : ds, dc)
  | otherwise = let (ds, dc) = mergeandcount' ls lenls rst (lenrs - 1) in (r : ds, dc + lenls)
-- | Number of inversions in a list as computed by sortandcount'; the sorted
-- list and its length are discarded.
countInversions' :: Ord a => [a] -> Int
countInversions' = third . sortandcount'
  where third (_, _, inversions) = inversions
--countInversions' performance:
--Counting time: 1.40508
--2407905288 inversions
-- | Time parsing "zIntegerArray.txt" (whitespace-separated Ints) and counting
-- its inversions with countInversions'. Prints the parse time, the count,
-- and the counting time; both times are in seconds.
--
-- readFile is lazy, so the parsed list is forced (via its sum) before the
-- parse timer stops. Without that, the file reading and parsing would be
-- charged to the counting interval, and "Parsing time" would measure
-- almost nothing.
testProgSet1 :: IO ()
testProgSet1 = do
  t0 <- getCPUTime
  contents <- readFile "zIntegerArray.txt"
  let numbers = map read (words contents) :: [Int]
  sum numbers `seq` return ()
  t1 <- getCPUTime
  putStrLn $ "Parsing time: " ++ show (toSeconds (t1 - t0))
  t2 <- getCPUTime
  print $ countInversions' numbers
  t3 <- getCPUTime
  putStrLn $ "Counting time: " ++ show (toSeconds (t3 - t2))
  where
    -- getCPUTime reports CPU time in picoseconds.
    toSeconds :: Integer -> Double
    toSeconds ps = fromIntegral ps / 1000000000000
-- | Read "zIntegerArray.txt" (whitespace-separated Ints), sort the numbers
-- with 'mergesort', and print the smallest one. Prints a message instead of
-- crashing when the file contains no numbers.
--
-- 'withFile' closes the handle even if reading or parsing throws. The
-- lazily read contents are consumed entirely inside the callback (deciding
-- between the empty and non-empty cases and printing the head both require
-- every element), so nothing is read after the handle is closed.
testProgSetMS :: IO ()
testProgSetMS = withFile "zIntegerArray.txt" ReadMode $ \handle -> do
  contents <- hGetContents handle
  case mergesort (map read (words contents) :: [Int]) of
    []        -> putStrLn "zIntegerArray.txt contains no integers"
    (least:_) -> print least
-- | Entry point: run the inversion-counting benchmark.
main :: IO ()
main = testProgSet1
import java.io.*;
import java.util.*;
public class Inversions {
    // Running inversion total for the current count() call. sortAndCount and
    // countSplitInv add to it as a side effect, so this class is not safe to
    // use from several threads at once.
    private static long inv = 0;

    /**
     * Counts the inversions in {@code arr}: pairs of indices i < j with
     * arr[i] > arr[j]. Equal values are not counted.
     *
     * @param arr the array to examine (it is only read, never modified)
     * @return the number of inversions; 0 for an empty array
     */
    public static long count( int[] arr ) {
        inv = 0;
        if( arr.length == 0 ) {
            // sortAndCount(arr, 0, -1) would read arr[0] and throw.
            return inv;
        }
        sortAndCount( arr, 0, arr.length-1 );
        return inv;
    }

    /**
     * Returns a new sorted array holding arr[p..r] (both ends inclusive) and
     * adds the number of inversions inside that range to {@code inv}.
     * Requires p <= r; when p == r the range is a single element.
     */
    public static int[] sortAndCount( int[] arr, int p, int r ) {
        if( r > p ) {
            // p + (r - p) / 2 cannot overflow, unlike (p + r) / 2.
            int q = p + (r - p) / 2;
            int[] b = sortAndCount( arr, p, q );
            int[] c = sortAndCount( arr, q+1, r );
            return countSplitInv( b, c );
        }
        return new int[] { arr[p] };
    }

    /**
     * Merges the sorted arrays b and c into a new sorted array. Each time an
     * element of c is taken while elements of b remain, all remaining
     * b-elements are strictly greater than it, so b.length - i inversions are
     * added to {@code inv}. Ties take from b first, so equal values are never
     * counted.
     */
    private static int[] countSplitInv( int[] b, int[] c ) {
        int[] d = new int[b.length+c.length];
        int i = 0, j = 0;
        for( int k = 0; k < d.length; k++ ) {
            if( i >= b.length ) {
                d[k] = c[j++];
            } else if( j >= c.length ) {
                d[k] = b[i++];
            } else if( b[i] <= c[j] ) {
                d[k] = b[i++];
            } else {
                d[k] = c[j++];
                inv += b.length-i;
            }
        }
        return d;
    }

    /**
     * Reads one integer per line from the hard-coded input file. Blank lines
     * are skipped. On any other error, the stack trace is printed and the
     * integers parsed so far are returned (best effort). The reader is always
     * closed.
     */
    private static int[] getArray() {
        List<Integer> list = new ArrayList<Integer>();
        BufferedReader br = null;
        try {
            br = new BufferedReader( new FileReader( "src/org/coursera/algo/zIntegerArray.txt" ) );
            String line;
            while( ( line = br.readLine() ) != null ) {
                line = line.trim();
                if( line.length() == 0 ) {
                    // Tolerate blank lines (e.g. at the end of the file)
                    // instead of aborting on a NumberFormatException.
                    continue;
                }
                list.add( Integer.parseInt( line ) );
            }
        } catch ( Exception e ) {
            e.printStackTrace();
        } finally {
            if( br != null ) {
                try {
                    br.close();
                } catch ( IOException e ) {
                    e.printStackTrace();
                }
            }
        }
        int[] arr = new int[list.size()];
        for( int i = 0; i < arr.length; i++ ) {
            arr[i] = list.get( i );
        }
        System.out.println("Parsed array of length " + arr.length);
        return arr;
    }

    /**
     * Parses the input file, then counts and prints its inversions, reporting
     * the wall-clock time (ms) of each phase.
     *
     * @param args unused
     */
    public static void main( String[] args ) {
        long start = System.currentTimeMillis();
        int[] arr = getArray();
        System.out.println( "Parse array time: " + (System.currentTimeMillis()-start));
        //Parse array time: 127
        start = System.currentTimeMillis();
        System.out.println( count( arr ) );//2407905288
        System.out.println( "Count inversions time: " + (System.currentTimeMillis()-start));
        //Count inversions time: 42
    }
}
see previous revisions for the content of this file (https://gist.github.com/2035851)
@nrolland
Copy link

I really enjoyed reading the Java code, especially after seeing that in F# it takes 35 lines and runs in 0.6 seconds :)) Does the Java sort really take 42 ms?

@MastaP
Copy link
Author

MastaP commented Jun 17, 2012

Java is quite verbose, true :) The algorithm can run even faster (in about 15–20 ms) with some JVM warm-up techniques that force JIT compilation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment