Skip to content

Instantly share code, notes, and snippets.

@wthsths
Created December 1, 2020 20:19
Show Gist options
  • Save wthsths/d40789428f1953f05fe782096fe359b4 to your computer and use it in GitHub Desktop.
Save wthsths/d40789428f1953f05fe782096fe359b4 to your computer and use it in GitHub Desktop.
package main
import (
"bufio"
"encoding/base64"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"time"
)
// Timing and size limits applied to every proxied WebSocket connection.
var (
	// pongWait is the read deadline granted to the WebSocket peer; it is
	// renewed each time a pong is received.
	pongWait = 60 * time.Second
	// pingPeriod is the interval between pings: nine tenths of pongWait, so a
	// ping is sent before the current read deadline runs out.
	pingPeriod = pongWait * 9 / 10
	// writeWait is the deadline given to each individual WebSocket write.
	writeWait = 30 * time.Second
	// maxMessageSize is the read limit, in bytes, for a single inbound
	// WebSocket message.
	maxMessageSize int64 = 2048
)
// upgrader converts incoming HTTP requests into WebSocket connections.
// Requests from any Origin are accepted, and per-message compression is
// explicitly disabled.
//
// NOTE(review): accepting every Origin lets any web page open this proxy
// from a visitor's browser; tighten CheckOrigin if that is not intended.
var upgrader = websocket.Upgrader{
	EnableCompression: false,
	CheckOrigin: func(*http.Request) bool {
		return true
	},
}
// main starts the WebSocket-to-TCP proxy.
//
// Configuration is read from the environment:
//   - WS_TCP_PROXY_SVC_ADDR: listen address; defaults to ":8080".
//   - WS_TCP_PROXY_SVC_MODE: unless it is exactly "production", debug
//     logging is enabled.
//
// Logs go to stdout. "/ping" answers "pong"; every other path is handled
// by wsHandler.
func main() {
	addr := os.Getenv("WS_TCP_PROXY_SVC_ADDR")
	if addr == "" {
		addr = ":8080"
	}
	debug := os.Getenv("WS_TCP_PROXY_SVC_MODE") != "production"

	logrus.SetOutput(os.Stdout)
	if debug {
		logrus.SetLevel(logrus.DebugLevel)
	}
	logrus.Debugf("websocket is starting on %v", addr)

	http.HandleFunc("/ping", func(writer http.ResponseWriter, _ *http.Request) {
		writer.Write([]byte("pong"))
	})
	http.HandleFunc("/", wsHandler)

	// http.ListenAndServe configures no timeouts, so a client could keep a
	// connection open indefinitely without ever finishing its request
	// headers. Only the header phase is bounded here; no other server
	// timeouts are set, so long-lived proxied sessions are not cut short.
	server := &http.Server{
		Addr:              addr, // nil Handler => http.DefaultServeMux
		ReadHeaderTimeout: 10 * time.Second,
	}
	if err := server.ListenAndServe(); err != nil {
		logrus.Fatal(err)
	}
}
// wsHandler upgrades the request to a WebSocket and proxies it to a TCP
// endpoint.
//
// The target address comes from the "addr" query parameter, encoded with
// standard base64 (e.g. "bG9jYWxob3N0OjEyMzQ=" for "localhost:1234").
// Clients must percent-encode it: query parsing turns a bare '+' into a
// space, which breaks decoding.
//
// Data flows TCP->WebSocket in `in` and WebSocket->TCP in `out`. The handler
// returns only after both have reported on doneChan.
//
// NOTE(review): with an Origin check that accepts everything, any page can
// make this service connect to arbitrary reachable hosts; consider an
// allowlist of targets.
func wsHandler(w http.ResponseWriter, r *http.Request) {
	const dialTimeout = 10 * time.Second

	logrus.Debugf("new connection handled")

	// Validate the target before upgrading. A malformed request then gets a
	// plain HTTP 400 instead of a WebSocket that is opened and then abandoned.
	encodedAddr := r.URL.Query().Get("addr")
	if encodedAddr == "" {
		logrus.Debugf("no addr found")
		http.Error(w, "missing addr query parameter", http.StatusBadRequest)
		return
	}
	tcpAddr, err := base64.StdEncoding.DecodeString(encodedAddr)
	if err != nil {
		logrus.Debugf("error decoding addr %q: %v", encodedAddr, err)
		http.Error(w, "addr is not valid base64", http.StatusBadRequest)
		return
	}

	wsConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		logrus.Debugf("error upgrading conn %v", err)
		return
	}

	// Bound the dial. An unreachable target would otherwise hold the upgraded
	// connection for as long as the OS connect timeout allows.
	tcpConn, err := net.DialTimeout("tcp", string(tcpAddr), dialTimeout)
	if err != nil {
		log.Printf("failed to dial tcp conn: %v", err)
		wsConn.Close() // release the upgraded connection instead of leaking it
		return
	}

	doneChan := make(chan bool)
	go in(wsConn, tcpConn, doneChan)
	go out(tcpConn, wsConn, doneChan)
	// The first pump to stop signals here. Closing both connections makes the
	// other pump's next read or write fail. Its signal is then received, so
	// neither goroutine stays blocked on the unbuffered doneChan.
	<-doneChan
	tcpConn.Close()
	wsConn.Close()
	<-doneChan
	logrus.Debugf("client has disconnected")
}
// in copies data from the TCP connection src to the WebSocket dst.
//
// The TCP byte stream is split on NUL bytes. Each frame, without its NUL, is
// written to dst as a text message. A ping is also written every pingPeriod,
// and each write gets a writeWait deadline. When src fails or reaches EOF, a
// close message is written to dst. Exactly one value is sent on doneChan
// before in returns, whichever side failed.
func in(dst *websocket.Conn, src io.ReadCloser, doneChan chan<- bool) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	sendChan := make(chan []byte)
	// Closing stop releases readFrames if it is blocked handing over a frame
	// after in has returned. Otherwise that goroutine would leak. readFrames,
	// the sender, is the only one that closes sendChan, so it can never send
	// on a closed channel.
	stop := make(chan struct{})
	defer close(stop)
	go readFrames(src, sendChan, stop)

	for {
		select {
		case message, ok := <-sendChan:
			dst.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// readFrames closed sendChan: the TCP side is gone.
				logrus.Debugf("ok false")
				dst.WriteMessage(websocket.CloseMessage, []byte{})
				doneChan <- true
				return
			}
			if err := dst.WriteMessage(websocket.TextMessage, message); err != nil {
				logrus.Debugf("error writing text message %s: %v", message, err)
				doneChan <- true
				return
			}
		case <-ticker.C:
			dst.SetWriteDeadline(time.Now().Add(writeWait))
			if err := dst.WriteMessage(websocket.PingMessage, nil); err != nil {
				logrus.Debugf("closing due to error writing ping message %v", err)
				doneChan <- true
				return
			}
		}
	}
}

// readFrames reads NUL-terminated frames from src and sends each one, with the
// NUL removed, on frames. It returns when src yields an error (including EOF),
// or when stop is closed while a frame is waiting to be delivered. It always
// closes frames on the way out. A trailing partial frame that has no NUL yet
// is discarded.
func readFrames(src io.Reader, frames chan<- []byte, stop <-chan struct{}) {
	defer close(frames)
	reader := bufio.NewReader(src)
	for {
		// ReadString returns a nil error only when the result ends with the
		// delimiter. Working in strings (not per-byte string(b)
		// conversions) keeps non-ASCII bytes intact.
		frame, err := reader.ReadString(0)
		if err != nil {
			logrus.Debugf("error reading byte %v", err)
			return
		}
		select {
		case frames <- []byte(strings.TrimSuffix(frame, "\000")):
		case <-stop:
			return
		}
	}
}
// out copies WebSocket messages from src to the TCP connection dst, appending
// "\r\n" to each one.
//
// Inbound messages are limited to maxMessageSize bytes. The read deadline
// starts at pongWait and is pushed back by pongWait whenever a pong arrives,
// so a peer that stops answering the pings written by `in` eventually times
// out. Exactly one value is sent on doneChan when either the WebSocket read or
// the TCP write fails, and then out returns.
func out(dst io.WriteCloser, src *websocket.Conn, doneChan chan<- bool) {
	src.SetReadLimit(maxMessageSize)
	src.SetReadDeadline(time.Now().Add(pongWait))
	src.SetPongHandler(func(string) error {
		src.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	for {
		_, msg, err := src.ReadMessage()
		if err != nil {
			logrus.Debugf("closing here %v", err)
			doneChan <- true
			return
		}
		logrus.Debugf("ws to tcp command %s", msg)
		// Send the message and its terminator in one write: one syscall
		// instead of two, and the pair cannot be split by a failure in
		// between. A failed write means the TCP side is gone, so stop
		// instead of continuing to read from the WebSocket.
		if _, err := dst.Write(append(msg, '\r', '\n')); err != nil {
			logrus.Debugf("error writing to tcp conn %v", err)
			doneChan <- true
			return
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment