Created
January 12, 2020 15:45
-
-
Save tomberek/e86ba3937ed15f919b12d4cc30710f1b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright 2017 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
// Modified by Thomas Bereknyei.
package main | |
import ( | |
"crypto/tls" | |
"crypto/x509" | |
"fmt" | |
"io" | |
"io/ioutil" | |
"net" | |
"os" | |
"github.com/golang/protobuf/proto" | |
"golang.org/x/net/context" | |
"google.golang.org/grpc" | |
"google.golang.org/grpc/codes" | |
"google.golang.org/grpc/credentials" | |
"google.golang.org/grpc/metadata" | |
"github.com/philippgille/gokv" | |
"github.com/philippgille/gokv/bbolt" | |
"gitlab.com/internal/go-cryptoservice" | |
) | |
type StreamDirector func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) | |
var ( | |
clientStreamDescForProxying = &grpc.StreamDesc{ | |
ServerStreams: true, | |
ClientStreams: true, | |
} | |
) | |
var creds grpc.DialOption | |
var store gokv.Store | |
var conn *grpc.ClientConn | |
func ExampleStreamDirector() StreamDirector { | |
return func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) { | |
md, ok := metadata.FromIncomingContext(ctx) | |
outCtx, _ := context.WithCancel(ctx) | |
outCtx = metadata.NewOutgoingContext(outCtx, md.Copy()) | |
if ok { | |
return outCtx, conn, nil | |
} | |
return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method") | |
} | |
} | |
// offline reports whether the proxy should answer from the cache only.
// It defaults to true; main clears it when the OFFLINE environment variable
// is empty or unset. While it is true the stream handler does not open a
// backend stream.
var offline = true
func main() { | |
fmt.Println("USAGE: mm") | |
fmt.Println(" SERVER_URL: required - remote_host:port") | |
fmt.Println(" AWS_REGION: required") | |
fmt.Println(" HTTP_PROXY: optional - default: localhost:5001") | |
fmt.Println(" OFFLINE : optional") | |
fmt.Println(" KEY : optional") | |
fmt.Println(" CERT : optional") | |
fmt.Println(" CA : optional") | |
fmt.Println(" DB_PATH : optional") | |
var err error | |
if os.Getenv("OFFLINE") == "" { | |
offline = false | |
} | |
var ta credentials.TransportCredentials | |
certificate, err := tls.LoadX509KeyPair(os.Getenv("CERT"), os.Getenv("KEY")) | |
if err != nil { | |
ta, err = cryptoservice.RequestNewGRPCTransportCredentials( | |
cryptoservice.Endpoint()) | |
if err != nil { | |
fmt.Println("Cannot obtain cryptoservice credentials") | |
} | |
fmt.Println("Received transport credentials from crypto service") | |
creds = grpc.WithTransportCredentials(ta) | |
} else { | |
certPool := x509.NewCertPool() | |
ca, err := ioutil.ReadFile(os.Getenv("CA")) | |
if err != nil { | |
fmt.Printf("could not read ca certificate: %s\n", err) | |
} | |
// Append the client certificates from the CA | |
if ok := certPool.AppendCertsFromPEM(ca); !ok { | |
fmt.Println("failed to append client certs") | |
return | |
} | |
ta = credentials.NewTLS(&tls.Config{ | |
ClientAuth: tls.RequireAndVerifyClientCert, | |
Certificates: []tls.Certificate{certificate}, | |
RootCAs: certPool, | |
ClientCAs: certPool, | |
}) | |
creds = grpc.WithTransportCredentials(ta) | |
fmt.Println("using stored creds") | |
} | |
options := bbolt.DefaultOptions | |
options.Path = getEnv("DB_PATH", "cache.db") | |
store, _ = bbolt.NewStore(options) | |
server := getEnv("SERVER_URL", "localhost:5000") | |
conn, err = grpc.Dial(server, grpc.WithCodec(&MyCodec{}), creds) | |
if err != nil { | |
panic(err) | |
} | |
// Explicitly don't check err | |
var director StreamDirector = ExampleStreamDirector() | |
s := grpc.NewServer( | |
grpc.CustomCodec(&MyCodec{}), | |
grpc.UnknownServiceHandler(TransparentHandler(director)), | |
grpc.Creds(ta), | |
) | |
host := getEnv("HTTP_PROXY", "localhost:5001") | |
lis, err := net.Listen("tcp", host) | |
if err != nil { | |
panic(err) | |
} | |
fmt.Printf("Listening on: %s\n", host) | |
fmt.Printf("Sending to: %s\n", server) | |
s.Serve(lis) | |
} | |
// Frame holds the raw bytes of a single message. MyCodec copies Payload to
// and from the wire without interpreting it.
type Frame struct {
	Payload []byte
}
// protoCodec is a Codec implementation with protobuf. It is the default rawCodec for gRPC. | |
type MyCodec struct{} | |
func (MyCodec) Marshal(v interface{}) ([]byte, error) { | |
// fmt.Printf("%+v", v) | |
out, ok := v.(*Frame) | |
if !ok { | |
return proto.Marshal(v.(proto.Message)) | |
} | |
return out.Payload, nil | |
} | |
func (MyCodec) Unmarshal(data []byte, v interface{}) error { | |
// fmt.Println(string(data)) | |
// fmt.Printf("%+v\n", v) | |
dst, ok := v.(*Frame) | |
if !ok { | |
return proto.Unmarshal(data, v.(proto.Message)) | |
} | |
dst.Payload = data | |
return nil | |
} | |
func (MyCodec) String() string { | |
return "proto" | |
} | |
func TransparentHandler(director StreamDirector) grpc.StreamHandler { | |
streamer := &handler{director} | |
return streamer.handler | |
} | |
type handler struct { | |
director StreamDirector | |
} | |
func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error { | |
fullMethodName, ok := grpc.MethodFromServerStream(serverStream) | |
if !ok { | |
return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") | |
} | |
// We require that the director's returned context inherits from the serverStream.Context(). | |
outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName) | |
if err != nil { | |
return err | |
} | |
clientCtx, _ := context.WithCancel(outgoingCtx) | |
// TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For. | |
opts := []grpc.CallOption{grpc.FailFast(true)} | |
var clientStream grpc.ClientStream | |
if !offline { | |
clientStream, err = grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName, opts...) | |
if err != nil { | |
panic(err) | |
offline = true | |
println("warning: backend not available, going offline") | |
//return err | |
} | |
} else { | |
clientStream = nil | |
} | |
c2sErrChan := s.forwardClientToServer(clientStream, serverStream) | |
c2sErr := <-c2sErrChan | |
if c2sErr != io.EOF { | |
return c2sErr | |
} | |
if clientStream != nil { | |
serverStream.SetTrailer(clientStream.Trailer()) | |
} | |
return nil | |
} | |
func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error { | |
ret := make(chan error, 1) | |
go func() { | |
answer := &Frame{} | |
question := &Frame{} | |
// Get question from client | |
if err := dst.RecvMsg(question); err != nil { | |
ret <- err // this can be io.EOF which is happy case | |
return | |
} | |
found, err := store.Get(string(question.Payload), answer) | |
if err != nil { | |
fmt.Println(err) | |
ret <- err | |
return | |
} | |
if found { | |
if err := dst.SendMsg(answer); err != nil { | |
fmt.Println("breaking error") | |
ret <- err | |
return | |
} | |
if err := dst.RecvMsg(answer); err != nil { | |
ret <- err | |
return | |
} | |
return | |
} | |
// Send question to server | |
if err := src.SendMsg(question); err != nil { | |
ret <- err | |
return | |
} | |
if err := src.RecvMsg(answer); err != nil { | |
ret <- err // this can be io.EOF which is happy case | |
return | |
} | |
err = store.Set(string(question.Payload), answer) | |
if err != nil { | |
ret <- err | |
return | |
} | |
// Send answer to client | |
if err := dst.SendMsg(answer); err != nil { | |
ret <- err | |
return | |
} | |
// Get ack from client | |
if err := dst.RecvMsg(answer); err != nil { | |
ret <- err | |
return | |
} | |
return | |
}() | |
return ret | |
} | |
// getEnv returns the value of the environment variable key, or defaultVal
// when the variable is not set. A variable that is set to the empty string
// counts as set and yields "".
func getEnv(key string, defaultVal string) string {
	if value, exists := os.LookupEnv(key); exists {
		return value
	}
	return defaultVal
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment