Skip to content

Instantly share code, notes, and snippets.

@ydnar
Last active June 11, 2018 20:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ydnar/1a13ef704982cc77d78a3e04b2ce58d7 to your computer and use it in GitHub Desktop.
Save ydnar/1a13ef704982cc77d78a3e04b2ce58d7 to your computer and use it in GitHub Desktop.
sync.Map with LoadOrCreate for idempotent deferred creation of expensive entries
package syncmap
import (
"sync"
)
// Map wraps a sync.Map to enable idempotent deferred creation of new values.
//
// Entries that are still being created by LoadOrCreate are held in the
// underlying sync.Map as a private valueFunc placeholder. Load, LoadOrStore and
// Range resolve such placeholders before handing a value to the caller. Methods
// promoted from the embedded sync.Map that are not overridden here (for example
// Store and Delete) operate on the underlying map directly.
type Map struct {
	sync.Map
}

// valueFunc is a private type to wrap a func that returns a (potentially new) value.
// Because the type is unexported, a value stored by a caller outside this
// package can never be mistaken for a placeholder.
type valueFunc func() interface{}

// resolveLazy returns the value produced by stored if stored is a valueFunc
// placeholder, and stored itself otherwise.
func resolveLazy(stored interface{}) interface{} {
	if lazy, isLazy := stored.(valueFunc); isLazy {
		return lazy()
	}
	return stored
}

// Load returns the value stored in the map for a key, or nil if no value is present.
// The ok result indicates whether value was found in the map.
// If the key maps to a pending placeholder, Load resolves it first, which may
// block until the in-flight creation has finished.
func (m *Map) Load(key interface{}) (value interface{}, ok bool) {
	stored, found := m.Map.Load(key)
	if !found {
		return nil, false
	}
	return resolveLazy(stored), true
}

// LoadOrCreate returns the existing value for the key if present.
// Otherwise, it calls f to create a new value, stores it back in the map, and returns the newly created value.
// Func f is guaranteed to not be called more than once for multiple in-flight requests.
// The loaded result is true if the value was loaded, false if created.
//
// If f panics, the placeholder for key is removed again so that the key is not
// left permanently mapped to a nil value and a later LoadOrCreate can retry.
// The panic propagates to the caller that triggered the creation; callers that
// were already waiting on the same in-flight creation observe a nil value.
func (m *Map) LoadOrCreate(key interface{}, f func() interface{}) (value interface{}, loaded bool) {
	var (
		once   sync.Once
		actual interface{}
	)
	placeholder := valueFunc(func() interface{} {
		once.Do(func() {
			created := false
			defer func() {
				// sync.Once counts a panicking call as done, so without this
				// cleanup the placeholder would yield nil forever.
				if !created {
					m.Delete(key)
				}
			}()
			actual = f()
			created = true
			// Swap the placeholder for the real value so later loads skip the Once.
			m.Store(key, actual)
		})
		return actual
	})
	// The overridden LoadOrStore resolves whichever placeholder won the race,
	// so every in-flight caller funnels into the same Once.
	return m.LoadOrStore(key, placeholder)
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
// A placeholder found (or just stored) under key is resolved before returning.
func (m *Map) LoadOrStore(key interface{}, value interface{}) (actual interface{}, loaded bool) {
	stored, loaded := m.Map.LoadOrStore(key, value)
	return resolveLazy(stored), loaded
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
// Placeholders are resolved before being passed to f, so f only sees real values.
func (m *Map) Range(f func(key, value interface{}) bool) {
	m.Map.Range(func(key, stored interface{}) bool {
		return f(key, resolveLazy(stored))
	})
}
package syncmap
import (
"reflect"
"sync/atomic"
"testing"
"time"
)
// ff builds a creation func for use in tests. Each call of the returned func
// atomically increments the counter c, sleeps for 100ms to stand in for an
// expensive construction, and then returns value.
func ff(value interface{}, c *int32) func() interface{} {
	slowCreate := func() interface{} {
		atomic.AddInt32(c, 1)
		time.Sleep(100 * time.Millisecond)
		return value
	}
	return slowCreate
}
// TestNotValueFunc ensures that (func() interface{}) is not the same as valueFunc:
// a plain func returned by ff, boxed in an interface, must not satisfy a type
// assertion to the private valueFunc type.
func TestNotValueFunc(t *testing.T) {
	var calls int32
	var boxed interface{} = ff(true, &calls)
	switch boxed.(type) {
	case valueFunc:
		t.Error("ff returned a valueFunc")
	}
}
// TestMap_Load checks that Load returns plainly stored values unchanged,
// resolves stored valueFunc entries to the value they produce, and reports
// absent keys as a nil value with ok == false.
func TestMap_Load(t *testing.T) {
	var (
		m     Map
		calls int32
	)
	m.Store("a", "1")
	m.Store("b", "2")
	m.Store("c", valueFunc(ff("lazy", &calls)))
	m.Store(true, valueFunc(ff(true, &calls)))

	type loadCase struct {
		name      string
		key       interface{}
		wantValue interface{}
		wantOK    bool
	}
	cases := []loadCase{
		{name: "a", key: "a", wantValue: "1", wantOK: true},
		{name: "b", key: "b", wantValue: "2", wantOK: true},
		{name: "c", key: "c", wantValue: "lazy", wantOK: true},
		{name: "true", key: true, wantValue: true, wantOK: true},
		{name: "false", key: false, wantValue: nil, wantOK: false},
		{name: "unknown", key: "unknown", wantValue: nil, wantOK: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, found := m.Load(tc.key)
			if !reflect.DeepEqual(got, tc.wantValue) {
				t.Errorf("Map.Load() got value = %v, want %v", got, tc.wantValue)
			}
			if found != tc.wantOK {
				t.Errorf("Map.Load() got ok = %v, want %v", found, tc.wantOK)
			}
		})
	}
}
// TestMap_LoadOrCreate drives LoadOrCreate through a sequence of first-time and
// repeated keys on one shared Map. For every step it checks the returned value,
// the loaded flag, and the shared counter c, which ff increments each time a
// creation func actually runs. It then re-reads the key with Load and checks
// that the value is present and that no additional creation was triggered.
func TestMap_LoadOrCreate(t *testing.T) {
	var m Map
	var c int32
	tests := []struct {
		name       string
		key        interface{}
		f          func() interface{}
		wantValue  interface{}
		wantLoaded bool
		wantC      int32
	}{
		{"empty map", "a", ff(1, &c), 1, false, 1},
		{"first key again", "a", ff(1, &c), 1, true, 1},
		{"second key", "b", ff(2, &c), 2, false, 2},
		{"second key again", "b", ff(2, &c), 2, true, 2},
		{"third key", "c", ff(3, &c), 3, false, 3},
		{"third key again", "c", ff(3, &c), 3, true, 3},
		{"third key again and again", "c", ff(3, &c), 3, true, 3},
		{"third key again and again and again", "c", ff(3, &c), 3, true, 3},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// Load with LoadOrCreate
			value, loaded := m.LoadOrCreate(tt.key, tt.f)
			if !reflect.DeepEqual(value, tt.wantValue) {
				t.Errorf("Map.LoadOrCreate() got value = %v, want %v", value, tt.wantValue)
			}
			if loaded != tt.wantLoaded {
				t.Errorf("Map.LoadOrCreate() got loaded = %v, want %v", loaded, tt.wantLoaded)
			}
			// Report the same atomically loaded snapshot that was compared.
			if gotC := atomic.LoadInt32(&c); gotC != tt.wantC {
				t.Errorf("Map.LoadOrCreate() changed c = %v, want %v", gotC, tt.wantC)
			}

			// Load again
			value, ok := m.Load(tt.key)
			if !reflect.DeepEqual(value, tt.wantValue) {
				t.Errorf("Map.Load() got value = %v, want %v", value, tt.wantValue)
			}
			if !ok {
				t.Errorf("Map.Load() got ok = %v, want %v", ok, true)
			}
			// c should not change: a plain Load must not trigger another creation.
			if gotC := atomic.LoadInt32(&c); gotC != tt.wantC {
				t.Errorf("Map.Load() changed c = %v, want %v", gotC, tt.wantC)
			}
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment