Skip to content

Instantly share code, notes, and snippets.

@mmitou
Created May 29, 2021 09:48
Show Gist options
  • Save mmitou/af112ffe114e5532f71c9e94c37284e1 to your computer and use it in GitHub Desktop.
Save mmitou/af112ffe114e5532f71c9e94c37284e1 to your computer and use it in GitHub Desktop.
redisを使うようにした
package main
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync/atomic"
	"time"

	"github.com/go-redis/redis/v8"
	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/rs/zerolog/pkgerrors"
)
// Per-connection timing and size limits shared by the read and write pumps.
const (
	// writeWait bounds how long any single write (data frame or ping) may take.
	writeWait = 10 * time.Second

	// pongWait is how long the reader waits for the next frame or pong before
	// its read deadline expires.
	pongWait = time.Minute

	// pingPeriod is the ping interval. It must stay below pongWait; here it is
	// 90% of it.
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize is the read limit applied to each peer connection.
	maxMessageSize = 512
)
// upgrader promotes incoming HTTP requests to websocket connections, using
// 1024-byte read and write buffers.
var upgrader = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
}
// wsclient is one websocket participant in a room.
type wsclient struct {
	// roomID is the room the client joined; id is its server-assigned ID.
	roomID, id string
	// conn is the client's websocket connection.
	conn *websocket.Conn
	// msg is the hub inbox the pumps report to; connect fills it in.
	msg chan<- message
}
// wsMessage is a single websocket frame as read from or written to a Conn.
type wsMessage struct {
	// messageType is the websocket frame type (e.g. websocket.TextMessage).
	messageType int
	// payload is the raw frame body.
	payload []byte
}
// message is what a client's pumps send to the hub: either a frame read from
// the client, or (when err is non-nil) the error that ended its connection,
// in which case the embedded frame is left zero.
type message struct {
	// roomID and clientID identify the originating client.
	roomID   string
	clientID string
	// wsMessage is the received frame (promoted fields messageType, payload).
	wsMessage
	// err is non-nil when the connection failed.
	err error
}
// reciever is c's read pump. It forwards every frame it reads, tagged with
// the room and client ID, to the hub through c.msg. When a read fails it
// forwards that error instead and stops; the connection is closed on return.
func (c wsclient) reciever() {
	conn := c.conn
	defer conn.Close()

	conn.SetReadLimit(maxMessageSize)
	conn.SetReadDeadline(time.Now().Add(pongWait))
	// Each pong from the peer pushes the read deadline out again.
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		kind, data, err := conn.ReadMessage()
		out := message{roomID: c.roomID, clientID: c.id, err: err}
		if err == nil {
			out.wsMessage = wsMessage{messageType: kind, payload: data}
		}
		c.msg <- out
		if err != nil {
			return
		}
	}
}
// runSender starts c's write pump and returns the channel the hub uses to
// queue outgoing frames for this client.
//
// The pump writes each queued frame and, every pingPeriod, a ping (whose pong
// lets reciever extend its read deadline). Every write gets a writeWait
// deadline. When a write fails, the pump closes the connection and reports
// the error (with room and client ID) to the hub once via c.msg. It then
// keeps draining the channel until the hub closes it. The hub sends on this
// channel without a timeout, so a pump that simply stopped reading could
// block the hub forever.
func (c wsclient) runSender() chan<- wsMessage {
	snd := make(chan wsMessage)
	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer func() {
			c.conn.Close()
			ticker.Stop()
		}()

		// fail closes the connection, reports err to the hub exactly once and
		// then discards queued frames until the hub closes snd.
		fail := func(err error) {
			c.conn.Close()
			report := c.msg
			for {
				select {
				case report <- message{roomID: c.roomID, clientID: c.id, err: err}:
					// A nil channel is never ready, so the error is sent only once.
					report = nil
				case _, ok := <-snd:
					if !ok {
						return
					}
				}
			}
		}

		for {
			select {
			case <-ticker.C:
				c.conn.SetWriteDeadline(time.Now().Add(writeWait))
				if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
					fail(err)
					return
				}
			case m, ok := <-snd:
				if !ok {
					return
				}
				log.Debug().Str("clientID", c.id).Str("message", string(m.payload)).Msg("write")
				c.conn.SetWriteDeadline(time.Now().Add(writeWait))
				if err := c.conn.WriteMessage(m.messageType, m.payload); err != nil {
					fail(err)
					return
				}
			}
		}
	}()
	return snd
}
// connect attaches c to the hub inbox msg, starts its read and write pumps,
// and returns the channel used to queue frames for it. Both pumps run on
// this local copy of c, so both see the msg set here.
func (c wsclient) connect(msg chan<- message) chan<- wsMessage {
	c.msg = msg
	go c.reciever()
	return c.runSender()
}
// runEchoServer starts an in-process hub and returns the channel used to
// register new clients. Every frame a client sends is rebroadcast to every
// client in the same room, the sender included.
//
// All room state lives in the hub goroutine. On the first error reported for
// a client, the hub closes that client's sender channel and removes it from
// its room. When ctx is cancelled, including while the hub is blocked handing
// a frame to a sender, the hub closes every remaining sender channel and
// exits.
func runEchoServer(ctx context.Context) chan<- wsclient {
	register := make(chan wsclient)
	go func() {
		snds := make(map[string]map[string]chan<- wsMessage)
		msg := make(chan message)

		// On shutdown, close every sender so the write pumps stop.
		defer func() {
			for _, room := range snds {
				for _, snd := range room {
					close(snd)
				}
			}
		}()

		for {
			select {
			case <-ctx.Done():
				return

			case client := <-register:
				log.Debug().Str("roomID", client.roomID).Str("clientID", client.id).Msg("register")
				snd := client.connect(msg)
				room, ok := snds[client.roomID]
				if !ok {
					room = make(map[string]chan<- wsMessage)
					snds[client.roomID] = room
				}
				room[client.id] = snd

			case m := <-msg:
				if m.err != nil {
					log.Debug().Err(m.err).Str("roomID", m.roomID).Str("clientID", m.clientID).Msg("")
					// Indexing a missing room yields a nil map, which reads as empty.
					room := snds[m.roomID]
					snd, ok := room[m.clientID]
					if !ok {
						// Already removed (e.g. both pumps reported an error).
						continue
					}
					close(snd)
					delete(room, m.clientID)
					if len(room) == 0 {
						delete(snds, m.roomID)
					}
					continue
				}

				log.Debug().Msg(string(m.payload))
				out := wsMessage{messageType: m.messageType, payload: m.payload}
				for id, snd := range snds[m.roomID] {
					log.Debug().Str("clientID", id).Str("message", string(m.payload)).Msg("send")
					select {
					case snd <- out:
					case <-ctx.Done():
						return
					}
				}
			}
		}
	}()
	return register
}
func hub(registrar chan<- wsclient) func(c echo.Context) error {
i := 0
return func(c echo.Context) error {
i++
w := c.Response()
r := c.Request()
roomID := c.Param("id")
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return err
}
registrar <- wsclient{roomID: roomID, id: fmt.Sprintf("hello%d", i), conn: conn}
return nil
}
}
// sender writes every frame received on snd to c's connection, with a
// writeWait deadline per write, until snd is closed. On the first write
// failure it reports the error (tagged with c's room and client ID) on msg
// and stops, rather than writing to a broken connection and reporting an
// error for every remaining frame. The connection is closed on return.
//
// NOTE(review): nothing in the visible code calls sender; connect uses
// runSender as the write pump.
func (c wsclient) sender(snd <-chan message, msg chan<- message) {
	defer c.conn.Close()
	for m := range snd {
		c.conn.SetWriteDeadline(time.Now().Add(writeWait))
		if err := c.conn.WriteMessage(m.messageType, m.payload); err != nil {
			msg <- message{roomID: c.roomID, clientID: c.id, err: err}
			return
		}
	}
}
// parseRedisPayload splits a pub/sub payload of the form "<roomID>@<rest>" at
// its first '@' and returns the room ID and everything after the separator.
// It fails with "no room id" when p has no '@' or the room ID is empty.
func parseRedisPayload(p string) (string, string, error) {
	sep := strings.IndexByte(p, '@')
	if sep <= 0 {
		return "", "", errors.New("no room id")
	}
	return p[:sep], p[sep+1:], nil
}
func runPubSub(ctx context.Context) chan<- wsclient {
register := make(chan wsclient)
rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Password: "", DB: 0})
subscriber := func () {
rooms := make(map[string]map[string]chan<-wsMessage)
sub := rdb.Subscribe(ctx, "sdp")
for {
select {
case client := <-register:
if _, ok := rooms[client.roomID]; !ok {
rooms[client.roomID] = make(map[string]chan<-wsMessage)
}
snd := make(chan wsMessage)
rooms[client.roomID][client.id] = snd
go func(conn *websocket.Conn, snd <-chan wsMessage) {
for {
select {
case msg := <-snd:
log.Debug().Str("goroutine", "sender").Msg("conn.WriteMessage")
conn.WriteMessage(msg.messageType, msg.payload)
}
}
}(client.conn, snd)
go func(roomID, clientID string, conn *websocket.Conn) {
log.Debug().Str("goroutine", "reciever").Msg("begin")
defer log.Debug().Str("goroutine", "reciever").Msg("end")
for {
_, p, err := conn.ReadMessage()
if err != nil {
log.Error().Err(err).Msg("conn.ReadMessage")
return
}
payload := append([]byte(roomID + "@" + clientID + "@"), p...)
log.Debug().Str("goroutine", "reciever").Str("payload", string(payload)).Msg("publish")
err = rdb.Publish(ctx, "sdp", payload).Err()
if err != nil {
log.Error().Err(err).Msg("rdb.Publish")
return
}
}
}(client.roomID, client.id, client.conn)
case rmsg := <-sub.Channel():
log.Debug().Str("goroutine","subscriber").Str("payload", rmsg.Payload).Msg("recieve subscribed msg")
roomID, payload, err := parseRedisPayload(rmsg.Payload)
if err != nil {
log.Error().Err(err).Str("payload", rmsg.Payload).Msg("parseRedisMessage")
return
}
for _, snd := range rooms[roomID] {
snd<- wsMessage{messageType: websocket.TextMessage, payload: []byte(payload)}
}
}
}
}
go subscriber()
return register
}
// main configures zerolog and starts the Redis-backed room hub. It then
// serves the websocket endpoint /rooms/:id/ws and static files from ./web on
// :8080.
func main() {
	zerolog.TimeFieldFormat = time.RFC3339Nano
	zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
	zerolog.SetGlobalLevel(zerolog.DebugLevel)

	e := echo.New()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// runEchoServer(ctx) is the in-process alternative to the Redis hub.
	registrar := runPubSub(ctx)

	// Log every handler error (with stack when available) before Echo's
	// default handling.
	e.HTTPErrorHandler = func(err error, c echo.Context) {
		detail := fmt.Sprintf("%+v", err)
		log.Debug().Err(err).Msg(detail)
		e.DefaultHTTPErrorHandler(err, c)
	}

	e.GET("/rooms/:id/ws", hub(registrar))
	e.Static("/", "./web")
	e.Logger.Fatal(e.Start(":8080"))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment