Skip to content

Instantly share code, notes, and snippets.

@vmihailenco
Created June 25, 2012 21:33
Show Gist options
  • Save vmihailenco/2991383 to your computer and use it in GitHub Desktop.
Save vmihailenco/2991383 to your computer and use it in GitHub Desktop.
Redis proxy
package main
import (
"flag"
"io"
"log"
"net"
"os"
"os/signal"
"runtime/pprof"
"sync"
)
// Command-line flags, read after flag.Parse in main.
var (
	// cpuprofile, when non-empty, is the path of the file that receives a CPU
	// profile for the lifetime of the process.
	cpuprofile = flag.String("cpuprofile", "", "write cpu profile to file")

	// proxyAddr is the TCP address on which the proxy accepts clients.
	proxyAddr = flag.String("l", "localhost:9999", "proxy address")

	// redisAddr is the TCP address of the upstream Redis server.
	redisAddr = flag.String("r", "localhost:6379", "redis address")
)
// readerPool hands out reusable BufReaders to proxyRequest and proxyResponse.
// It is created with room for 100 readers, each produced by createReader.
var readerPool = NewReaderPool(100, createReader)

// createReader is the readerFactory behind readerPool. Each reader it builds
// starts with an 8192-byte buffer; this factory never reports an error.
func createReader() (*BufReader, error) {
	reader := NewBufReader(8192)
	return reader, nil
}
// RedisConn dials the Redis server at *redisAddr over TCP.
//
// If the first dial fails, the failure is only logged (without its error) and
// one more dial is attempted immediately, with no delay. If that second dial
// also fails, its error is returned.
func RedisConn() (io.ReadWriteCloser, error) {
	if first, err := net.Dial("tcp", *redisAddr); err == nil {
		return first, nil
	}

	log.Printf("redisConn: start redis")
	retry, err := net.Dial("tcp", *redisAddr)
	if err != nil {
		return nil, err
	}
	return retry, nil
}
// proxyRequest forwards one chunk of client data to Redis.
//
// It borrows a BufReader from readerPool and fills it with a single PeekAll on
// conn. PeekAll may read again only when the buffer fills up. The bytes
// obtained are then written to redisConn.
//
// It returns false if the reader cannot be obtained, the read fails, or the
// write fails. io.EOF from the client is treated as a normal close and is not
// logged.
func proxyRequest(conn, redisConn io.ReadWriteCloser) bool {
	r, err := readerPool.Get()
	if err != nil {
		log.Printf("proxyRequest: bufreader.Get: %v", err)
		return false
	}
	// Return the reader on every path, not only on success. Otherwise each
	// failed or closed request leaks one reader and forces the pool to
	// allocate replacements.
	defer readerPool.Add(r)

	if err := r.PeekAll(conn); err != nil {
		if err != io.EOF {
			log.Printf("proxyRequest: PeekAll: %v", err)
		}
		return false
	}
	if _, err := redisConn.Write(r.Bytes()); err != nil {
		log.Printf("proxyRequest: Write: %v", err)
		return false
	}
	return true
}
// proxyResponse forwards one chunk of Redis output back to the client.
//
// It borrows a BufReader from readerPool and fills it with a single PeekAll on
// redisConn. PeekAll may read again only when the buffer fills up. The bytes
// obtained are then written to conn.
//
// It returns false if the reader cannot be obtained, the read fails, or the
// write fails. io.EOF from Redis is treated as a normal close and is not
// logged.
func proxyResponse(conn, redisConn io.ReadWriteCloser) bool {
	r, err := readerPool.Get()
	if err != nil {
		log.Printf("proxyResponse: bufreader.Get: %v", err)
		return false
	}
	// Return the reader on every path so error exits do not leak it.
	defer readerPool.Add(r)

	if err := r.PeekAll(redisConn); err != nil {
		if err != io.EOF {
			log.Printf("proxyResponse: PeekAll: %v", err)
		}
		return false
	}
	if _, err := conn.Write(r.Bytes()); err != nil {
		log.Printf("proxyResponse: Write: %v", err)
		return false
	}
	return true
}
// proxyConn serves one client connection.
//
// It dials a dedicated Redis connection, then alternates proxyRequest and
// proxyResponse until either step reports failure. Both connections are
// closed when it returns, including when Redis could not be reached.
//
// NOTE(review): this lockstep scheme assumes each client request and each
// Redis reply arrives in a single PeekAll. Pipelined requests, or a reply
// split across several TCP reads, can leave the two directions out of sync or
// blocked. Confirm this is acceptable for the intended use.
func proxyConn(conn io.ReadWriteCloser) {
	// Deferred first so the accepted client connection is released even when
	// dialing Redis fails below.
	defer conn.Close()

	redisConn, err := RedisConn()
	if err != nil {
		log.Printf("proxyConn: RedisConn: %v", err)
		return
	}
	defer redisConn.Close()

	for {
		if !proxyRequest(conn, redisConn) || !proxyResponse(conn, redisConn) {
			return
		}
	}
}
// waitSignals blocks until one value arrives on c.
//
// If that value is os.Interrupt, it stops CPU profiling (when -cpuprofile was
// given) and exits the process with status 0. Any other value is ignored and
// the function simply returns, handling no further signals.
func waitSignals(c <-chan os.Signal) {
	sig := <-c
	if sig != os.Interrupt {
		return
	}
	if *cpuprofile != "" {
		pprof.StopCPUProfile()
	}
	os.Exit(0)
}
// main starts the proxy. It performs these steps in order:
//   - parses flags;
//   - routes Ctrl-C (os.Interrupt) to waitSignals;
//   - starts CPU profiling if -cpuprofile is set;
//   - accepts clients on -l, serving each one in its own proxyConn goroutine.
//
// A Listen or Accept failure panics.
func main() {
	flag.Parse()

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt)
	go waitSignals(sigs)

	if path := *cpuprofile; path != "" {
		out, err := os.Create(path)
		if err != nil {
			log.Fatal(err)
		}
		if err := pprof.StartCPUProfile(out); err != nil {
			log.Fatal(err)
		}
		// This only runs if main unwinds, e.g. through one of the panics
		// below; panic runs deferred calls, log.Fatal would not. A Ctrl-C
		// exit stops the profile in waitSignals instead.
		defer pprof.StopCPUProfile()
	}

	log.Printf("proxy addr: %v", *proxyAddr)
	log.Printf("redis addr: %v", *redisAddr)

	ln, err := net.Listen("tcp", *proxyAddr)
	if err != nil {
		panic(err)
	}
	for {
		client, err := ln.Accept()
		if err != nil {
			panic(err)
		}
		go proxyConn(client)
	}
}
//------------------------------------------------------------------------------
// BufPool is a mutex-guarded free list of byte slices, each bufSize bytes
// long. When it runs dry, PopBuffer refills it with size fresh buffers.
//
// NOTE(review): nothing in the code shown uses BufPool; consider removing it.
type BufPool struct {
	bufs          [][]byte
	mtx           sync.Mutex
	size, bufSize int
}

// NewBufPool returns a pool pre-filled with size buffers of bufSize bytes.
func NewBufPool(size, bufSize int) *BufPool {
	p := &BufPool{
		bufs:    make([][]byte, 0, size),
		size:    size,
		bufSize: bufSize,
	}
	p.fill()
	return p
}

// fill appends p.size newly allocated buffers. NewBufPool calls it before the
// pool is shared; PopBuffer calls it with p.mtx held.
func (p *BufPool) fill() {
	for i := 0; i < p.size; i++ {
		p.addBuffer(make([]byte, p.bufSize))
	}
}

// PopBuffer removes and returns one buffer, refilling the pool first if it is
// empty. The error result is always nil; it is kept for API compatibility.
func (p *BufPool) PopBuffer() ([]byte, error) {
	// Hold the lock for the whole call: the emptiness check, the refill and
	// the slice shrink must be atomic with respect to other PopBuffer and
	// AddBuffer callers.
	p.mtx.Lock()
	defer p.mtx.Unlock()

	if len(p.bufs) == 0 {
		log.Print("PopBuffer: increase pool size")
		p.fill()
		if len(p.bufs) == 0 {
			// size <= 0 makes fill a no-op; hand out a one-off buffer rather
			// than indexing an empty slice.
			return make([]byte, p.bufSize), nil
		}
	}
	last := len(p.bufs) - 1
	buf := p.bufs[last]
	p.bufs = p.bufs[:last]
	return buf, nil
}

// addBuffer pushes buf onto the free list; the caller must hold p.mtx or own p
// exclusively.
func (p *BufPool) addBuffer(buf []byte) {
	p.bufs = append(p.bufs, buf)
}

// AddBuffer returns buf to the pool.
func (p *BufPool) AddBuffer(buf []byte) {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	p.addBuffer(buf)
}
//------------------------------------------------------------------------------
// readerFactory builds one new BufReader for a ReaderPool.
type readerFactory func() (*BufReader, error)

// ReaderPool is a mutex-guarded free list of *BufReader. When it is empty,
// Get refills it with size readers built by createReader.
type ReaderPool struct {
	readers []*BufReader
	mtx     sync.Mutex
	// NOTE(review): bufSize is never read in the code shown; only size is used.
	size, bufSize int
	createReader  readerFactory
}

// NewReaderPool returns a pool pre-filled with up to size readers.
//
// The signature has no error result, so a factory error during this initial
// fill is not reported here. Any readers built before the error are kept, and
// Get retries the fill and reports the error if the pool is still empty.
func NewReaderPool(size int, createReader readerFactory) *ReaderPool {
	p := &ReaderPool{
		readers:      make([]*BufReader, 0, size),
		size:         size,
		createReader: createReader,
	}
	_ = p.fill() // best effort; see the doc comment above
	return p
}

// fill appends p.size readers from createReader, stopping at the first error.
// Readers created before the error stay in the pool. NewReaderPool calls it
// before the pool is shared; Get calls it with p.mtx held.
func (p *ReaderPool) fill() error {
	for i := 0; i < p.size; i++ {
		r, err := p.createReader()
		if err != nil {
			return err
		}
		p.add(r)
	}
	return nil
}

// Get removes and returns a reader, refilling the pool when it is empty. It
// returns an error only when the pool is empty and the factory fails to
// produce a reader.
func (p *ReaderPool) Get() (*BufReader, error) {
	p.mtx.Lock()
	// Deferred so the lock is released on the error returns as well.
	defer p.mtx.Unlock()

	if len(p.readers) == 0 {
		log.Print("Get: increase pool size")
		err := p.fill()
		if len(p.readers) == 0 {
			if err != nil {
				return nil, err
			}
			// size <= 0 makes fill a no-op; build a single reader directly.
			return p.createReader()
		}
		// A partial fill (err != nil but some readers created) is still
		// usable, so the error is dropped in that case.
	}
	last := len(p.readers) - 1
	r := p.readers[last]
	p.readers = p.readers[:last]
	return r, nil
}

// add pushes reader onto the free list; the caller must hold p.mtx or own p
// exclusively.
func (p *ReaderPool) add(reader *BufReader) {
	p.readers = append(p.readers, reader)
}

// Add returns reader to the pool.
func (p *ReaderPool) Add(reader *BufReader) {
	p.mtx.Lock()
	defer p.mtx.Unlock()
	p.add(reader)
}
//------------------------------------------------------------------------------
// BufReader is a reusable read buffer:
//   - PeekAll loads the next chunk from an io.Reader;
//   - Bytes exposes the loaded data;
//   - ReadLine walks that data line by line.
type BufReader struct {
	// NOTE(review): rd is never assigned or read in the code shown.
	rd  io.Reader
	buf []byte
	// n is the number of valid bytes in buf; pos is the ReadLine cursor
	// within buf[:n].
	n, pos int
}

// NewBufReader returns a BufReader whose initial buffer holds size bytes.
func NewBufReader(size int) *BufReader {
	return &BufReader{buf: make([]byte, size)}
}

// PeekAll replaces the buffered data with the next chunk read from rd and
// resets the ReadLine cursor.
//
// It issues one Read. Whenever the reads so far have completely filled the
// buffer, the buffer is doubled and Read is called again for the new free
// space.
//
// Following the io.Reader contract, bytes returned together with an error
// are kept. If any data was read, PeekAll returns nil so the caller can use
// Bytes; it relies on rd reporting the error again on the next call (true for
// io.EOF; NOTE(review): confirm for other errors of the readers used here).
// If no data was read, the error is returned.
//
// NOTE(review): if a message exactly fills the free space, the follow-up Read
// blocks until rd delivers more data.
func (r *BufReader) PeekAll(rd io.Reader) error {
	r.n = 0
	r.pos = 0
	for {
		n, err := rd.Read(r.buf[r.n:])
		r.n += n
		if err != nil {
			if r.n > 0 {
				return nil
			}
			return err
		}
		if r.n < len(r.buf) {
			return nil
		}
		log.Printf("PeekAll: increase buffer size")
		r.grow()
	}
}

// grow doubles the buffer in place, preserving the r.n bytes already read. An
// empty buffer grows to 4096 bytes so PeekAll cannot loop forever on a
// zero-sized reader.
func (r *BufReader) grow() {
	size := 2 * len(r.buf)
	if size == 0 {
		size = 4096
	}
	grown := make([]byte, size)
	copy(grown, r.buf[:r.n])
	r.buf = grown
}

// ReadLine returns the next '\n'-terminated line of the data loaded by
// PeekAll, with the trailing "\n" or "\r\n" removed, and advances past it. It
// returns io.EOF when no complete line remains.
//
// The result aliases the internal buffer and is overwritten by the next
// PeekAll.
func (r *BufReader) ReadLine() ([]byte, error) {
	start := r.pos
	for i := start; i < r.n; i++ {
		if r.buf[i] != '\n' {
			continue
		}
		r.pos = i + 1
		end := i
		// Strip an optional '\r'. The bounds check makes an empty "\n" line
		// safe, whereas subtracting 2 unconditionally would underflow.
		if end > start && r.buf[end-1] == '\r' {
			end--
		}
		return r.buf[start:end], nil
	}
	return nil, io.EOF
}

// Bytes returns the data loaded by the last PeekAll. The slice aliases the
// internal buffer.
func (r *BufReader) Bytes() []byte {
	return r.buf[:r.n]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment