// Source: GitHub gist aarzilli/d45fe34d6fe72bc5e1f390964a4ac394,
// created September 24, 2021.
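// This requires Go 1.18 or later, the first release with generics. When it
// was written (before 1.18 shipped) that meant the tip of the Go master
// branch, for which `gotip` is a convenient tool.
//
// The file defines Test and Benchmark functions, so save it under a name
// ending in _test.go. Run the benchmarks with
//
//	$ go test -bench=. <filename>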
package main | |
import ( | |
"math/rand" | |
"sort" | |
"testing" | |
) | |
// N is the number of random pairs generated and sorted in each benchmark
// iteration.
const (
	N = 100_000
)
// Pair is the element type used by the tests and benchmarks. Their
// comparators order pairs by A only; B is random payload.
type Pair struct {
	A int
	B int
}
func makeRandomPairs(n int) []*Pair { | |
pairs := make([]*Pair, n) | |
for i := 0; i < n; i++ { | |
pairs[i] = &Pair{ rand.Intn(n), rand.Intn(n) } | |
} | |
return pairs | |
} | |
func TestTest(t *testing.T) { | |
pairs := makeRandomPairs(10) | |
quickSortFunc_gen(pairs, 0, len(pairs), maxDepth(len(pairs)), func(a, b *Pair) bool { return a.A < b.A }) | |
for _, pair := range pairs { | |
t.Logf("%#v", pair) | |
} | |
} | |
func BenchmarkSortSlice(b *testing.B) { | |
rand.Seed(42) | |
for i := 0; i < b.N; i++ { | |
b.StopTimer() | |
pairs := makeRandomPairs(N) | |
b.StartTimer() | |
sort.Slice(pairs, func(i, j int) bool { return pairs[i].A < pairs[j].A }) | |
} | |
} | |
func BenchmarkSortGeneric(b *testing.B) { | |
rand.Seed(42) | |
for i := 0; i < b.N; i++ { | |
b.StopTimer() | |
pairs := makeRandomPairs(N) | |
b.StartTimer() | |
quickSortFunc_gen(pairs, 0, len(pairs), maxDepth(len(pairs)), func(a, b *Pair) bool { return a.A < b.A }) | |
} | |
} | |
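// The functions below are adapted from the Go standard library's sort.go.
// The *_func variants are specialized to []int and compare elements with <
// directly, avoiding interface method dispatch; the *_gen variants further
// down are generic over the element type and take a less function instead.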
// insertionSort_func sorts data[a:b] in ascending order by insertion sort. Each
// element is swapped leftwards until it is no smaller than its predecessor.
// quickSort_func only calls it on ranges of at most 12 elements.
func insertionSort_func(data []int, a, b int) {
	for hi := a + 1; hi < b; hi++ {
		for k := hi; k > a; k-- {
			if data[k] >= data[k-1] {
				break
			}
			data[k-1], data[k] = data[k], data[k-1]
		}
	}
}
// maxDepth returns the partition budget the callers pass to quickSort_func /
// quickSortFunc_gen before those fall back to heapsort. The budget is twice
// the number of bits needed to represent n, and 0 when n <= 0.
func maxDepth(n int) int {
	bits := 0
	for n > 0 {
		n >>= 1
		bits++
	}
	return bits + bits
}
func quickSort_func(data []int, a, b, maxDepth int) { | |
for b-a > 12 { | |
if maxDepth == 0 { | |
heapSort_func(data, a, b) | |
return | |
} | |
maxDepth-- | |
mlo, mhi := doPivot_func(data, a, b) | |
if mlo-a < b-mhi { | |
quickSort_func(data, a, mlo, maxDepth) | |
a = mhi | |
} else { | |
quickSort_func(data, mhi, b, maxDepth) | |
b = mlo | |
} | |
} | |
if b-a > 1 { | |
for i := a + 6; i < b; i++ { | |
if data[i] < data[i-6] { | |
data[i], data[i-6] = data[i-6], data[i] | |
} | |
} | |
insertionSort_func(data, a, b) | |
} | |
} | |
// siftDown_func moves the element at heap index lo downwards until neither of
// its children (within heap size hi) is larger, restoring the max-heap
// property. Heap indices are relative to first: heap node i lives at
// data[first+i].
func siftDown_func(data []int, lo, hi, first int) {
	for node := lo; ; {
		kid := 2*node + 1
		if kid >= hi {
			return
		}
		if right := kid + 1; right < hi && data[first+kid] < data[first+right] {
			kid = right
		}
		if data[first+node] >= data[first+kid] {
			return
		}
		data[first+node], data[first+kid] = data[first+kid], data[first+node]
		node = kid
	}
}
func heapSort_func(data []int, a, b int) { | |
first := a | |
lo := 0 | |
hi := b - a | |
for i := (hi - 1) / 2; i >= 0; i-- { | |
siftDown_func(data, i, hi, first) | |
} | |
for i := hi - 1; i >= 0; i-- { | |
data[first], data[first+i] = data[first+i], data[first] | |
siftDown_func(data, lo, i, first) | |
} | |
} | |
// medianOfThree_func reorders the three values at m0, m1, m2 so that
// data[m0] <= data[m1] <= data[m2]. Afterwards data[m1] holds their median.
// Note the parameter order: the middle index m1 comes first.
func medianOfThree_func(data []int, m1, m0, m2 int) {
	if data[m0] > data[m1] {
		data[m0], data[m1] = data[m1], data[m0]
	}
	// data[m0] <= data[m1] now; finished unless data[m2] is below data[m1].
	if data[m1] <= data[m2] {
		return
	}
	data[m1], data[m2] = data[m2], data[m1]
	if data[m0] > data[m1] {
		data[m0], data[m1] = data[m1], data[m0]
	}
}
func doPivot_func(data []int, lo, hi int) (midlo, midhi int) { | |
m := int(uint(lo+hi) >> 1) | |
if hi-lo > 40 { | |
s := (hi - lo) / 8 | |
medianOfThree_func(data, lo, lo+s, lo+2*s) | |
medianOfThree_func(data, m, m-s, m+s) | |
medianOfThree_func(data, hi-1, hi-1-s, hi-1-2*s) | |
} | |
medianOfThree_func(data, lo, m, hi-1) | |
pivot := lo | |
a, c := lo+1, hi-1 | |
for ; a < c && data[a] < data[pivot]; a++ { | |
} | |
b := a | |
for { | |
for ; b < c && data[pivot] >= data[b]; b++ { | |
} | |
for ; b < c && data[pivot] < data[c-1]; c-- { | |
} | |
if b >= c { | |
break | |
} | |
data[b], data[c-1] = data[c-1], data[b] | |
b++ | |
c-- | |
} | |
protect := hi-c < 5 | |
if !protect && hi-c < (hi-lo)/4 { | |
dups := 0 | |
if data[pivot] >= data[hi-1] { | |
data[c], data[hi-1] = data[hi-1], data[c] | |
c++ | |
dups++ | |
} | |
if data[b-1] >= data[pivot] { | |
b-- | |
dups++ | |
} | |
if data[m] >= data[pivot] { | |
data[m], data[b-1] = data[b-1], data[m] | |
b-- | |
dups++ | |
} | |
protect = dups > 1 | |
} | |
if protect { | |
for { | |
for ; a < b && data[b-1] >= data[pivot]; b-- { | |
} | |
for ; a < b && data[a] < data[pivot]; a++ { | |
} | |
if a >= b { | |
break | |
} | |
data[a], data[b-1] = data[b-1], data[a] | |
a++ | |
b-- | |
} | |
} | |
data[pivot], data[b-1] = data[b-1], data[pivot] | |
return b - 1, c | |
} | |
// insertionSortFunc_gen sorts data[a:b] by insertion sort, ordering elements
// with less. Each element moves left while it is less than its predecessor,
// so elements that compare equal keep their relative order.
func insertionSortFunc_gen[Elem any](data []Elem, a, b int, less func(a, b Elem) bool) {
	for hi := a + 1; hi < b; hi++ {
		for k := hi; k > a; k-- {
			if !less(data[k], data[k-1]) {
				break
			}
			data[k-1], data[k] = data[k], data[k-1]
		}
	}
}
func quickSortFunc_gen[Elem any](data []Elem, a, b, maxDepth int, less func(a, b Elem) bool) { | |
for b-a > 12 { | |
if maxDepth == 0 { | |
heapSortFunc_gen(data, a, b, less) | |
return | |
} | |
maxDepth-- | |
mlo, mhi := doPivot_gen(data, a, b, less) | |
if mlo-a < b-mhi { | |
quickSortFunc_gen(data, a, mlo, maxDepth, less) | |
a = mhi | |
} else { | |
quickSortFunc_gen(data, mhi, b, maxDepth, less) | |
b = mlo | |
} | |
} | |
if b-a > 1 { | |
for i := a + 6; i < b; i++ { | |
if less(data[i], data[i-6]) { | |
data[i], data[i-6] = data[i-6], data[i] | |
} | |
} | |
insertionSortFunc_gen(data, a, b, less) | |
} | |
} | |
// siftDown_gen moves the element at heap index lo downwards until no child
// (within heap size hi) is greater according to less. This restores the
// max-heap property. Heap node i lives at data[first+i].
func siftDown_gen[Elem any](data []Elem, lo, hi, first int, less func(a, b Elem) bool) {
	for node := lo; ; {
		kid := 2*node + 1
		if kid >= hi {
			return
		}
		if right := kid + 1; right < hi && less(data[first+kid], data[first+right]) {
			kid = right
		}
		if !less(data[first+node], data[first+kid]) {
			return
		}
		data[first+node], data[first+kid] = data[first+kid], data[first+node]
		node = kid
	}
}
func heapSortFunc_gen[Elem any](data []Elem, a, b int, less func(a, b Elem) bool) { | |
first := a | |
lo := 0 | |
hi := b - a | |
for i := (hi - 1) / 2; i >= 0; i-- { | |
siftDown_gen(data, i, hi, first, less) | |
} | |
for i := hi - 1; i >= 0; i-- { | |
data[first], data[first+i] = data[first+i], data[first] | |
siftDown_gen(data, lo, i, first, less) | |
} | |
} | |
// medianOfThree_gen reorders the values at m0, m1, m2 according to less so
// that, afterwards, none is out of order and data[m1] holds their median. Note
// the parameter order: the middle index m1 comes first.
func medianOfThree_gen[Elem any](data []Elem, m1, m0, m2 int, less func(a, b Elem) bool) {
	if less(data[m1], data[m0]) {
		data[m0], data[m1] = data[m1], data[m0]
	}
	// data[m0] <= data[m1] now; finished unless data[m2] is below data[m1].
	if !less(data[m2], data[m1]) {
		return
	}
	data[m1], data[m2] = data[m2], data[m1]
	if less(data[m1], data[m0]) {
		data[m0], data[m1] = data[m1], data[m0]
	}
}
func doPivot_gen[Elem any](data []Elem, lo, hi int, less func(a, b Elem) bool) (midlo, midhi int) { | |
m := int(uint(lo+hi) >> 1) | |
if hi-lo > 40 { | |
s := (hi - lo) / 8 | |
medianOfThree_gen(data, lo, lo+s, lo+2*s, less) | |
medianOfThree_gen(data, m, m-s, m+s, less) | |
medianOfThree_gen(data, hi-1, hi-1-s, hi-1-2*s, less) | |
} | |
medianOfThree_gen(data, lo, m, hi-1, less) | |
pivot := lo | |
a, c := lo+1, hi-1 | |
for ; a < c && less(data[a], data[pivot]); a++ { | |
} | |
b := a | |
for { | |
for ; b < c && !less(data[pivot], data[b]); b++ { | |
} | |
for ; b < c && less(data[pivot], data[c-1]); c-- { | |
} | |
if b >= c { | |
break | |
} | |
data[b], data[c-1] = data[c-1], data[b] | |
b++ | |
c-- | |
} | |
protect := hi-c < 5 | |
if !protect && hi-c < (hi-lo)/4 { | |
dups := 0 | |
if !less(data[pivot], data[hi-1]) { | |
data[c], data[hi-1] = data[hi-1], data[c] | |
c++ | |
dups++ | |
} | |
if !less(data[b-1], data[pivot]) { | |
b-- | |
dups++ | |
} | |
if !less(data[m], data[pivot]) { | |
data[m], data[b-1] = data[b-1], data[m] | |
b-- | |
dups++ | |
} | |
protect = dups > 1 | |
} | |
if protect { | |
for { | |
for ; a < b && !less(data[b-1], data[pivot]); b-- { | |
} | |
for ; a < b && less(data[a], data[pivot]); a++ { | |
} | |
if a >= b { | |
break | |
} | |
data[a], data[b-1] = data[b-1], data[a] | |
a++ | |
b-- | |
} | |
} | |
data[pivot], data[b-1] = data[b-1], data[pivot] | |
return b - 1, c | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment