Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
// Copyright 2017 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.
// Modified by Thomas Bereknyei.
package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"github.com/philippgille/gokv"
"github.com/philippgille/gokv/bbolt"
"gitlab.com/internal/go-cryptoservice"
)
// StreamDirector picks, for one incoming call identified by fullMethodName,
// the context (carrying outgoing metadata) and the backend connection the
// call is proxied over. A non-nil error rejects the call with that error.
type StreamDirector func(
	ctx context.Context,
	fullMethodName string,
) (context.Context, *grpc.ClientConn, error)

var (
	// clientStreamDescForProxying describes the backend stream opened for
	// every proxied call: both client- and server-streaming.
	clientStreamDescForProxying = &grpc.StreamDesc{
		ClientStreams: true,
		ServerStreams: true,
	}

	// creds is the dial option holding the transport credentials set up in main.
	creds grpc.DialOption
	// store is the response cache; entries are keyed by raw request bytes.
	store gokv.Store
	// conn is the shared connection to the backend server (SERVER_URL).
	conn *grpc.ClientConn
)
// ExampleStreamDirector returns a StreamDirector that routes every call to the
// package-level backend connection conn. The caller's incoming metadata is
// copied into the outgoing context. Calls whose context carries no incoming
// metadata are rejected with codes.Unimplemented.
func ExampleStreamDirector() StreamDirector {
	return func(ctx context.Context, fullMethodName string) (context.Context, *grpc.ClientConn, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method")
		}
		// Derive straight from ctx: cancellation already follows the incoming
		// call, and a WithCancel here would drop its cancel func (vet: lostcancel).
		return metadata.NewOutgoingContext(ctx, md.Copy()), conn, nil
	}
}
// offline, when true, stops handlers from opening backend streams so calls
// are served from the cache only. main overwrites it from the OFFLINE
// environment variable at startup.
var offline = true
func main() {
fmt.Println("USAGE: mm")
fmt.Println(" SERVER_URL: required - remote_host:port")
fmt.Println(" AWS_REGION: required")
fmt.Println(" HTTP_PROXY: optional - default: localhost:5001")
fmt.Println(" OFFLINE : optional")
fmt.Println(" KEY : optional")
fmt.Println(" CERT : optional")
fmt.Println(" CA : optional")
fmt.Println(" DB_PATH : optional")
var err error
if os.Getenv("OFFLINE") == "" {
offline = false
}
var ta credentials.TransportCredentials
certificate, err := tls.LoadX509KeyPair(os.Getenv("CERT"), os.Getenv("KEY"))
if err != nil {
ta, err = cryptoservice.RequestNewGRPCTransportCredentials(
cryptoservice.Endpoint())
if err != nil {
fmt.Println("Cannot obtain cryptoservice credentials")
}
fmt.Println("Received transport credentials from crypto service")
creds = grpc.WithTransportCredentials(ta)
} else {
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(os.Getenv("CA"))
if err != nil {
fmt.Printf("could not read ca certificate: %s\n", err)
}
// Append the client certificates from the CA
if ok := certPool.AppendCertsFromPEM(ca); !ok {
fmt.Println("failed to append client certs")
return
}
ta = credentials.NewTLS(&tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
RootCAs: certPool,
ClientCAs: certPool,
})
creds = grpc.WithTransportCredentials(ta)
fmt.Println("using stored creds")
}
options := bbolt.DefaultOptions
options.Path = getEnv("DB_PATH", "cache.db")
store, _ = bbolt.NewStore(options)
server := getEnv("SERVER_URL", "localhost:5000")
conn, err = grpc.Dial(server, grpc.WithCodec(&MyCodec{}), creds)
if err != nil {
panic(err)
}
// Explicitly don't check err
var director StreamDirector = ExampleStreamDirector()
s := grpc.NewServer(
grpc.CustomCodec(&MyCodec{}),
grpc.UnknownServiceHandler(TransparentHandler(director)),
grpc.Creds(ta),
)
host := getEnv("HTTP_PROXY", "localhost:5001")
lis, err := net.Listen("tcp", host)
if err != nil {
panic(err)
}
fmt.Printf("Listening on: %s\n", host)
fmt.Printf("Sending to: %s\n", server)
s.Serve(lis)
}
// Frame is an opaque gRPC message. MyCodec moves its Payload to and from the
// wire verbatim, so the proxy can relay and cache messages without knowing
// their protobuf types.
type Frame struct{ Payload []byte }
// MyCodec is the gRPC codec installed on both the proxy server and its backend
// connection. A *Frame is passed through as raw bytes. Any other value must be
// a proto.Message and is encoded as ordinary protobuf.
type MyCodec struct{}

// Marshal returns the raw payload of a *Frame, or the protobuf encoding of a
// proto.Message. Any other type yields an error instead of a panic.
func (MyCodec) Marshal(v interface{}) ([]byte, error) {
	switch m := v.(type) {
	case *Frame:
		return m.Payload, nil
	case proto.Message:
		return proto.Marshal(m)
	default:
		return nil, fmt.Errorf("MyCodec: cannot marshal %T: not a *Frame or proto.Message", v)
	}
}

// Unmarshal stores data as the payload of a *Frame, or decodes it into a
// proto.Message. Any other destination type yields an error instead of a panic.
func (MyCodec) Unmarshal(data []byte, v interface{}) error {
	switch m := v.(type) {
	case *Frame:
		// Keep the received bytes as-is; the proxy relays or caches them verbatim.
		m.Payload = data
		return nil
	case proto.Message:
		return proto.Unmarshal(data, m)
	default:
		return fmt.Errorf("MyCodec: cannot unmarshal into %T: not a *Frame or proto.Message", v)
	}
}

// String returns the codec name. It reports "proto", the name of gRPC's
// built-in protobuf codec. Presumably this is so the wire content-subtype
// matches what unmodified peers expect; confirm before renaming.
func (MyCodec) String() string {
	return "proto"
}
// TransparentHandler returns a grpc.StreamHandler that proxies each call it
// receives, consulting director to choose the backend. main installs it via
// grpc.UnknownServiceHandler so that calls to unregistered services reach it.
func TransparentHandler(director StreamDirector) grpc.StreamHandler {
	h := handler{director: director}
	return h.handler
}
// handler holds the StreamDirector consulted for every proxied call. Its
// handler method is the grpc.StreamHandler returned by TransparentHandler.
type handler struct{ director StreamDirector }
// handler proxies one incoming call. It:
//   - asks the director for an outgoing context and backend connection;
//   - opens a bidirectional backend stream (unless offline);
//   - lets forwardClientToServer answer the call from the cache or the backend.
//
// When forwarding ends cleanly (nil or io.EOF), the backend trailer, if any,
// is copied to the client.
func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error {
	fullMethodName, ok := grpc.MethodFromServerStream(serverStream)
	if !ok {
		return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context")
	}
	// We require that the director's returned context inherits from the serverStream.Context().
	outgoingCtx, backendConn, err := s.director(serverStream.Context(), fullMethodName)
	if err != nil {
		return err
	}
	// Cancelling on return releases the backend stream. It is opened for every
	// call, even when the answer ends up coming from the cache.
	clientCtx, cancel := context.WithCancel(outgoingCtx)
	defer cancel()
	// TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
	opts := []grpc.CallOption{grpc.FailFast(true)}
	var clientStream grpc.ClientStream
	if !offline {
		clientStream, err = grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName, opts...)
		if err != nil {
			// Degrade this call to cache-only instead of crashing the server.
			// The shared offline flag is deliberately left alone: it is a plain
			// bool, and handlers for different calls may run concurrently.
			fmt.Fprintf(os.Stderr, "warning: backend not available for %s, serving from cache only: %v\n", fullMethodName, err)
			clientStream = nil
		}
	}
	c2sErr := <-s.forwardClientToServer(clientStream, serverStream)
	if c2sErr != nil && c2sErr != io.EOF {
		return c2sErr
	}
	if clientStream != nil {
		serverStream.SetTrailer(clientStream.Trailer())
	}
	return nil
}
// forwardClientToServer serves one request arriving on dst (the client-facing
// server stream). The request is answered from the cache. On a miss it is sent
// to src (the backend stream), and the backend's answer is cached and relayed.
//
// The returned channel always receives exactly one value:
//   - io.EOF when the client half-closed, which is the normal end of a call;
//   - nil if the client sent a further message after the answer (that message
//     is discarded);
//   - any other error.
//
// src may be nil (offline, or the backend stream could not be opened). A cache
// miss is then reported as codes.Unavailable.
func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
	ret := make(chan error, 1)
	go func() {
		// A single unconditional send: the receiver must never be left waiting.
		ret <- s.exchange(src, dst)
	}()
	return ret
}

// exchange performs the single cache-or-backend round trip behind
// forwardClientToServer and returns its outcome.
func (s *handler) exchange(src grpc.ClientStream, dst grpc.ServerStream) error {
	// Get question from client; io.EOF here means it closed without sending.
	question := &Frame{}
	if err := dst.RecvMsg(question); err != nil {
		return err
	}
	// NOTE(review): the key is the raw request bytes alone, so different
	// methods receiving identical bytes share one entry, and an empty request
	// maps to the key "". Confirm the store accepts empty keys. The key format
	// is kept as-is so existing cache files stay valid.
	key := string(question.Payload)
	answer := &Frame{}
	found, err := store.Get(key, answer)
	if err != nil {
		return fmt.Errorf("reading cache: %v", err)
	}
	if !found {
		if src == nil {
			return grpc.Errorf(codes.Unavailable, "backend unavailable and response not cached")
		}
		// Send question to server and wait for its answer.
		if err := src.SendMsg(question); err != nil {
			return err
		}
		if err := src.RecvMsg(answer); err != nil {
			return err
		}
		if err := store.Set(key, answer); err != nil {
			return fmt.Errorf("writing cache: %v", err)
		}
	}
	// Send answer to client.
	if err := dst.SendMsg(answer); err != nil {
		return err
	}
	// Consume the client's next event (normally io.EOF from its half-close).
	return dst.RecvMsg(&Frame{})
}
// getEnv returns the value of the environment variable key. It returns
// defaultVal only when the variable is not set at all. A variable that is set
// to the empty string yields "".
func getEnv(key string, defaultVal string) string {
	value, exists := os.LookupEnv(key)
	if !exists {
		return defaultVal
	}
	return value
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment