Skip to content

Instantly share code, notes, and snippets.

@UnPolinomio
Created August 4, 2021 19:46
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save UnPolinomio/5098c4146e97bd3c3aedc26e791e8785 to your computer and use it in GitHub Desktop.
Save UnPolinomio/5098c4146e97bd3c3aedc26e791e8785 to your computer and use it in GitHub Desktop.
Concurrent cache with Go
package main
import (
"fmt"
"sync"
"time"
)
// Function is the computation memoized by a Cache. It receives the key being
// looked up and returns the value (or error) to store for that key.
type Function func(key interface{}) (interface{}, error)

// FunctionResult is the stored outcome of one Function call. Errors are kept
// alongside values, so a key whose call failed is not retried.
type FunctionResult struct {
	Value interface{}
	Error error
}

// Placeholder is the zero-size value stored in Cache.InProgress; only the
// presence of a key matters.
type Placeholder struct{}

// Cache memoizes fn per key and deduplicates concurrent calls: while one
// goroutine computes a key, other callers for that key block until its result
// is ready instead of calling fn again.
//
// Create a Cache with NewCache; the zero value has nil maps and panics on the
// first miss. Keys must be comparable (valid map keys), otherwise the map
// operations in Get panic.
type Cache struct {
	fn         Function                                // computation being memoized
	InProgress map[interface{}]Placeholder             // keys whose fn call is running
	Waiting    map[interface{}][]chan<- FunctionResult // callers blocked on an in-progress key
	Results    map[interface{}]FunctionResult          // finished fn calls, values and errors
	Lock       sync.RWMutex                            // guards the three maps above
}

// NewCache returns an empty Cache that memoizes fn.
func NewCache(fn Function) *Cache {
	return &Cache{
		fn:         fn,
		InProgress: make(map[interface{}]Placeholder),
		Waiting:    make(map[interface{}][]chan<- FunctionResult),
		Results:    make(map[interface{}]FunctionResult),
	}
}

// Get returns the value and error produced by fn(key). It calls fn at most
// once per key for the lifetime of the Cache and prints progress messages to
// stdout.
//
// There are three cases:
//   - The first caller for a key runs fn.
//   - Callers arriving while fn runs block and then receive the same result.
//   - Later callers get the stored result immediately.
//
// Errors are stored and returned just like values.
//
// Deciding whether to wait or to compute happens in one critical section. So a
// caller can never register as a waiter after the result was published (which
// would leave it blocked forever). Two callers can also never both start
// computing the same key.
//
// NOTE(review): if fn panics, the key stays in InProgress and any waiters
// block forever; callers must ensure fn does not panic.
func (c *Cache) Get(key interface{}) (interface{}, error) {
	// Fast path: a finished result only needs the shared lock.
	c.Lock.RLock()
	result, ok := c.Results[key]
	c.Lock.RUnlock()
	if ok {
		fmt.Printf("Key %v found\n", key)
		return result.Value, result.Error
	}

	c.Lock.Lock()
	// Re-check under the exclusive lock: the result may have been published
	// between the RUnlock above and this Lock.
	if result, ok = c.Results[key]; ok {
		c.Lock.Unlock()
		fmt.Printf("Key %v found\n", key)
		return result.Value, result.Error
	}
	if _, busy := c.InProgress[key]; busy {
		// Registered while still holding the lock, so the computing goroutine
		// is guaranteed to see this channel when it detaches the waiters.
		// Buffered so that its send never blocks.
		ch := make(chan FunctionResult, 1)
		c.Waiting[key] = append(c.Waiting[key], ch)
		c.Lock.Unlock()
		fmt.Printf("Waiting key %v...\n", key)
		result = <-ch
		fmt.Printf("Key %v recieved\n", key)
		return result.Value, result.Error
	}
	c.InProgress[key] = Placeholder{}
	c.Lock.Unlock()

	fmt.Printf("Key %v in progress\n", key)
	value, err := c.fn(key)
	result = FunctionResult{Value: value, Error: err}

	// Store the result and detach the waiter list in one critical section, so
	// every later caller finds the result instead of registering as a waiter.
	c.Lock.Lock()
	waiters := c.Waiting[key]
	c.Results[key] = result
	delete(c.InProgress, key)
	delete(c.Waiting, key)
	c.Lock.Unlock()

	fmt.Printf("Sending key %v...\n", key)
	for _, ch := range waiters {
		ch <- result
	}
	return result.Value, result.Error
}
// Double returns twice its argument after a simulated two-second delay.
//
// number must hold an int. For any other dynamic type (including nil) Double
// returns a nil value and an error naming the value and type it received.
func Double(number interface{}) (interface{}, error) {
	n, ok := number.(int)
	if !ok {
		// Report the original argument: after a failed assertion n is just 0.
		return nil, fmt.Errorf("got %v of type %T, expected an int", number, number)
	}
	time.Sleep(2 * time.Second) // simulated slow computation
	return n * 2, nil
}
// wg lets main wait for the demo goroutines it starts; its zero value is
// ready to use.
var wg sync.WaitGroup
// PrintDouble looks up v in cache and prints either the value returned for it
// or the error that lookup produced.
func PrintDouble(cache *Cache, v int) {
	value, lookupErr := cache.Get(v)
	if lookupErr != nil {
		fmt.Printf("Error on key %d, %v\n", v, lookupErr)
		return
	}
	fmt.Printf("Got %d: %v\n", v, value)
}
// main demonstrates the cache. Keys are looked up concurrently with duplicates
// (1 and 2 twice), and key 3 is requested again three seconds later, by which
// time its first two-second computation should already be stored.
func main() {
	cache := NewCache(Double)
	keys := []int{1, 1, 2, 2, 3}

	// One slot per immediate lookup plus one for the delayed lookup below.
	wg.Add(len(keys) + 1)
	for _, k := range keys {
		go func(key int) {
			defer wg.Done()
			PrintDouble(cache, key)
		}(k)
	}

	go func() {
		defer wg.Done()
		time.Sleep(3 * time.Second)
		PrintDouble(cache, 3)
	}()

	wg.Wait()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment