Created
June 9, 2020 18:04
-
-
Save pieterlouw/a01c17ac8ec84bf7d6e3a639019a42fe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* This is a snippet showing how to add CORS middleware to a grpc-gateway implementation. */
package main
// all relevant imports here
func main() { | |
ctx := BackgroundWithSignals() | |
//create grpc server implementation | |
server := yourGrpcServerImplementation.New() | |
//create new grpc server () | |
s := grpc.NewServer(/*with relevant options*/) | |
//register grpc server | |
grpcpb.RegisterXXXServer(s, server) | |
// Serve gRPC Server | |
logger.Info("Serving gRPC on %s", addr) | |
go func() { | |
logger.Fatal(s.Serve(lis)) | |
}() | |
// Setup and start grpc-gateway server | |
go func() { | |
gwAddr := fmt.Sprintf("%s:%d", cfg.GWListener.Addr, cfg.GWListener.Port) | |
// See https://github.com/grpc/grpc/blob/master/doc/naming.md | |
// for gRPC naming standard information. | |
dialAddr := fmt.Sprintf("passthrough://localhost/%s", grpcAddr) | |
conn, err := grpc.DialContext( | |
context.Background(), | |
dialAddr, | |
//grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(insecure.CertPool, "")), | |
grpc.WithInsecure(), | |
grpc.WithBlock(), | |
) | |
if err != nil { | |
logger.Fatalln("Failed to dial server:", err) | |
} | |
// create new http mux | |
mux := http.NewServeMux() | |
// create new grpc-gateway mux | |
gwmux := runtime.NewServeMux( | |
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{ | |
EmitDefaults: true, | |
//Indent: " ", | |
OrigName: true, | |
}), | |
// This is necessary to get error details properly | |
// marshalled in unary requests. | |
runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler), | |
) | |
// register grpc-gateway mux | |
err = grpcpb.RegisterXXXHandler(context.Background(), gwmux, conn) | |
if err != nil { | |
logger.Fatalln("Failed to register gateway:", err) | |
} | |
mux.Handle("/", allowCORS(gwmux)) | |
logger.Info("Serving gRPC-Gateway on http://", gwAddr) | |
gwServer := http.Server{ | |
Addr: gwAddr, | |
Handler: mux, | |
ReadTimeout: time.Second * 5, | |
IdleTimeout: time.Second * 5, | |
} | |
gwServer.ListenAndServe() | |
}() | |
//wait for close signal | |
<-ctx.Done() | |
s.GracefulStop() | |
// force exit after graceful period exceeded | |
<-time.After(gracefulStopMaxWait) | |
logger.Println("main: Exit") | |
} | |
// allowCORS allows Cross Origin Resoruce Sharing from any origin. | |
// Don't do this without consideration in production systems. | |
func allowCORS(h http.Handler) http.Handler { | |
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
if origin := r.Header.Get("Origin"); origin != "" { | |
w.Header().Set("Access-Control-Allow-Origin", origin) | |
if r.Method == "OPTIONS" && r.Header.Get("Access-Control-Request-Method") != "" { | |
preflightHandler(w, r) | |
return | |
} | |
} | |
h.ServeHTTP(w, r) | |
}) | |
} | |
// preflightHandler answers a CORS preflight request by advertising which
// request headers and HTTP methods cross-origin callers may use. It only sets
// response headers; it writes no body and no explicit status code.
func preflightHandler(w http.ResponseWriter, r *http.Request) {
	allowedHeaders := strings.Join([]string{"Content-Type", "Accept"}, ",")
	allowedMethods := strings.Join([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}, ",")

	respHeader := w.Header()
	respHeader.Set("Access-Control-Allow-Headers", allowedHeaders)
	respHeader.Set("Access-Control-Allow-Methods", allowedMethods)
}
// BackgroundWithSignals returns a Context that is canceled when the process
// receives an os.Interrupt (SIGINT) signal.
//
// The signal channel is registered before the function returns, so an
// interrupt delivered immediately after the call is already intercepted. A
// goroutine then waits for the first interrupt, resets os.Interrupt handling
// (so a later interrupt is no longer intercepted here) and cancels the
// returned Context. The goroutine lives until that first interrupt arrives.
func BackgroundWithSignals() context.Context {
	ctx, cancelFn := context.WithCancel(context.Background())

	// Buffer of one so the signal is not dropped if it arrives before the
	// goroutine below starts receiving.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)

	go func() {
		<-c
		signal.Reset(os.Interrupt)
		cancelFn()
	}()
	return ctx
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment