Skip to content

Instantly share code, notes, and snippets.

@maaft
Created December 9, 2020 10:31
Show Gist options
  • Save maaft/ac197c72fe5d18eec5916f3f5a75943f to your computer and use it in GitHub Desktop.
Save maaft/ac197c72fe5d18eec5916f3f5a75943f to your computer and use it in GitHub Desktop.
modified GraphQL subscription client (from https://github.com/hasura/go-graphql-client)
package subscription
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/google/uuid"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)
// OperationMessageType is the "type" field of a graphql-ws protocol message.
//
// The subscription transport follows Apollo's subscriptions-transport-ws
// protocol specification:
// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
type OperationMessageType string

const (
	// GQL_CONNECTION_INIT is sent by the client right after the plain websocket
	// connection opens, to start communication with the server.
	GQL_CONNECTION_INIT OperationMessageType = "connection_init"
	// GQL_CONNECTION_ERROR may be sent by the server in response to
	// GQL_CONNECTION_INIT to indicate that it rejected the connection.
	// The protocol spells this type "connection_error"; the former value
	// "conn_err" is not a protocol message type and never matched.
	GQL_CONNECTION_ERROR OperationMessageType = "connection_error"
	// GQL_START is sent by the client to execute a GraphQL operation.
	GQL_START OperationMessageType = "start"
	// GQL_STOP is sent by the client to stop a running GraphQL operation
	// (for example, to unsubscribe).
	GQL_STOP OperationMessageType = "stop"
	// GQL_ERROR is sent by the server when an operation fails before GraphQL
	// execution, usually because of validation errors. Resolver errors are
	// delivered inside GQL_DATA as an "errors" array instead.
	GQL_ERROR OperationMessageType = "error"
	// GQL_DATA is sent by the server to transfer a GraphQL execution result to
	// the client, in response to GQL_START.
	GQL_DATA OperationMessageType = "data"
	// GQL_COMPLETE is sent by the server to indicate that an operation is done
	// and no more data will arrive for it.
	GQL_COMPLETE OperationMessageType = "complete"
	// GQL_CONNECTION_KEEP_ALIVE is sent by the server right after each
	// GQL_CONNECTION_ACK and then periodically to keep the connection alive.
	// The client only starts considering keep-alives once the first one arrives.
	GQL_CONNECTION_KEEP_ALIVE OperationMessageType = "ka"
	// GQL_CONNECTION_ACK may be sent by the server in response to
	// GQL_CONNECTION_INIT to indicate that it accepted the connection.
	// It may optionally include a payload.
	GQL_CONNECTION_ACK OperationMessageType = "connection_ack"
	// GQL_CONNECTION_TERMINATE is sent by the client to terminate the connection.
	GQL_CONNECTION_TERMINATE OperationMessageType = "connection_terminate"
	// GQL_UNKNOWN marks an unrecognized message type; used for logging only.
	GQL_UNKNOWN OperationMessageType = "unknown"
	// GQL_INTERNAL marks client-internal status messages; used for logging only.
	GQL_INTERNAL OperationMessageType = "internal"
)
// OperationMessage is a single graphql-ws protocol frame exchanged with the
// server. Payload is kept as raw JSON so each message type can decode it later.
type OperationMessage struct {
	ID      string               `json:"id,omitempty"`
	Type    OperationMessageType `json:"type"`
	Payload json.RawMessage      `json:"payload,omitempty"`
}

// String renders the message as its JSON encoding. It yields an empty string
// when the message cannot be marshalled.
func (om OperationMessage) String() string {
	encoded, err := json.Marshal(om)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// WebsocketConn abstracts the websocket connection used by SubscriptionClient,
// so that a custom transport can be plugged in through WithWebSocket.
type WebsocketConn interface {
	// SetReadLimit sets the maximum size in bytes for a message read from the
	// peer. What happens when a message exceeds it is up to the implementation
	// (for example, closing the connection and returning an error).
	SetReadLimit(limit int64)
	// WriteJSON encodes v as JSON and sends it as one websocket frame.
	WriteJSON(v interface{}) error
	// ReadJSON reads one websocket frame and decodes its JSON into v.
	ReadJSON(v interface{}) error
	// Close closes the websocket connection.
	Close() error
}
// handlerFunc is the caller-supplied subscription callback. It receives either
// a raw data payload or an error; a non-nil return value is forwarded to the
// client's error channel by wrapHandler.
type handlerFunc func(data *json.RawMessage, err error) error

// subscription is the client-side state of one registered subscription.
type subscription struct {
	// started reports whether GQL_START has been written for it on the
	// current connection.
	started bool
	// query and variables make up the GQL_START payload.
	query     string
	variables map[string]interface{}
	// handler is the wrapped callback invoked with data or an error.
	handler func(data *json.RawMessage, err error)
}
// SubscriptionClient is a GraphQL subscription client that talks to the server
// over a websocket using the graphql-ws (subscriptions-transport-ws) protocol.
type SubscriptionClient struct {
	// Connection settings.
	url              string
	connectionParams map[string]interface{}
	timeout          time.Duration
	retryTimeout     time.Duration
	readLimit        int64 // max size of response message. Default 10 MB
	createConn       func(sc *SubscriptionClient) (WebsocketConn, error)

	// State of the live connection; replaced on every (re)connect.
	conn    WebsocketConn
	context context.Context
	cancel  context.CancelFunc

	// subscribersMu is the lock taken when mutating subscriptions and isRunning.
	subscribersMu sync.Mutex
	subscriptions map[string]*subscription
	isRunning     bool

	// Event hooks, handler error delivery and logging.
	onConnected      func()
	onDisconnected   func()
	onError          func(sc *SubscriptionClient, err error) error
	errorChan        chan error
	log              func(args ...interface{})
	disabledLogTypes []OperationMessageType
}
// NewSubscriptionClient returns a client for the GraphQL websocket endpoint at
// url. Defaults: a one-minute timeout for each websocket read/write, a
// one-minute reconnect window, a 10 MB read limit and the built-in
// nhooyr.io/websocket dialer. Call Run to connect.
func NewSubscriptionClient(url string) *SubscriptionClient {
	sc := new(SubscriptionClient)
	sc.url = url
	sc.createConn = newWebsocketConn
	sc.subscriptions = make(map[string]*subscription)
	sc.errorChan = make(chan error)
	sc.timeout = time.Minute
	sc.retryTimeout = time.Minute
	sc.readLimit = 10 * 1024 * 1024 // set default limit 10MB
	return sc
}
// GetURL returns the URL of the GraphQL server that the client dials.
func (sc *SubscriptionClient) GetURL() string {
return sc.url
}
// GetContext returns the client's current context. init replaces it on every
// (re)connect, and it is nil until Run has been called.
func (sc *SubscriptionClient) GetContext() context.Context {
return sc.context
}
// GetTimeout returns the timeout that the default websocket handler applies to
// each individual JSON read and write.
func (sc *SubscriptionClient) GetTimeout() time.Duration {
return sc.timeout
}
// WithWebSocket replaces the websocket connection constructor stored in
// createConn. NewSubscriptionClient installs newWebsocketConn, which uses
// https://github.com/nhooyr/websocket.
func (sc *SubscriptionClient) WithWebSocket(fn func(sc *SubscriptionClient) (WebsocketConn, error)) *SubscriptionClient {
sc.createConn = fn
return sc
}
// WithConnectionParams sets the params that are JSON-encoded into the payload
// of the GQL_CONNECTION_INIT message. They are usually used for the
// authentication handshake. The map is stored by reference, not copied.
func (sc *SubscriptionClient) WithConnectionParams(params map[string]interface{}) *SubscriptionClient {
sc.connectionParams = params
return sc
}
// WithTimeout sets the timeout returned by GetTimeout. The default websocket
// handler applies it to each individual read and write.
func (sc *SubscriptionClient) WithTimeout(timeout time.Duration) *SubscriptionClient {
sc.timeout = timeout
return sc
}
// WithRetryTimeout sets how long the client keeps trying to (re)connect, with
// one attempt per second, before giving up and firing OnDisconnected.
func (sc *SubscriptionClient) WithRetryTimeout(timeout time.Duration) *SubscriptionClient {
sc.retryTimeout = timeout
return sc
}
// WithLog sets the function used to print protocol messages and internal
// notices. With no logger set (the default), nothing is printed.
func (sc *SubscriptionClient) WithLog(logger func(args ...interface{})) *SubscriptionClient {
sc.log = logger
return sc
}
// WithoutLogTypes suppresses logging for the given operation types. Each call
// replaces, rather than extends, the previously disabled set.
func (sc *SubscriptionClient) WithoutLogTypes(types ...OperationMessageType) *SubscriptionClient {
sc.disabledLogTypes = types
return sc
}
// WithReadLimit sets the maximum size in bytes of a message read from the
// server (default 10 MB). It is applied to each new connection via SetReadLimit.
func (sc *SubscriptionClient) WithReadLimit(limit int64) *SubscriptionClient {
sc.readLimit = limit
return sc
}
// OnError registers the bottom-level error handler. It receives errors
// returned by subscription handlers, plus websocket read errors that do not
// trigger a reconnect or a normal shutdown.
// If no handler is set, or the handler returns nil, the error is ignored.
// If it returns an error, Run stops and returns that error.
func (sc *SubscriptionClient) OnError(onError func(sc *SubscriptionClient, err error) error) *SubscriptionClient {
sc.onError = onError
return sc
}
// OnConnected registers a callback that Run invokes when the server
// acknowledges the connection (GQL_CONNECTION_ACK).
func (sc *SubscriptionClient) OnConnected(fn func()) *SubscriptionClient {
sc.onConnected = fn
return sc
}
// OnDisconnected registers a callback that is invoked when the server is still
// unreachable after the retry timeout has elapsed.
func (sc *SubscriptionClient) OnDisconnected(fn func()) *SubscriptionClient {
sc.onDisconnected = fn
return sc
}
// setIsRunning updates the running flag while holding subscribersMu.
func (sc *SubscriptionClient) setIsRunning(value bool) {
	sc.subscribersMu.Lock()
	defer sc.subscribersMu.Unlock()
	sc.isRunning = value
}
// init installs a fresh cancellable context, opens the websocket (if none is
// open) and sends GQL_CONNECTION_INIT. When either step fails, it retries once
// per second. Once retryTimeout has elapsed since the first attempt, it calls
// onDisconnected (if set) and returns the last error.
func (sc *SubscriptionClient) init() error {
	now := time.Now()
	ctx, cancel := context.WithCancel(context.Background())
	sc.context = ctx
	sc.cancel = cancel

	// Honour a constructor installed through WithWebSocket. Fall back to the
	// default nhooyr-based dialer when none is set.
	createConn := sc.createConn
	if createConn == nil {
		createConn = newWebsocketConn
	}

	for {
		var err error
		if sc.conn == nil {
			var conn WebsocketConn
			conn, err = createConn(sc)
			if err == nil {
				sc.conn = conn
			}
		}
		if err == nil {
			sc.conn.SetReadLimit(sc.readLimit)
			// send connection init event to the server
			if err = sc.sendConnectionInit(); err == nil {
				return nil
			}
			// The connection is unusable. Drop it so the next attempt redials
			// instead of writing to the same broken socket again.
			_ = sc.conn.Close()
			sc.conn = nil
		}
		if time.Since(now) > sc.retryTimeout {
			if sc.onDisconnected != nil {
				sc.onDisconnected()
			}
			return err
		}
		sc.printLog(err.Error()+". retry in second....", GQL_INTERNAL)
		time.Sleep(time.Second)
	}
}
// printLog passes message to the configured logger. Nothing is printed when no
// logger is set or when opType is among the disabled log types.
func (sc *SubscriptionClient) printLog(message interface{}, opType OperationMessageType) {
	logger := sc.log
	if logger == nil {
		return
	}
	for i := range sc.disabledLogTypes {
		if sc.disabledLogTypes[i] == opType {
			return
		}
	}
	logger(message)
}
// sendConnectionInit writes GQL_CONNECTION_INIT to the server. The payload
// carries the JSON-encoded connection params, or is omitted when none are set.
func (sc *SubscriptionClient) sendConnectionInit() (err error) {
	msg := OperationMessage{Type: GQL_CONNECTION_INIT}
	if sc.connectionParams != nil {
		if msg.Payload, err = json.Marshal(sc.connectionParams); err != nil {
			return err
		}
	}
	sc.printLog(msg, GQL_CONNECTION_INIT)
	return sc.conn.WriteJSON(msg)
}
// Subscribe registers a subscription for query/variables and returns its ID,
// which can later be passed to Unsubscribe. If the client is already running,
// GQL_START is sent immediately; otherwise Run sends it once connected.
// handler receives either the raw "data" payload or an error, and is invoked
// on a new goroutine for each message. A non-nil error returned by handler is
// forwarded to the OnError hook.
func (sc *SubscriptionClient) Subscribe(query string, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error) (string, error) {
return sc.do(query, variables, handler, "")
}
// NamedSubscribe is like Subscribe but takes an operation name.
// NOTE(review): name is passed to do but is not included in the GQL_START
// payload, so it does not currently reach the server — confirm intent.
func (sc *SubscriptionClient) NamedSubscribe(name string, query string, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error) (string, error) {
return sc.do(query, variables, handler, name)
}
// do registers a new subscription under a fresh UUID. If the client is already
// running, it also sends GQL_START immediately; otherwise Run starts the
// subscription once connected. It returns the subscription ID.
//
// NOTE(review): name is accepted but unused, so no operationName is sent.
func (sc *SubscriptionClient) do(query string, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error, name string) (string, error) {
	id := uuid.New().String()
	sub := &subscription{
		query:     query,
		variables: variables,
		handler:   sc.wrapHandler(handler),
	}

	// Register before starting: the server may answer GQL_START before this
	// function returns, and Run discards data for IDs missing from the map.
	// isRunning is read under the same lock that setIsRunning writes it with.
	sc.subscribersMu.Lock()
	sc.subscriptions[id] = sub
	running := sc.isRunning
	sc.subscribersMu.Unlock()

	// if the websocket client is running, start subscription immediately
	if running {
		if err := sc.startSubscription(id, sub); err != nil {
			sc.subscribersMu.Lock()
			delete(sc.subscriptions, id)
			sc.subscribersMu.Unlock()
			return "", err
		}
	}
	return id, nil
}
// startSubscription writes a GQL_START message carrying sub's query and
// variables, then marks sub as started. It does nothing for a nil or
// already-started subscription.
func (sc *SubscriptionClient) startSubscription(id string, sub *subscription) error {
	if sub == nil || sub.started {
		return nil
	}
	payload, err := json.Marshal(struct {
		Query     string                 `json:"query"`
		Variables map[string]interface{} `json:"variables,omitempty"`
	}{sub.query, sub.variables})
	if err != nil {
		return err
	}
	start := OperationMessage{ID: id, Type: GQL_START, Payload: payload}
	sc.printLog(start, GQL_START)
	if err = sc.conn.WriteJSON(start); err != nil {
		return err
	}
	sub.started = true
	return nil
}
// wrapHandler adapts a user handler to the internal callback shape. A non-nil
// error returned by fn is delivered to Run through errorChan.
//
// errorChan is unbuffered and is only drained by Run, so the send also watches
// the client context. Once the client is closed (or is resetting), the error
// is dropped instead of blocking the handler goroutine forever.
func (sc *SubscriptionClient) wrapHandler(fn handlerFunc) func(data *json.RawMessage, err error) {
	return func(data *json.RawMessage, err error) {
		errValue := fn(data, err)
		if errValue == nil {
			return
		}
		ctx := sc.context
		if ctx == nil {
			// No context yet (Run never initialised): fall back to a plain send.
			sc.errorChan <- errValue
			return
		}
		select {
		case sc.errorChan <- errValue:
		case <-ctx.Done():
		}
	}
}
// Run connects to the server, starts every registered subscription, then
// processes incoming messages. It keeps going until the client is closed, its
// context is cancelled, or an unrecoverable error occurs. Run blocks, so it is
// normally started in its own goroutine and stopped with Close.
//
// EOF and non-normal websocket close statuses trigger Reset, which tears the
// connection down and re-enters Run. A normal closure returns nil. Handler
// errors and other read errors are passed to the OnError hook; a non-nil
// result from that hook stops Run with that error.
func (sc *SubscriptionClient) Run() error {
	if err := sc.init(); err != nil {
		// Keep the original message prefix, but preserve the cause.
		return fmt.Errorf("retry timeout. exiting...: %w", err)
	}

	// Lazily start subscriptions registered before the connection existed.
	// Iterate over a snapshot taken under subscribersMu, because Subscribe and
	// Unsubscribe mutate the map from other goroutines.
	sc.subscribersMu.Lock()
	pending := make(map[string]*subscription, len(sc.subscriptions))
	for id, sub := range sc.subscriptions {
		pending[id] = sub
	}
	sc.subscribersMu.Unlock()
	for id, sub := range pending {
		if err := sc.startSubscription(id, sub); err != nil {
			// Best-effort cleanup; the start error is what the caller needs.
			_ = sc.Unsubscribe(id)
			return err
		}
	}

	sc.setIsRunning(true)
	for sc.running() {
		select {
		case <-sc.context.Done():
			return nil
		case e := <-sc.errorChan:
			if sc.onError != nil {
				if err := sc.onError(sc, e); err != nil {
					return err
				}
			}
		default:
			var message OperationMessage
			if err := sc.conn.ReadJSON(&message); err != nil {
				// manual EOF check: errors that merely mention EOF are matched by text
				if err == io.EOF || strings.Contains(err.Error(), "EOF") {
					return sc.Reset()
				}
				closeStatus := websocket.CloseStatus(err)
				if closeStatus == websocket.StatusNormalClosure {
					// close event from websocket client, exiting...
					return nil
				}
				if closeStatus != -1 {
					// Abnormal close from the peer: reconnect.
					sc.printLog(fmt.Sprintf("%s. Retry connecting...", err), GQL_INTERNAL)
					return sc.Reset()
				}
				if sc.onError != nil {
					if err = sc.onError(sc, err); err != nil {
						return err
					}
				}
				continue
			}

			switch message.Type {
			case GQL_ERROR:
				sc.printLog(message, GQL_ERROR)
				fallthrough
			case GQL_DATA:
				sc.printLog(message, GQL_DATA)
				id, err := uuid.Parse(message.ID)
				if err != nil {
					continue
				}
				// The map is shared with Subscribe/Unsubscribe; read it under the lock.
				sc.subscribersMu.Lock()
				sub, ok := sc.subscriptions[id.String()]
				sc.subscribersMu.Unlock()
				if !ok {
					continue
				}
				var out struct {
					Data   *json.RawMessage
					Errors errors
					//Extensions interface{} // Unused.
				}
				if err = json.Unmarshal(message.Payload, &out); err != nil {
					go sub.handler(nil, err)
					continue
				}
				if len(out.Errors) > 0 {
					go sub.handler(nil, out.Errors)
					continue
				}
				go sub.handler(out.Data, nil)
			case GQL_CONNECTION_ERROR:
				sc.printLog(message, GQL_CONNECTION_ERROR)
			case GQL_COMPLETE:
				sc.printLog(message, GQL_COMPLETE)
				sc.Unsubscribe(message.ID)
			case GQL_CONNECTION_KEEP_ALIVE:
				sc.printLog(message, GQL_CONNECTION_KEEP_ALIVE)
			case GQL_CONNECTION_ACK:
				sc.printLog(message, GQL_CONNECTION_ACK)
				if sc.onConnected != nil {
					sc.onConnected()
				}
			default:
				sc.printLog(message, GQL_UNKNOWN)
			}
		}
	}
	// if the running status is false, stop retrying
	if !sc.running() {
		return nil
	}
	return sc.Reset()
}

// running reports the isRunning flag, read under subscribersMu to pair with
// setIsRunning.
func (sc *SubscriptionClient) running() bool {
	sc.subscribersMu.Lock()
	defer sc.subscribersMu.Unlock()
	return sc.isRunning
}
// Unsubscribe sends GQL_STOP for the subscription with the given ID (the value
// returned by Subscribe) and removes it from the client. The subscription is
// removed even if sending GQL_STOP fails; that send error is returned.
// It returns an error if no subscription with that ID is registered.
func (sc *SubscriptionClient) Unsubscribe(id string) error {
	// The map is shared with Subscribe and Run; check membership under the lock.
	sc.subscribersMu.Lock()
	_, ok := sc.subscriptions[id]
	sc.subscribersMu.Unlock()
	if !ok {
		return fmt.Errorf("subscription id %s does not exist", id)
	}
	err := sc.stopSubscription(id)
	sc.subscribersMu.Lock()
	delete(sc.subscriptions, id)
	sc.subscribersMu.Unlock()
	return err
}
// stopSubscription writes a GQL_STOP message for id. It does nothing when
// there is no open connection.
func (sc *SubscriptionClient) stopSubscription(id string) error {
	if sc.conn == nil {
		return nil
	}
	stop := OperationMessage{ID: id, Type: GQL_STOP}
	sc.printLog(stop, GQL_STOP)
	return sc.conn.WriteJSON(stop)
}
// terminate writes GQL_CONNECTION_TERMINATE to the server. It does nothing
// when there is no open connection.
func (sc *SubscriptionClient) terminate() error {
	if sc.conn == nil {
		return nil
	}
	bye := OperationMessage{Type: GQL_CONNECTION_TERMINATE}
	sc.printLog(bye, GQL_CONNECTION_TERMINATE)
	return sc.conn.WriteJSON(bye)
}
// Reset tears down the current connection and restarts the client by calling
// Run again. Every registered subscription is marked as not started, so Run
// re-sends its GQL_START. Reset does nothing (and returns nil) when the client
// is not running. Because it re-enters Run, it blocks just as Run does.
func (sc *SubscriptionClient) Reset() error {
	// Read the running flag and snapshot the subscriptions under the lock
	// shared with setIsRunning, Subscribe and Unsubscribe.
	sc.subscribersMu.Lock()
	running := sc.isRunning
	subs := make(map[string]*subscription, len(sc.subscriptions))
	for id, sub := range sc.subscriptions {
		subs[id] = sub
	}
	sc.subscribersMu.Unlock()
	if !running {
		return nil
	}

	for id, sub := range subs {
		// Best effort: the connection is usually already broken here.
		_ = sc.stopSubscription(id)
		sub.started = false
	}
	if sc.conn != nil {
		_ = sc.terminate()
		_ = sc.conn.Close()
		sc.conn = nil
	}
	if sc.cancel != nil {
		sc.cancel()
	}
	return sc.Run()
}
// Close stops the client. It marks the client as not running, unsubscribes
// every registered subscription, sends GQL_CONNECTION_TERMINATE, closes the
// websocket and cancels the client context.
//
// A failed unsubscribe does not abort the teardown, so the websocket is never
// leaked. The first error encountered is returned. Close is safe to call
// before Run, when no context or connection exists yet.
func (sc *SubscriptionClient) Close() (err error) {
	sc.setIsRunning(false)

	// Snapshot the IDs under the lock; Unsubscribe mutates the map itself.
	sc.subscribersMu.Lock()
	ids := make([]string, 0, len(sc.subscriptions))
	for id := range sc.subscriptions {
		ids = append(ids, id)
	}
	sc.subscribersMu.Unlock()

	for _, id := range ids {
		if unsubErr := sc.Unsubscribe(id); unsubErr != nil && err == nil {
			err = unsubErr
		}
	}

	if sc.conn != nil {
		// Best effort: the close below is what actually releases the socket.
		_ = sc.terminate()
		if closeErr := sc.conn.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
		sc.conn = nil
	}

	if sc.cancel != nil {
		sc.cancel()
	}
	return err
}
// websocketHandler is the default WebsocketConn implementation, backed by
// https://github.com/nhooyr/websocket. Each ReadJSON/WriteJSON call is bounded
// by timeout, using a context derived from ctx.
type websocketHandler struct {
	*websocket.Conn
	ctx     context.Context
	timeout time.Duration
}
// WriteJSON encodes v as JSON and sends it as one websocket message. It gives
// up once wh.timeout elapses or wh.ctx is cancelled.
func (wh *websocketHandler) WriteJSON(v interface{}) error {
	writeCtx, release := context.WithTimeout(wh.ctx, wh.timeout)
	defer release()
	return wsjson.Write(writeCtx, wh.Conn, v)
}
// ReadJSON reads one websocket message and decodes its JSON into v. It gives
// up once wh.timeout elapses or wh.ctx is cancelled.
// NOTE(review): nhooyr.io/websocket appears to close the connection when a read
// context expires, so a server that stays silent longer than timeout may drop
// the connection — confirm against the library docs.
func (wh *websocketHandler) ReadJSON(v interface{}) error {
	readCtx, release := context.WithTimeout(wh.ctx, wh.timeout)
	defer release()
	return wsjson.Read(readCtx, wh.Conn, v)
}
// Close closes the underlying websocket with a normal-closure status and the
// reason "close websocket".
func (wh *websocketHandler) Close() error {
return wh.Conn.Close(websocket.StatusNormalClosure, "close websocket")
}
// newWebsocketConn is the default connection constructor. It dials sc's URL
// with the "graphql-ws" subprotocol using the client's current context, and
// wraps the result in a websocketHandler that carries sc's timeout.
func newWebsocketConn(sc *SubscriptionClient) (WebsocketConn, error) {
	ctx := sc.GetContext()
	conn, _, err := websocket.Dial(ctx, sc.GetURL(), &websocket.DialOptions{
		Subprotocols: []string{"graphql-ws"},
	})
	if err != nil {
		return nil, err
	}
	handler := &websocketHandler{Conn: conn, ctx: ctx, timeout: sc.GetTimeout()}
	return handler, nil
}
// errors represents the "errors" array in a response from a GraphQL server.
// It is normally returned as an error only when it contains at least one element.
//
// Specification: https://facebook.github.io/graphql/#sec-Errors.
type errors []struct {
	Message   string
	Locations []struct {
		Line   int
		Column int
	}
}

// Error implements the error interface by returning the first error's message.
// An empty list yields a generic message instead of panicking on e[0].
func (e errors) Error() string {
	if len(e) == 0 {
		return "graphql: empty errors list"
	}
	return e[0].Message
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment