Skip to content

Instantly share code, notes, and snippets.

@mmitou
Created May 24, 2021 09:02
Show Gist options
  • Save mmitou/ae231cc7e6fd58e7347bcddb6542416c to your computer and use it in GitHub Desktop.
Save mmitou/ae231cc7e6fd58e7347bcddb6542416c to your computer and use it in GitHub Desktop.
package main
import (
	"context"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/rs/zerolog/pkgerrors"
)
// Connection timing and size limits used by the per-client reader and writer
// goroutines.
const (
	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait is how long the read side waits for the next pong (or any
	// other frame) before its read deadline expires.
	pongWait = 60 * time.Second

	// pingPeriod is how often pings are sent. It must stay below pongWait so
	// that a pong can arrive before the read deadline runs out.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize is the read limit applied to messages from the peer.
	maxMessageSize = 512
)
// upgrader turns the HTTP requests handled by hub into WebSocket connections,
// using 1024-byte read and write buffers and the library's default settings
// for everything else.
var upgrader = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
}
// wsclient is one connected WebSocket peer as tracked by the hub.
type wsclient struct {
	// id is the key under which the hub stores this client's sender channel.
	id string
	// conn is the upgraded WebSocket connection.
	conn *websocket.Conn
	// msg is where this client's goroutines report inbound frames and errors;
	// it is filled in by connect.
	msg chan<- message
}
type (
	// wsMessage is a single WebSocket frame: its message type as reported
	// by the websocket library, plus the raw payload.
	wsMessage struct {
		messageType int
		payload     []byte
	}

	// message is what a client's goroutines send to the hub. It carries
	// either a received frame (err == nil) or the error that ended the
	// client's reader or writer, tagged with the client's id.
	message struct {
		clientID string
		wsMessage
		err error
	}
)
// reciever reads frames from c.conn and forwards each one to the hub on c.msg,
// tagged with c.id. When a read fails, the error is forwarded as a message
// carrying only err. The loop then stops and the connection is closed.
//
// Reads are capped at maxMessageSize. The read deadline is set to pongWait
// from now at start-up and pushed out again on every pong, so a peer that
// stops answering pings eventually makes ReadMessage return an error.
func (c wsclient) reciever() {
	defer c.conn.Close()

	extendDeadline := func() {
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
	}

	c.conn.SetReadLimit(maxMessageSize)
	extendDeadline()
	c.conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		kind, data, err := c.conn.ReadMessage()
		out := message{clientID: c.id, err: err}
		if err == nil {
			out.wsMessage = wsMessage{messageType: kind, payload: data}
		}
		c.msg <- out
		if err != nil {
			return
		}
	}
}
// runSender starts the goroutine that performs every write to c.conn. It
// returns the channel through which the hub hands that goroutine outbound
// messages.
//
// The goroutine pings the peer every pingPeriod and writes each wsMessage
// received on the returned channel. Every write, not just pings, gets a fresh
// writeWait deadline. gorilla/websocket stores the deadline given to
// SetWriteDeadline and applies it to later writes. Without the reset, a data
// message sent more than writeWait after the last ping would fail with a
// timeout.
//
// On a write error the goroutine does three things:
//   - closes the connection;
//   - reports the error on c.msg;
//   - keeps draining the returned channel until the hub closes it.
//
// Draining is required because the hub sends on that channel synchronously.
// If this goroutine stopped receiving while the hub was broadcasting to it,
// and was itself blocked reporting to the hub, the two would wait on each
// other forever.
//
// The goroutine exits, closing the connection, once the hub closes the
// returned channel.
func (c wsclient) runSender() chan<- wsMessage {
	snd := make(chan wsMessage)
	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer func() {
			c.conn.Close()
			ticker.Stop()
		}()

		// write sends one frame with a freshly extended write deadline.
		write := func(messageType int, payload []byte) error {
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			return c.conn.WriteMessage(messageType, payload)
		}

		// fail closes the connection, hands err to the hub, and consumes snd
		// until the hub closes it. If the hub closes snd before the report is
		// delivered, the report is dropped because the hub has already
		// forgotten this client.
		fail := func(err error) {
			c.conn.Close()
			report := message{clientID: c.id, err: err}
			out := c.msg
			for {
				select {
				case out <- report:
					out = nil // delivered; a nil channel disables this case
				case _, ok := <-snd:
					if !ok {
						return
					}
				}
			}
		}

		for {
			select {
			case <-ticker.C:
				if err := write(websocket.PingMessage, nil); err != nil {
					fail(err)
					return
				}
			case m, ok := <-snd:
				if !ok {
					return
				}
				if err := write(m.messageType, m.payload); err != nil {
					fail(err)
					return
				}
			}
		}
	}()
	return snd
}
// connect attaches the client to the hub. It records msg as the channel for
// inbound frames and errors, starts the reader goroutine, and then starts the
// writer goroutine, returning the channel that feeds it. c is a value
// receiver, so msg is set on this call's copy, and that copy is what both
// goroutines use.
func (c wsclient) connect(msg chan<- message) chan<- wsMessage {
	c.msg = msg
	go c.reciever()
	return c.runSender()
}
// runEchoServer starts the hub goroutine. It returns the channel on which
// newly upgraded clients are registered.
//
// The hub keeps one outbound (sender) channel per client id:
//   - Every frame successfully read from any client is broadcast to all
//     registered clients, including the one that sent it.
//   - The first error reported for a client closes and removes that client's
//     sender channel. A client's reader and writer can each report an error,
//     so later errors for an id that is no longer registered are ignored.
//     Closing the missing (nil) channel would panic.
//   - When ctx is cancelled, every remaining sender channel is closed so the
//     writer goroutines exit and close their connections, and the hub
//     returns.
func runEchoServer(ctx context.Context) chan<- wsclient {
	register := make(chan wsclient)
	go func() {
		snds := make(map[string]chan<- wsMessage)
		msg := make(chan message)
		for {
			select {
			case <-ctx.Done():
				for _, snd := range snds {
					close(snd)
				}
				return
			case client := <-register:
				snds[client.id] = client.connect(msg)
			case m := <-msg:
				if m.err != nil {
					if snd, ok := snds[m.clientID]; ok {
						close(snd)
						delete(snds, m.clientID)
					}
					continue
				}
				out := wsMessage{messageType: m.messageType, payload: m.payload}
				for _, snd := range snds {
					snd <- out
				}
			}
		}
	}()
	return register
}
// hub returns the echo handler for the WebSocket endpoint. Each request is
// upgraded to a WebSocket connection and handed to registrar with an id of
// the form "hello<N>". N counts handler invocations starting at 1, including
// those whose upgrade fails.
//
// Handlers run concurrently (net/http serves each connection on its own
// goroutine), so the counter is advanced atomically. Otherwise two clients
// could receive the same id, and the hub keys its sender table by id.
func hub(registrar chan<- wsclient) func(c echo.Context) error {
	var counter int64
	return func(c echo.Context) error {
		n := atomic.AddInt64(&counter, 1)
		conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
		if err != nil {
			return err
		}
		registrar <- wsclient{id: fmt.Sprintf("hello%d", n), conn: conn}
		return nil
	}
}
// main configures zerolog and starts the broadcast hub. It then serves the
// WebSocket endpoint at /ws and static files from ./web on port 8080. HTTP
// errors are logged at debug level with full detail before echo's default
// error handling runs.
func main() {
	zerolog.TimeFieldFormat = time.RFC3339Nano
	zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
	zerolog.SetGlobalLevel(zerolog.DebugLevel)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	registrar := runEchoServer(ctx)

	e := echo.New()
	e.HTTPErrorHandler = func(err error, c echo.Context) {
		detail := fmt.Sprintf("%+v", err)
		log.Debug().Err(err).Msg(detail)
		e.DefaultHTTPErrorHandler(err, c)
	}
	e.GET("/ws", hub(registrar))
	e.Static("/", "./web")

	startErr := e.Start(":8080")
	e.Logger.Fatal(startErr)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment