Skip to content

Instantly share code, notes, and snippets.

@eliquious
Created January 23, 2016 19:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eliquious/b4add2899fa6107d0941 to your computer and use it in GitHub Desktop.
LMAX Disruptor TCP Server
//CLIENT
package main
import (
"bufio"
"fmt"
"io"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
)
// Benchmark configuration and the shared progress counter.
var (
	// requestCount is the number of replies received across all
	// connections; the read loops increment it with sync/atomic.
	requestCount uint64

	// totalPingsPerConnection is how many pings each connection sends and
	// waits to have echoed back.
	totalPingsPerConnection uint64 = 10000000

	// concurrentConnections is how many client connections run in parallel.
	concurrentConnections uint64 = 32

	// totalPings is the number of replies expected across all connections.
	totalPings = concurrentConnections * totalPingsPerConnection
)
// monitor starts a goroutine that reports benchmark throughput.
//
// Once per second it prints how many replies were counted since the previous
// tick, plus the running total. When a value is received on done, it prints
// summary statistics: average time per request, total requests, requests per
// second and elapsed time. It then sends true on the returned channel.
//
// The caller must send on done exactly once and then receive from the
// returned channel; the goroutine exits after that exchange.
func monitor(done chan bool) chan bool {
	out := make(chan bool)
	go func() {
		// One ticker for the whole run instead of a fresh timer per iteration.
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()

		var last uint64
		start := time.Now()
		for {
			select {
			case <-done:
				elapsed := time.Since(start)
				total := atomic.LoadUint64(&requestCount)
				if total == 0 {
					// Avoid printing NaN statistics when nothing was received.
					fmt.Printf("%d requests\n", total)
					fmt.Printf("elapsed: %s\r\n", elapsed)
				} else {
					fmt.Printf("%f ns\n", float64(elapsed)/float64(total))
					fmt.Printf("%d requests\n", total)
					fmt.Printf("%f requests per second\n", float64(total)/elapsed.Seconds())
					fmt.Printf("elapsed: %s\r\n", elapsed)
				}
				out <- true
				return
			case <-ticker.C:
				current := atomic.LoadUint64(&requestCount)
				fmt.Printf("%d combined requests per second (%d)\n", current-last, current)
				last = current
				// Deliberately keep running after every ping has been counted:
				// main always sends on done, and returning here would leave
				// that unbuffered send without a receiver (deadlock).
			}
		}
	}()
	return out
}
// readLoop counts replies on c.conn until totalPingsPerConnection have been
// received, the server closes the connection, or a read fails. It calls
// wg.Done when it returns.
//
// Each ping sent by writeLoop is the 4-byte message "Ping", so one reply is
// counted per 4 bytes read. TCP is a byte stream and a single Read may return
// only part of a reply, so replies are framed with io.ReadFull instead of
// counting Read calls.
func (c *client) readLoop(wg *sync.WaitGroup) {
	defer wg.Done()
	rd := bufio.NewReader(c.conn)
	reply := make([]byte, 4)
	for atomic.LoadUint64(&c.revcd) < totalPingsPerConnection {
		if _, err := io.ReadFull(rd, reply); err != nil {
			// io.EOF: the server closed the connection between replies.
			// Anything else, including io.ErrUnexpectedEOF for a truncated
			// reply, is reported.
			if err != io.EOF {
				fmt.Println(err)
			}
			return
		}
		atomic.AddUint64(&c.revcd, 1)
		atomic.AddUint64(&requestCount, 1)
	}
}
// writeLoop sends totalPingsPerConnection copies of the 4-byte message
// "Ping" over c.conn through a 64 KiB buffered writer. It flushes whatever
// remains buffered at the end. It stops at the first write or flush error,
// which it prints, and calls wg.Done when it returns.
func (c *client) writeLoop(wg *sync.WaitGroup) {
	defer wg.Done()
	wr := bufio.NewWriterSize(c.conn, 65536)
	outBuf := []byte("Ping")
	for atomic.LoadUint64(&c.sent) < totalPingsPerConnection {
		// Any error from a write is fatal: a bufio.Writer stays in the error
		// state, so every later Write would fail as well.
		if _, err := wr.Write(outBuf); err != nil {
			fmt.Println(err)
			return
		}
		atomic.AddUint64(&c.sent, 1)
	}
	// Without a successful Flush, the final partial buffer of pings never
	// reaches the server.
	if err := wr.Flush(); err != nil {
		fmt.Println(err)
	}
}
// RingBufferCapacity is 1 MiB expressed in bytes.
// NOTE(review): nothing in the client program references this constant; it
// looks like a leftover from the server and could probably be removed.
const RingBufferCapacity = 1 << 20
// client is the state of one benchmark connection: how many pings have
// been written, how many replies have been read, and the socket itself.
//
// The two counters are accessed with sync/atomic and are kept as the leading
// fields. This keeps them 64-bit aligned, which 64-bit atomic operations
// require on 32-bit platforms.
type client struct {
	sent, revcd uint64 // pings written / replies received
	conn        *net.TCPConn
}
// NewClient runs one benchmark connection to localhost:9022 to completion.
//
// It dials the server, then runs writeLoop and readLoop concurrently on the
// connection until both have returned. It closes the connection and calls
// wg.Done when it returns, including when resolving or dialing fails (the
// error is printed). Despite its name it returns nothing.
func NewClient(wg *sync.WaitGroup) {
	defer wg.Done()

	tcpAddr, err := net.ResolveTCPAddr("tcp4", "localhost:9022")
	if err != nil {
		// Previously ignored; a nil address would only fail later in DialTCP
		// with a less useful message.
		fmt.Println(err)
		return
	}
	conn, err := net.DialTCP("tcp", nil, tcpAddr)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer conn.Close()

	var w sync.WaitGroup
	c := client{conn: conn}
	w.Add(2)
	go c.writeLoop(&w)
	go c.readLoop(&w)
	w.Wait()
}
// main runs the benchmark. It starts the throughput monitor, opens
// concurrentConnections client connections in parallel and waits for all of
// them to finish. It then asks the monitor for its summary and waits until
// the summary has been printed.
func main() {
	runtime.GOMAXPROCS(8)

	stop := make(chan bool)
	finished := monitor(stop)

	var clients sync.WaitGroup
	clients.Add(int(concurrentConnections))
	for remaining := concurrentConnections; remaining > 0; remaining-- {
		go NewClient(&clients)
	}
	clients.Wait()

	stop <- true
	<-finished
}
package main
import (
"fmt"
"io"
"log"
"net"
"os"
"runtime"
"time"
disruptor "github.com/smartystreets/go-disruptor"
"golang.org/x/net/context"
)
// main runs the echo server on :9022. Start blocks until the server is
// stopped. If the server cannot be started (for example the port is already
// in use), main logs the error and exits with a non-zero status instead of
// exiting silently.
func main() {
	runtime.GOMAXPROCS(8)
	server := Server{}
	if err := server.Start(":9022"); err != nil {
		log.Fatal(err)
	}
}
// Server accepts incoming TCP connections and dispatches each one to its
// own handler. The zero value is ready for Start, which fills in every field.
type Server struct {
	// Logger is (re)created by Start, writing to stdout, and is shared with
	// every connection handler.
	Logger *log.Logger
	// Addr is the address the listener is actually bound to.
	Addr *net.TCPAddr

	// context and cancel govern the lifetime of the accept loop and of every
	// handler derived from it; Stop calls cancel.
	context context.Context
	cancel  context.CancelFunc

	// listener is the bound TCP socket; the accept loop closes it on exit.
	listener *net.TCPListener
}
// Start binds a TCP listener to addr and launches the accept loop in a
// background goroutine.
//
// Start itself BLOCKS until the server's context is cancelled (see Stop),
// then returns nil. It returns an error without starting anything when:
//   - addr is empty,
//   - addr cannot be resolved as a TCP address, or
//   - the listener cannot be opened.
func (s *Server) Start(addr string) error {
	if addr == "" {
		return fmt.Errorf("server: Empty bind address")
	}

	netAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		// Keep the underlying cause; previously it was discarded.
		return fmt.Errorf("server: Invalid tcp address %q: %v", addr, err)
	}

	listener, err := net.ListenTCP("tcp", netAddr)
	if err != nil {
		return err
	}

	s.Logger = log.New(os.Stdout, "logger: ", log.Lshortfile)
	s.listener = listener
	s.Addr = listener.Addr().(*net.TCPAddr)
	s.Logger.Println("Starting server", "addr", addr)

	ctx, cancel := context.WithCancel(context.Background())
	s.context = ctx
	s.cancel = cancel

	go s.listen(ctx)
	<-ctx.Done()
	return nil
}
// Stop cancels the server's context. That makes Start return, ends the
// accept loop (it re-checks the context at least once per second), and
// cancels every connection handler's derived context.
//
// Stop does not wait for any of those goroutines to exit. Calling it on a
// server whose Start never got as far as listening is a no-op rather than a
// nil-pointer panic.
func (s *Server) Stop() {
	if s.cancel == nil {
		return
	}
	s.Logger.Println("[INFO] Shutting down server...")
	s.cancel()
}
// listen accepts new connections and handles the conversion from TCP to SSH connections.
func (s *Server) listen(c context.Context) {
defer s.listener.Close()
for {
// Accepts will only block for 1s
s.listener.SetDeadline(time.Now().Add(time.Second))
select {
// Stop server on channel receive
case <-c.Done():
s.Logger.Println("[DEBUG] Context Completed")
return
default:
// Accept new connection
tcpConn, err := s.listener.Accept()
if err != nil {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
// s.Logger.Println("[DBG] Connection timeout...")
} else {
s.Logger.Println("[WRN] Connection failed", "error", err)
}
continue
}
// Handle connection
s.Logger.Println("[INF] Successful TCP connection:", tcpConn.RemoteAddr().String())
h := NewTcpHandler(tcpConn, s.Logger, s.context)
go h.Execute()
}
}
}
// RingBufferCapacity is the number of byte slots in each connection's ring
// buffer (16 MiB). It must stay a power of two: RingBufferMask is used in
// place of a modulo to map a sequence number onto a slot, and that only
// works for power-of-two sizes.
const (
	RingBufferCapacity = 16 << 20
	RingBufferMask     = RingBufferCapacity - 1
)
// NewTcpHandler builds the handler for one accepted connection.
//
// It gives the connection its own RingBufferCapacity-byte ring buffer and a
// disruptor whose single consumer (a ByteConsumer) writes committed bytes
// back to conn. The disruptor is started before returning. The handler's
// context is a cancellable child of ctx, so cancelling the server context
// also ends the handler.
func NewTcpHandler(conn net.Conn, logger *log.Logger, ctx context.Context) *tcpHandler {
	ring := new([RingBufferCapacity]byte)
	consumer := &ByteConsumer{Writer: conn, ring: ring}
	controller := disruptor.Configure(RingBufferCapacity).WithConsumerGroup(consumer).Build()
	controller.Start()

	handlerCtx, cancel := context.WithCancel(ctx)
	return &tcpHandler{
		logger:     logger,
		conn:       conn,
		ring:       ring,
		controller: &controller,
		context:    handlerCtx,
		cancel:     cancel,
	}
}
// tcpHandler owns a single accepted connection. createReadLoop publishes
// bytes read from conn into ring through controller's writer, and the
// consumer registered with controller writes them back to conn.
//
// NOTE(review): keep the field order stable; it matters to any constructor
// that uses a positional composite literal.
type tcpHandler struct {
	logger     *log.Logger
	conn       net.Conn
	ring       *[RingBufferCapacity]byte // per-connection ring shared with the consumer
	controller *disruptor.Disruptor      // started on construction, stopped by Execute
	context    context.Context           // handler lifetime; Execute returns once it is done
	cancel     context.CancelFunc        // called when the read loop exits
}
// Execute serves the connection until the handler's context is done.
//
// It runs the read loop in its own goroutine and waits for cancellation,
// which comes either from the read loop exiting or from the parent server
// context. It then stops the disruptor and closes the connection, in that
// order.
func (t *tcpHandler) Execute() {
	conn, controller := t.conn, t.controller
	defer conn.Close()      // runs second
	defer controller.Stop() // runs first (deferred calls are LIFO)

	go t.createReadLoop()

	done := t.context.Done()
	<-done
}
func (t *tcpHandler) createReadLoop() {
defer t.cancel()
writer := t.controller.Writer()
buffer := make([]byte, 65336)
var seq int64
for {
select {
case <-t.context.Done():
return
default:
n, err := t.conn.Read(buffer)
// fmt.Printf("n: %d; err: %s\r\n", n, err)
if n > 0 {
seq = writer.Reserve(int64(n))
for i := 0; i < n; i++ {
t.ring[seq&RingBufferMask] = buffer[i]
}
writer.Commit(seq, seq+int64(n))
} else if err == io.EOF {
return
}
if err != nil && err != io.EOF {
return
}
}
}
}
// ByteConsumer is the disruptor consumer for one connection. It copies
// committed bytes out of the shared ring buffer into a staging buffer and
// writes them to Writer in chunks.
type ByteConsumer struct {
	// ring is the ring buffer filled by the connection's read loop.
	ring *[RingBufferCapacity]byte
	// buffer stages bytes copied out of ring so they can be written in bulk.
	// NOTE(review): 65336 looks like a typo for 65536 (64 KiB). It is harmless,
	// but any code that hard-codes this size must change together with it.
	buffer [65336]byte
	// Writer receives the echoed bytes (the client connection).
	Writer io.Writer
}
// Consume handles the inclusive sequence range [lower, upper] that the
// producer has committed. It copies those bytes out of the ring buffer into
// b.buffer and writes them to b.Writer, flushing whenever the staging buffer
// fills and once more for any remainder.
//
// Consume cannot return an error. If a write fails, the rest of the batch is
// dropped, since the destination is presumably no longer usable.
func (b *ByteConsumer) Consume(lower, upper int64) {
	offset := 0
	for seq := lower; seq <= upper; seq++ {
		if offset == len(b.buffer) {
			if _, err := b.Writer.Write(b.buffer[:]); err != nil {
				return
			}
			offset = 0
		}
		b.buffer[offset] = b.ring[seq&RingBufferMask]
		offset++
	}
	// The old `offset >= 0` check was always true and could issue an empty write.
	if offset > 0 {
		b.Writer.Write(b.buffer[:offset])
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment