Skip to content

Instantly share code, notes, and snippets.

@tristanwietsma
Created September 9, 2013 21:48
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tristanwietsma/6501991 to your computer and use it in GitHub Desktop.
Save tristanwietsma/6501991 to your computer and use it in GitHub Desktop.
Key-Value server (GET, SET, DEL, PUB, SUB)
package main
import (
"log"
"net"
"sync"
"strings"
)
// Store is an in-memory key/value store with per-key publish/subscribe
// fan-out. The embedded RWMutex guards both maps; its Lock/RLock methods are
// promoted onto Store.
type Store struct {
	dataMap map[string]string          // latest value stored under each key
	subMap  map[string][]chan<- string // channels subscribed to each key
	sync.RWMutex
}

// Init allocates the internal maps. It must be called before any other
// method, because writing to a nil map panics.
func (s *Store) Init() {
	s.dataMap = make(map[string]string)
	s.subMap = make(map[string][]chan<- string)
}

// Get returns the value stored under key; ok is false when the key is absent.
func (s *Store) Get(key string) (value string, ok bool) {
	s.RLock()
	defer s.RUnlock()
	value, ok = s.dataMap[key]
	return value, ok
}

// Set stores value under key, replacing any previous value. It always
// returns true.
func (s *Store) Set(key, value string) bool {
	s.Lock()
	defer s.Unlock()
	s.dataMap[key] = value
	return true
}

// Delete removes every listed key. Keys that are not present are ignored.
func (s *Store) Delete(keys []string) {
	s.Lock()
	defer s.Unlock()
	for _, key := range keys {
		delete(s.dataMap, key)
	}
}

// Publish consumes values from incoming until that channel is closed. Each
// value is stored under key and then forwarded to the key's subscribers.
func (s *Store) Publish(key string, incoming <-chan string) {
	for value := range incoming {
		s.Set(key, value)
		s.updateSubscribers(key, value)
	}
}

// Subscribe registers outgoing to receive every value subsequently published
// to key.
func (s *Store) Subscribe(key string, outgoing chan<- string) {
	s.Lock()
	defer s.Unlock()
	// The lookup and the update happen under one write lock so concurrent
	// subscribers cannot overwrite each other's registration. append on the
	// nil slice of a missing key creates the list.
	s.subMap[key] = append(s.subMap[key], outgoing)
}

// helpers

// unsubscribe removes outgoing from key's subscriber list, dropping the key
// entirely once no subscribers remain.
func (s *Store) unsubscribe(key string, outgoing chan<- string) {
	s.Lock()
	defer s.Unlock()
	// Read the list while holding the write lock; filtering a snapshot taken
	// earlier could discard a subscriber added in the meantime.
	subs, hasSubs := s.subMap[key]
	if !hasSubs {
		return
	}
	// Build a fresh slice instead of filtering in place: snapshots handed out
	// by fetchSubscribers share the old backing array and may still be
	// iterated by a publisher.
	remaining := make([]chan<- string, 0, len(subs))
	for _, sub := range subs {
		if sub != outgoing {
			remaining = append(remaining, sub)
		}
	}
	if len(remaining) == 0 {
		delete(s.subMap, key)
		return
	}
	s.subMap[key] = remaining
}

// fetchSubscribers returns the current subscriber list for key and whether
// the key has a list at all. The returned slice is a snapshot that callers
// must treat as read-only.
func (s *Store) fetchSubscribers(key string) ([]chan<- string, bool) {
	s.RLock()
	defer s.RUnlock()
	subs, hasSubs := s.subMap[key]
	return subs, hasSubs
}

// updateSubscribers sends value to every channel subscribed to key. A
// subscriber whose channel has been closed is unsubscribed, and delivery
// continues with the remaining subscribers. Sends block until the subscriber
// receives (or its channel buffer has room).
func (s *Store) updateSubscribers(key, value string) {
	subs, ok := s.fetchSubscribers(key)
	if !ok {
		return
	}
	for _, out := range subs {
		if !sendToSubscriber(out, value) {
			s.unsubscribe(key, out)
		}
	}
}

// sendToSubscriber sends value on out and reports whether the send completed.
// Sending on a closed channel panics; the panic is recovered here, per send,
// so one dead subscriber cannot abort delivery to the others.
func sendToSubscriber(out chan<- string, value string) (sent bool) {
	defer func() {
		if r := recover(); r != nil {
			sent = false
		}
	}()
	out <- value
	return true
}
// HandleConnection serves a single client connection and closes it on return.
//
// The first read (at most 1024 bytes) is split on single spaces into a
// command and its arguments:
//
//	GET key            replies with the value, or "(nil)" if the key is absent
//	SET key value      replies "OK"; replies "FAIL" unless exactly one value is given
//	DEL key [key ...]  replies "OK"
//	PUB key            replies "READY", then publishes the bytes of every later read
//	SUB key            streams every value published to key until a write fails
//
// NOTE(review): arguments are split on single spaces only, so a trailing
// newline sent by a line-oriented client becomes part of the last argument.
// Confirm whether clients are expected to omit it.
func HandleConnection(c net.Conn) {
	defer c.Close()
	fromAddr := c.RemoteAddr()

	buf := make([]byte, 1024)
	nb, err := c.Read(buf)
	if err != nil {
		log.Printf("[%s] read error: '%s'.\n", fromAddr, err)
		return
	}

	args := strings.Split(string(buf[:nb]), " ")
	numArgs := len(args)
	if numArgs < 2 {
		log.Printf("[%s] failed to provide command.\n", fromAddr)
		return
	}

	// reply writes msg to the client, logs a failed write, and reports
	// whether the write succeeded.
	reply := func(msg string) bool {
		if _, err := c.Write([]byte(msg)); err != nil {
			log.Printf("[%s] write to socket failed.\n", fromAddr)
			return false
		}
		return true
	}

	switch args[0] {
	case "GET": // GET key
		value, ok := DB.Get(args[1])
		log.Printf("[%s] got value of '%s'.\n", fromAddr, args[1])
		if !ok {
			value = "(nil)"
		}
		reply(value)

	case "SET": // SET key value
		if numArgs == 3 && DB.Set(args[1], args[2]) {
			log.Printf("[%s] set '%s' to '%s'.\n", fromAddr, args[1], args[2])
			reply("OK")
		} else {
			log.Printf("[%s] failed to set value of '%s'.\n", fromAddr, args[1])
			reply("FAIL")
		}

	case "DEL": // DEL key [key ...]
		DB.Delete(args[1:])
		log.Printf("[%s] deleted key(s): %s.\n", fromAddr, args[1:])
		reply("OK")

	case "PUB": // PUB key; values... (maintains persistent connection)
		incoming := make(chan string)
		// This handler is the only sender, so it owns closing the channel;
		// closing it is what terminates the Publish goroutine.
		defer close(incoming)
		go DB.Publish(args[1], incoming)
		log.Printf("[%s] publishing to '%s'.\n", fromAddr, args[1])
		if !reply("READY") {
			return
		}
		for {
			nb, err := c.Read(buf)
			// A Read may return data together with an error; forward the
			// data before acting on the error so the final value is kept.
			if nb > 0 {
				incoming <- string(buf[:nb])
			}
			if err != nil {
				log.Printf("[%s] finished publishing to '%s'.\n", fromAddr, args[1])
				return
			}
		}

	case "SUB": // SUB key
		outgoing := make(chan string)
		DB.Subscribe(args[1], outgoing)
		log.Printf("[%s] subscribed to '%s'.\n", fromAddr, args[1])
		for value := range outgoing {
			if _, err := c.Write([]byte(value)); err != nil {
				// Deregister first so later publishes skip this client.
				// The channel is still closed afterwards: a publisher that
				// is already blocked sending to it is released by the
				// resulting panic, which the store recovers from.
				DB.unsubscribe(args[1], outgoing)
				close(outgoing)
				log.Printf("[%s] unsubscribed to '%s'.\n", fromAddr, args[1])
				return
			}
		}

	default:
		log.Printf("[%s] unknown command '%s'.\n", fromAddr, args[0])
	}
}
//
// Main
//
// DB is the process-wide store shared by every connection handler.
var DB Store

// main initialises DB, listens on TCP port 2000 and hands each accepted
// connection to HandleConnection on its own goroutine. A listen or accept
// error terminates the process.
func main() {
	DB.Init()

	// Open the listening socket.
	listener, listenErr := net.Listen("tcp", ":2000")
	if listenErr != nil {
		log.Fatal(listenErr)
	}
	defer listener.Close()

	// Serve clients concurrently, one goroutine per connection.
	for {
		client, acceptErr := listener.Accept()
		if acceptErr != nil {
			log.Fatal(acceptErr)
		}
		go HandleConnection(client)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment