Skip to content

Instantly share code, notes, and snippets.

@adamgoose
Created December 28, 2018 06:24
Show Gist options
  • Save adamgoose/005cda8d9ce4030047e05e5b07953fd4 to your computer and use it in GitHub Desktop.
Save adamgoose/005cda8d9ce4030047e05e5b07953fd4 to your computer and use it in GitHub Desktop.
A rough, minimal GraphQL subscription client for Go (graphql-ws protocol over WebSocket).
package subs
import (
	"context"
	"encoding/json"
	"log"
	"reflect"
	"sync"

	"github.com/gorilla/websocket"
	uuid "github.com/satori/go.uuid"
)
// Subs is a minimal client for the Apollo "graphql-ws" subscription protocol.
// It multiplexes any number of subscriptions over a single WebSocket
// connection. Create one with New.
type Subs struct {
	conn *websocket.Conn

	// Context is cancelled when the connection's read loop stops, either
	// because reading from the server failed or because Close was called.
	Context context.Context

	// mu guards subs. The read loop holds it while handing a message to a
	// subscriber, so a subscription channel is never closed mid-send.
	mu   sync.Mutex
	subs map[uuid.UUID]*subscriber

	// writeMu serializes writes: gorilla/websocket allows only one
	// concurrent writer per connection.
	writeMu sync.Mutex
}

// subscriber is the delivery end of one active subscription.
type subscriber struct {
	ch   chan json.RawMessage
	done <-chan struct{} // closed when the subscription is cancelled
}

// New dials url, negotiating the "graphql-ws" subprotocol, starts a goroutine
// that dispatches incoming "data" messages to subscribers, and sends the
// protocol's connection_init message. The returned client's Context is
// cancelled once the connection stops being readable.
func New(url string) (*Subs, error) {
	ctx, cancel := context.WithCancel(context.Background())

	// Work on a copy: websocket.DefaultDialer is a shared *Dialer, and setting
	// Subprotocols on it directly would affect every other dial in the process.
	dialer := *websocket.DefaultDialer
	dialer.Subprotocols = []string{"graphql-ws"}
	conn, _, err := dialer.DialContext(ctx, url, nil)
	if err != nil {
		cancel() // nothing else will ever release this context
		return nil, err
	}

	s := &Subs{
		conn:    conn,
		Context: ctx,
		subs:    make(map[uuid.UUID]*subscriber),
	}
	go s.readLoop(cancel)

	s.writeMu.Lock()
	err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_init","payload":{}}`))
	s.writeMu.Unlock()
	if err != nil {
		// Closing the socket also makes readLoop exit instead of leaking.
		cancel()
		conn.Close()
		return nil, err
	}
	return s, nil
}

// readLoop reads server messages until reading fails, forwarding the payload
// data of every "data" message to the subscriber with the matching id.
// Messages of other types, or whose id is not a UUID, are ignored.
// On exit it closes the connection and calls cancel, which ends s.Context
// and, with it, every subscription created by SubJSON.
func (s *Subs) readLoop(cancel context.CancelFunc) {
	defer cancel()
	defer s.conn.Close()
	for {
		var msg SubscriptionMessage
		if err := s.conn.ReadJSON(&msg); err != nil {
			log.Println(err)
			return
		}
		if msg.Type != DATA {
			continue
		}
		id, err := uuid.FromString(msg.ID)
		if err != nil {
			continue
		}
		s.deliver(id, msg.Payload.Data)
	}
}

// deliver hands data to the subscriber registered under id, if any. It waits
// until the subscriber accepts the value or is cancelled, holding s.mu
// throughout so the subscriber cannot be removed (and its channel closed)
// mid-send. A subscriber that neither reads nor cancels therefore stalls the
// read loop.
func (s *Subs) deliver(id uuid.UUID, data json.RawMessage) {
	s.mu.Lock()
	defer s.mu.Unlock()
	sub, ok := s.subs[id]
	if !ok {
		return
	}
	select {
	case sub.ch <- data:
	case <-sub.done:
	}
}

// remove unregisters the subscriber for id, if still present, and closes its
// channel so receivers ranging over it terminate.
func (s *Subs) remove(id uuid.UUID) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if sub, ok := s.subs[id]; ok {
		delete(s.subs, id)
		close(sub.ch)
	}
}

// writeJSON sends v as a JSON text message, serialized with other writers.
func (s *Subs) writeJSON(v interface{}) error {
	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	return s.conn.WriteJSON(v)
}

// SubStruct enables you to provide your query in a shurcooL/graphql-compliant
// way; query should be a pointer to such a struct. Every "data" message is
// decoded into a newly allocated value of query's type, so values received
// from the returned chan can safely be type-asserted to the type of query and
// are never overwritten by later messages. Messages that fail to decode are
// logged and skipped. The chan is closed when the subscription ends, either
// through the returned cancel function or because the connection was lost.
func (s *Subs) SubStruct(query interface{}, variables map[string]interface{}) (chan interface{}, context.CancelFunc) {
	out := make(chan interface{}, 50)
	raw, cancelRaw := s.SubJSON("", constructSubscription(query, variables), variables)

	// stop lets the forwarding goroutine exit even while it is blocked sending
	// to out after the consumer cancelled and stopped reading.
	stop := make(chan struct{})
	var once sync.Once
	cancel := func() {
		once.Do(func() { close(stop) })
		cancelRaw()
	}

	go func() {
		defer close(out)
		for msg := range raw {
			v := newLike(query)
			if err := UnmarshalGraphQL(msg, v); err != nil {
				log.Println("Unable to unmarshal graphql: ", err)
				continue
			}
			select {
			case out <- v:
			case <-stop:
				return
			}
		}
	}()
	return out, cancel
}

// newLike returns a pointer to a new zero value of the type query points to.
// When query is not a pointer it is returned unchanged, so UnmarshalGraphQL
// reports its usual non-pointer error for it.
func newLike(query interface{}) interface{} {
	t := reflect.TypeOf(query)
	if t == nil || t.Kind() != reflect.Ptr {
		return query
	}
	return reflect.New(t.Elem()).Interface()
}

// SubJSON enables you to provide your query in a string format.
// Values passed on the return chan contain a JSON representation of
// the "data" attribute returned by the GraphQL server.
//
// Ending the subscription, via the returned function or because the
// connection is lost, closes the chan and, if the connection is still up,
// sends the protocol's "stop" message. If the "start" message cannot be sent,
// the error is logged and the returned chan is already closed.
func (s *Subs) SubJSON(operationName, subscription string, variables map[string]interface{}) (chan json.RawMessage, context.CancelFunc) {
	id := uuid.Must(uuid.NewV4())
	// Derive from s.Context so a lost connection ends the subscription too.
	ctx, cancel := context.WithCancel(s.Context)
	sub := &subscriber{ch: make(chan json.RawMessage, 1), done: ctx.Done()}

	// Register before sending "start" so a result that arrives immediately is
	// not dropped by the read loop.
	s.mu.Lock()
	s.subs[id] = sub
	s.mu.Unlock()

	start := Subscription{
		ID:   id.String(),
		Type: START,
		Payload: GraphQLPayload{
			OperationName: operationName,
			Query:         subscription,
			Variables:     variables,
		},
	}
	if err := s.writeJSON(start); err != nil {
		// Library code must not exit the process; report and hand back a
		// closed channel instead.
		log.Println(err)
		cancel() // first, so a deliver blocked on this subscriber gives up
		s.remove(id)
		return sub.ch, cancel
	}

	go func() {
		<-ctx.Done()
		s.remove(id)
		if s.Context.Err() != nil {
			return // the connection is gone; there is no server to notify
		}
		stop := map[string]interface{}{"id": id.String(), "type": STOP}
		if err := s.writeJSON(stop); err != nil {
			log.Println(err)
		}
	}()
	return sub.ch, cancel
}

// Close closes the underlying connection. The read loop then stops, which
// cancels Context and ends every subscription.
func (s *Subs) Close() error {
	return s.conn.Close()
}

// Done returns a channel that is closed once the connection's read loop has
// stopped, i.e. when Context is cancelled.
func (s *Subs) Done() <-chan struct{} {
	return s.Context.Done()
}
// UnmarshalGraphQL and its helpers below decode JSON into a GraphQL query
// data structure. They are borrowed from shurcooL/graphql's internal jsonutil package:
// https://github.com/shurcooL/graphql/blob/master/internal/jsonutil/graphql.go
package subs
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
"strings"
)
// UnmarshalGraphQL decodes the JSON-encoded GraphQL response data into the
// GraphQL query data structure that v points to.
//
// Decoding is driven token by token via "encoding/json".Decoder, with
// numbers kept as json.Number. The input must hold exactly one top-level
// JSON value; any trailing token is reported as an error.
func UnmarshalGraphQL(data []byte, v interface{}) error {
	tokens := json.NewDecoder(bytes.NewReader(data))
	tokens.UseNumber()
	if err := (&decoder{tokenizer: tokens}).Decode(v); err != nil {
		return err
	}
	extra, err := tokens.Token()
	if err == io.EOF {
		// Nothing follows the decoded value, which is exactly what we want.
		return nil
	}
	if err != nil {
		return err
	}
	return fmt.Errorf("invalid token '%v' after top-level value", extra)
}
// decoder is a JSON decoder with custom unmarshaling behavior for GraphQL
// query data structures, built on top of a JSON tokenizer.
type decoder struct {
	// tokenizer yields the input one JSON token at a time; a *json.Decoder
	// satisfies it.
	tokenizer interface {
		Token() (json.Token, error)
	}

	// parseState is a stack of the JSON containers (objects, arrays) the
	// decoder is currently inside.
	parseState []json.Delim

	// vs holds one or more stacks of unmarshal targets; the top of each stack
	// is the reflect.Value that receives the next JSON value. There can be
	// several stacks because a single JSON value may need to land in multiple
	// GraphQL fragments or embedded structs at once.
	vs [][]reflect.Value
}
// Decode decodes a single JSON value from d.tokenizer into v, which must be a
// non-nil pointer.
func (d *decoder) Decode(v interface{}) error {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr {
		return fmt.Errorf("cannot decode into non-pointer %T", v)
	}
	if rv.IsNil() {
		// Without this check rv.Elem() is the invalid zero Value, which the
		// decode loop skips as "no target", so a scalar input would be
		// dropped while reporting success.
		return fmt.Errorf("cannot decode into nil %T", v)
	}
	d.vs = [][]reflect.Value{{rv.Elem()}}
	return d.decode()
}
// decode decodes a single JSON value from d.tokenizer into d.vs.
//
// It walks the token stream iteratively. d.parseState records whether the
// decoder is inside an object ('{') or an array ('['). Each stack in d.vs
// holds the chain of Go locations leading to the current target. Several
// stacks exist at once when a JSON object maps onto a struct with GraphQL
// fragments or embedded fields. Every stack receives each value, but an
// object key only has to match a field in one of them.
//
// Stacks whose current target is the invalid zero reflect.Value (no matching
// Go field) are carried along but skipped when storing. The loop ends once
// popAllVs has emptied and discarded every stack.
//
// It returns an error on premature end of input, a non-string object key, a
// key that matches no field in any stack, an array element with no slice
// target in any stack, an unexpected token, or a failure from unmarshalValue.
func (d *decoder) decode() error {
// The loop invariant is that the top of each d.vs stack
// is where we try to unmarshal the next JSON value we see.
for len(d.vs) > 0 {
tok, err := d.tokenizer.Token()
if err == io.EOF {
return errors.New("unexpected end of JSON input")
} else if err != nil {
return err
}
// Step 1: if tok starts a new member of the enclosing object or array,
// push that member's Go location onto every stack.
switch {
// Are we inside an object and seeing next key (rather than end of object)?
case d.state() == '{' && tok != json.Delim('}'):
key, ok := tok.(string)
if !ok {
return errors.New("unexpected non-key in JSON input")
}
someFieldExist := false
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
// f stays the invalid zero Value when this stack's target is not a
// struct or has no field for key; later steps skip invalid targets.
var f reflect.Value
if v.Kind() == reflect.Struct {
f = fieldByGraphQLName(v, key)
if f.IsValid() {
someFieldExist = true
}
}
d.vs[i] = append(d.vs[i], f)
}
// The key must match a field in at least one stack.
if !someFieldExist {
return fmt.Errorf("struct field for %s doesn't exist in any of %v places to unmarshal", key, len(d.vs))
}
// We've just consumed the current token, which was the key.
// Read the next token, which should be the value, and let the rest of code process it.
tok, err = d.tokenizer.Token()
if err == io.EOF {
return errors.New("unexpected end of JSON input")
} else if err != nil {
return err
}
// Are we inside an array and seeing next value (rather than end of array)?
case d.state() == '[' && tok != json.Delim(']'):
someSliceExist := false
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
var f reflect.Value
if v.Kind() == reflect.Slice {
// Appending may reallocate the backing array. By then the previous
// element has been fully decoded, so the copy carries it over.
v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) // v = append(v, T).
f = v.Index(v.Len() - 1)
someSliceExist = true
}
d.vs[i] = append(d.vs[i], f)
}
if !someSliceExist {
return fmt.Errorf("slice doesn't exist in any of %v places to unmarshal", len(d.vs))
}
}
// Step 2: handle the value token itself against the targets now on top.
switch tok := tok.(type) {
case string, json.Number, bool, nil:
// Value.
// Scalar: store it into every valid target. The value is then complete,
// so pop one level from every stack.
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
if !v.IsValid() {
continue
}
err := unmarshalValue(tok, v)
if err != nil {
return err
}
}
d.popAllVs()
case json.Delim:
switch tok {
case '{':
// Start of object.
// Nil pointer targets are allocated. Every fragment or embedded
// field reachable from the targets then becomes an extra one-element
// stack, so object keys can match against it too. Those extra stacks
// are popped away again at the matching '}'.
d.pushState(tok)
frontier := make([]reflect.Value, len(d.vs)) // Places to look for GraphQL fragments/embedded structs.
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
frontier[i] = v
// TODO: Do this recursively or not? Add a test case if needed.
if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) // v = new(T).
}
}
// Find GraphQL fragments/embedded structs recursively, adding to frontier
// as new ones are discovered and exploring them further.
for len(frontier) > 0 {
v := frontier[0]
frontier = frontier[1:]
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
continue
}
for i := 0; i < v.NumField(); i++ {
if isGraphQLFragment(v.Type().Field(i)) || v.Type().Field(i).Anonymous {
// Add GraphQL fragment or embedded struct.
d.vs = append(d.vs, []reflect.Value{v.Field(i)})
frontier = append(frontier, v.Field(i))
}
}
}
case '[':
// Start of array.
// Elements are appended one by one in step 1 as they arrive.
d.pushState(tok)
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
// TODO: Confirm this is needed, write a test case.
//if v.Kind() == reflect.Ptr && v.IsNil() {
// v.Set(reflect.New(v.Type().Elem())) // v = new(T).
//}
// Reset slice to empty (in case it had non-zero initial value).
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Slice {
continue
}
v.Set(reflect.MakeSlice(v.Type(), 0, 0)) // v = make(T, 0, 0).
}
case '}', ']':
// End of object or array.
// Pop the finished container's target from every stack.
d.popAllVs()
d.popState()
default:
return errors.New("unexpected delimiter in JSON input")
}
default:
return errors.New("unexpected token in JSON input")
}
}
return nil
}
// pushState records that the decoder has entered the JSON container opened by
// delim.
func (d *decoder) pushState(delim json.Delim) {
	stack := d.parseState
	d.parseState = append(stack, delim)
}
// popState discards the innermost parse state, which the caller has already
// read. The stack must be non-empty, otherwise the slice expression panics.
func (d *decoder) popState() {
	top := len(d.parseState) - 1
	d.parseState = d.parseState[:top]
}
// state returns the innermost parse state, or the zero json.Delim when the
// decoder is not inside any object or array.
func (d *decoder) state() json.Delim {
	if n := len(d.parseState); n > 0 {
		return d.parseState[n-1]
	}
	return 0
}
// popAllVs removes the top value from every stack in d.vs, then drops the
// stacks that became empty.
func (d *decoder) popAllVs() {
	var kept [][]reflect.Value
	for i, stack := range d.vs {
		stack = stack[:len(stack)-1]
		d.vs[i] = stack
		if len(stack) != 0 {
			kept = append(kept, stack)
		}
	}
	d.vs = kept
}
// fieldByGraphQLName returns the first field of struct value v for which
// hasGraphQLName reports a match with name, or the invalid zero reflect.Value
// if no field matches. v must be a struct, otherwise NumField panics.
func fieldByGraphQLName(v reflect.Value, name string) reflect.Value {
	typ := v.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		if !hasGraphQLName(typ.Field(i), name) {
			continue
		}
		return v.Field(i)
	}
	return reflect.Value{}
}
// hasGraphQLName reports whether struct field f corresponds to the GraphQL
// response key name.
//
// Untagged fields match their Go name case-insensitively. Tagged fields match
// the key their graphql tag produces in the response. For
// `graphql:"alias: field(arg: $x) @include(if: $y)"` that key is "alias", and
// for `graphql:"field(arg: $x)"` it is "field". Fragment tags (starting with
// "...") have no name and never match.
func hasGraphQLName(f reflect.StructField, name string) bool {
	value, ok := f.Tag.Lookup("graphql")
	if !ok {
		// TODO: caseconv package is relatively slow. Optimize it, then consider using it here.
		//return caseconv.MixedCapsToLowerCamelCase(f.Name) == name
		return strings.EqualFold(f.Name, name)
	}
	value = strings.TrimSpace(value) // TODO: Parse better.
	if strings.HasPrefix(value, "...") {
		// GraphQL fragment. It doesn't have a name.
		return false
	}
	// Cut arguments first: their values may themselves contain ':' or '@'.
	if i := strings.Index(value, "("); i != -1 {
		value = value[:i]
	}
	// With an alias ("alias: field") the response key is the alias.
	if i := strings.Index(value, ":"); i != -1 {
		value = value[:i]
	}
	// Directives ("field @include(if: $x)") are not part of the name.
	if i := strings.Index(value, "@"); i != -1 {
		value = value[:i]
	}
	return strings.TrimSpace(value) == name
}
// isGraphQLFragment reports whether struct field f is a GraphQL fragment: its
// graphql tag, after trimming surrounding whitespace, begins with "...".
// Untagged fields are never fragments.
func isGraphQLFragment(f reflect.StructField) bool {
	tag, tagged := f.Tag.Lookup("graphql")
	return tagged && strings.HasPrefix(strings.TrimSpace(tag), "...")
}
// unmarshalValue stores the scalar JSON token value into v by re-encoding it
// and letting encoding/json decode it into v's address, so v must be
// addressable.
func unmarshalValue(value json.Token, v reflect.Value) error {
	encoded, err := json.Marshal(value) // TODO: Short-circuit (if profiling says it's worth it).
	switch {
	case err != nil:
		return err
	case !v.CanAddr():
		return fmt.Errorf("value %v is not addressable", v)
	}
	return json.Unmarshal(encoded, v.Addr().Interface())
}
// ======= borrowed from https://github.com/shurcooL/graphql/blob/master/query.go =======
package subs
import (
"bytes"
"encoding/json"
"io"
"reflect"
"sort"
"github.com/shurcooL/graphql/ident"
)
// constructQuery builds a query document from the shurcooL-style struct v.
// With variables it emits "query($a:T!...){...}". Without them it returns the
// bare selection set, GraphQL's anonymous query shorthand.
func constructQuery(v interface{}, variables map[string]interface{}) string {
	selection := query(v)
	if len(variables) == 0 {
		return selection
	}
	return "query(" + queryArguments(variables) + ")" + selection
}
// constructMutation builds a mutation document from the shurcooL-style struct
// v, declaring variables in the operation header when any are given.
func constructMutation(v interface{}, variables map[string]interface{}) string {
	selection := query(v)
	if len(variables) == 0 {
		return "mutation" + selection
	}
	return "mutation(" + queryArguments(variables) + ")" + selection
}
// constructSubscription builds a subscription document from the
// shurcooL-style struct v, declaring variables in the operation header when
// any are given.
func constructSubscription(v interface{}, variables map[string]interface{}) string {
	selection := query(v)
	if len(variables) == 0 {
		return "subscription" + selection
	}
	return "subscription(" + queryArguments(variables) + ")" + selection
}
// queryArguments constructs a minified arguments string for variables.
//
// E.g., map[string]interface{}{"a": Int(123), "b": NewBoolean(true)} -> "$a:Int!$b:Boolean".
func queryArguments(variables map[string]interface{}) string {
// Sort keys in order to produce deterministic output for testing purposes.
// TODO: If tests can be made to work with non-deterministic output, then no need to sort.
keys := make([]string, 0, len(variables))
for k := range variables {
keys = append(keys, k)
}
sort.Strings(keys)
var buf bytes.Buffer
for _, k := range keys {
io.WriteString(&buf, "$")
io.WriteString(&buf, k)
io.WriteString(&buf, ":")
writeArgumentType(&buf, reflect.TypeOf(variables[k]), true)
// Don't insert a comma here.
// Commas in GraphQL are insignificant, and we want minified output.
// See https://facebook.github.io/graphql/October2016/#sec-Insignificant-Commas.
}
return buf.String()
}
// writeArgumentType writes the minified GraphQL type for t to w.
// Pointers become optional (nullable) types. Slices and arrays become lists
// ("[T]"), and any other type is written by its Go name, with string
// rendered as ID. When value is true the type is required and "!" is
// appended.
func writeArgumentType(w io.Writer, t reflect.Type, value bool) {
	switch t.Kind() {
	case reflect.Ptr:
		// Optional: describe the pointee without a trailing "!".
		writeArgumentType(w, t.Elem(), false)
		return
	case reflect.Slice, reflect.Array:
		io.WriteString(w, "[")
		writeArgumentType(w, t.Elem(), true)
		io.WriteString(w, "]")
	default:
		typeName := t.Name()
		if typeName == "string" { // HACK: Workaround for https://github.com/shurcooL/githubv4/issues/12.
			typeName = "ID"
		}
		io.WriteString(w, typeName)
	}
	if value {
		io.WriteString(w, "!")
	}
}
// query returns the minified selection set for the type of v, built by
// writeQuery.
//
// E.g., struct{Foo Int, BarBaz *Boolean} -> "{foo,barBaz}".
func query(v interface{}) string {
	buf := new(bytes.Buffer)
	writeQuery(buf, reflect.TypeOf(v), false)
	return buf.String()
}
// writeQuery writes a minified selection set for t to w.
// Pointers and slices are looked through. Structs become "{...}" with one
// entry per field: the graphql tag verbatim if present, otherwise the
// lowerCamelCase field name. Untagged embedded fields are inlined. Struct
// types whose pointer implements json.Unmarshaler are treated as scalars.
// Other kinds write nothing. If inline is true, t's fields are written
// without surrounding braces, into the parent's selection set.
func writeQuery(w io.Writer, t reflect.Type, inline bool) {
	kind := t.Kind()
	if kind == reflect.Ptr || kind == reflect.Slice {
		writeQuery(w, t.Elem(), false)
		return
	}
	if kind != reflect.Struct {
		return
	}
	// If the type implements json.Unmarshaler, it's a scalar. Don't expand it.
	if reflect.PtrTo(t).Implements(jsonUnmarshaler) {
		return
	}
	if !inline {
		io.WriteString(w, "{")
	}
	for i, n := 0, t.NumField(); i < n; i++ {
		if i > 0 {
			io.WriteString(w, ",")
		}
		field := t.Field(i)
		tag, tagged := field.Tag.Lookup("graphql")
		embedded := field.Anonymous && !tagged
		switch {
		case embedded:
			// Inlined: the embedded struct's fields speak for themselves.
		case tagged:
			io.WriteString(w, tag)
		default:
			io.WriteString(w, ident.ParseMixedCaps(field.Name).ToLowerCamelCase())
		}
		writeQuery(w, field.Type, embedded)
	}
	if !inline {
		io.WriteString(w, "}")
	}
}
// jsonUnmarshaler is the reflect.Type of the json.Unmarshaler interface.
// writeQuery uses it to keep struct types with custom JSON decoding as
// scalars instead of expanding their fields.
var jsonUnmarshaler = reflect.TypeOf(new(json.Unmarshaler)).Elem()
package subs
import "encoding/json"
// GraphQLPayload is the operation carried in a client "start" message: the
// operation name, the query text, and its variables and extensions.
// Every field is always encoded; nil maps encode as JSON null.
type GraphQLPayload struct {
	OperationName string                 `json:"operationName"`
	Query         string                 `json:"query"`
	Variables     map[string]interface{} `json:"variables"`
	Extensions    map[string]interface{} `json:"extensions"`
}
// SubscriptionAction is the value of the "type" field of a graphql-ws
// protocol message.
type SubscriptionAction string

// Message types used by this package. START begins an operation and STOP
// ends it. DATA carries a result for an operation. ERROR reports a failed
// operation.
const (
	START SubscriptionAction = "start"
	STOP  SubscriptionAction = "stop"
	DATA  SubscriptionAction = "data"
	ERROR SubscriptionAction = "error"
)
// Subscription is a client-to-server message identifying an operation by ID,
// such as the "start" message that carries the operation in Payload.
type Subscription struct {
	ID      string             `json:"id"`
	Type    SubscriptionAction `json:"type"`
	Payload GraphQLPayload     `json:"payload"`
}
// SubscriptionMessage is a message received from the server.
//
// For "data" messages, the graphql-ws protocol places the execution result
// ({data, errors}) inside payload, so GraphQL errors land in Payload.Errors.
// For "error" messages the payload is the error object itself, which this
// struct does not capture.
type SubscriptionMessage struct {
	ID      string             `json:"id"`
	Type    SubscriptionAction `json:"type"`
	Payload struct {
		Data   json.RawMessage `json:"data"`
		Errors json.RawMessage `json:"errors"`
	} `json:"payload"`
	// Errors reads a top-level "errors" key. The protocol does not define
	// one; errors arrive in Payload.Errors. The field is kept for compatibility.
	Errors json.RawMessage `json:"errors"`
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment