Skip to content

Instantly share code, notes, and snippets.

@Hiyorimi
Created Jan 11, 2019
Embed
What would you like to do?
Minimal TCP testing server implementation
package main
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
)
// TestingTCPServer is the minimum contract a testing server must satisfy.
// *TCPServer, returned by NewTestingTCPServer, implements it; only a TCP
// implementation exists.
type TestingTCPServer interface {
	// Run starts the server (TCPServer.Run blocks until it stops).
	Run() error
	// Put queues a payload to be sent to a connected client.
	Put(*[]byte) error
	// Close shuts the server down.
	Close() error
}
// NewTestingTCPServer creates a TCPServer that will listen on addr once Run
// is called. protocol is matched case-insensitively and only "tcp" is
// supported: "udp" is rejected explicitly as unsupported, and any other
// value is rejected as invalid.
func NewTestingTCPServer(protocol, addr string) (*TCPServer, error) {
	switch strings.ToLower(protocol) {
	case "tcp":
		return &TCPServer{
			addr: addr,
			// Buffered so up to 10 queued payloads can sit in the channel
			// before any connection handler receives them.
			sendingQueue: make(chan *[]byte, 10),
		}, nil
	case "udp":
		// Previously an empty case that silently fell through to the generic
		// "invalid protocol" error; make the lack of UDP support explicit.
		return nil, errors.New("udp protocol is not supported")
	default:
		return nil, errors.New("invalid protocol given")
	}
}
// TCPServer is the TCP-backed testing server. Its zero value is not usable
// (sendingQueue would be nil), so construct it with NewTestingTCPServer.
type TCPServer struct {
	// addr is the address handed to net.Listen when Run starts.
	addr string
	// server is the listener opened by Run; it is nil until Run has listened.
	server net.Listener
	// sendingQueue carries payloads handed to Put. Every connection handler
	// receives from it, so each queued payload is delivered to one handler.
	sendingQueue chan *[]byte
}
// Run listens on t.addr and serves each accepted connection on its own
// goroutine. A Listen failure is returned immediately. Otherwise Run blocks
// until Accept fails (for example because Close closed the listener) and
// returns that failure with context.
//
// On return Run closes only the listener it opened. The send queue is left
// for Close to release, so calling Close to stop Run does not close the
// queue twice.
func (t *TCPServer) Run() (err error) {
	t.server, err = net.Listen("tcp", t.addr)
	if err != nil {
		return err
	}
	// The listener may already be closed by Close; the second Close error is
	// irrelevant here.
	defer t.server.Close()
	for {
		// A distinct name keeps the accept error from shadowing the named
		// result, which previously made Run report success on failure.
		conn, acceptErr := t.server.Accept()
		if acceptErr != nil {
			return fmt.Errorf("could not accept connection: %v", acceptErr)
		}
		fmt.Printf("Connection accepted\n")
		go t.handleConnection(conn)
	}
}
// Close shuts down the TCP server. It closes the queue used by Put and, if
// Run has opened a listener, closes that listener so Run's Accept returns.
//
// The queue is closed unconditionally, so Close must be called at most once
// and Put must not be called afterwards (sending on a closed channel panics).
func (t *TCPServer) Close() (err error) {
	close(t.sendingQueue)
	if t.server == nil {
		// Run was never called or its Listen failed: there is no listener,
		// and calling Close on the nil interface would panic.
		return nil
	}
	return t.server.Close()
}
// Put queues data to be written verbatim to whichever connection handler
// receives it first from sendingQueue.
//
// The send happens on a new goroutine, so Put does not block and the payload
// may not have reached the queue when Put returns. Put must not be called
// after Close, because sending on the closed queue panics. A nil data
// pointer is rejected, since a handler dereferences every payload it writes.
// The caller must not modify *data after Put, as it is shared, not copied.
func (t *TCPServer) Put(data *[]byte) (err error) {
	if data == nil {
		return errors.New("nil data given")
	}
	fmt.Printf("Putting data into channel\n")
	go func(data *[]byte) {
		t.sendingQueue <- data
		fmt.Printf("Data really put into channel\n")
	}(data)
	// Printed before the goroutine above has necessarily completed its send.
	fmt.Printf("Data put into channel\n")
	return nil
}
// handleConnection serves one client connection and closes it when the
// first of these happens:
//   - five seconds have passed since the handler started,
//   - the client stops producing lines (EOF or any read error),
//   - sendingQueue has been closed,
//   - writing to the client fails.
//
// Every newline-terminated line from the client is answered with
// "Request received: <line>". Payloads queued with Put are written verbatim;
// sendingQueue is shared by all handlers, so each payload goes to whichever
// handler receives it first.
func (t *TCPServer) handleConnection(conn net.Conn) {
	defer conn.Close()
	fmt.Printf("Connection being handled\n")

	readBuffer := bufio.NewReader(conn)
	writeBuffer := bufio.NewWriter(conn)

	readerChannel := make(chan *[]byte)
	// done is closed when this handler returns, so the reader goroutine
	// cannot block forever handing off a reply nobody will receive.
	done := make(chan struct{})
	defer close(done)

	go func() {
		// A closed readerChannel tells the loop below the client is finished.
		defer close(readerChannel)
		for {
			req, err := readBuffer.ReadString('\n')
			if err != nil {
				// Every read error is terminal: after EOF or a closed
				// connection ReadString keeps failing, so looping again would
				// only emit bogus replies.
				if err != io.EOF {
					fmt.Printf("Reader: %v\n", err)
				}
				return
			}
			messageRead := []byte(fmt.Sprintf("Request received: %s", req))
			fmt.Printf("Reader: Putting message into channel\n")
			select {
			case readerChannel <- &messageRead:
			case <-done:
				return
			}
		}
	}()

	// A one-shot timer caps the connection's lifetime. Unlike time.Tick it can
	// be stopped, so it does not outlive the handler.
	lifetime := time.NewTimer(5 * time.Second)
	defer lifetime.Stop()

	for {
		fmt.Printf("Started for loop\n")
		var data *[]byte
		select {
		case msg, ok := <-readerChannel:
			if !ok {
				return
			}
			fmt.Printf("Read written data\n")
			data = msg
		case msg, ok := <-t.sendingQueue:
			if !ok {
				// The queue was closed; every further receive would yield nil.
				return
			}
			fmt.Printf("Read pushed data\n")
			data = msg
		case <-lifetime.C:
			fmt.Printf("Tick\n")
			return
		}
		// data can only be nil if a nil pointer was queued; skip it rather
		// than dereference it.
		if data != nil {
			// A failed Write is sticky in bufio.Writer and surfaces from Flush.
			_, _ = writeBuffer.Write(*data)
			if err := writeBuffer.Flush(); err != nil {
				fmt.Printf("Writer: %v\n", err)
				return
			}
		}
		fmt.Printf("Finished for loop\n")
	}
}
package main
import (
"bytes"
"database/sql"
"fmt"
"log"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
// utilTestingSrv is the shared server that init starts on :41123 for the
// tests in this file. It stays nil if the server could not be created.
var utilTestingSrv *TCPServer

func init() {
	// Testing utils.go
	// Start the shared server. Assign to the package-level variable instead of
	// declaring a new local with :=, which would leave utilTestingSrv nil.
	var err error
	utilTestingSrv, err = NewTestingTCPServer("tcp", ":41123")
	if err != nil {
		log.Println("error starting TestingTCPServer server:", err)
		return
	}
	srv := utilTestingSrv
	// Run the server in a goroutine so it does not block test execution.
	go func() {
		if err := srv.Run(); err != nil {
			log.Println("TestingTCPServer stopped:", err)
		}
	}()
}
// TestUtils_TestingTCPServer_Running checks that the shared server started
// by init accepts connections.
func TestUtils_TestingTCPServer_Running(t *testing.T) {
	servers := []struct {
		protocol string
		addr     string
	}{
		{"tcp", ":41123"},
	}
	for _, serv := range servers {
		conn, err := net.Dial(serv.protocol, serv.addr)
		if !assert.Nil(t, err, "could not connect to %s server at %s", serv.protocol, serv.addr) {
			// conn is nil when Dial fails; calling Close on it would panic.
			continue
		}
		// Close immediately rather than deferring inside the loop, so
		// connections do not accumulate until the test returns.
		conn.Close()
	}
}
// TestUtils_TestingTCPServer_Request checks that each newline-terminated
// request sent to the shared server is answered with "Request received: ...".
func TestUtils_TestingTCPServer_Request(t *testing.T) {
	servers := []struct {
		protocol string
		addr     string
	}{
		{"tcp", ":41123"},
	}
	tt := []struct {
		test    string
		payload []byte
		want    []byte
	}{
		{"Sending a simple request returns result",
			[]byte("hello world\n"),
			[]byte("Request received: hello world")},
		{"Sending another simple request works",
			[]byte("goodbye world\n"),
			[]byte("Request received: goodbye world")},
	}
	for _, serv := range servers {
		for _, tc := range tt {
			// Subtests run synchronously (no t.Parallel), so capturing the
			// loop variables here is safe on every Go version.
			t.Run(tc.test, func(t *testing.T) {
				conn, err := net.Dial(serv.protocol, serv.addr)
				if err != nil {
					// Stop here: conn is nil and every later step would panic.
					t.Fatalf("could not connect to server: %v", err)
				}
				defer conn.Close()
				if _, err := conn.Write(tc.payload); err != nil {
					t.Fatalf("could not write payload to server: %v", err)
				}
				out := make([]byte, 1024)
				n, err := conn.Read(out)
				if err != nil {
					t.Fatalf("could not read from connection: %v", err)
				}
				// The server echoes the request line including its '\n';
				// want omits it, so compare only the bytes actually read.
				assert.Equal(t, tc.want, bytes.TrimSuffix(out[:n], []byte("\n")))
			})
		}
	}
}
// TestUtils_TestingTCPServer_WritesRequest checks that a payload queued with
// Put is delivered verbatim to a connecting client.
func TestUtils_TestingTCPServer_WritesRequest(t *testing.T) {
	payload := []byte("hello world\n")
	// Use a dedicated server on its own port. The shared server started by
	// init already listens on :41123, so a second Listen there would fail and
	// the payload would sit in the queue of a server that never accepts.
	const addr = ":41124"
	srv, err := NewTestingTCPServer("tcp", addr)
	if err != nil {
		t.Fatalf("could not create TestingTCPServer: %v", err)
	}
	// Run the server in a goroutine so it does not block test execution.
	// A Run failure is reported with log, not t, because t must not be used
	// after the test has returned.
	go func() {
		if err := srv.Run(); err != nil {
			log.Println("TestingTCPServer stopped:", err)
		}
	}()
	fmt.Printf("Putting payload into queue\n")
	err = srv.Put(&payload)
	assert.Nil(t, err)

	// Run starts listening asynchronously, so retry the dial a bounded number
	// of times until the listener is up.
	var conn net.Conn
	for attempt := 0; attempt < 1000; attempt++ {
		if conn, err = net.Dial("tcp", addr); err == nil {
			break
		}
	}
	if err != nil {
		// Stop here: conn is nil and every later step would panic.
		t.Fatalf("could not connect to server: %v", err)
	}
	defer conn.Close()

	out := make([]byte, 1024)
	n, err := conn.Read(out)
	if err != nil {
		t.Fatalf("could not read from connection: %v", err)
	}
	assert.Equal(t, payload, out[:n])
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment