Skip to content

Instantly share code, notes, and snippets.

@eliquious
Last active January 27, 2016 02:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eliquious/90baf4ce72be7a17aac4 to your computer and use it in GitHub Desktop.
Save eliquious/90baf4ce72be7a17aac4 to your computer and use it in GitHub Desktop.
Simple key-value store in Go
//CLIENT
package main
import (
"bufio"
"fmt"
"io"
"math/rand"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
// Benchmark configuration and shared progress counter.
var (
	// requestCount is the number of responses received across all
	// connections; it is updated with sync/atomic by every readLoop.
	requestCount uint64
	// totalPingsPerConnection is how many requests each connection sends.
	totalPingsPerConnection uint64 = 1000000
	// concurrentConnections is how many client connections are opened.
	concurrentConnections uint64 = 128
	// totalPings is the number of responses expected over the whole run.
	totalPings = concurrentConnections * totalPingsPerConnection
)
// monitor reports benchmark progress. It starts a goroutine that prints, once
// per second, how many responses arrived during the last interval. The
// goroutine stops when done receives a value (or is closed), when all
// totalPings responses have been counted, or when an interval passes with no
// new responses. It then prints a summary and sends true on the returned
// channel; that channel is unbuffered, so the caller must receive from it.
func monitor(done chan bool) chan bool {
	out := make(chan bool)
	go func() {
		var last uint64
		start := time.Now()
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
	OUTER:
		for {
			select {
			case <-done:
				break OUTER
			case <-ticker.C:
				current := atomic.LoadUint64(&requestCount)
				fmt.Printf("%d combined requests per second (%d)\n", current-last, current)
				// Stop when everything has arrived, or when progress stalls.
				if current >= totalPings || current-last == 0 {
					break OUTER
				}
				last = current
			}
		}
		elapsed := time.Since(start)
		// Load atomically: if we stopped because of a stall, the read loops
		// may still be running and incrementing the counter.
		total := atomic.LoadUint64(&requestCount)
		if total == 0 {
			// No responses: per-request figures would be +Inf/NaN.
			fmt.Printf("%d requests\n", total)
		} else {
			nsPerRequest := float64(elapsed) / float64(total)
			fmt.Printf("%f ns\n", nsPerRequest)
			fmt.Printf("%d requests\n", total)
			fmt.Printf("%f requests per second\n", float64(time.Second)/nsPerRequest)
		}
		fmt.Printf("elapsed: %s\r\n", elapsed)
		out <- true
	}()
	return out
}
// readLoop reads response lines from the connection until
// totalPingsPerConnection responses have been counted. Each complete line
// increments both the per-connection and the global counters. It stops early
// on a read error, or on a line starting with '-' (a server error response),
// printing it in either case. It calls wg.Done on return.
func (c *client) readLoop(wg *sync.WaitGroup) {
	defer wg.Done()
	rd := bufio.NewReader(c.conn)
	// midLine is true while consuming the remaining fragments of a line that
	// was longer than the reader's buffer.
	midLine := false
	for atomic.LoadUint64(&c.revcd) < totalPingsPerConnection {
		frag, isPrefix, err := rd.ReadLine()
		if err != nil {
			fmt.Println(err)
			return
		}
		// Only the first fragment of a line can carry the '-' error marker.
		if !midLine && len(frag) > 0 && frag[0] == '-' {
			fmt.Println(string(frag))
			return
		}
		midLine = isPrefix
		if isPrefix {
			// ReadLine split an over-long line; count it once, when its last
			// fragment arrives, not once per fragment.
			continue
		}
		atomic.AddUint64(&c.revcd, 1)
		atomic.AddUint64(&requestCount, 1)
	}
}
// writeLoop sends totalPingsPerConnection identical GET requests through a
// 64 KiB buffered writer, then flushes whatever is still buffered. It prints
// and stops on a write or flush error, and calls wg.Done on return.
func (c *client) writeLoop(wg *sync.WaitGroup) {
	defer wg.Done()
	wr := bufio.NewWriterSize(c.conn, 65536)
	// Every request on this connection asks for the same randomly chosen key
	// in key0..key7.
	request := []byte(fmt.Sprintf("GET key%d\r\n", rand.Intn(8)))
	// request := []byte(fmt.Sprintf("SET key%d 0 value\r\n", rand.Intn(int(concurrentConnections)*8)))
	for atomic.LoadUint64(&c.sent) < totalPingsPerConnection {
		if _, err := wr.Write(request); err != nil && err != io.EOF {
			fmt.Println(err)
			return
		}
		atomic.AddUint64(&c.sent, 1)
	}
	// Data reaches the socket only when the buffer fills, so the final partial
	// buffer must be flushed explicitly. Report the error instead of losing it.
	if err := wr.Flush(); err != nil {
		fmt.Println(err)
	}
}
// client is the state of one benchmark connection. writeLoop and readLoop run
// on separate goroutines and update their counters with sync/atomic.
type client struct {
	// sent counts requests written by writeLoop.
	// Keep the uint64 counters first: 64-bit atomic operations require 64-bit
	// alignment on 32-bit platforms, which is only guaranteed for the first
	// word of an allocated struct.
	sent uint64
	// revcd counts responses read by readLoop ("received").
	revcd uint64
	// conn is the TCP connection shared by both loops.
	conn *net.TCPConn
}
// NewClient opens one benchmark connection to localhost:9022. It runs the
// connection's writer and reader loops concurrently, waits for both to finish,
// and then closes the connection. Resolution and dial errors are printed and
// abort this client. wg.Done is called on every return path.
func NewClient(wg *sync.WaitGroup) {
	defer wg.Done()
	tcpAddr, err := net.ResolveTCPAddr("tcp4", "localhost:9022")
	if err != nil {
		// Report resolution failures directly instead of dialing a nil address.
		fmt.Println(err)
		return
	}
	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		fmt.Println(err)
		return
	}
	// Runs before the deferred wg.Done, so the connection is closed before
	// main observes this client as finished.
	defer conn.Close()

	c := client{conn: conn}
	var loops sync.WaitGroup
	loops.Add(2)
	go c.writeLoop(&loops)
	go c.readLoop(&loops)
	loops.Wait()
}
// main starts the progress monitor and opens concurrentConnections benchmark
// clients. It waits for all clients to finish, stops the monitor, and waits
// for the monitor to print its summary.
func main() {
	runtime.GOMAXPROCS(2)
	var wg sync.WaitGroup
	done := make(chan bool)
	report := monitor(done)
	for i := uint64(0); i < concurrentConnections; i++ {
		wg.Add(1)
		go NewClient(&wg)
	}
	wg.Wait()
	// Close instead of sending. monitor can leave its select loop by itself
	// (all responses counted, or a stalled second) and then block sending on
	// report. A blocking send on done would then deadlock; close never blocks.
	close(done)
	<-report
}
package main
import (
"bufio"
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"strconv"
"time"
// "github.com/pkg/profile"
"github.com/coocood/freecache"
disruptor "github.com/smartystreets/go-disruptor"
"golang.org/x/net/context"
)
// main seeds the cache with key0..key8 (each mapped to "value") and serves it
// on :9022. It blocks until the server stops, and exits non-zero if the
// server cannot start (for example, when the port is already in use).
func main() {
	runtime.GOMAXPROCS(8)
	// defer profile.Start(profile.CPUProfile, profile.ProfilePath(".")).Stop()
	cache := freecache.NewCache(0)
	for i := 0; i <= 8; i++ {
		// Set errors are ignored: these small fixed entries are well within
		// freecache's key/entry size limits.
		cache.Set([]byte("key"+strconv.Itoa(i)), []byte("value"), 0)
	}
	server := Server{Cache: cache}
	if err := server.Start(":9022"); err != nil {
		log.Fatal(err)
	}
}
// Server handles all the incoming connections as well as handler dispatch.
type Server struct {
	// Cache is the store shared by every connection handler.
	Cache *freecache.Cache
	// Logger is (re)assigned by Start to a stdout logger prefixed "logger: ".
	Logger *log.Logger
	// Addr is the address actually bound, set by Start once listening.
	Addr *net.TCPAddr
	// listener is created by Start and closed when listen returns.
	listener *net.TCPListener
	// context and cancel are created by Start; Stop calls cancel.
	context context.Context
	cancel  context.CancelFunc
}
// Start binds a TCP listener on addr and runs the accept loop on a background
// goroutine. Start itself blocks until the server is stopped with Stop, then
// returns nil. It returns an error immediately if addr is empty, cannot be
// resolved, or cannot be bound.
func (s *Server) Start(addr string) (err error) {
	// Validate the bind address.
	if addr == "" {
		return fmt.Errorf("server: Empty bind address")
	}
	// Resolve the TCP address, keeping the underlying cause in the error.
	netAddr, e := net.ResolveTCPAddr("tcp", addr)
	if e != nil {
		return fmt.Errorf("server: Invalid tcp address: %v", e)
	}
	// Bind the listener.
	listener, e := net.ListenTCP("tcp", netAddr)
	if e != nil {
		return e
	}
	s.Logger = log.New(os.Stdout, "logger: ", log.Lshortfile)
	s.listener = listener
	s.Addr = listener.Addr().(*net.TCPAddr)
	s.Logger.Println("Starting server", "addr", addr)

	c, cancel := context.WithCancel(context.Background())
	s.context = c
	s.cancel = cancel
	go s.listen(c)
	// Block until Stop cancels the server context.
	<-c.Done()
	return nil
}
// Stop cancels the server's context. This unblocks Start, makes the accept
// loop exit (closing the listener) within about a second, and signals every
// connection handler, whose contexts are derived from it. Stop does not wait
// for those goroutines to finish. Calling Stop on a server that was never
// successfully started is a no-op instead of a nil-pointer panic.
func (s *Server) Stop() {
	if s.cancel == nil {
		return
	}
	s.Logger.Println("[INFO] Shutting down server...")
	s.cancel()
}
// listen runs the accept loop until c is cancelled. Each accepted connection is
// handed to a new tcpHandler that runs on its own goroutine. The handler's
// context is derived from c, so cancelling the server also signals the
// handlers. The listener is closed when listen returns.
func (s *Server) listen(c context.Context) {
	defer s.listener.Close()
	for {
		// Bound Accept to one second so cancellation of c is noticed promptly
		// even when no clients connect. The error is ignored because a broken
		// listener also makes Accept fail, and that failure is logged below.
		_ = s.listener.SetDeadline(time.Now().Add(time.Second))

		select {
		case <-c.Done():
			s.Logger.Println("[DEBUG] Context Completed")
			return
		default:
		}

		tcpConn, err := s.listener.Accept()
		if err != nil {
			// Deadline expiry is the normal idle case; anything else is logged.
			if neterr, ok := err.(net.Error); !ok || !neterr.Timeout() {
				s.Logger.Println("[WRN] Connection failed", "error", err)
			}
			continue
		}

		s.Logger.Println("[INF] Successful TCP connection:", tcpConn.RemoteAddr().String())
		// Use this loop's own context (the one Start also stores in s.context),
		// so the handler's lifetime matches the select above.
		h := NewTcpHandler(s.Cache, tcpConn, s.Logger, c)
		go h.Execute()
	}
}
// Ring buffer geometry shared by the read loop and ByteConsumer. The capacity
// must stay a power of two so that "sequence & RingBufferMask" is equivalent
// to "sequence % RingBufferCapacity".
const (
	RingBufferCapacity = 1 << 22 // 4 MiB
	RingBufferMask     = RingBufferCapacity - 1
)
func NewTcpHandler(cache *freecache.Cache, conn net.Conn, logger *log.Logger, ctx context.Context) *tcpHandler {
ring := [RingBufferCapacity]byte{}
controller := disruptor.
Configure(RingBufferCapacity).
WithConsumerGroup(&ByteConsumer{
Writer: bufio.NewWriterSize(conn, 128*1024),
Closer: conn,
ring: &ring,
cache: cache,
logger: logger,
}).Build()
controller.Start()
c, cancel := context.WithCancel(ctx)
return &tcpHandler{cache,
logger, conn, &ring, &controller, c, cancel,
}
}
// tcpHandler owns one client connection. Its read loop copies raw bytes from
// conn into ring and publishes them through controller; the ByteConsumer
// configured in NewTcpHandler parses them and writes responses.
// NOTE: NewTcpHandler may build this with a positional literal, so keep the
// field order unchanged.
type tcpHandler struct {
	cache      *freecache.Cache
	logger     *log.Logger
	conn       net.Conn
	ring       *[RingBufferCapacity]byte
	controller *disruptor.Disruptor
	// context is cancelled when the read loop exits or the server context is
	// cancelled; Execute waits on it.
	context context.Context
	cancel  context.CancelFunc
}
// Execute serves the connection. It launches the read loop on its own
// goroutine and blocks until the handler's context is done. On return it
// stops the disruptor first and then closes the connection.
func (t *tcpHandler) Execute() {
	conn, controller := t.conn, t.controller
	// Deferred calls run last-in first-out: Stop, then Close.
	defer conn.Close()
	defer controller.Stop()

	go t.createReadLoop()

	finished := t.context.Done()
	<-finished
}
func (t *tcpHandler) createReadLoop() {
defer t.cancel()
writer := t.controller.Writer()
buffer := make([]byte, 64*1024)
var sequence, reservations int64
var idx int
rd := bufio.NewReaderSize(t.conn, 1024*1024)
for {
select {
case <-t.context.Done():
return
default:
n, err := rd.Read(buffer)
// t.logger.Printf("n: %d; err: %s\r\n", n, err)
// t.logger.Printf("body: %s\r\n", string(buffer[:n]))
if n > 0 {
idx = 0
reservations = int64(n)
sequence = writer.Reserve(reservations)
for lower := sequence - reservations + 1; lower <= sequence; lower++ {
t.ring[lower&RingBufferMask] = buffer[idx]
idx++
}
writer.Commit(sequence-reservations+1, sequence)
} else if err == io.EOF {
return
}
if err != nil && err != io.EOF {
return
}
}
}
}
// Error replies written to clients. Each is one protocol line that starts
// with '-' and ends with "\r\n".
var (
	ErrMaxSize             = []byte("-ERRMAXSIZE Request too large\r\n")
	ErrUnknownCmd          = []byte("-ERRPARSE Unknown command\r\n")
	ErrIncompleteCmd       = []byte("-ERRPARSE Incomplete command\r\n")
	ErrEmptyRequest        = []byte("-ERRPARSE Empty request\r\n")
	ErrInvalidCmdDelimiter = []byte("-ERRPARSE Missing tab character after command\r\n")
	ErrInvalidKeyDelimiter = []byte("-ERRPARSE Missing tab character after key\r\n")
	ErrLargeKey            = []byte("-ERRLARGEKEY The key is larger than 65535\r\n")
	ErrLargeEntry          = []byte("-ERRLARGEENTRY The entry size is larger than 1/1024 of cache size\r\n")
	ErrNotFound            = []byte("-ERRNOTFOUND Entry not found\r\n")
	ErrInvalidExpiration   = []byte("-ERRINVEXP Invalid key expiration\r\n")
)
// ByteConsumer is the disruptor consumer for one connection. It reassembles
// request lines from the shared ring buffer, parses them, and writes replies.
type ByteConsumer struct {
	// Writer buffers replies to the connection.
	Writer *bufio.Writer
	// Closer closes the connection after a fatal request error.
	Closer io.Closer
	logger *log.Logger
	cache  *freecache.Cache
	// ring is the byte ring filled by the read loop, indexed by
	// sequence & RingBufferMask.
	ring *[RingBufferCapacity]byte
	// buffer accumulates the current request line; its length is the maximum
	// accepted request size.
	// NOTE(review): 65336 looks like it may be a typo for 65536 — confirm the
	// intended limit.
	buffer [65336 * 4]byte
	// closed is set once the connection has been closed; later input is ignored.
	closed bool
	// requestSize is the number of bytes of the current line held in buffer.
	requestSize int
}
// Consume is the disruptor callback for the published sequence range
// [lower, upper]. It reads each byte from the shared ring and appends it to
// b.buffer until a '\n' completes the request. The request is then passed to
// parse and the reply is flushed. '\r' bytes are dropped wherever they occur.
// After a fatal condition (oversized request, parse failure, or a failed
// flush) the connection is closed and every later call does nothing.
func (b *ByteConsumer) Consume(lower, upper int64) {
	if b.closed {
		return
	}
	for sequence := lower; sequence <= upper; sequence++ {
		char := b.ring[sequence&RingBufferMask]
		switch char {
		case '\r':
			continue
		case '\n':
			ok := b.parse(b.buffer[:b.requestSize], b.Writer)
			// A failed flush means replies can no longer reach the client.
			if err := b.Writer.Flush(); err != nil {
				ok = false
			}
			if !ok {
				b.closed = true
				b.Closer.Close()
				return
			}
			b.requestSize = 0
		default:
			// Reject only when a byte would actually overflow the buffer, so a
			// request that exactly fills it is still accepted when its line
			// terminator arrives.
			if b.requestSize >= len(b.buffer) {
				b.Writer.Write(ErrMaxSize)
				b.logger.Printf("ERR %s\r\n", string(ErrMaxSize))
				b.Writer.Flush()
				b.Closer.Close()
				b.closed = true
				return
			}
			b.buffer[b.requestSize] = char
			b.requestSize++
		}
	}
}
// parse interprets a single request line (Consume has already removed '\r' and
// '\n') and writes the reply to w. Commands are case-insensitive, with a tab
// or space after each field:
//
//	GET <key>                         -> "+VALUE <value>\r\n"
//	SET <key> <expireSeconds> <value> -> "+OK\r\n"
//
// The GET key is everything after the command separator. The SET value is
// everything after the separator that follows the expiration, so it may
// contain spaces. parse returns true when the connection can keep serving.
// It returns false after writing an error reply (the caller then closes the
// connection) or when a reply write fails.
func (b *ByteConsumer) parse(line []byte, w *bufio.Writer) bool {
	if len(line) == 0 {
		w.Write(ErrEmptyRequest)
		return false
	}
	// b.logger.Printf("Parsing line: %s\r\n", strconv.Quote(string(line)))
	var i, expiration int
	var c byte
	state := OP_START
	var e error
	var key, value, err []byte

	// Index loop instead of range so the SET branch can advance i itself.
	for i = 0; i < len(line); i++ {
		c = line[i]
		switch state {
		case OP_START:
			switch c {
			case 'G', 'g':
				state = OP_G
			case 'S', 's':
				state = OP_S
			default:
				b.logger.Printf("OP_START: (%q) %s\r\n", c, strconv.Quote(string(line)))
				err = ErrUnknownCmd
				goto PARSE_ERR
			}
		case OP_G:
			switch c {
			case 'E', 'e':
				state = OP_GE
			default:
				b.logger.Printf("OP_G: (%q) %s\r\n", c, strconv.Quote(string(line)))
				err = ErrUnknownCmd
				goto PARSE_ERR
			}
		case OP_GE:
			switch c {
			case 'T', 't':
				state = OP_GET
			default:
				b.logger.Printf("OP_GE: (%q) %s\r\n", c, strconv.Quote(string(line)))
				err = ErrUnknownCmd
				goto PARSE_ERR
			}
		case OP_GET:
			switch c {
			case '\t', ' ':
				key = line[i+1:]
				goto PERFORM_GET
			default:
				err = ErrInvalidCmdDelimiter
				goto PARSE_ERR
			}
		case OP_S:
			switch c {
			case 'E', 'e':
				state = OP_SE
			default:
				b.logger.Printf("OP_S: (%q) %s\r\n", c, strconv.Quote(string(line)))
				err = ErrUnknownCmd
				goto PARSE_ERR
			}
		case OP_SE:
			switch c {
			case 'T', 't':
				state = OP_SET
			default:
				b.logger.Printf("OP_SE: (%q) %s\r\n", c, strconv.Quote(string(line)))
				err = ErrUnknownCmd
				goto PARSE_ERR
			}
		case OP_SET:
			switch c {
			case '\t', ' ':
				state = OP_SET_KEY
			default:
				err = ErrInvalidCmdDelimiter
				goto PARSE_ERR
			}
		case OP_SET_KEY:
			// Key: bytes up to the next separator. It must be non-empty and
			// must be followed by more input.
			offset := i
			for i < len(line) && line[i] != '\t' && line[i] != ' ' {
				i++
			}
			if i == len(line) || i == offset {
				err = ErrIncompleteCmd
				goto PARSE_ERR
			}
			key = line[offset:i]
			i++ // skip the separator after the key

			// Expiration: bytes up to the next separator, which is consumed too.
			offset = i
			for i < len(line) {
				sep := line[i] == '\t' || line[i] == ' '
				i++
				if sep {
					break
				}
			}
			// Require the separator and at least one value byte after it.
			if i >= len(line) || i == offset {
				err = ErrIncompleteCmd
				goto PARSE_ERR
			}
			exp, e := strconv.Atoi(string(line[offset : i-1]))
			if e != nil {
				err = ErrInvalidExpiration
				goto PARSE_ERR
			}
			expiration = exp
			// line[i-1] is the separator; the value starts right after it.
			value = line[i:]
			goto PERFORM_SET
		}
	}

	// The line ended before a command was complete (e.g. "GE", "GET", "SET ").
	// Without this, control would fall into PARSE_ERR with a nil error and
	// close the connection without sending any reply.
	b.logger.Printf("END %q\r\n", c)
	err = ErrIncompleteCmd

PARSE_ERR:
	// Ignoring all write errors here, because we are going to return false
	// and close the connection due to the parse error anyway.
	w.Write(err)
	b.logger.Printf("%s (%s)\r\n", string(err), strconv.Quote(string(line)))
	return false

PERFORM_GET:
	value, e = b.cache.Get(key)
	switch {
	case e == freecache.ErrLargeKey:
		err = ErrLargeKey
	case e == freecache.ErrLargeEntry:
		err = ErrLargeEntry
	case e == freecache.ErrNotFound:
		err = ErrNotFound
	case e != nil:
		err = []byte("-ERRCACHE Unknown cache error\r\n")
	}
	if err != nil {
		goto PARSE_ERR
	}
	if _, we := w.Write([]byte("+VALUE ")); we != nil {
		return false
	}
	if _, we := w.Write(value); we != nil {
		return false
	}
	if _, we := w.Write([]byte("\r\n")); we != nil {
		return false
	}
	return true

PERFORM_SET:
	e = b.cache.Set(key, value, expiration)
	switch {
	case e == freecache.ErrLargeKey:
		err = ErrLargeKey
	case e == freecache.ErrLargeEntry:
		err = ErrLargeEntry
	case e != nil:
		err = []byte("-ERRCACHE Unknown cache error\r\n")
	}
	if err != nil {
		goto PARSE_ERR
	}
	if _, we := w.Write([]byte("+OK\r\n")); we != nil {
		return false
	}
	return true
}
// Parser states used by (*ByteConsumer).parse. Each state names the prefix of
// the command recognised so far; the values are consecutive from 0.
const (
	OP_START   int = iota // nothing consumed yet
	OP_G                  // saw "G"
	OP_GE                 // saw "GE"
	OP_GET                // saw "GET"; expecting a separator, then the key
	OP_S                  // saw "S"
	OP_SE                 // saw "SE"
	OP_SET                // saw "SET"; expecting a separator
	OP_SET_KEY            // parsing "<key> <expiration> <value>"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment