Skip to content

Instantly share code, notes, and snippets.

@BurntSushi
Last active May 14, 2019 16:49
Show Gist options
  • Save BurntSushi/5298812 to your computer and use it in GitHub Desktop.
Save BurntSushi/5298812 to your computer and use it in GitHub Desktop.
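Code from my blog post "Writing type parametric functions in Go." The snippets below are independent programs shown back to back (each later one starts with its own `package main`); the first is intentionally invalid Go that sketches the API we would like to write.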
// This is *not* valid Go!
// It sketches the type-parametric Map we would like to write. A and B are
// placeholders for arbitrary element types; they are not declared anywhere,
// which is why this does not compile. Map applies f to every element of xs
// and collects the results in a new slice of the same length.
func Map(f func(A) B, xs []A) []B {
ys := make([]B, len(xs))
for i, x := range xs {
ys[i] = f(x)
}
return ys
}
// Hypothetical usage (a bare call like this is also not legal at top level):
Map(func(x int) int { return x * x }, []int{1, 2, 3})
// Returns: [1, 4, 9]
package main
import "fmt"
// Map applies f to every element of xs and returns the results in a new
// slice of the same length (a non-nil, empty slice when xs is empty).
// Values travel as interface{}, so callers must box inputs and unbox
// outputs themselves; nothing is checked at compile time.
func Map(f func(interface{}) interface{}, xs []interface{}) []interface{} {
	out := make([]interface{}, len(xs))
	for idx := range xs {
		out[idx] = f(xs[idx])
	}
	return out
}
func main() {
	// Box the concrete ints so they fit Map's []interface{} parameter.
	nums := []int{1, 2, 3, 4}
	boxed := make([]interface{}, len(nums))
	for i := range nums {
		boxed[i] = nums[i]
	}

	// square only handles ints; any other dynamic type panics at run time.
	square := func(v interface{}) interface{} {
		n := v.(int)
		return n * n
	}

	// Unbox the results back into a []int before printing.
	results := Map(square, boxed)
	squared := make([]int, len(results))
	for i, r := range results {
		squared[i] = r.(int)
	}
	fmt.Printf("%v\n", squared)

	// Run time type safety is questionable: the call below compiles, but
	// asserting the int elements to string panics at run time.
	// Map(func(a interface{}) interface{} { return len(a.(string)) },
	// []interface{}{1, 2, 3})
}
package main
import (
"fmt"
"reflect"
)
// Map uses reflection to call f (a one-argument function value) on each
// element of the slice xs, collecting the first result of every call boxed
// as interface{}. Nothing is validated up front: mismatches between f and
// xs surface as panics from the reflect package.
func Map(f interface{}, xs interface{}) []interface{} {
	fn := reflect.ValueOf(f)
	src := reflect.ValueOf(xs)
	results := make([]interface{}, src.Len())
	for i := range results {
		args := []reflect.Value{src.Index(i)}
		results[i] = fn.Call(args)[0].Interface()
	}
	return results
}
func main() {
	// With reflection, Map takes the concrete function and slice directly;
	// only the []interface{} results still need unboxing.
	results := Map(func(x int) int { return x * x }, []int{1, 2, 3, 4})
	squared := make([]int, 0, len(results))
	for _, r := range results {
		squared = append(squared, r.(int))
	}
	fmt.Printf("%v\n", squared)

	// Run time type safety is still questionable: this compiles, and the
	// mismatch between string and int is only found by reflect at run time.
	// Map(func(a string) int { return len(a) }, []int{1, 2, 3})
}
package main
import (
"fmt"
"reflect"
)
// Map calls f on each element of xs and returns a slice whose element type
// is f's first result type, boxed in an interface{} (for example a []int
// for a func(int) int), so the caller needs just one type assertion.
// f and xs are not validated; bad arguments surface as reflect panics.
func Map(f interface{}, xs interface{}) interface{} {
	fn, src := reflect.ValueOf(f), reflect.ValueOf(xs)
	outType := reflect.SliceOf(fn.Type().Out(0))
	n := src.Len()
	dst := reflect.MakeSlice(outType, n, n)
	for i := 0; i < n; i++ {
		result := fn.Call([]reflect.Value{src.Index(i)})
		dst.Index(i).Set(result[0])
	}
	return dst.Interface()
}
func main() {
	square := func(x int) int { return x * x }
	squared := Map(square, []int{1, 2, 3}).([]int)
	fmt.Printf("%v\n", squared)

	// Run time type safety is *still* questionable: this compiles even
	// though f expects strings and xs holds ints.
	// _ = Map(func(a string) int { return len(a) }, []int{1, 2, 3}).([]int)
}
package main
import (
"fmt"
"log"
"reflect"
)
// Map applies f to each element of xs and returns the results as a slice
// whose element type is f's result type, boxed in an interface{}.
//
// Conceptually Map has the parametric type
//
//	func Map(f func(A) B, xs []A) []B
//
// Go cannot express that signature, so the relationship between f and xs
// is checked at run time. Any violation panics via log.Panicf with a
// message describing the mismatch, rather than with a lower-level reflect
// error.
func Map(f interface{}, xs interface{}) interface{} {
	vf := reflect.ValueOf(f)
	vxs := reflect.ValueOf(xs)

	// reflect.ValueOf(nil) yields the zero Value, whose Type method panics
	// with a generic reflect error. Report an untyped nil argument
	// explicitly instead.
	if !vf.IsValid() {
		log.Panicf("`f` should be %s but got nil", reflect.Func)
	}
	if !vxs.IsValid() {
		log.Panicf("`xs` should be %s but got nil", reflect.Slice)
	}
	ftype := vf.Type()
	xstype := vxs.Type()

	// 1) Map's first parameter type must be `func(A) B`.
	if ftype.Kind() != reflect.Func {
		log.Panicf("`f` should be %s but got %s", reflect.Func, ftype.Kind())
	}
	if ftype.NumIn() != 1 {
		log.Panicf("`f` should have 1 parameter but it has %d parameters",
			ftype.NumIn())
	}
	// A variadic func(...A) B reports a single parameter of type []A, yet
	// reflect.Value.Call builds that slice itself from the arguments it is
	// given. The element-type check below would therefore compare against
	// the wrong type, so reject variadic functions outright.
	if ftype.IsVariadic() {
		log.Panicf("`f` should not be variadic but it has type %s", ftype)
	}
	if ftype.NumOut() != 1 {
		log.Panicf("`f` should return 1 value but it returns %d values",
			ftype.NumOut())
	}

	// 2) Map's second parameter type must be `[]A1` where `A == A1`.
	if xstype.Kind() != reflect.Slice {
		log.Panicf("`xs` should be %s but got %s", reflect.Slice, xstype.Kind())
	}
	if xstype.Elem() != ftype.In(0) {
		log.Panicf("type of `f`'s parameter should be %s but xs contains %s",
			ftype.In(0), xstype.Elem())
	}

	// 3) Map's return type must be `[]B1` where `B == B1`.
	tys := reflect.SliceOf(ftype.Out(0))
	n := vxs.Len()
	vys := reflect.MakeSlice(tys, n, n)
	for i := 0; i < n; i++ {
		y := vf.Call([]reflect.Value{vxs.Index(i)})[0]
		vys.Index(i).Set(y)
	}
	return vys.Interface()
}
func main() {
	squares := Map(func(x int) int { return x * x }, []int{1, 2, 3})
	fmt.Printf("%v\n", squares.([]int))

	// Type errors are still only caught at run time: the next call
	// compiles, and Map deliberately panics when it compares f's parameter
	// type with the element type of xs.
	_ = Map(func(a string) int { return len(a) }, []int{1, 2, 3}).([]int)
	// _ = Map(5, []int{1, 2, 3}).([]int)
}
package main
import (
"fmt"
"reflect"
"github.com/BurntSushi/ty"
)
// Map has a parametric type:
//
//	func Map(f func(A) B, xs []A) []B
//
// Map returns a new slice holding, in order, the result of applying `f` to
// each element of `xs`. ty.Check matches f and xs against that signature
// (presumably panicking on a mismatch; see the ty package) and yields the
// reflected arguments plus the concrete result type used to build the
// output slice.
func Map(f, xs interface{}) interface{} {
	chk := ty.Check(new(func(func(ty.A) ty.B, []ty.A) []ty.B), f, xs)
	fn, src := chk.Args[0], chk.Args[1]
	n := src.Len()
	out := reflect.MakeSlice(chk.Returns[0], n, n)
	for i := 0; i < n; i++ {
		res := fn.Call([]reflect.Value{src.Index(i)})
		out.Index(i).Set(res[0])
	}
	return out.Interface()
}
func main() {
	nums := []int{1, 2, 3}
	squared := Map(func(x int) int { return x * x }, nums).([]int)
	fmt.Printf("%v\n", squared)

	// This mismatched call still compiles; any type error can only be
	// reported at run time, when Map checks its arguments.
	_ = Map(func(a string) int { return len(a) }, []int{1, 2, 3}).([]int)
}
package main
import (
"fmt"
"math/rand"
"reflect"
"time"
"github.com/BurntSushi/ty"
)
// Shuffle has a parametric type:
//
//	func Shuffle(xs []A)
//
// Shuffle permutes the elements of `xs` in place, drawing random indices
// from math/rand's default source via rand.Intn.
func Shuffle(xs interface{}) {
	chk := ty.Check(new(func([]ty.A)), xs)
	slice := chk.Args[0]

	// Scratch value for a three-way swap; the reflect equivalent of
	// `var tmp A`.
	tmp := reflect.New(slice.Type().Elem()).Elem()

	// Fisher-Yates: walk from the last index down to 1, swapping position
	// i with an index chosen from [0, i].
	for i := slice.Len() - 1; i > 0; i-- {
		j := rand.Intn(i + 1)
		a, b := slice.Index(i), slice.Index(j)
		tmp.Set(a)
		a.Set(b)
		b.Set(tmp)
	}
}
func main() {
	// Seed the global generator from the clock so runs differ.
	rand.Seed(time.Now().UnixNano())

	words := []string{
		"the", "quick", "brown", "fox", "jumps",
		"over", "the", "lazy", "dog",
	}
	Shuffle(words)
	fmt.Printf("%v\n", words)
}
package main
import (
"fmt"
"reflect"
"github.com/BurntSushi/ty"
)
// Union has a parametric type:
//
//	func Union(a map[A]bool, b map[A]bool) map[A]bool
//
// Union builds a new set, represented as a `map[A]bool`, containing every
// key of `a` and every key of `b`. Each such key maps to true in the
// result, whatever its value was in the input. Neither `a` nor `b` is
// modified.
func Union(a, b interface{}) interface{} {
	chk := ty.Check(
		new(func(map[ty.A]bool, map[ty.A]bool) map[ty.A]bool),
		a, b)
	result := reflect.MakeMap(chk.Returns[0])
	present := reflect.ValueOf(true)
	for _, set := range []reflect.Value{chk.Args[0], chk.Args[1]} {
		for _, key := range set.MapKeys() {
			result.SetMapIndex(key, present)
		}
	}
	return result.Interface()
}
func main() {
	first := map[string]bool{"springsteen": true, "j. geils": true, "seger": true}
	second := map[string]bool{"petty": true, "seger": true}

	// The result holds every name in either set, so call it a union.
	either := Union(first, second).(map[string]bool)
	fmt.Printf("%v\n", either)
}
package main
import (
"fmt"
"github.com/BurntSushi/ty/fun"
)
func main() {
	// Declare fib before assigning it so the closure refers to the variable
	// itself. After fib is replaced by its memoized wrapper below, the
	// recursive calls also go through that wrapper.
	var fib func(n int64) int64
	fib = func(n int64) int64 {
		if n == 0 || n == 1 {
			return n
		}
		return fib(n-1) + fib(n-2)
	}

	// Wrap it with a memoizing function; the type assertion is the only
	// burden placed on the caller.
	fib = fun.Memo(fib).(func(int64) int64)

	// Without memoization, this call would keep the CPU busy for a long time.
	fmt.Println(fib(80))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment