Skip to content

Instantly share code, notes, and snippets.

@padurean
Last active December 22, 2020 10:05
Show Gist options
  • Save padurean/3c88c6c7054522d16fb7d6683a1be928 to your computer and use it in GitHub Desktop.
Save padurean/3c88c6c7054522d16fb7d6683a1be928 to your computer and use it in GitHub Desktop.
package websocket
import (
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	logger "github.com/rs/zerolog/log"
	"golang.org/x/net/websocket"
)
// WebSocket holds the configuration used by SendAndWaitForAcknowledgement to
// push a value to a websocket client and read the client's replies.
//
// The duration fields have no usable zero value: PollValueInterval is passed
// to time.NewTicker, which panics on a non-positive interval. Call
// SetDefaultDurations or set them explicitly. GetValue and NotEqual are
// called unconditionally and must be non-nil.
type WebSocket struct {
	// Name identifies this endpoint in log messages.
	Name string

	// KeepAliveInterval: a send also happens once at least this long has
	// passed since the previous send.
	// MaxClientIdleTime: the connection is closed when nothing has been
	// received from the client for at least this long.
	// SendDeadline: write deadline applied to every send.
	// PollValueInterval: period of the ticker that drives idle checks and
	// (re)sends.
	KeepAliveInterval, MaxClientIdleTime, SendDeadline, PollValueInterval time.Duration

	// GetValue produces the value sent to the client (formatted with %v).
	GetValue func() interface{}
	// NotEqual is called as NotEqual(current, lastSent) and reports whether
	// the value changed and should be sent again.
	NotEqual func(a, b interface{}) bool
}
// SetDefaultDurations overwrites all four duration fields with default values:
// keep-alive 20s, maximum client idle time 60s, send deadline 5s and poll
// interval 3s. Name, GetValue and NotEqual are not touched and must still be
// set by the caller.
func (ws *WebSocket) SetDefaultDurations() {
	ws.KeepAliveInterval, ws.MaxClientIdleTime, ws.SendDeadline, ws.PollValueInterval =
		20*time.Second, 60*time.Second, 5*time.Second, 3*time.Second
}
// SendAndWaitForAcknowledgement pushes values produced by ws.GetValue to the
// client on conn and reads the client's replies until the connection ends.
//
// Two loops run concurrently:
//   - A poller goroutine sends the current value immediately. Then, on every
//     PollValueInterval tick, it calls GetValue again and sends the result
//     when NotEqual(current, lastSent) reports a change, or when
//     KeepAliveInterval has elapsed since the last send. Each send is bounded
//     by a write deadline of SendDeadline. The poller closes conn if a send
//     fails or if nothing has been received from the client for
//     MaxClientIdleTime.
//   - The calling goroutine receives messages from the client. Every
//     successfully received message refreshes the client's last-seen time.
//
// The method blocks until receiving fails: the client closed the connection,
// the poller closed it, or an I/O error occurred. Before returning it stops
// the poller, closes conn and waits for the poller goroutine to exit.
//
// PollValueInterval must be positive, since time.NewTicker panics otherwise.
func (ws *WebSocket) SendAndWaitForAcknowledgement(conn *websocket.Conn) {
	logPrefix := fmt.Sprintf("WebSocket %s -", ws.Name)
	logger.Info().Msgf("%s START connection", logPrefix)

	pollValueChangedTicker := time.NewTicker(ws.PollValueInterval)
	// done tells the poller to exit. Ticker.Stop does not close the ticker
	// channel, so ranging over it alone would leave the poller blocked
	// forever once this method returns.
	done := make(chan struct{})
	var pollerWG sync.WaitGroup
	defer func() {
		close(done)
		pollValueChangedTicker.Stop()
		// Closing conn also unblocks a send the poller may be in the middle
		// of. An error from closing an already-closed conn is irrelevant here.
		conn.Close()
		pollerWG.Wait()
		logger.Info().Msgf("%s END connection", logPrefix)
	}()

	// clientLastSeenAt is written by the receive loop and read by the poller
	// goroutine, so every access goes through mu.
	var mu sync.Mutex
	clientLastSeenAt := time.Now()
	markClientSeen := func() {
		mu.Lock()
		clientLastSeenAt = time.Now()
		mu.Unlock()
	}
	clientIdleTime := func() time.Duration {
		mu.Lock()
		defer mu.Unlock()
		return time.Since(clientLastSeenAt)
	}

	pollerWG.Add(1)
	go func() {
		defer pollerWG.Done()

		// lastValue and lastSentAt are only touched by this goroutine.
		var lastValue interface{}
		var lastSentAt time.Time

		// send writes value to the client and records it as the last sent one.
		send := func(value interface{}) error {
			now := time.Now()
			if err := conn.SetWriteDeadline(now.Add(ws.SendDeadline)); err != nil {
				logger.Err(err).Msgf(
					"%s ABORT send: error setting write deadline", logPrefix)
				return err
			}
			msg := fmt.Sprintf("%v", value)
			if err := websocket.Message.Send(conn, msg); err != nil {
				logger.Err(err).Msgf("%s ABORT send: error sending", logPrefix)
				return err
			}
			logger.Info().Msgf("%s SEND: %v", logPrefix, msg)
			lastValue = value
			lastSentAt = now
			return nil
		}

		// execute 1st send with no delay
		if err := send(ws.GetValue()); err != nil {
			conn.Close()
			return
		}

		for {
			select {
			case <-done:
				return
			case <-pollValueChangedTicker.C:
			}

			// close connection if client did not respond since too long
			if idle := clientIdleTime(); idle >= ws.MaxClientIdleTime {
				logger.Info().Msgf(
					"%s ABORT send and close connection: did not hear from client since "+
						"%s (more than max idle time %s)",
					logPrefix, idle, ws.MaxClientIdleTime)
				conn.Close()
				return
			}

			// Re-poll the value on every tick. Send only if it changed or the
			// keep-alive interval has elapsed since the last send.
			currValue := ws.GetValue()
			if ws.NotEqual(currValue, lastValue) || time.Since(lastSentAt) >= ws.KeepAliveInterval {
				if err := send(currValue); err != nil {
					conn.Close()
					return
				}
			}
		}
	}()

	// Receive loop: any message from the client counts as a sign of life.
	// Returning lets the deferred cleanup close conn and stop the poller.
	for {
		var received string
		if err := websocket.Message.Receive(conn, &received); err != nil {
			switch {
			case err == io.EOF:
				logger.Info().Msgf(
					"%s ABORT receive: client closed the connection: %v",
					logPrefix, err)
			case strings.Contains(err.Error(), "use of closed network connection"):
				// Typically the poller closed conn (idle timeout or send error).
				logger.Info().Msgf(
					"%s ABORT receive: connection was meanwhile closed: %v",
					logPrefix, err)
			default:
				logger.Err(err).Msgf(
					"%s ABORT receive: error receiving from client", logPrefix)
			}
			return
		}
		logger.Info().Msgf(
			"%s RECEIVE client response: %s", logPrefix, received)
		markClientSeen()
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment