Skip to content

Instantly share code, notes, and snippets.

@navono
Created July 23, 2019 07:09
Show Gist options
  • Star 13 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save navono/d3742c4b0f26f68f1a48d86cf4556726 to your computer and use it in GitHub Desktop.
Save navono/d3742c4b0f26f68f1a48d86cf4556726 to your computer and use it in GitHub Desktop.
websocket reconnect in golang
package reconWS
import (
"errors"
"math/rand"
"net/http"
"net/url"
"sync"
"time"
"fmt"
"github.com/gorilla/websocket"
"github.com/jpillora/backoff"
//"github.com/sirupsen/logrus"
"go.uber.org/zap"
)
// ErrNotConnected is returned by ReadMessage, WriteMessage and WriteJSON when
// there is no live connection to operate on.
var ErrNotConnected = errors.New("websocket not connected")

// ErrUrlEmpty is returned by Dial when it is given an empty URL.
var ErrUrlEmpty = errors.New("url can not be empty")

// ErrUrlWrongScheme is returned by Dial when the URL scheme is neither ws nor wss.
var ErrUrlWrongScheme = errors.New("websocket uri must start with ws or wss scheme")

// ErrUrlNamePassNotAllowed is returned by Dial when the URL carries user info
// (user name and/or password).
var ErrUrlNamePassNotAllowed = errors.New("user name and password are not allowed in websocket uri")
// WsOpts customizes the websocket.Dialer that Dial builds. Options are applied
// in the order given, after the defaults (proxy, handshake timeout) are set and
// before the first connection attempt is started.
type WsOpts func(dialer *websocket.Dialer)
// Websocket wraps a gorilla/websocket connection and re-dials it with
// exponential backoff after read/write failures. Configure the exported
// fields, then call Dial to validate the URL and start connecting.
type Websocket struct {
// Id identifies this Websocket in its log messages.
Id uint64
// Meta holds arbitrary caller-defined data; no code in this file reads it.
Meta map[string]interface{}
//Logger *logrus.Logger
// Logger receives connect/reconnect messages when Verbose is true; may be nil.
Logger *zap.Logger
// NOTE(review): Errors is never written to by any code in this file.
Errors chan<- error
// NOTE(review): Reconnect is never consulted by any code in this file;
// read/write failures always trigger a reconnect regardless of its value.
Reconnect bool
// ReconnectIntervalMin is the minimum backoff between dial attempts;
// defaults to 2 seconds.
ReconnectIntervalMin time.Duration
// ReconnectIntervalMax is the maximum backoff between dial attempts;
// defaults to 30 seconds.
ReconnectIntervalMax time.Duration
// ReconnectIntervalFactor is the backoff growth factor between attempts;
// defaults to 1.5.
ReconnectIntervalFactor float64
// HandshakeTimeout is passed to the websocket.Dialer and is also how long
// Dial sleeps waiting for the first attempt; defaults to 2 seconds.
HandshakeTimeout time.Duration
// Verbose enables logging of connection attempts through Logger (only when
// Logger is non-nil). When false, nothing is logged.
Verbose bool
// Optional callbacks; nil callbacks are skipped.
// OnConnect runs after a dial succeeds; OnConnectError after each failed dial.
// OnDisconnect runs after Close closes a connection cleanly; OnDisconnectError
// when closing it fails. OnReadError/OnWriteError run when a read/write fails,
// before the connection is torn down and re-dialed.
OnConnect func(ws *Websocket)
OnDisconnect func(ws *Websocket)
OnConnectError func(ws *Websocket, err error)
OnDisconnectError func(ws *Websocket, err error)
OnReadError func(ws *Websocket, err error)
OnWriteError func(ws *Websocket, err error)
// dialer is built by Dial (and customized by WsOpts) and used by Connect.
dialer *websocket.Dialer
// url is the validated target passed to Dial.
url string
// requestHeader is sent with each handshake attempt made by Connect.
requestHeader http.Header
// httpResponse is the HTTP response from the most recent dial attempt.
httpResponse *http.Response
// mu is held when Connect publishes Conn, dialErr, isConnected and
// httpResponse, and by the getters that read them.
mu sync.Mutex
// dialErr is the error from the most recent dial attempt (nil on success).
dialErr error
// isConnected is true after a successful dial until Close runs.
isConnected bool
// isClosed is not used by active code; all of its uses are commented out.
isClosed bool
// Conn is the current connection (nil until a dial succeeds); its methods
// are promoted onto Websocket.
*websocket.Conn
}
// WriteJSON encodes v as JSON and writes it to the current connection.
//
// It returns ErrNotConnected when there is no live connection. If the write
// fails, OnWriteError (when set) is called with the error, then
// closeAndReconnect is invoked, which may block while a new connection is
// dialed; the original write error is returned afterwards.
func (ws *Websocket) WriteJSON(v interface{}) error {
	// Snapshot the connection under the lock: Connect replaces ws.Conn from
	// another goroutine, so reading the field unguarded is a data race.
	ws.mu.Lock()
	conn, connected := ws.Conn, ws.isConnected
	ws.mu.Unlock()
	if !connected {
		return ErrNotConnected
	}

	if err := conn.WriteJSON(v); err != nil {
		if ws.OnWriteError != nil {
			ws.OnWriteError(ws, err)
		}
		ws.closeAndReconnect()
		return err
	}
	return nil
}
// WriteMessage writes a message of the given type (e.g. websocket.TextMessage)
// with payload data to the current connection.
//
// It returns ErrNotConnected when there is no live connection. If the write
// fails, OnWriteError (when set) is called with the error, then
// closeAndReconnect is invoked, which may block while a new connection is
// dialed; the original write error is returned afterwards.
func (ws *Websocket) WriteMessage(messageType int, data []byte) error {
	// Snapshot the connection under the lock: Connect replaces ws.Conn from
	// another goroutine, so reading the field unguarded is a data race.
	ws.mu.Lock()
	conn, connected := ws.Conn, ws.isConnected
	ws.mu.Unlock()
	if !connected {
		return ErrNotConnected
	}

	if err := conn.WriteMessage(messageType, data); err != nil {
		if ws.OnWriteError != nil {
			ws.OnWriteError(ws, err)
		}
		ws.closeAndReconnect()
		return err
	}
	return nil
}
// ReadMessage reads the next message from the current connection.
//
// It returns ErrNotConnected when there is no live connection. If the read
// fails, OnReadError (when set) is called with the error, then
// closeAndReconnect is invoked, which may block while a new connection is
// dialed; the original read error is returned afterwards.
//
// NOTE(review): gorilla/websocket supports only one concurrent reader per
// connection, so callers should read from a single goroutine.
func (ws *Websocket) ReadMessage() (messageType int, message []byte, err error) {
	// Snapshot the connection under the lock: Connect replaces ws.Conn from
	// another goroutine, so reading the field unguarded is a data race.
	ws.mu.Lock()
	conn, connected := ws.Conn, ws.isConnected
	ws.mu.Unlock()
	if !connected {
		return 0, nil, ErrNotConnected
	}

	messageType, message, err = conn.ReadMessage()
	if err != nil {
		if ws.OnReadError != nil {
			ws.OnReadError(ws, err)
		}
		ws.closeAndReconnect()
	}
	return messageType, message, err
}
// Close closes the current connection, if one is open, and marks the
// Websocket as disconnected.
//
// OnDisconnect is invoked when an open connection is closed cleanly and
// OnDisconnectError when closing it fails. Both callbacks run after ws.mu has
// been released, so they may safely call methods such as IsConnected.
// Calling Close when no connection is open is a no-op.
//
// NOTE(review): Close does not stop a Connect loop that is already retrying.
func (ws *Websocket) Close() {
	ws.closeConn()
}

// closeAndReconnect tears down the current connection and, if this call was
// the one that closed it, blocks in Connect until a new connection is made.
// When several goroutines hit I/O errors on the same connection at once, only
// the first one reconnects; the rest return immediately, so duplicate dial
// loops (and the connections they would leak) are not started.
func (ws *Websocket) closeAndReconnect() {
	if ws.closeConn() {
		ws.Connect()
	}
}

// closeConn marks the Websocket as disconnected, closes the current
// connection if it was open, and reports whether this call closed it.
func (ws *Websocket) closeConn() bool {
	ws.mu.Lock()
	conn, wasConnected := ws.Conn, ws.isConnected
	ws.isConnected = false
	ws.mu.Unlock()

	// Connect stores a non-nil Conn only together with isConnected == true
	// (gorilla's Dial returns a nil *Conn whenever it returns an error), so a
	// non-nil Conn with isConnected == false has already been closed. Closing
	// it again would only produce a spurious "use of closed network connection".
	if conn == nil || !wasConnected {
		return false
	}

	// Callbacks run without ws.mu held to avoid self-deadlock if they call
	// back into the Websocket.
	if err := conn.Close(); err != nil {
		if ws.OnDisconnectError != nil {
			ws.OnDisconnectError(ws, err)
		}
	} else if ws.OnDisconnect != nil {
		ws.OnDisconnect(ws)
	}
	return true
}
// Dial validates urlStr, stores the connection settings, and starts
// connecting in a background goroutine (see Connect).
//
// reqHeader is sent with every handshake, including reconnects. opts are
// applied to the underlying websocket.Dialer after its defaults are set.
// Dial then sleeps for HandshakeTimeout to give the first attempt a chance to
// finish, but returns nil whether or not that attempt succeeded; use
// IsConnected or GetDialError to inspect the outcome. A non-nil error is
// returned only when urlStr fails validation.
func (ws *Websocket) Dial(urlStr string, reqHeader http.Header, opts ...WsOpts) error {
	if _, err := parseUrl(urlStr); err != nil {
		return err
	}

	ws.url = urlStr
	// Store the header so every (re)connect attempt in Connect sends it.
	ws.requestHeader = reqHeader
	ws.setDefaults()

	ws.dialer = &websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: ws.HandshakeTimeout,
	}
	for _, opt := range opts {
		opt(ws.dialer)
	}

	// Capture the timeout before Connect (and its callbacks) start running.
	wait := ws.HandshakeTimeout
	go ws.Connect()
	// Give the first attempt up to one handshake timeout to complete.
	time.Sleep(wait)
	return nil
}
// Connect dials ws.url repeatedly until a handshake succeeds. Between attempts
// it sleeps for an interval chosen by a jittered exponential backoff whose
// bounds and growth factor come from the ReconnectInterval* fields.
//
// After every attempt the connection, dial error, connection state and HTTP
// response are published under ws.mu. On success it logs (when Verbose is set
// and Logger is non-nil), invokes OnConnect and returns. On failure it logs,
// invokes OnConnectError, sleeps, and tries again.
//
// NOTE(review): Connect blocks the calling goroutine and has no cancellation;
// it only returns once a connection has been established.
func (ws *Websocket) Connect() {
	retry := &backoff.Backoff{
		Min:    ws.ReconnectIntervalMin,
		Max:    ws.ReconnectIntervalMax,
		Factor: ws.ReconnectIntervalFactor,
		Jitter: true,
	}
	// Seed math/rand for the backoff jitter.
	// NOTE(review): rand.Seed is deprecated since Go 1.20 (auto-seeded).
	rand.Seed(time.Now().UTC().UnixNano())

	for {
		wait := retry.Duration()
		conn, resp, dialErr := ws.dialer.Dial(ws.url, ws.requestHeader)

		ws.mu.Lock()
		ws.Conn, ws.dialErr, ws.isConnected, ws.httpResponse = conn, dialErr, dialErr == nil, resp
		ws.mu.Unlock()

		if dialErr != nil {
			if ws.Verbose && ws.Logger != nil {
				ws.Logger.Error(fmt.Sprintf("Websocket[%d].Dial: can't connect to %s, will try again in %v\n", ws.Id, ws.url, wait))
			}
			if ws.OnConnectError != nil {
				ws.OnConnectError(ws, dialErr)
			}
			time.Sleep(wait)
			continue
		}

		if ws.Verbose && ws.Logger != nil {
			ws.Logger.Info(fmt.Sprintf("Websocket[%d].Dial: connection was successfully established with %s\n", ws.Id, ws.url))
		}
		if ws.OnConnect != nil {
			ws.OnConnect(ws)
		}
		return
	}
}
// GetHTTPResponse returns the HTTP response recorded by the most recent dial
// attempt in Connect; it may be nil. The read is performed under ws.mu.
func (ws *Websocket) GetHTTPResponse() *http.Response {
	ws.mu.Lock()
	resp := ws.httpResponse
	ws.mu.Unlock()
	return resp
}
// GetDialError returns the error recorded by the most recent dial attempt in
// Connect: nil after a successful dial or before any attempt has completed.
// The read is performed under ws.mu.
func (ws *Websocket) GetDialError() error {
	ws.mu.Lock()
	lastErr := ws.dialErr
	ws.mu.Unlock()
	return lastErr
}
// IsConnected reports whether the most recent dial succeeded and Close has
// not run since. The read is performed under ws.mu.
func (ws *Websocket) IsConnected() bool {
	ws.mu.Lock()
	connected := ws.isConnected
	ws.mu.Unlock()
	return connected
}
// setDefaults replaces any zero-valued tuning field with its default:
// 2s minimum and 30s maximum reconnect interval, a 1.5 backoff factor and a
// 2s handshake timeout. Non-zero values chosen by the caller are kept.
func (ws *Websocket) setDefaults() {
	durationDefaults := []struct {
		field    *time.Duration
		fallback time.Duration
	}{
		{&ws.ReconnectIntervalMin, 2 * time.Second},
		{&ws.ReconnectIntervalMax, 30 * time.Second},
		{&ws.HandshakeTimeout, 2 * time.Second},
	}
	for _, d := range durationDefaults {
		if *d.field == 0 {
			*d.field = d.fallback
		}
	}

	if ws.ReconnectIntervalFactor == 0 {
		ws.ReconnectIntervalFactor = 1.5
	}
}
// parseUrl validates urlStr as a WebSocket URL and returns the parsed form.
//
// It rejects, in order:
//   - an empty string (ErrUrlEmpty);
//   - anything url.Parse cannot parse (that error is returned as-is);
//   - schemes other than ws/wss (ErrUrlWrongScheme; url.Parse lower-cases the
//     scheme, so "WS://..." is accepted);
//   - URLs carrying user info (ErrUrlNamePassNotAllowed);
//   - URLs without a host, such as "ws:/path" or "ws://".
func parseUrl(urlStr string) (*url.URL, error) {
	if urlStr == "" {
		return nil, ErrUrlEmpty
	}
	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, err
	}
	if u.Scheme != "ws" && u.Scheme != "wss" {
		return nil, ErrUrlWrongScheme
	}
	if u.User != nil {
		return nil, ErrUrlNamePassNotAllowed
	}
	// A ws/wss URI must name a host (RFC 6455 §3). Without this check a
	// host-less URL would pass validation and Connect would dial it in its
	// endless retry loop.
	if u.Host == "" {
		return nil, fmt.Errorf("websocket uri %q has no host", urlStr)
	}
	return u, nil
}
@franck34
Copy link

Code for using this package should be great ! thx

@navono
Copy link
Author

navono commented Mar 18, 2021

Code for using this package should be great ! thx

glad to be of help.

@ethermachine
Copy link

ethermachine commented Aug 18, 2022

Code for using this package should be great ! thx

glad to be of help.

Could you please explain how I can use this code to catch the following error and then try to reconnect?

websocket: close 1006 (abnormal closure): unexpected EOF
panic: repeated read on failed websocket connection

I'm currently using "github.com/gorilla/websocket" with a simple loop, but I don't know how to integrate your code into it. This is what I'm using at the moment:

for {
	_, nextNotification, err := wsSubscriber.ReadMessage()
	if err != nil {
		fmt.Println(err)
		// NOTE: execution falls through here. doStuff below still runs with the
		// message returned alongside the error, and the next iteration reads the
		// same failed connection again. gorilla/websocket panics on such a
		// repeated read ("repeated read on failed websocket connection"), so this
		// branch needs to break/return or re-dial instead of continuing.
	}

	go doStuff(nextNotification, method, network)
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment