Last active
January 10, 2024 15:10
-
-
Save jauhararifin/663752042410c5526b2b7076965a2a00 to your computer and use it in GitHub Desktop.
Simple Redis implementation in Go that supports only the GET and SET commands
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"errors" | |
"fmt" | |
"io" | |
"log/slog" | |
"net" | |
"os" | |
"os/signal" | |
"strings" | |
"sync" | |
"sync/atomic" | |
"time" | |
) | |
// Sentinel errors shared by the server lifecycle methods and the RESP parser.
var (
	// errServerNotStarted is returned by Stop when Start has not been called.
	errServerNotStarted = errors.New("server is not yet started")
	// errServerAlreadyStarted is returned by every call to Start after the first.
	errServerAlreadyStarted = errors.New("server already started")
	// errServerIsStopping is returned by every call to Stop after the first.
	errServerIsStopping = errors.New("server is already stopping")
	// errMalformedInput reports request bytes that do not follow the RESP
	// framing the parser expects.
	errMalformedInput = errors.New("malformed input")
	// NOTE(review): unsupportedValue is not referenced anywhere in this file,
	// and its name breaks the errXxx convention used by its siblings.
	unsupportedValue = errors.New("unsupported value")
)
func main() { | |
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo})) | |
address := "0.0.0.0:3010" | |
logger.Info("starting server", slog.String("addr", address)) | |
listener, err := net.Listen("tcp", address) | |
if err != nil { | |
logger.Error( | |
"cannot start tcp server", | |
slog.String("addr", address), | |
slog.String("err", err.Error()), | |
) | |
os.Exit(1) | |
} | |
server := NewServer(listener, logger) | |
go func() { | |
if err := server.Start(); err != nil { | |
logger.Error("cannot start server", slog.String("error", err.Error())) | |
os.Exit(1) | |
} | |
}() | |
c := make(chan os.Signal, 1) | |
signal.Notify(c, os.Interrupt) | |
<-c | |
if err := server.Stop(); err != nil { | |
logger.Error("cannot stop server", slog.String("error", err.Error())) | |
os.Exit(1) | |
} | |
} | |
type server struct { | |
listener net.Listener | |
logger *slog.Logger | |
started atomic.Bool | |
stopping atomic.Bool | |
stopped chan struct{} | |
clientGroup sync.WaitGroup | |
clientId atomic.Int64 | |
clients sync.Map | |
handler map[string]commandDesc | |
database database | |
} | |
// commandDesc describes how one command is executed. handler is called with
// the request reader positioned just after the command name, the request's
// total array length (command name included), and the writer that receives
// the reply. A non-nil error ends the client's connection.
type commandDesc struct {
	handler func(r io.Reader, nArgs int, w io.Writer) error
}
// database is the in-memory key/value store. m guards every read and write of
// dict. Because m is a pointer, copies of a database value share one lock.
type database struct {
	m    *sync.RWMutex
	dict map[string]string
}
func NewServer(listener net.Listener, logger *slog.Logger) *server { | |
s := &server{ | |
listener: listener, | |
logger: logger, | |
stopping: atomic.Bool{}, | |
stopped: make(chan struct{}), | |
clientGroup: sync.WaitGroup{}, | |
clientId: atomic.Int64{}, | |
clients: sync.Map{}, | |
handler: make(map[string]commandDesc), | |
database: newDatabase(), | |
} | |
s.handler = map[string]commandDesc{ | |
"GET": { | |
handler: s.handleGetCommand, | |
}, | |
"SET": { | |
handler: s.handleSetCommand, | |
}, | |
} | |
return s | |
} | |
func newDatabase() database { | |
return database{ | |
m: &sync.RWMutex{}, | |
dict: map[string]string{}, | |
} | |
} | |
func (s *server) Start() error { | |
if alreadyStarted := !s.started.CompareAndSwap(false, true); alreadyStarted { | |
return errServerAlreadyStarted | |
} | |
defer close(s.stopped) | |
s.logger.Info("server started") | |
for !s.stopping.Load() { | |
conn, err := s.listener.Accept() | |
if err != nil { | |
if s.stopping.Load() { | |
// upon shutting down, the listener will be Closed and | |
// calling listener.Accept() will return an error. | |
// in this case, the error is expected and we shouldn't return | |
// it back to the caller. | |
break | |
} | |
return err | |
} | |
s.clientGroup.Add(1) | |
go s.handleClient(conn) | |
} | |
s.clients.Range(func(key, _ any) bool { | |
key.(net.Conn).Close() | |
return true | |
}) | |
s.clientGroup.Wait() | |
return nil | |
} | |
func (s *server) Stop() error { | |
if !s.started.Load() { | |
return errServerNotStarted | |
} | |
if isAlreadyClosing := !s.stopping.CompareAndSwap(false, true); isAlreadyClosing { | |
return errServerIsStopping | |
} | |
s.logger.Info("closing server") | |
err := s.listener.Close() | |
<-s.stopped | |
return err | |
} | |
func (s *server) handleClient(netConn net.Conn) { | |
clientId := s.clientId.Add(1) | |
logger := s.logger.With(slog.Int64("client_id", clientId)) | |
s.clients.Store(netConn, nil) | |
defer func() { | |
s.clients.Delete(netConn) | |
s.clientGroup.Done() | |
netConn.Close() | |
}() | |
reader := bufio.NewReader(netConn) | |
writer := newBufferWriter(netConn) | |
closing := atomic.Bool{} | |
go func() { | |
if err := writer.Start(&closing); err != nil { | |
logger.Error("cannot flush write to client", slog.String("error", err.Error())) | |
} | |
closing.Store(true) | |
}() | |
logError := func(err error) { | |
if errors.Is(err, io.EOF) { | |
logger.Info("client disconnected") | |
} else { | |
logger.Error("cannot read from client", slog.String("error", err.Error())) | |
} | |
} | |
for !closing.Load() { | |
length, err := readArrayValue(reader) | |
if err != nil { | |
logError(err) | |
break | |
} | |
if length < 1 { | |
logError(errMalformedInput) | |
break | |
} | |
commandStr := strings.Builder{} | |
if err := readAnyString(reader, &commandStr); err != nil { | |
logError(err) | |
break | |
} | |
commandName := strings.ToUpper(commandStr.String()) | |
commandDesc, ok := s.handler[commandName] | |
if !ok { | |
logError(fmt.Errorf("unknown command '%s'", commandName)) | |
break | |
} | |
if err := commandDesc.handler(reader, length, writer); err != nil { | |
logError(err) | |
break | |
} | |
} | |
closing.Store(true) | |
} | |
func (s *server) handleGetCommand(r io.Reader, nArgs int, w io.Writer) error { | |
key, err := s.parseGetCommandArgs(r, nArgs) | |
if err != nil { | |
return err | |
} | |
s.database.m.RLock() | |
defer s.database.m.RUnlock() | |
if value, ok := s.database.dict[key]; ok { | |
return encodeBulkStringValue(w, value) | |
} else { | |
return encodeNilValue(w) | |
} | |
} | |
func (s *server) parseGetCommandArgs(r io.Reader, nArgs int) (key string, err error) { | |
if nArgs != 2 { | |
return "", fmt.Errorf("wrong number of arguments") | |
} | |
b := strings.Builder{} | |
if err := readAnyString(r, &b); err != nil { | |
return "", err | |
} | |
return b.String(), nil | |
} | |
func (s *server) handleSetCommand(r io.Reader, nArgs int, w io.Writer) error { | |
key, value, err := s.parseSetCommandArgs(r, nArgs) | |
if err != nil { | |
return err | |
} | |
s.database.m.Lock() | |
s.database.dict[key] = value | |
s.database.m.Unlock() | |
return encodeSimpleString(w, "OK") | |
} | |
func (s *server) parseSetCommandArgs(r io.Reader, nArgs int) (string, string, error) { | |
if nArgs != 3 { | |
return "", "", fmt.Errorf("wrong number of arguments") | |
} | |
key := strings.Builder{} | |
if err := readAnyString(r, &key); err != nil { | |
return "", "", err | |
} | |
value := strings.Builder{} | |
if err := readAnyString(r, &value); err != nil { | |
return "", "", err | |
} | |
return key.String(), value.String(), nil | |
} | |
func readArrayValue(r io.Reader) (length int, err error) { | |
b, err := readOne(r) | |
if err != nil { | |
return 0, err | |
} | |
if b != '*' { | |
return 0, errMalformedInput | |
} | |
length, err = readValueLength(r) | |
if err != nil { | |
return 0, err | |
} | |
return length, nil | |
} | |
// readOne reads exactly one byte from r. A *bufio.Reader is read with
// ReadByte; any other reader goes through io.ReadFull with a one-byte buffer.
// On error the returned byte is 0.
func readOne(r io.Reader) (byte, error) {
	br, isBuffered := r.(*bufio.Reader)
	if isBuffered {
		return br.ReadByte()
	}

	var one [1]byte
	_, err := io.ReadFull(r, one[:])
	if err != nil {
		return 0, err
	}
	return one[0], nil
}
func readValueLength(r io.Reader) (int, error) { | |
b, err := readOne(r) | |
if err != nil { | |
return 0, err | |
} | |
if b < '0' || b > '9' { | |
return 0, errMalformedInput | |
} | |
result := int(b - '0') | |
for b != '\r' { | |
b, err = readOne(r) | |
if err != nil { | |
return 0, err | |
} | |
if b != '\r' { | |
if b < '0' || b > '9' { | |
return 0, errMalformedInput | |
} | |
result = result*10 + int(b-'0') | |
} | |
} | |
b, err = readOne(r) | |
if err != nil { | |
return 0, err | |
} | |
if b != '\n' { | |
return 0, errMalformedInput | |
} | |
return result, nil | |
} | |
func readUntilCRLF(r io.Reader, w io.Writer) error { | |
b, err := readOne(r) | |
if err != nil { | |
return err | |
} | |
for b != '\r' { | |
if _, err := w.Write([]byte{b}); err != nil { | |
return err | |
} | |
b, err = readOne(r) | |
if err != nil { | |
return err | |
} | |
} | |
b, err = readOne(r) | |
if err != nil { | |
return err | |
} | |
if b != '\n' { | |
return errMalformedInput | |
} | |
return nil | |
} | |
func skipCRLF(r io.Reader) error { | |
b, err := readOne(r) | |
if err != nil { | |
return err | |
} | |
if b != '\r' { | |
return errMalformedInput | |
} | |
b, err = readOne(r) | |
if err != nil { | |
return err | |
} | |
if b != '\n' { | |
return errMalformedInput | |
} | |
return nil | |
} | |
func readAnyString(r io.Reader, w io.Writer) error { | |
b, err := readOne(r) | |
if err != nil { | |
return err | |
} | |
if b == '$' { | |
length, err := readValueLength(r) | |
if err != nil { | |
return err | |
} | |
if sb, ok := w.(interface{ Grow(int) }); ok { | |
sb.Grow(length) | |
} | |
if _, err := io.CopyN(w, r, int64(length)); err != nil { | |
return err | |
} | |
return skipCRLF(r) | |
} else if b == '+' { | |
return readUntilCRLF(r, w) | |
} else { | |
return errMalformedInput | |
} | |
} | |
// encodeSimpleString writes s to w as a RESP simple string ("+<s>\r\n") and
// returns the write error, if any. s is not escaped, so callers should not pass
// text containing CR or LF.
func encodeSimpleString(w io.Writer, s string) error {
	_, err := fmt.Fprintf(w, "+%s\r\n", s)
	return err
}
// encodeBulkStringValue writes s to w as a RESP bulk string,
// "$<byte length>\r\n<s>\r\n", using four successive writes. It stops at and
// returns the first write error.
func encodeBulkStringValue(w io.Writer, s string) (err error) {
	if _, err = w.Write([]byte{'$'}); err != nil {
		return err
	}
	// The length prefix counts bytes, not runes.
	if _, err = fmt.Fprintf(w, "%d\r\n", len(s)); err != nil {
		return err
	}
	if _, err = io.WriteString(w, s); err != nil {
		return err
	}
	_, err = w.Write([]byte{'\r', '\n'})
	return err
}
// encodeNilValue writes the RESP2 null reply, the null bulk string "$-1\r\n".
// This is what Redis sends for GET on a missing key. It returns the write
// error, if any.
//
// The RESP3 null "_\r\n" may only be sent after a client has switched to
// RESP3 with HELLO 3. This server implements no HELLO command, so its clients
// are speaking RESP2, and RESP2-only parsers reject '_' as an unknown type.
func encodeNilValue(w io.Writer) error {
	_, err := w.Write([]byte("$-1\r\n"))
	return err
}
// bufferWriter is an io.Writer that collects output in a fixed-size buffer.
// A separate goroutine running Start flushes the buffer to w.
type bufferWriter struct {
	w    io.Writer   // destination for flushed bytes
	lock *sync.Mutex // guards n, buff and err
	cond *sync.Cond  // bound to lock; Write signals it, Start waits on it
	size int         // capacity of buff
	buff []byte      // pending output lives in buff[:n]
	n    int         // number of pending bytes
	err  error       // first write failure; once set, flush and Write return it
}
func newBufferWriter(w io.Writer) *bufferWriter { | |
lock := &sync.Mutex{} | |
return &bufferWriter{ | |
w: w, | |
lock: lock, | |
cond: sync.NewCond(lock), | |
size: 4096, | |
buff: make([]byte, 4096), | |
n: 0, | |
err: nil, | |
} | |
} | |
func (b *bufferWriter) Start(stop *atomic.Bool) error { | |
for !stop.Load() { | |
b.lock.Lock() | |
for b.n == 0 { | |
b.cond.Wait() | |
b.lock.Unlock() | |
time.Sleep(500 * time.Microsecond) | |
b.lock.Lock() | |
} | |
if err := b.flush(); err != nil { | |
return err | |
} | |
b.lock.Unlock() | |
} | |
return nil | |
} | |
func (b *bufferWriter) Write(buff []byte) (int, error) { | |
b.lock.Lock() | |
defer b.lock.Unlock() | |
totalWrite := 0 | |
for len(buff) > b.size-b.n && b.err == nil { | |
var n int | |
if b.n == 0 { | |
n, b.err = b.w.Write(buff) | |
} else { | |
n = copy(b.buff[b.n:], buff) | |
b.n += n | |
b.flush() | |
} | |
totalWrite += n | |
buff = buff[n:] | |
} | |
if b.err != nil { | |
return totalWrite, b.err | |
} | |
n := copy(b.buff[b.n:], buff) | |
b.n += n | |
totalWrite += n | |
b.cond.Signal() | |
return totalWrite, nil | |
} | |
// WARNING: b.lock should already be locked | |
func (b *bufferWriter) flush() error { | |
if b.err != nil { | |
return b.err | |
} | |
if b.n == 0 { | |
return nil | |
} | |
n, err := b.w.Write(b.buff[:b.n]) | |
if n < b.n && err == nil { | |
err = io.ErrShortWrite | |
} | |
if err != nil { | |
if n > 0 && n < b.n { | |
copy(b.buff[0:b.n-n], b.buff[n:b.n]) | |
} | |
b.n -= n | |
b.err = err | |
return err | |
} | |
b.n = 0 | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Benchmark Result
This benchmark was run on a MacBook Pro (M1).
Actual Redis:
This uses the default Redis configuration with snapshotting disabled.
This implementation: