Skip to content

Instantly share code, notes, and snippets.

@jauhararifin
Last active January 10, 2024 15:10
Show Gist options
  • Save jauhararifin/663752042410c5526b2b7076965a2a00 to your computer and use it in GitHub Desktop.
Save jauhararifin/663752042410c5526b2b7076965a2a00 to your computer and use it in GitHub Desktop.
A simple Redis server implementation in Go that supports only the GET and SET commands
package main
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
"net"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"time"
)
// Sentinel errors returned by the server lifecycle methods (Start/Stop) and
// by the RESP parsing helpers.
var (
// errServerNotStarted is returned by Stop when Start has not been called.
errServerNotStarted = errors.New("server is not yet started")
// errServerAlreadyStarted is returned by Start when it is called more than once.
errServerAlreadyStarted = errors.New("server already started")
// errServerIsStopping is returned by Stop when an earlier Stop call already began shutdown.
errServerIsStopping = errors.New("server is already stopping")
// errMalformedInput is returned by the parsing helpers when the bytes sent by
// the client do not match the framing they expect.
errMalformedInput = errors.New("malformed input")
// NOTE(review): unsupportedValue is not referenced by any visible code and does
// not follow the errXxx naming of its siblings; consider removing or renaming it.
unsupportedValue = errors.New("unsupported value")
)
// main starts the server on a fixed TCP address and shuts it down gracefully
// when the process receives an interrupt signal. Any failure to listen, start
// or stop is logged and terminates the process with exit status 1.
func main() {
	const listenAddr = "0.0.0.0:3010"
	appLogger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))
	appLogger.Info("starting server", slog.String("addr", listenAddr))

	ln, listenErr := net.Listen("tcp", listenAddr)
	if listenErr != nil {
		appLogger.Error(
			"cannot start tcp server",
			slog.String("addr", listenAddr),
			slog.String("err", listenErr.Error()),
		)
		os.Exit(1)
	}

	srv := NewServer(ln, appLogger)

	// Start blocks for the lifetime of the server, so run it in the
	// background and keep the main goroutine free to wait for a signal.
	go func() {
		startErr := srv.Start()
		if startErr == nil {
			return
		}
		appLogger.Error("cannot start server", slog.String("error", startErr.Error()))
		os.Exit(1)
	}()

	interrupts := make(chan os.Signal, 1)
	signal.Notify(interrupts, os.Interrupt)
	<-interrupts

	if stopErr := srv.Stop(); stopErr != nil {
		appLogger.Error("cannot stop server", slog.String("error", stopErr.Error()))
		os.Exit(1)
	}
}
// server accepts TCP connections and serves GET/SET commands against an
// in-memory database. Create it with NewServer, run it with Start and shut it
// down with Stop.
type server struct {
// listener is the socket Start accepts connections from; Stop closes it.
listener net.Listener
logger *slog.Logger
// started makes Start refuse to run a second time.
started atomic.Bool
// stopping is set by Stop; Start's accept loop exits once it observes it.
stopping atomic.Bool
// stopped is closed when Start returns; Stop blocks on it.
stopped chan struct{}
// clientGroup counts running handleClient goroutines so Start can wait for them.
clientGroup sync.WaitGroup
// clientId is a counter used to give each connection an id for logging.
clientId atomic.Int64
// clients holds every live net.Conn as a key (the value is unused) so that
// Start can close them all on shutdown.
clients sync.Map
// handler maps an upper-cased command name to its implementation.
handler map[string]commandDesc
// database is the key/value store shared by all clients.
database database
}
// commandDesc describes one command the server understands.
type commandDesc struct {
// handler reads the command's remaining arguments from r and writes the
// reply to w. nArgs is the length of the whole request array, which includes
// the command name itself (so GET is called with nArgs == 2).
handler func(r io.Reader, nArgs int, w io.Writer) error
}
// database is the in-memory key/value store shared by all clients.
type database struct {
// m guards dict: GET takes the read lock, SET takes the write lock.
m *sync.RWMutex
dict map[string]string
}
// NewServer builds a server that will accept connections from listener and
// log through logger. It does not accept anything until Start is called.
//
// The returned server understands the GET and SET commands.
func NewServer(listener net.Listener, logger *slog.Logger) *server {
	// Zero values are ready to use for the atomics, WaitGroup, sync.Map and
	// id counter, so only the fields that need construction are set here.
	s := &server{
		listener: listener,
		logger:   logger,
		stopped:  make(chan struct{}),
		database: newDatabase(),
	}
	// The command table is filled in after construction because the handlers
	// are method values bound to s.
	s.handler = map[string]commandDesc{
		"GET": {handler: s.handleGetCommand},
		"SET": {handler: s.handleSetCommand},
	}
	return s
}
// newDatabase returns an empty database whose lock and map are ready to use.
func newDatabase() database {
	var db database
	db.m = new(sync.RWMutex)
	db.dict = make(map[string]string)
	return db
}
// Start runs the accept loop, serving each connection on its own goroutine,
// until Stop is called or Accept fails. Before returning it closes every
// client connection, waits for all client goroutines to exit, and then closes
// s.stopped so a pending Stop can return.
//
// It returns errServerAlreadyStarted on a second call. If Accept fails while
// the server is not stopping, it returns that error after the cleanup above.
// Otherwise it returns nil.
func (s *server) Start() error {
	if !s.started.CompareAndSwap(false, true) {
		return errServerAlreadyStarted
	}
	defer close(s.stopped)
	s.logger.Info("server started")

	var acceptErr error
	for !s.stopping.Load() {
		conn, err := s.listener.Accept()
		if err != nil {
			// Stop closes the listener, which makes Accept fail. That error is
			// expected during shutdown and is not reported to the caller.
			if !s.stopping.Load() {
				acceptErr = err
			}
			break
		}
		// Register the connection here, on the accepting goroutine, so the
		// shutdown sweep below cannot miss a connection whose handler has not
		// started running yet.
		s.clients.Store(conn, nil)
		s.clientGroup.Add(1)
		go s.handleClient(conn)
	}

	// Closing each connection unblocks its handler's pending read. This is
	// best-effort: a close error is not actionable here.
	s.clients.Range(func(key, _ any) bool {
		key.(net.Conn).Close()
		return true
	})
	s.clientGroup.Wait()
	return acceptErr
}
// Stop begins a graceful shutdown. It closes the listener so Start's accept
// loop ends, then blocks until Start has closed all clients and returned.
//
// It returns errServerNotStarted if Start was never called and
// errServerIsStopping if an earlier Stop already began shutdown. Otherwise it
// returns the error from closing the listener.
func (s *server) Stop() error {
	switch {
	case !s.started.Load():
		return errServerNotStarted
	case !s.stopping.CompareAndSwap(false, true):
		return errServerIsStopping
	}
	s.logger.Info("closing server")
	closeErr := s.listener.Close()
	<-s.stopped
	return closeErr
}
// handleClient serves one connection until one of three things happens: the
// client disconnects, the client sends something the server cannot handle,
// or the reply writer fails.
//
// Requests are read from netConn as RESP arrays whose first element is the
// command name. Replies are written through a bufferWriter whose flusher runs
// on a separate goroutine. Each side sets `closing` to make the other stop.
func (s *server) handleClient(netConn net.Conn) {
	clientID := s.clientId.Add(1)
	logger := s.logger.With(slog.Int64("client_id", clientID))

	s.clients.Store(netConn, nil)
	defer func() {
		s.clients.Delete(netConn)
		// Close before Done so that, once Start's clientGroup.Wait returns,
		// every connection served here is already closed.
		netConn.Close()
		s.clientGroup.Done()
	}()

	reader := bufio.NewReader(netConn)
	writer := newBufferWriter(netConn)
	var closing atomic.Bool
	go func() {
		if err := writer.Start(&closing); err != nil {
			logger.Error("cannot flush write to client", slog.String("error", err.Error()))
		}
		closing.Store(true)
	}()

	logError := func(err error) {
		if errors.Is(err, io.EOF) {
			logger.Info("client disconnected")
		} else {
			logger.Error("cannot read from client", slog.String("error", err.Error()))
		}
	}

	for !closing.Load() {
		length, err := readArrayValue(reader)
		if err != nil {
			logError(err)
			break
		}
		if length < 1 {
			logError(errMalformedInput)
			break
		}
		commandStr := strings.Builder{}
		if err := readAnyString(reader, &commandStr); err != nil {
			logError(err)
			break
		}
		commandName := strings.ToUpper(commandStr.String())
		commandDesc, ok := s.handler[commandName]
		if !ok {
			logError(fmt.Errorf("unknown command '%s'", commandName))
			break
		}
		if err := commandDesc.handler(reader, length, writer); err != nil {
			logError(err)
			break
		}
	}

	closing.Store(true)
	// The flusher may be parked in writer.cond.Wait with an empty buffer, and
	// Write (the only other signaller) will not be called again. Broadcast
	// while holding the lock so the wake-up cannot be lost between its check
	// and its Wait. NOTE: the goroutine only exits if bufferWriter.Start
	// re-checks its stop flag after being woken.
	writer.lock.Lock()
	writer.cond.Broadcast()
	writer.lock.Unlock()
}
// handleGetCommand implements GET. It reads the key and replies with the
// stored value as a bulk string, or with a nil reply when the key is absent.
//
// The read lock is held only for the map lookup, not while encoding the
// reply. w may block on the client's socket, because bufferWriter.Write can
// write to the connection directly. Holding the lock across that would stall
// every SET. Through RWMutex's writer preference, it would then also stall
// every new GET.
func (s *server) handleGetCommand(r io.Reader, nArgs int, w io.Writer) error {
	key, err := s.parseGetCommandArgs(r, nArgs)
	if err != nil {
		return err
	}

	s.database.m.RLock()
	value, ok := s.database.dict[key]
	s.database.m.RUnlock()

	if !ok {
		return encodeNilValue(w)
	}
	return encodeBulkStringValue(w, value)
}
// parseGetCommandArgs reads GET's single argument, the key, from r.
// nArgs is the request array length including the command name, so a valid
// GET has exactly 2. Any other count is rejected without reading from r.
func (s *server) parseGetCommandArgs(r io.Reader, nArgs int) (key string, err error) {
	if nArgs == 2 {
		var sb strings.Builder
		if err = readAnyString(r, &sb); err == nil {
			key = sb.String()
		}
		return key, err
	}
	return "", fmt.Errorf("wrong number of arguments")
}
// handleSetCommand implements SET. It stores value under key, overwriting any
// previous value, and replies with the simple string "OK".
func (s *server) handleSetCommand(r io.Reader, nArgs int, w io.Writer) error {
	key, value, err := s.parseSetCommandArgs(r, nArgs)
	if err != nil {
		return err
	}

	db := &s.database
	db.m.Lock()
	db.dict[key] = value
	db.m.Unlock()

	return encodeSimpleString(w, "OK")
}
// parseSetCommandArgs reads SET's two arguments, key then value, from r.
// nArgs is the request array length including the command name, so a valid
// SET has exactly 3. Any other count is rejected without reading from r.
func (s *server) parseSetCommandArgs(r io.Reader, nArgs int) (string, string, error) {
	if nArgs != 3 {
		return "", "", fmt.Errorf("wrong number of arguments")
	}
	// parts[0] receives the key and parts[1] the value.
	var parts [2]strings.Builder
	for i := range parts {
		if err := readAnyString(r, &parts[i]); err != nil {
			return "", "", err
		}
	}
	return parts[0].String(), parts[1].String(), nil
}
// readArrayValue reads a RESP array header ("*<count>\r\n") from r and
// returns the element count. It returns errMalformedInput if the first byte
// is not '*'. Otherwise it returns any error from reading or from
// readValueLength.
func readArrayValue(r io.Reader) (length int, err error) {
	var prefix byte
	if prefix, err = readOne(r); err != nil {
		return 0, err
	}
	if prefix != '*' {
		return 0, errMalformedInput
	}
	if length, err = readValueLength(r); err != nil {
		return 0, err
	}
	return length, nil
}
// readOne reads a single byte from r.
//
// A reader that implements io.ByteReader (such as the *bufio.Reader used by
// handleClient, or *strings.Reader) is read directly. Any other reader falls
// back to io.ReadFull on a stack-allocated one-byte buffer. At end of input
// the error is io.EOF for the standard readers.
func readOne(r io.Reader) (byte, error) {
	if br, ok := r.(io.ByteReader); ok {
		return br.ReadByte()
	}
	var one [1]byte
	if _, err := io.ReadFull(r, one[:]); err != nil {
		return 0, err
	}
	return one[0], nil
}
// readValueLength parses the "<n>\r\n" part of a RESP "*<n>\r\n" or
// "$<n>\r\n" header from r and returns n, a non-negative decimal length.
//
// It returns errMalformedInput in these cases:
//   - the input does not start with a digit;
//   - a non-digit appears before '\r';
//   - the '\r' is not followed by '\n';
//   - the number does not fit in an int.
//
// Negative lengths, such as the RESP null bulk string "$-1", are not accepted.
func readValueLength(r io.Reader) (int, error) {
	const maxInt = int(^uint(0) >> 1)

	b, err := readOne(r)
	if err != nil {
		return 0, err
	}
	if b < '0' || b > '9' {
		return 0, errMalformedInput
	}
	result := int(b - '0')
	for {
		b, err = readOne(r)
		if err != nil {
			return 0, err
		}
		if b == '\r' {
			break
		}
		if b < '0' || b > '9' {
			return 0, errMalformedInput
		}
		digit := int(b - '0')
		// The digits come straight from the client. Reject values that would
		// overflow rather than silently wrapping to a negative or unrelated
		// number that callers then use as a size.
		if result > (maxInt-digit)/10 {
			return 0, errMalformedInput
		}
		result = result*10 + digit
	}
	b, err = readOne(r)
	if err != nil {
		return 0, err
	}
	if b != '\n' {
		return 0, errMalformedInput
	}
	return result, nil
}
// readUntilCRLF copies bytes from r to w, one byte at a time, up to (but not
// including) the next "\r\n". The CRLF is consumed. It returns
// errMalformedInput if the first '\r' is not followed by '\n'. Read or write
// errors are returned unchanged.
func readUntilCRLF(r io.Reader, w io.Writer) error {
	for {
		c, err := readOne(r)
		if err != nil {
			return err
		}
		if c == '\r' {
			break
		}
		if _, err := w.Write([]byte{c}); err != nil {
			return err
		}
	}
	c, err := readOne(r)
	if err != nil {
		return err
	}
	if c != '\n' {
		return errMalformedInput
	}
	return nil
}
// skipCRLF consumes exactly "\r\n" from r. It returns errMalformedInput as
// soon as a byte differs from what is expected, and stops reading at that
// point.
func skipCRLF(r io.Reader) error {
	for _, want := range [...]byte{'\r', '\n'} {
		got, err := readOne(r)
		if err != nil {
			return err
		}
		if got != want {
			return errMalformedInput
		}
	}
	return nil
}
// readAnyString reads one RESP string from r and appends its bytes to w.
// Two encodings are accepted:
//   - bulk string:   "$<len>\r\n<len bytes>\r\n"
//   - simple string: "+<text>\r\n"
//
// Any other leading byte yields errMalformedInput.
//
// If w has a Grow(int) method (e.g. *strings.Builder), it is pre-sized from
// the declared length, but by at most maxStringPrealloc bytes. The length is
// chosen by the client, so growing by it directly would let a single header
// such as "$99999999999" make the process attempt a huge allocation. A
// negative value would make strings.Builder.Grow panic.
func readAnyString(r io.Reader, w io.Writer) error {
	// Upper bound on the capacity hint. Longer values still work; the
	// destination simply grows as bytes actually arrive.
	const maxStringPrealloc = 64 << 10

	b, err := readOne(r)
	if err != nil {
		return err
	}
	switch b {
	case '$':
		length, err := readValueLength(r)
		if err != nil {
			return err
		}
		if g, ok := w.(interface{ Grow(int) }); ok && length > 0 {
			g.Grow(min(length, maxStringPrealloc))
		}
		if _, err := io.CopyN(w, r, int64(length)); err != nil {
			return err
		}
		return skipCRLF(r)
	case '+':
		return readUntilCRLF(r, w)
	default:
		return errMalformedInput
	}
}
// encodeSimpleString writes s to w as a RESP simple string ("+<s>\r\n").
// s is written verbatim, so the caller must ensure it contains no CR or LF.
func encodeSimpleString(w io.Writer, s string) error {
	_, err := fmt.Fprintf(w, "+%s\r\n", s)
	return err
}
// encodeBulkStringValue writes s to w as a RESP bulk string,
// "$<len(s)>\r\n<s>\r\n", where the length is counted in bytes. It stops at
// the first write error and returns it.
func encodeBulkStringValue(w io.Writer, s string) error {
	_, err := w.Write([]byte{'$'})
	if err == nil {
		_, err = fmt.Fprintf(w, "%d\r\n", len(s))
	}
	if err == nil {
		_, err = io.WriteString(w, s)
	}
	if err == nil {
		_, err = w.Write([]byte{'\r', '\n'})
	}
	return err
}
// encodeNilValue writes the RESP2 null bulk string ("$-1\r\n") to w. GET uses
// it as the reply for a missing key.
//
// The RESP3 null ("_\r\n") is only expected by clients that switched to
// protocol 3 with HELLO. This server does not implement HELLO, so its
// connections speak RESP2 and a RESP2 client cannot parse "_".
func encodeNilValue(w io.Writer) error {
	_, err := io.WriteString(w, "$-1\r\n")
	return err
}
// bufferWriter is an io.Writer that accumulates writes in an in-memory buffer.
// A separate goroutine running Start flushes the buffer to w. Write may also
// write to w itself when the incoming data does not fit in the buffer.
type bufferWriter struct {
// w is the destination the buffered bytes are eventually written to.
w io.Writer
// lock guards buff, n and err. cond uses lock and is signalled by Write
// after it adds data.
lock *sync.Mutex
cond *sync.Cond
// size is the capacity of the buffer (len(buff)).
size int
buff []byte
// n is the number of pending bytes at the start of buff.
n int
// err is the first write error seen. Once set, flush and Write keep returning it.
err error
}
// newBufferWriter returns a bufferWriter over w with a 4096-byte buffer.
// Buffered data reaches w only when Start flushes it, or when a Write does not
// fit in the remaining buffer space.
func newBufferWriter(w io.Writer) *bufferWriter {
	const bufSize = 4096
	mu := new(sync.Mutex)
	bw := &bufferWriter{w: w, lock: mu, size: bufSize}
	bw.cond = sync.NewCond(mu)
	bw.buff = make([]byte, bufSize)
	return bw
}
// Start is the flusher loop for b. Each round it waits until the buffer holds
// data, sleeps 500µs so that further writes can be batched, and then flushes
// the buffer to the underlying writer. It returns nil once stop is set, or the
// first flush error.
//
// While the buffer is empty, Start is parked in b.cond.Wait, and setting stop
// alone cannot wake it. Whoever sets stop should also Broadcast on b.cond
// while holding b.lock, so the loop can notice the flag and return.
//
// b.lock is held except while waiting or sleeping, and is always released
// when Start returns.
func (b *bufferWriter) Start(stop *atomic.Bool) error {
	b.lock.Lock()
	// Releasing via defer also covers the flush-error return, so later Write
	// calls get the sticky error instead of blocking on the lock forever.
	defer b.lock.Unlock()

	for !stop.Load() {
		for b.n == 0 {
			if stop.Load() {
				return nil
			}
			b.cond.Wait()
			// Give concurrent writers a short window to append more data so
			// that it goes out in a single write.
			b.lock.Unlock()
			time.Sleep(500 * time.Microsecond)
			b.lock.Lock()
		}
		if err := b.flush(); err != nil {
			return err
		}
	}
	return nil
}
// Write implements io.Writer. The data is copied into the buffer and the
// flusher goroutine is signalled. If the data does not fit in the remaining
// space, Write makes room first:
//   - if the buffer already holds data, it tops the buffer up and flushes it;
//   - if the buffer is empty, it writes the data straight to the underlying
//     writer.
//
// It repeats this until the remainder fits. Once a write error has occurred,
// it is returned along with the count of bytes accepted so far.
func (b *bufferWriter) Write(buff []byte) (int, error) {
	b.lock.Lock()
	defer b.lock.Unlock()

	written := 0
	for b.err == nil && len(buff) > b.size-b.n {
		var chunk int
		if b.n != 0 {
			chunk = copy(b.buff[b.n:], buff)
			b.n += chunk
			// flush records any failure in b.err, which ends the loop.
			_ = b.flush()
		} else {
			chunk, b.err = b.w.Write(buff)
		}
		written += chunk
		buff = buff[chunk:]
	}
	if b.err != nil {
		return written, b.err
	}

	copied := copy(b.buff[b.n:], buff)
	b.n += copied
	b.cond.Signal()
	return written + copied, nil
}
// flush writes all pending bytes in b.buff[:b.n] to the underlying writer.
//
// WARNING: b.lock should already be locked.
//
// Errors are sticky. Once a write fails (including a short write, reported as
// io.ErrShortWrite), the error is stored in b.err and returned by every later
// call. On failure, any unwritten tail is moved to the front of the buffer so
// that b.n still counts only the bytes that were not written.
func (b *bufferWriter) flush() error {
	switch {
	case b.err != nil:
		return b.err
	case b.n == 0:
		return nil
	}

	sent, err := b.w.Write(b.buff[:b.n])
	if err == nil && sent < b.n {
		err = io.ErrShortWrite
	}
	if err == nil {
		b.n = 0
		return nil
	}

	if sent > 0 && sent < b.n {
		copy(b.buff[:b.n-sent], b.buff[sent:b.n])
	}
	b.n -= sent
	b.err = err
	return err
}
@jauhararifin
Copy link
Author

jauhararifin commented Jan 10, 2024

Benchmark Result

This benchmark was run on a MacBook Pro (M1).

❯ uname -a
Darwin xxxxxxxx-MacBook-Pro.local 23.0.0 Darwin Kernel Version 23.0.0: Fri Sep 15 14:41:34 PDT 2023; root:xnu-10002.1.13~1/RELEASE_ARM64_T8103 x86_64

The actual Redis:
This uses the default Redis configuration with snapshotting disabled.

❯ redis-benchmark -h localhost -p 6379 --seed 0 -n 1000000 -r 100000000000 -t set,get -c 50 -q -P 100
SET: 712758.38 requests per second, p50=6.031 msec
GET: 1190476.25 requests per second, p50=4.087 msec

This implementation:

# Run the server using:
❯ go run redis.go

# Benchmark using
❯ redis-benchmark -h localhost -p 3010 --seed 0 -n 1000000 -r 100000000000 -t set,get -c 50 -q -P 100
WARNING: Could not fetch server CONFIG
SET: 1574803.12 requests per second, p50=0.679 msec
GET: 4629629.50 requests per second, p50=0.831 msec

# Benchmark with the server running under GOMAXPROCS=1
❯ redis-benchmark -h localhost -p 3010 --seed 0 -n 1000000 -r 100000000000 -t set,get -c 50 -q -P 100
WARNING: Could not fetch server CONFIG
SET: 1018329.94 requests per second, p50=4.143 msec
GET: 1615508.88 requests per second, p50=2.767 msec

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment