Skip to content

Instantly share code, notes, and snippets.

@pieterlouw
Created June 9, 2020 18:04
Show Gist options
  • Save pieterlouw/a01c17ac8ec84bf7d6e3a639019a42fe to your computer and use it in GitHub Desktop.
Save pieterlouw/a01c17ac8ec84bf7d6e3a639019a42fe to your computer and use it in GitHub Desktop.
/* This is a snippet to show how to add CORS middleware for grpc-gateway implementation */
package main
// all relevant imports here
// main wires up a gRPC server and a grpc-gateway HTTP/JSON proxy in front of
// it (with CORS enabled via allowCORS). It then blocks until the context from
// BackgroundWithSignals is canceled by SIGINT and shuts both servers down.
// The whole shutdown gets a budget of gracefulStopMaxWait; if in-flight gRPC
// calls have not drained by then, the gRPC server is stopped forcibly.
func main() {
	ctx := BackgroundWithSignals()

	// Create the gRPC service implementation and a server to host it.
	server := yourGrpcServerImplementation.New()
	s := grpc.NewServer( /*with relevant options*/ )
	grpcpb.RegisterXXXServer(s, server)

	// Serve gRPC. Serve returns nil after Stop/GracefulStop, so only a
	// non-nil error is fatal; otherwise a clean shutdown would exit the
	// process through logger.Fatal(nil).
	logger.Info("Serving gRPC on ", addr)
	go func() {
		if err := s.Serve(lis); err != nil {
			logger.Fatal(err)
		}
	}()

	// The gateway's HTTP server is built here, outside the goroutine, so the
	// shutdown sequence below can reach it. Gateway routes are attached to
	// mux once the gateway has connected to the gRPC server.
	gwAddr := fmt.Sprintf("%s:%d", cfg.GWListener.Addr, cfg.GWListener.Port)
	mux := http.NewServeMux()
	gwServer := &http.Server{
		Addr:        gwAddr,
		Handler:     mux,
		ReadTimeout: time.Second * 5,
		IdleTimeout: time.Second * 5,
	}

	// Setup and start grpc-gateway server.
	go func() {
		// See https://github.com/grpc/grpc/blob/master/doc/naming.md
		// for gRPC naming standard information.
		dialAddr := fmt.Sprintf("passthrough://localhost/%s", grpcAddr)
		conn, err := grpc.DialContext(
			context.Background(),
			dialAddr,
			//grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(insecure.CertPool, "")),
			grpc.WithInsecure(),
			grpc.WithBlock(),
		)
		if err != nil {
			logger.Fatalln("Failed to dial server:", err)
		}
		defer conn.Close()

		// create new grpc-gateway mux
		gwmux := runtime.NewServeMux(
			runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{
				EmitDefaults: true,
				//Indent: " ",
				OrigName: true,
			}),
			// This is necessary to get error details properly
			// marshalled in unary requests.
			runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler),
		)
		// register grpc-gateway mux
		err = grpcpb.RegisterXXXHandler(context.Background(), gwmux, conn)
		if err != nil {
			logger.Fatalln("Failed to register gateway:", err)
		}
		mux.Handle("/", allowCORS(gwmux))

		logger.Info("Serving gRPC-Gateway on http://", gwAddr)
		// http.ErrServerClosed is the expected result of the Shutdown call
		// below, not a failure.
		if err := gwServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.Fatalln("Failed to serve gateway:", err)
		}
	}()

	// wait for close signal
	<-ctx.Done()

	// One deadline covers the whole shutdown sequence.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), gracefulStopMaxWait)
	defer cancel()

	// Stop the gateway first: it proxies to the gRPC server, so its
	// in-flight requests need the gRPC server to still be up.
	if err := gwServer.Shutdown(shutdownCtx); err != nil {
		logger.Println("main: gateway shutdown:", err)
	}

	// GracefulStop blocks until all pending RPCs finish. Run it in the
	// background and force-stop the server if the deadline expires first.
	stopped := make(chan struct{})
	go func() {
		s.GracefulStop()
		close(stopped)
	}()
	select {
	case <-stopped:
	case <-shutdownCtx.Done():
		logger.Println("main: graceful stop timed out, forcing stop")
		s.Stop()
	}
	logger.Println("main: Exit")
}
// allowCORS wraps h to allow Cross-Origin Resource Sharing from any origin.
// When a request carries an Origin header, that origin is echoed back in
// Access-Control-Allow-Origin. A CORS preflight (an OPTIONS request with an
// Access-Control-Request-Method header) is answered by preflightHandler and
// never reaches h. All other requests are passed through to h.
// Don't do this without consideration in production systems.
func allowCORS(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The CORS headers depend on the request's Origin, so caches must key
		// on it; otherwise one origin's response could be served to another.
		// Add (not Set) keeps any Vary values set elsewhere.
		w.Header().Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); origin != "" {
			w.Header().Set("Access-Control-Allow-Origin", origin)
			if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
				preflightHandler(w, r)
				return
			}
		}
		h.ServeHTTP(w, r)
	})
}
// preflightHandler answers a CORS preflight request by advertising the
// request headers and HTTP methods that cross-origin callers may use.
// It only sets response headers; it writes no body and no explicit status.
func preflightHandler(w http.ResponseWriter, r *http.Request) {
	allowedHeaders := strings.Join([]string{"Content-Type", "Accept"}, ",")
	allowedMethods := strings.Join([]string{"GET", "HEAD", "POST", "PUT", "DELETE"}, ",")

	hdr := w.Header()
	hdr.Set("Access-Control-Allow-Headers", allowedHeaders)
	hdr.Set("Access-Control-Allow-Methods", allowedMethods)
}
// BackgroundWithSignals returns a Context derived from context.Background
// that is canceled when the process receives SIGINT (os.Interrupt).
// It starts a goroutine that registers for the signal and waits for it.
// After the first SIGINT arrives, the handler is reset with signal.Reset
// before the context is canceled, so a later SIGINT gets the default
// behavior again.
func BackgroundWithSignals() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		// Runs after Reset below, matching the original ordering.
		defer cancel()
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, os.Interrupt)
		<-sigCh
		signal.Reset(os.Interrupt)
	}()
	return ctx
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment