Skip to content

Instantly share code, notes, and snippets.

@DanielHeath
Forked from cjyar/memoize.go
Created December 2, 2013 02:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DanielHeath/7744196 to your computer and use it in GitHub Desktop.
Save DanielHeath/7744196 to your computer and use it in GitHub Desktop.
package memoize
import (
"fmt"
"reflect"
)
// Memoize stores in *fptr a memoized wrapper around the function impl.
//
// fptr must be a non-nil pointer to a variable that a value of impl's type
// can be assigned to: typically a *func(...) with impl's signature, or a
// pointer to a named function type with the same underlying signature.
// impl must take one or more arguments, all of whose types are usable as
// map keys, and must return one or more values. Memoize panics if any of
// these requirements is not met.
//
// The wrapper calls impl at most once for each distinct argument list and
// replays the cached results on every later call with equal arguments.
// Cached results are never evicted. Arguments that are not equal to
// themselves (such as a floating-point NaN) never hit the cache.
//
// The cache is built from ordinary Go maps with no locking, so the wrapper
// must not be called from multiple goroutines concurrently. Recursive calls
// from impl back into the wrapper are fine.
func Memoize(fptr, impl interface{}) {
	implType := reflect.TypeOf(impl)
	// A nil impl has a nil Type; test for it before calling any method on it.
	if implType == nil || implType.Kind() != reflect.Func {
		panic(fmt.Sprintf("Not a function: %v", impl))
	}
	if implType.NumIn() == 0 {
		panic(fmt.Sprintf("%v takes no inputs", impl))
	}
	if implType.NumOut() == 0 {
		panic(fmt.Sprintf("%v gives no outputs", impl))
	}
	// Validate the destination with the same rule the final Set applies:
	// impl's type must be assignable to the pointed-to variable. Comparing
	// *implType against fptr's type instead would wrongly reject pointers to
	// named function types, and would let a typed nil pointer through.
	dest := reflect.ValueOf(fptr)
	if dest.Kind() != reflect.Ptr || dest.IsNil() ||
		!implType.AssignableTo(dest.Type().Elem()) {
		panic(fmt.Sprintf("Can't assign %v to %v", impl, fptr))
	}
	implValue := reflect.ValueOf(impl)

	// The cache is a chain of nested maps, one level per argument, whose
	// innermost values are the []reflect.Value results of impl. For
	// func(A, B) R that is map[A]map[B][]reflect.Value. mapTypes[i] is the
	// map type keyed by argument i. MapOf panics if an argument type cannot
	// be a map key.
	numIn := implType.NumIn()
	mapTypes := make([]reflect.Type, numIn)
	cacheType := reflect.TypeOf([]reflect.Value{})
	for i := numIn - 1; i >= 0; i-- {
		cacheType = reflect.MapOf(implType.In(i), cacheType)
		mapTypes[i] = cacheType
	}
	root := reflect.MakeMap(mapTypes[0])

	mem := func(args []reflect.Value) []reflect.Value {
		last := len(args) - 1
		// Descend one level per leading argument, creating missing levels.
		m := root
		for i, arg := range args[:last] {
			next := m.MapIndex(arg)
			if !next.IsValid() {
				next = reflect.MakeMap(mapTypes[i+1])
				m.SetMapIndex(arg, next)
			}
			m = next
		}
		if cached := m.MapIndex(args[last]); cached.IsValid() {
			// cached is a reflect.Value wrapping the stored []reflect.Value;
			// Interface unwraps it in one step.
			return cached.Interface().([]reflect.Value)
		}
		results := implValue.Call(args)
		m.SetMapIndex(args[last], reflect.ValueOf(results))
		return results
	}
	dest.Elem().Set(reflect.MakeFunc(implType, mem))
}
// deVal returns the value held by val as an interface{}. It works by
// invoking, through reflection, a func(interface{}) with val as its only
// argument and returning whatever that function received.
func deVal(val reflect.Value) interface{} {
	var captured interface{}
	sink := reflect.ValueOf(func(x interface{}) { captured = x })
	sink.Call([]reflect.Value{val})
	return captured
}
package memoize
import (
"testing"
)
// intData lists single-argument calls in the order they are made. in is the
// argument passed to the memoized function; calls is the total number of
// times the underlying function should have run after that call returns,
// so repeating an argument already seen leaves it unchanged.
var intData = []struct {
	in    int
	calls int
}{
	{in: 0, calls: 1},
	{in: 1, calls: 2},
	{in: 0, calls: 2},
	{in: 2, calls: 3},
	{in: 2, calls: 3},
}
// TestInt1 memoizes a single-argument identity function into a separate
// variable and checks both the returned values and how often the
// underlying function actually ran.
func TestInt1(t *testing.T) {
	calls := 0
	identity := func(i int) int {
		calls++
		return i
	}
	var memo func(int) int
	Memoize(&memo, identity)
	for _, tc := range intData {
		if got := memo(tc.in); got != tc.in {
			t.Errorf("Got %d, want %d", got, tc.in)
		}
		if calls != tc.calls {
			t.Errorf("Num calls = %d, want %d", calls, tc.calls)
		}
	}
}
// TestReassign memoizes a function in place: the variable holding the
// original function is overwritten with the memoized wrapper around it.
func TestReassign(t *testing.T) {
	var calls int
	fn := func(i int) int {
		calls++
		return i
	}
	Memoize(&fn, fn)
	for i := range intData {
		tc := intData[i]
		if got := fn(tc.in); got != tc.in {
			t.Errorf("Got %d, want %d", got, tc.in)
		}
		if calls != tc.calls {
			t.Errorf("Num calls = %d, want %d", calls, tc.calls)
		}
	}
}
// int2Data lists two-argument calls in order, with the expected running
// count of underlying-function invocations after each call.
var int2Data = []struct {
	in1, in2 int
	calls    int
}{
	{in1: 0, in2: 0, calls: 1},
	{in1: 0, in2: 1, calls: 2},
	{in1: 1, in2: 0, calls: 3},
	{in1: 0, in2: 0, calls: 3},
	{in1: 0, in2: 1, calls: 3},
	{in1: 1, in2: 0, calls: 3},
}
// TestInt2 checks memoization of a two-argument, single-result function,
// memoized in place.
func TestInt2(t *testing.T) {
	var calls int
	sum := func(a, b int) int {
		calls++
		return a + b
	}
	Memoize(&sum, sum)
	for _, tc := range int2Data {
		got := sum(tc.in1, tc.in2)
		if want := tc.in1 + tc.in2; got != want {
			t.Errorf("Got %d, want %d + %d", got, tc.in1, tc.in2)
		}
		if calls != tc.calls {
			t.Errorf("Num calls = %d, want %d", calls, tc.calls)
		}
	}
}
// int22Data lists calls to a two-argument, two-result function in order,
// with the expected running count of underlying-function invocations.
var int22Data = []struct {
	in1, in2 int
	calls    int
}{
	{in1: 0, in2: 0, calls: 1},
	{in1: 1, in2: 0, calls: 2},
	{in1: 0, in2: 1, calls: 3},
	{in1: 0, in2: 1, calls: 3},
	{in1: 1, in2: 0, calls: 3},
	{in1: 0, in2: 0, calls: 3},
}
// TestInt22 checks that both results of a two-result function are cached
// and returned in the right order.
func TestInt22(t *testing.T) {
	var calls int
	swap := func(a, b int) (int, int) {
		calls++
		return b, a
	}
	Memoize(&swap, swap)
	for _, tc := range int22Data {
		first, second := swap(tc.in1, tc.in2)
		if !(first == tc.in2 && second == tc.in1) {
			t.Errorf("Got (%d, %d) from (%d, %d)", first, second,
				tc.in1, tc.in2)
		}
		if calls != tc.calls {
			t.Errorf("Num calls = %d, want %d", calls, tc.calls)
		}
	}
}
// mixedData lists (int, string) calls in order, with the expected running
// count of underlying-function invocations after each call.
var mixedData = []struct {
	in1   int
	in2   string
	calls int
}{
	{in1: 0, in2: "zero", calls: 1},
	{in1: 1, in2: "zero", calls: 2},
	{in1: 0, in2: "one", calls: 3},
	{in1: 1, in2: "one", calls: 4},
	{in1: 0, in2: "zero", calls: 4},
	{in1: 1, in2: "zero", calls: 4},
	{in1: 0, in2: "one", calls: 4},
	{in1: 1, in2: "one", calls: 4},
}
// TestMixed checks memoization when the argument and result types differ
// (int and string keys, string and int results).
func TestMixed(t *testing.T) {
	var calls int
	flip := func(a int, b string) (string, int) {
		calls++
		return b, a
	}
	Memoize(&flip, flip)
	for _, tc := range mixedData {
		s, n := flip(tc.in1, tc.in2)
		if !(s == tc.in2 && n == tc.in1) {
			t.Errorf("Got (%s, %d) from (%d, %s)", s, n,
				tc.in1, tc.in2)
		}
		if calls != tc.calls {
			t.Errorf("Num calls = %d, want %d", calls, tc.calls)
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment