Last active
January 30, 2019 00:44
-
-
Save macnibblet/8c25cd7964c6c28fe09c40126424032c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package handler | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"strings" | |
"sync" | |
"github.com/99designs/gqlgen/complexity" | |
"github.com/99designs/gqlgen/graphql" | |
"github.com/gorilla/websocket" | |
"github.com/hashicorp/golang-lru" | |
"github.com/vektah/gqlparser/ast" | |
"github.com/vektah/gqlparser/gqlerror" | |
"github.com/vektah/gqlparser/parser" | |
"github.com/vektah/gqlparser/validator" | |
) | |
// params is the JSON shape of one GraphQL request. POST bodies decode into it
// directly; GET requests fill it field by field from the query string.
type params struct {
	// Query is the GraphQL document text.
	Query string `json:"query"`
	// OperationName selects which operation of Query should run.
	OperationName string `json:"operationName"`
	// Variables holds the client-supplied variable values.
	Variables map[string]interface{} `json:"variables"`
}
type Config struct { | |
cacheSize int | |
upgrader websocket.Upgrader | |
recover graphql.RecoverFunc | |
errorPresenter graphql.ErrorPresenterFunc | |
resolverHook graphql.FieldMiddleware | |
requestHook graphql.RequestMiddleware | |
tracer graphql.Tracer | |
complexityLimit int | |
} | |
func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { | |
reqCtx := graphql.NewRequestContext(doc, query, variables) | |
if hook := c.recover; hook != nil { | |
reqCtx.Recover = hook | |
} | |
if hook := c.errorPresenter; hook != nil { | |
reqCtx.ErrorPresenter = hook | |
} | |
if hook := c.resolverHook; hook != nil { | |
reqCtx.ResolverMiddleware = hook | |
} | |
if hook := c.requestHook; hook != nil { | |
reqCtx.RequestMiddleware = hook | |
} | |
if hook := c.tracer; hook != nil { | |
reqCtx.Tracer = hook | |
} else { | |
reqCtx.Tracer = &graphql.NopTracer{} | |
} | |
if c.complexityLimit > 0 { | |
reqCtx.ComplexityLimit = c.complexityLimit | |
operationComplexity := complexity.Calculate(es, op, variables) | |
reqCtx.OperationComplexity = operationComplexity | |
} | |
return reqCtx | |
} | |
type Option func(cfg *Config) | |
func WebsocketUpgrader(upgrader websocket.Upgrader) Option { | |
return func(cfg *Config) { | |
cfg.upgrader = upgrader | |
} | |
} | |
func RecoverFunc(recover graphql.RecoverFunc) Option { | |
return func(cfg *Config) { | |
cfg.recover = recover | |
} | |
} | |
// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides | |
// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default | |
// implementation in graphql.DefaultErrorPresenter for an example. | |
func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { | |
return func(cfg *Config) { | |
cfg.errorPresenter = f | |
} | |
} | |
// ComplexityLimit sets a maximum query complexity that is allowed to be executed. | |
// If a query is submitted that exceeds the limit, a 422 status code will be returned. | |
func ComplexityLimit(limit int) Option { | |
return func(cfg *Config) { | |
cfg.complexityLimit = limit | |
} | |
} | |
// ResolverMiddleware allows you to define a function that will be called around every resolver, | |
// useful for logging. | |
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { | |
return func(cfg *Config) { | |
if cfg.resolverHook == nil { | |
cfg.resolverHook = middleware | |
return | |
} | |
lastResolve := cfg.resolverHook | |
cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { | |
return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) { | |
return middleware(ctx, next) | |
}) | |
} | |
} | |
} | |
// RequestMiddleware allows you to define a function that will be called around the root request, | |
// after the query has been parsed. This is useful for logging | |
func RequestMiddleware(middleware graphql.RequestMiddleware) Option { | |
return func(cfg *Config) { | |
if cfg.requestHook == nil { | |
cfg.requestHook = middleware | |
return | |
} | |
lastResolve := cfg.requestHook | |
cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte { | |
return lastResolve(ctx, func(ctx context.Context) []byte { | |
return middleware(ctx, next) | |
}) | |
} | |
} | |
} | |
// Tracer allows you to add a request/resolver tracer that will be called around the root request, | |
// calling resolver. This is useful for tracing | |
func Tracer(tracer graphql.Tracer) Option { | |
return func(cfg *Config) { | |
if cfg.tracer == nil { | |
cfg.tracer = tracer | |
} else { | |
lastResolve := cfg.tracer | |
cfg.tracer = &tracerWrapper{ | |
tracer1: lastResolve, | |
tracer2: tracer, | |
} | |
} | |
opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { | |
ctx = tracer.StartOperationExecution(ctx) | |
resp := next(ctx) | |
tracer.EndOperationExecution(ctx) | |
return resp | |
}) | |
opt(cfg) | |
} | |
} | |
type tracerWrapper struct { | |
tracer1 graphql.Tracer | |
tracer2 graphql.Tracer | |
} | |
func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context { | |
ctx = tw.tracer1.StartOperationParsing(ctx) | |
ctx = tw.tracer2.StartOperationParsing(ctx) | |
return ctx | |
} | |
func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) { | |
tw.tracer2.EndOperationParsing(ctx) | |
tw.tracer1.EndOperationParsing(ctx) | |
} | |
func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context { | |
ctx = tw.tracer1.StartOperationValidation(ctx) | |
ctx = tw.tracer2.StartOperationValidation(ctx) | |
return ctx | |
} | |
func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) { | |
tw.tracer2.EndOperationValidation(ctx) | |
tw.tracer1.EndOperationValidation(ctx) | |
} | |
func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context { | |
ctx = tw.tracer1.StartOperationExecution(ctx) | |
ctx = tw.tracer2.StartOperationExecution(ctx) | |
return ctx | |
} | |
func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { | |
ctx = tw.tracer1.StartFieldExecution(ctx, field) | |
ctx = tw.tracer2.StartFieldExecution(ctx, field) | |
return ctx | |
} | |
func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context { | |
ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc) | |
ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc) | |
return ctx | |
} | |
func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context { | |
ctx = tw.tracer1.StartFieldChildExecution(ctx) | |
ctx = tw.tracer2.StartFieldChildExecution(ctx) | |
return ctx | |
} | |
func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) { | |
tw.tracer2.EndFieldExecution(ctx) | |
tw.tracer1.EndFieldExecution(ctx) | |
} | |
func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { | |
tw.tracer2.EndOperationExecution(ctx) | |
tw.tracer1.EndOperationExecution(ctx) | |
} | |
// CacheSize sets the maximum size of the query cache. | |
// If size is less than or equal to 0, the cache is disabled. | |
func CacheSize(size int) Option { | |
return func(cfg *Config) { | |
cfg.cacheSize = size | |
} | |
} | |
const DefaultCacheSize = 1000 | |
func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { | |
cfg := &Config{ | |
cacheSize: DefaultCacheSize, | |
upgrader: websocket.Upgrader{ | |
ReadBufferSize: 1024, | |
WriteBufferSize: 1024, | |
}, | |
} | |
for _, option := range options { | |
option(cfg) | |
} | |
var cache *lru.Cache | |
if cfg.cacheSize > 0 { | |
var err error | |
cache, err = lru.New(DefaultCacheSize) | |
if err != nil { | |
// An error is only returned for non-positive cache size | |
// and we already checked for that. | |
panic("unexpected error creating cache: " + err.Error()) | |
} | |
} | |
if cfg.tracer == nil { | |
cfg.tracer = &graphql.NopTracer{} | |
} | |
handler := &graphqlHandler{ | |
cfg: cfg, | |
cache: cache, | |
exec: exec, | |
} | |
return handler.ServeHTTP | |
} | |
var _ http.Handler = (*graphqlHandler)(nil) | |
type graphqlHandler struct { | |
cfg *Config | |
cache *lru.Cache | |
exec graphql.ExecutableSchema | |
} | |
func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |
if r.Method == http.MethodOptions { | |
w.Header().Set("Allow", "OPTIONS, GET, POST") | |
w.WriteHeader(http.StatusOK) | |
return | |
} | |
if strings.Contains(r.Header.Get("Upgrade"), "websocket") { | |
connectWs(gh.exec, w, r, gh.cfg) | |
return | |
} | |
var reqCollection []params | |
switch r.Method { | |
case http.MethodGet: | |
var reqParams params | |
reqParams.Query = r.URL.Query().Get("query") | |
reqParams.OperationName = r.URL.Query().Get("operationName") | |
if variables := r.URL.Query().Get("variables"); variables != "" { | |
if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { | |
sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") | |
return | |
} | |
} | |
reqCollection = append(reqCollection, reqParams) | |
case http.MethodPost: | |
if r.Header.Get("X-Batching") == "true" { | |
if err := jsonDecode(r.Body, &reqCollection); err != nil { | |
sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) | |
return | |
} | |
} else { | |
var reqParams params | |
if err := jsonDecode(r.Body, &reqParams); err != nil { | |
sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) | |
return | |
} | |
reqCollection = append(reqCollection, reqParams) | |
} | |
default: | |
w.WriteHeader(http.StatusMethodNotAllowed) | |
return | |
} | |
w.Header().Set("Content-Type", "application/json") | |
ctx := r.Context() | |
wg := sync.WaitGroup{} | |
responses := make([]*graphql.Response, len(reqCollection)) | |
for idx, op := range reqCollection { | |
wg.Add(1) | |
go func(idx int, reqParams params) { | |
defer wg.Done() | |
var doc *ast.QueryDocument | |
var cacheHit bool | |
if gh.cache != nil { | |
val, ok := gh.cache.Get(reqParams.Query) | |
if ok { | |
doc = val.(*ast.QueryDocument) | |
cacheHit = true | |
} | |
} | |
ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{ | |
Query: reqParams.Query, | |
CachedDoc: doc, | |
}) | |
if gqlErr != nil { | |
responses[idx] = &graphql.Response{ | |
Errors: []*gqlerror.Error{gqlErr}, | |
} | |
return | |
} | |
ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{ | |
Doc: doc, | |
OperationName: reqParams.OperationName, | |
CacheHit: cacheHit, | |
R: r, | |
Variables: reqParams.Variables, | |
}) | |
if len(listErr) != 0 { | |
responses[idx] = &graphql.Response{ | |
Errors: listErr, | |
} | |
return | |
} | |
if gh.cache != nil && !cacheHit { | |
gh.cache.Add(reqParams.Query, doc) | |
} | |
reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars) | |
ctx = graphql.WithRequestContext(ctx, reqCtx) | |
defer func() { | |
if err := recover(); err != nil { | |
userErr := reqCtx.Recover(ctx, err) | |
responses[idx] = &graphql.Response{ | |
Errors: []*gqlerror.Error{ | |
{ | |
Message: userErr.Error(), | |
}, | |
}, | |
} | |
} | |
}() | |
if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit { | |
responses[idx] = &graphql.Response{ | |
Errors: []*gqlerror.Error{ | |
{ | |
Message: fmt.Sprintf("operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit), | |
}, | |
}, | |
} | |
return | |
} | |
switch op.Operation { | |
case ast.Query: | |
responses[idx] = gh.exec.Query(ctx, op) | |
case ast.Mutation: | |
responses[idx] = gh.exec.Mutation(ctx, op) | |
default: | |
responses[idx] = &graphql.Response{ | |
Errors: []*gqlerror.Error{ | |
{ | |
Message: "Unsupported operation type", | |
}, | |
}, | |
} | |
} | |
}(idx, op) | |
} | |
wg.Wait() | |
encoder := json.NewEncoder(w) | |
if len(responses) == 1 { | |
encoder.Encode(responses[0]) | |
} else { | |
encoder.Encode(responses) | |
} | |
} | |
type parseOperationArgs struct { | |
Query string | |
CachedDoc *ast.QueryDocument | |
} | |
func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) { | |
ctx = gh.cfg.tracer.StartOperationParsing(ctx) | |
defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }() | |
if args.CachedDoc != nil { | |
return ctx, args.CachedDoc, nil | |
} | |
doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query}) | |
if gqlErr != nil { | |
return ctx, nil, gqlErr | |
} | |
return ctx, doc, nil | |
} | |
type validateOperationArgs struct { | |
Doc *ast.QueryDocument | |
OperationName string | |
CacheHit bool | |
R *http.Request | |
Variables map[string]interface{} | |
} | |
func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) { | |
ctx = gh.cfg.tracer.StartOperationValidation(ctx) | |
defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }() | |
if !args.CacheHit { | |
listErr := validator.Validate(gh.exec.Schema(), args.Doc) | |
if len(listErr) != 0 { | |
return ctx, nil, nil, listErr | |
} | |
} | |
op := args.Doc.Operations.ForName(args.OperationName) | |
if op == nil { | |
return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)} | |
} | |
if op.Operation != ast.Query && args.R.Method == http.MethodGet { | |
return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} | |
} | |
vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables) | |
if err != nil { | |
return ctx, nil, nil, gqlerror.List{err} | |
} | |
return ctx, op, vars, nil | |
} | |
// jsonDecode decodes the next JSON value from r into val. Numbers are kept as
// json.Number instead of float64 so large integers are not rounded.
func jsonDecode(r io.Reader, val interface{}) error {
	decoder := json.NewDecoder(r)
	decoder.UseNumber()
	if err := decoder.Decode(val); err != nil {
		return err
	}
	return nil
}
func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { | |
w.WriteHeader(code) | |
b, err := json.Marshal(&graphql.Response{Errors: errors}) | |
if err != nil { | |
panic(err) | |
} | |
w.Write(b) | |
} | |
func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) { | |
sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment