Skip to content

Instantly share code, notes, and snippets.

@neel-bp
Last active December 1, 2023 16:52
Show Gist options
  • Save neel-bp/8b4e16197f1421478b81faec88b012b2 to your computer and use it in GitHub Desktop.
Save neel-bp/8b4e16197f1421478b81faec88b012b2 to your computer and use it in GitHub Desktop.
package main
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/redis/go-redis/v9"
)
// Websocket keep-alive and size limits used by readPump and writePump.
const (
	// WRITE_WAIT bounds how long a single write to the peer may take.
	WRITE_WAIT = 10 * time.Second

	// PONG_WAIT is how long the reader waits for the next pong before its
	// read deadline expires.
	PONG_WAIT = 60 * time.Second

	// PING_PERIOD is the interval between pings: 90% of PONG_WAIT, so a ping
	// always goes out before the pong window closes. Must stay below PONG_WAIT.
	PING_PERIOD = PONG_WAIT * 9 / 10

	// MAX_MESSAGE_SIZE is the largest message accepted from the peer.
	MAX_MESSAGE_SIZE = 512
)
var (
	// Upgrader turns HTTP requests into websocket connections. It offers the
	// "name" subprotocol, which AuthMiddleware requires as the first entry of
	// the client's subprotocol list.
	//
	// NOTE(review): CheckOrigin accepts every Origin, so pages on any site can
	// open a socket here; restrict it if cross-site access is not intended.
	Upgrader = websocket.Upgrader{
		ReadBufferSize:  2048,
		WriteBufferSize: 2048,
		Subprotocols:    []string{"name"},
		CheckOrigin:     func(*http.Request) bool { return true },
	}

	// RDB is the shared Redis client; it is assigned by InitializeRedis.
	RDB *redis.Client
)
// Message is the JSON envelope sent to websocket clients and carried on the
// Redis pub/sub channel: the sender's name and the message text.
type Message struct {
	From    string `json:"from"`
	Message string `json:"message"`
}

// String returns the JSON encoding of m.
func (m Message) String() string {
	return string(m.Byte())
}

// Byte returns the JSON encoding of m.
//
// The marshal error is deliberately discarded: Message holds only string
// fields, which encoding/json always encodes (invalid UTF-8 is replaced with
// U+FFFD rather than rejected).
func (m Message) Byte() []byte {
	encoded, _ := json.Marshal(m)
	return encoded
}
// InitializeRedis connects the package-level RDB client to Redis at
// localhost:6379 and verifies the connection with a PING bounded by a
// 5-second timeout. It is a no-op when RDB is already set.
//
// On failure RDB stays nil, the freshly created client is closed so its
// connection pool is not leaked, and the returned error names the address.
func InitializeRedis() error {
	if RDB != nil {
		return nil
	}

	const addr = "localhost:6379"
	rdb := redis.NewClient(&redis.Options{Addr: addr})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := rdb.Ping(ctx).Err(); err != nil {
		// The client is never published, so release its resources here.
		_ = rdb.Close()
		return fmt.Errorf("pinging redis at %s: %w", addr, err)
	}

	RDB = rdb
	return nil
}
// AuthMiddleware validates the websocket handshake before RoomSocket runs.
// The request must be a websocket upgrade that lists at least two
// subprotocols: the literal "name" followed by the caller's display name,
// for example
//
//	Sec-WebSocket-Protocol: name, alice
//
// The display name is stored in the gin context under the key "name". Any
// violation aborts the handler chain with a 400 JSON error.
func AuthMiddleware(c *gin.Context) {
	reject := func(msg string) {
		c.JSON(http.StatusBadRequest, gin.H{"message": msg})
		c.Abort()
	}

	if !websocket.IsWebSocketUpgrade(c.Request) {
		reject("not a websocket upgrade")
		return
	}

	protocols := websocket.Subprotocols(c.Request)
	if len(protocols) < 2 || protocols[0] != "name" {
		reject("wrong format of subprotocols")
		return
	}

	c.Set("name", protocols[1])
	c.Next()
}
// RoomSocket serves GET /ws/:room. It reserves the caller's display name (set
// by AuthMiddleware) in the Redis set keyed by the room, upgrades the request
// to a websocket, and relays messages between the socket and the Redis
// pub/sub channel of the same name until one of the pumps exits. Afterwards
// the name is released and a "<server>" departure notice is published.
//
// Pre-upgrade failures are reported as JSON errors: 400 for a missing room,
// a missing name or a name already in use, and 500 for Redis errors.
func RoomSocket(c *gin.Context) {
	room, ok := c.Params.Get("room")
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"message": "no room code in url"})
		return
	}
	name := c.GetString("name")
	if name == "" {
		c.JSON(http.StatusBadRequest, gin.H{"message": "no name provided"})
		return
	}

	// Reserve the name atomically. SADD reports how many members it actually
	// added, so two clients racing for the same name cannot both win (a
	// separate SISMEMBER check followed by SADD has a check-then-act race).
	added, err := RDB.SAdd(c.Request.Context(), room, name).Result()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
		return
	}
	if added == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"message": "name already taken"})
		return
	}

	conn, err := Upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error; writing another
		// response here would be a superfluous WriteHeader. Just give the
		// reserved name back.
		log.Printf("room %q: websocket upgrade for %q: %v", room, name, err)
		leaveRoom(room, name, false)
		return
	}
	defer conn.Close()

	// Redis messages are decoded by pubsubQueue into writerChan, which
	// writePump drains together with readPump's server notices.
	pubsub := RDB.Subscribe(c.Request.Context(), room)
	defer pubsub.Close()

	// breaker is signalled by each pump when it exits. Its one-slot buffer
	// lets the second pump's send complete once this handler has consumed
	// the first signal.
	breaker := make(chan struct{}, 1)
	writerChan := make(chan Message, 512)
	pubsubChan := pubsub.Channel()
	go pubsubQueue(pubsubChan, writerChan)
	go readPump(c.Request.Context(), conn, breaker, writerChan, room, name)
	go writePump(conn, breaker, writerChan, pubsubChan)
	<-breaker

	leaveRoom(room, name, true)
}

// leaveRoom removes name from room's member set and, when announce is true,
// publishes a "<server>" notice that the user left. It uses a detached,
// time-bounded context so cleanup still runs if the request context has been
// cancelled, but cannot hang indefinitely on Redis. Failures are logged
// because there is no client left to report them to.
func leaveRoom(room, name string, announce bool) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := RDB.SRem(ctx, room, name).Err(); err != nil {
		log.Printf("room %q: releasing name %q: %v", room, name, err)
	}
	if !announce {
		return
	}
	notice := Message{From: "<server>", Message: fmt.Sprintf("%s has left the chat", name)}
	if err := RDB.Publish(ctx, room, notice.Byte()).Err(); err != nil {
		log.Printf("room %q: announcing departure of %q: %v", room, name, err)
	}
}
// pubsubQueue forwards messages from a Redis subscription channel into the
// write queue consumed by writePump. Payloads that do not decode as Message
// JSON are skipped. It returns once pubsumMessage is closed.
//
// The hand-off is non-blocking. If writeQueue is full, typically because
// writePump has already exited on a dead connection, the message is dropped.
// A blocking send in that state would never complete, so this goroutine
// would outlive the subscription and leak together with the queue.
func pubsubQueue(pubsumMessage <-chan *redis.Message, writeQueue chan Message) {
	for msg := range pubsumMessage {
		var m Message
		if err := json.Unmarshal([]byte(msg.Payload), &m); err != nil {
			continue
		}
		select {
		case writeQueue <- m:
		default:
			// Queue full: drop rather than block forever (see above).
		}
	}
}
// readPump reads messages from the client on conn and publishes each text
// message to the room's Redis channel, stamped with the sender's name.
// Non-text messages and publish failures are reported back to this
// connection only, through writeQueue, as messages from "<server>".
//
// It caps incoming messages at MAX_MESSAGE_SIZE and extends the read
// deadline by PONG_WAIT on every pong. On any read error it returns and
// signals breakerChan.
//
// NOTE(review): the notices are sent to writeQueue with a blocking send. If
// writePump has already exited and the queue is full, this goroutine can
// stall there.
func readPump(ctx context.Context, conn *websocket.Conn, breakerChan chan struct{}, writeQueue chan Message, room, name string) {
	refresh := func() {
		conn.SetReadDeadline(time.Now().Add(PONG_WAIT))
	}

	conn.SetReadLimit(MAX_MESSAGE_SIZE)
	refresh()
	conn.SetPongHandler(func(string) error {
		refresh()
		return nil
	})
	defer func() { breakerChan <- struct{}{} }()

	notify := func(text string) {
		writeQueue <- Message{From: "<server>", Message: text}
	}

	for {
		kind, r, err := conn.NextReader()
		if err != nil {
			return
		}
		if kind != websocket.TextMessage {
			notify("wrong type of messge please send text message when communicating")
			continue
		}

		body, err := io.ReadAll(r)
		if err != nil {
			return
		}

		out := Message{From: name, Message: string(body)}
		if err := RDB.Publish(ctx, room, out.Byte()).Err(); err != nil {
			notify(err.Error())
		}
	}
}
// writePump sends each Message taken from writeQueue to the client as a JSON
// text message. Every PING_PERIOD it also pings the peer, so readPump's pong
// handler can keep extending the read deadline. Each write is bounded by
// WRITE_WAIT. It returns on the first write error, stops its ticker and
// signals breakerChan.
//
// pubsubMessage is not read here; Redis messages arrive through writeQueue.
func writePump(conn *websocket.Conn, breakerChan chan struct{}, writeQueue chan Message, pubsubMessage <-chan *redis.Message) {
	pinger := time.NewTicker(PING_PERIOD)
	defer func() {
		pinger.Stop()
		breakerChan <- struct{}{}
	}()

	send := func(m Message) error {
		conn.SetWriteDeadline(time.Now().Add(WRITE_WAIT))
		w, err := conn.NextWriter(websocket.TextMessage)
		if err != nil {
			return err
		}
		// The Write result is not checked; only Close's error ends the pump.
		// NOTE(review): presumably a failed Write resurfaces from Close in
		// gorilla/websocket — confirm.
		_, _ = w.Write(m.Byte())
		return w.Close()
	}

	for {
		var err error
		select {
		case m := <-writeQueue:
			err = send(m)
		case <-pinger.C:
			conn.SetWriteDeadline(time.Now().Add(WRITE_WAIT))
			err = conn.WriteMessage(websocket.PingMessage, nil)
		}
		if err != nil {
			return
		}
	}
}
// main connects to Redis, then serves the chat endpoint GET /ws/:room, guarded
// by AuthMiddleware, on port 8080. It exits via log.Fatal if either step
// fails.
func main() {
	err := InitializeRedis()
	if err != nil {
		log.Fatal(err)
	}

	r := gin.Default()
	r.Group("/ws", AuthMiddleware).GET("/:room", RoomSocket)

	err = r.Run(":8080")
	if err != nil {
		log.Fatal(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment