Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
gRPC OpenTracing Interceptors
package otgrpc
import (
context "golang.org/x/net/context"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/opentracing/opentracing-go/log"
"google.golang.org/grpc"
)
// UnaryClientInterceptor returns a grpc.UnaryClientInterceptor that, for each
// traced unary call, starts a client-kind span named after the method, injects
// the span's context into the outgoing gRPC metadata, and finishes the span
// once the invoker returns.
//
// The parent reference comes from any span already stored in ctx. Calls for
// which the configured trace-enabled predicate returns false go straight to
// the invoker with no span. A failure to inject the span context is logged on
// the span but does not abort the call.
func UnaryClientInterceptor(tracer opentracing.Tracer, o ...Option) grpc.UnaryClientInterceptor {
	traceOpts := newOptions(o...)
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		if !traceOpts.traceEnabledFunc(method, false) {
			return invoker(ctx, method, req, reply, cc, opts...)
		}
		parentSpanCtx := spanContextFromContext(ctx)
		childSpan := tracer.StartSpan(method, opentracing.ChildOf(parentSpanCtx), ext.SpanKindRPCClient, GRPCComponentTag)
		// Deferred so the span is finished (and reported) even if the invoker panics.
		defer childSpan.Finish()

		newCtx, err := injectSpanToMetadata(tracer, childSpan, ctx)
		if err != nil {
			// Best effort: on failure injectSpanToMetadata hands back the original
			// ctx, so the call still proceeds without propagated trace context.
			childSpan.LogFields(log.String(EventKey, "Span injection failed"), log.Error(err))
		}
		err = invoker(newCtx, method, req, reply, cc, opts...)
		if err != nil {
			// Flag the span as failed per the OpenTracing "error" tag convention.
			ext.Error.Set(childSpan, true)
			childSpan.LogFields(log.String(EventKey, "gRPC invocation failed"), log.Error(err))
		}
		return err
	}
}
// StreamClientInterceptor returns a grpc.StreamClientInterceptor that, for each
// traced streaming call, starts a client-kind span named after the method and
// injects the span's context into the outgoing gRPC metadata before the stream
// is opened.
//
// Calls for which the configured trace-enabled predicate returns false go
// straight to the streamer with no span. A failure to inject the span context
// is logged on the span but does not abort the call.
//
// NOTE(review): the span is finished as soon as streamer returns, so it
// measures stream establishment only. Messages sent or received on the
// returned stream afterwards are not covered. Tracing the whole stream lifetime
// would require wrapping the returned grpc.ClientStream.
func StreamClientInterceptor(tracer opentracing.Tracer, o ...Option) grpc.StreamClientInterceptor {
	traceOpts := newOptions(o...)
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		if !traceOpts.traceEnabledFunc(method, true) {
			return streamer(ctx, desc, cc, method, opts...)
		}
		parentSpanCtx := spanContextFromContext(ctx)
		childSpan := tracer.StartSpan(method, opentracing.ChildOf(parentSpanCtx), ext.SpanKindRPCClient, GRPCComponentTag)
		// Deferred so the span is finished (and reported) even if the streamer panics.
		defer childSpan.Finish()

		newCtx, err := injectSpanToMetadata(tracer, childSpan, ctx)
		if err != nil {
			// Best effort: on failure injectSpanToMetadata hands back the original
			// ctx, so the stream is still opened without propagated trace context.
			childSpan.LogFields(log.String(EventKey, "Span injection failed"), log.Error(err))
		}
		stream, err := streamer(newCtx, desc, cc, method, opts...)
		if err != nil {
			// Flag the span as failed per the OpenTracing "error" tag convention.
			ext.Error.Set(childSpan, true)
			childSpan.LogFields(log.String(EventKey, "gRPC invocation failed"), log.Error(err))
		}
		return stream, err
	}
}
package otgrpc
import (
"context"
"google.golang.org/grpc/metadata"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// EventKey is the log-field key under which the interceptors record a short
// description of what happened (e.g. an injection or invocation failure) on a
// span.
const (
	EventKey = "event"
)

// GRPCComponentTag sets the OpenTracing "component" tag to "gRPC" on every
// span the interceptors start.
var GRPCComponentTag = opentracing.Tag{
	Key:   string(ext.Component),
	Value: "gRPC",
}
// extractSpanContext picks the span context to use as a parent for work done
// under ctx. A span already stored in ctx wins. Otherwise the span context is
// decoded from the gRPC metadata carried by ctx, which may yield nil.
func extractSpanContext(tracer opentracing.Tracer, ctx context.Context) opentracing.SpanContext {
	if local := spanContextFromContext(ctx); local != nil {
		return local
	}
	return extractSpanContextFromMetadata(tracer, ctx)
}
// spanContextFromContext returns the context of the span stored in ctx, or a
// nil interface when ctx holds no span.
func spanContextFromContext(ctx context.Context) opentracing.SpanContext {
	span := opentracing.SpanFromContext(ctx)
	if span == nil {
		return nil
	}
	return span.Context()
}
// injectSpanToMetadata serializes span's context into gRPC metadata using the
// HTTP-headers carrier format.
//
// It returns a derived context carrying that metadata. Metadata already
// attached to ctx is copied first, so the caller's metadata is never mutated.
// If the tracer fails to inject, ctx is returned unchanged together with the
// error.
//
// NOTE(review): newer grpc-go releases replace metadata.FromContext/NewContext
// with FromOutgoingContext/NewOutgoingContext for client-side metadata. Confirm
// against the grpc-go version this package is built with.
func injectSpanToMetadata(tracer opentracing.Tracer, span opentracing.Span, ctx context.Context) (context.Context, error) {
	md, found := metadata.FromContext(ctx)
	if found {
		md = md.Copy()
	} else {
		md = metadata.New(nil)
	}

	carrier := NewMetadataReaderWriter(md)
	if injectErr := tracer.Inject(span.Context(), opentracing.HTTPHeaders, carrier); injectErr != nil {
		return ctx, injectErr
	}
	return metadata.NewContext(ctx, md), nil
}
// extractSpanContextFromMetadata decodes a span context from the gRPC metadata
// attached to ctx, using the HTTP-headers carrier format.
//
// When ctx carries no metadata, an empty carrier is used. The function returns
// nil whenever tracer.Extract reports an error, including
// opentracing.ErrSpanContextNotFound, which tracers return when the carrier
// holds no span context. Callers can therefore treat nil as "no remote parent".
func extractSpanContextFromMetadata(tracer opentracing.Tracer, ctx context.Context) opentracing.SpanContext {
	md, ok := metadata.FromContext(ctx)
	if !ok {
		md = metadata.New(nil)
	}
	spanContext, err := tracer.Extract(opentracing.HTTPHeaders, NewMetadataReaderWriter(md))
	if err != nil {
		// Extraction is best effort. A missing context is the normal case for
		// uninstrumented callers, and a corrupted carrier must not fail the RPC.
		// Drop whatever the tracer returned so a partial context is never used
		// as a parent.
		return nil
	}
	return spanContext
}
package otgrpc
import (
"strings"
"google.golang.org/grpc/metadata"
)
// MetadataReaderWriter adapts gRPC metadata so that span contexts can be
// written into it and read back out of it by an OpenTracing tracer.
type MetadataReaderWriter struct {
	md metadata.MD
}

// NewMetadataReaderWriter creates an object that implements the
// opentracing.TextMapReader and opentracing.TextMapWriter interfaces on top
// of md.
//
// Set writes into md in place, so md must be non-nil if Set will be called
// (writing to a nil map panics).
func NewMetadataReaderWriter(md metadata.MD) *MetadataReaderWriter {
	return &MetadataReaderWriter{md: md}
}

// ForeachKey decodes every (key, value) pair in the metadata with
// metadata.DecodeKeyValue and passes the result to handler. It stops at, and
// returns, the first error from either the decoding or the handler. Pairs are
// visited in Go map iteration order, which is unspecified.
func (rw *MetadataReaderWriter) ForeachKey(handler func(string, string) error) error {
	for rawKey, rawValues := range rw.md {
		for _, rawValue := range rawValues {
			key, value, err := metadata.DecodeKeyValue(rawKey, rawValue)
			if err != nil {
				return err
			}
			if err = handler(key, value); err != nil {
				return err
			}
		}
	}
	return nil
}

// Set appends value to the values stored under the lower-cased form of key.
func (rw *MetadataReaderWriter) Set(key, value string) {
	// headers should be lowercase
	normalized := strings.ToLower(key)
	rw.md[normalized] = append(rw.md[normalized], value)
}
package otgrpc
// Option customizes the interceptors built by this package. Each interceptor
// constructor accepts any number of Options, applied in order.
type Option func(*options)

// options holds the settings assembled from a list of Option values.
type options struct {
	// traceEnabledFunc reports whether a call to method should be traced.
	// isStream is true for streaming calls.
	traceEnabledFunc func(method string, isStream bool) bool
}

// newOptions applies opts in order to a fresh options value. If no predicate
// was configured, it installs a default that traces every call.
func newOptions(opts ...Option) *options {
	var cfg options
	for i := range opts {
		opts[i](&cfg)
	}
	if cfg.traceEnabledFunc == nil {
		cfg.traceEnabledFunc = func(string, bool) bool { return true }
	}
	return &cfg
}

// WithTraceEnabledFunc defines a function that indicates to the tracing
// implementation whether the method should be traced or not. Passing nil
// restores the default of tracing every call.
func WithTraceEnabledFunc(f func(method string, isStream bool) bool) Option {
	return func(cfg *options) { cfg.traceEnabledFunc = f }
}
package otgrpc
import (
context "golang.org/x/net/context"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/opentracing/opentracing-go/log"
"google.golang.org/grpc"
)
// UnaryServerInterceptor returns a grpc.UnaryServerInterceptor that, for each
// traced unary call, starts a server span named after info.FullMethod.
//
// Any span context found in the incoming gRPC metadata is used as the remote
// parent. The span is stored in the context passed to the handler and is
// finished when the handler returns. Calls for which the configured
// trace-enabled predicate returns false go straight to the handler with no
// span.
func UnaryServerInterceptor(tracer opentracing.Tracer, o ...Option) grpc.UnaryServerInterceptor {
	traceOpts := newOptions(o...)
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if !traceOpts.traceEnabledFunc(info.FullMethod, false) {
			return handler(ctx, req)
		}
		spanContext := extractSpanContextFromMetadata(tracer, ctx)
		span := tracer.StartSpan(info.FullMethod, ext.RPCServerOption(spanContext), GRPCComponentTag)
		// A span that is never finished is never reported, so always finish it,
		// even if the handler panics.
		defer span.Finish()

		newCtx := opentracing.ContextWithSpan(ctx, span)
		resp, err := handler(newCtx, req)
		if err != nil {
			// Flag the span as failed per the OpenTracing "error" tag convention.
			ext.Error.Set(span, true)
			span.LogFields(log.String(EventKey, "gRPC invocation failed"), log.Error(err))
		}
		return resp, err
	}
}
// StreamServerInterceptor returns a grpc.StreamServerInterceptor that, for each
// traced streaming call, starts a server span named after info.FullMethod.
//
// Any span context found in the incoming gRPC metadata is used as the remote
// parent. The handler receives a wrapped stream whose Context() carries the
// span, and the span is finished when the handler returns. Calls for which the
// configured trace-enabled predicate returns false go straight to the handler
// with no span.
func StreamServerInterceptor(tracer opentracing.Tracer, o ...Option) grpc.StreamServerInterceptor {
	traceOpts := newOptions(o...)
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if !traceOpts.traceEnabledFunc(info.FullMethod, true) {
			return handler(srv, ss)
		}
		spanContext := extractSpanContextFromMetadata(tracer, ss.Context())
		span := tracer.StartSpan(info.FullMethod, ext.RPCServerOption(spanContext), GRPCComponentTag)
		// The streaming RPC is over once the handler returns, so finish the span
		// here. The alternative, a per-stream goroutine blocked on ctx.Done(),
		// never exits and never finishes the span if the stream's context is
		// never cancelled (e.g. a stream backed by context.Background()).
		defer span.Finish()

		newCtx := opentracing.ContextWithSpan(ss.Context(), span)
		err := handler(srv, WrapServerStream(ss, newCtx))
		if err != nil {
			// Flag the span as failed per the OpenTracing "error" tag convention.
			ext.Error.Set(span, true)
			span.LogFields(log.String(EventKey, "gRPC invocation failed"), log.Error(err))
		}
		return err
	}
}
package otgrpc
import (
context "golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// ServerStreamWrapper is a grpc.ServerStream that forwards every call to an
// underlying stream, except Context, which returns a caller-supplied
// replacement context. StreamServerInterceptor uses it to hand handlers a
// context carrying the server span.
type ServerStreamWrapper struct {
	stream grpc.ServerStream
	ctx    context.Context
}

// Compile-time check that the wrapper satisfies grpc.ServerStream.
var _ grpc.ServerStream = (*ServerStreamWrapper)(nil)

// WrapServerStream returns a wrapper around stream whose Context method
// reports ctx instead of stream.Context().
func WrapServerStream(stream grpc.ServerStream, ctx context.Context) *ServerStreamWrapper {
	w := new(ServerStreamWrapper)
	w.stream, w.ctx = stream, ctx
	return w
}

// Context returns the replacement context given to WrapServerStream.
func (w *ServerStreamWrapper) Context() context.Context { return w.ctx }

// SetHeader forwards to the wrapped stream.
func (w *ServerStreamWrapper) SetHeader(md metadata.MD) error { return w.stream.SetHeader(md) }

// SendHeader forwards to the wrapped stream.
func (w *ServerStreamWrapper) SendHeader(md metadata.MD) error { return w.stream.SendHeader(md) }

// SetTrailer forwards to the wrapped stream.
func (w *ServerStreamWrapper) SetTrailer(md metadata.MD) { w.stream.SetTrailer(md) }

// SendMsg forwards to the wrapped stream.
func (w *ServerStreamWrapper) SendMsg(m interface{}) error { return w.stream.SendMsg(m) }

// RecvMsg forwards to the wrapped stream.
func (w *ServerStreamWrapper) RecvMsg(m interface{}) error { return w.stream.RecvMsg(m) }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment