Skip to content

Instantly share code, notes, and snippets.

@NathanBaulch
Last active May 3, 2023 21:14
Show Gist options
  • Save NathanBaulch/cf585100454b0afd724ccb6eb70eb334 to your computer and use it in GitHub Desktop.
Save NathanBaulch/cf585100454b0afd724ccb6eb70eb334 to your computer and use it in GitHub Desktop.
Pass OpenTelemetry (OTel) span context across a Redis pub/sub boundary
package redisotel
import (
"context"
"strings"
"sync"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
)
// tc is the W3C Trace Context propagator used to encode a span context into
// published payloads and to decode it again on receipt. An injected payload is
// one line per tc.Fields() header followed by the original payload.
var tc = propagation.TraceContext{}

type (
	// PubSub wraps a redis.PubSub subscription together with the client it was
	// created from. Publish/SPublish prefix outgoing payloads with trace
	// context; ReceiveMessage/Channel decode that prefix on incoming messages.
	PubSub struct {
		*redis.PubSub

		// rds is the client used by Publish and SPublish.
		rds redis.UniversalClient

		// msgcOnce guards the lazy creation of msgc by Channel.
		msgcOnce sync.Once
		msgc     chan *Message
	}

	// Message is a received redis.Message, with any trace-context prefix
	// removed from Payload, paired with the span context decoded from that
	// prefix. SpanContext is the zero value when none was found.
	Message struct {
		*redis.Message
		trace.SpanContext
	}
)
// Subscribe subscribes rds to the given channels and returns a PubSub whose
// received messages carry any span context propagated by the publisher.
// rds is retained so the returned PubSub can also Publish and SPublish.
func Subscribe(ctx context.Context, rds redis.UniversalClient, channels ...string) *PubSub {
	sub := rds.Subscribe(ctx, channels...)
	return &PubSub{rds: rds, PubSub: sub}
}
// PSubscribe is the pattern-subscription counterpart of Subscribe: channels
// are passed to rds.PSubscribe as patterns, and the resulting subscription is
// wrapped so received messages carry any propagated span context.
func PSubscribe(ctx context.Context, rds redis.UniversalClient, channels ...string) *PubSub {
	sub := rds.PSubscribe(ctx, channels...)
	return &PubSub{rds: rds, PubSub: sub}
}
// SSubscribe is the sharded-channel counterpart of Subscribe: channels are
// passed to rds.SSubscribe, and the resulting subscription is wrapped so
// received messages carry any propagated span context.
func SSubscribe(ctx context.Context, rds redis.UniversalClient, channels ...string) *PubSub {
	sub := rds.SSubscribe(ctx, channels...)
	return &PubSub{rds: rds, PubSub: sub}
}
// Publish publishes message to channel through the client this PubSub was
// created from. If ctx carries a valid span context, the payload is first
// prefixed with its trace-context header lines.
func (c *PubSub) Publish(ctx context.Context, channel string, message interface{}) *redis.IntCmd {
	payload := c.inject(ctx, message)
	return c.rds.Publish(ctx, channel, payload)
}
// SPublish is the sharded-channel counterpart of Publish: the payload is
// prefixed with ctx's trace-context header lines (when ctx carries a valid
// span context) and sent with SPublish on the originating client.
func (c *PubSub) SPublish(ctx context.Context, channel string, message interface{}) *redis.IntCmd {
	payload := c.inject(ctx, message)
	return c.rds.SPublish(ctx, channel, payload)
}
// ReceiveMessage waits for the next message via redis.PubSub.ReceiveMessage
// and returns it with any propagated span context decoded and the
// trace-context prefix removed from its payload. Errors from the underlying
// call are returned unchanged.
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
	msg, err := c.PubSub.ReceiveMessage(ctx)
	if err != nil {
		return nil, err
	}
	return c.extract(msg), nil
}
// Channel returns a Go channel that delivers received messages, each with its
// propagated span context decoded. The channel is created on the first call
// only: opts are forwarded to redis.PubSub.Channel at that point, and later
// calls return the same channel regardless of their opts. It is buffered to
// 100 messages and is closed once the underlying redis channel closes.
//
// NOTE(review): the forwarding goroutine sends without a select, so if the
// consumer stops reading while the buffer is full it stays blocked on that
// send indefinitely.
func (c *PubSub) Channel(opts ...redis.ChannelOption) <-chan *Message {
	c.msgcOnce.Do(func() {
		out := make(chan *Message, 100)
		c.msgc = out
		go func() {
			defer close(out)
			src := c.PubSub.Channel(opts...)
			for {
				raw, ok := <-src
				if !ok {
					return
				}
				out <- c.extract(raw)
			}
		}()
	})
	return c.msgc
}
// inject prefixes message with the trace context carried by ctx so that a
// subscriber can recover it with extract. The result is one line per
// tc.Fields() header (a header the propagator did not set yields an empty
// line), followed by the original payload:
//
//	<header 1>\n<header 2>\n...\n<payload>
//
// When ctx has no valid span context, message is returned unchanged so that
// untraced publishes are byte-identical to a plain Redis publish.
func (c *PubSub) inject(ctx context.Context, message interface{}) interface{} {
	if !trace.SpanContextFromContext(ctx).IsValid() {
		return message
	}
	mc := propagation.MapCarrier{}
	tc.Inject(ctx, mc)

	var b strings.Builder
	for _, fld := range tc.Fields() {
		// Trace-context header values cannot contain '\n', so it is a safe
		// separator between the headers and the payload.
		b.WriteString(mc.Get(fld))
		b.WriteByte('\n')
	}
	b.WriteString(payloadString(ctx, message))
	return b.String()
}

// payloadString renders message as the text placed after the trace-context
// header lines. Strings and byte slices are used verbatim and nil becomes the
// empty string, matching go-redis's wire encoding of those arguments.
// This avoids redis.Cmd.String, which is a debug rendering of a command's
// arguments and is not guaranteed to reproduce the published bytes.
//
// NOTE(review): all other types still fall back to Cmd.String, which may not
// match go-redis's wire encoding (e.g. it does not call MarshalBinary on
// encoding.BinaryMarshaler values). Consider encoding those explicitly.
func payloadString(ctx context.Context, message interface{}) string {
	switch v := message.(type) {
	case nil:
		return ""
	case string:
		return v
	case []byte:
		return string(v)
	default:
		return redis.NewCmd(ctx, message).String()
	}
}
// extract reverses inject. If msg.Payload begins with one line per
// tc.Fields() header and those lines decode to a valid span context, the
// lines are removed from msg.Payload (in place) and the span context is
// attached to the returned Message.
//
// Otherwise msg is returned untouched with a zero SpanContext. Requiring a
// valid span context prevents truncating ordinary multi-line payloads that
// were published without trace context, whether by a plain Redis client or
// by Publish with an untraced ctx.
func (c *PubSub) extract(msg *redis.Message) *Message {
	fields := tc.Fields()
	lines := strings.SplitN(msg.Payload, "\n", len(fields)+1)
	if len(lines) <= len(fields) {
		return &Message{Message: msg}
	}

	mc := propagation.MapCarrier{}
	for i, fld := range fields {
		mc.Set(fld, lines[i])
	}
	sc := trace.SpanContextFromContext(tc.Extract(context.Background(), mc))
	if !sc.IsValid() {
		// The leading lines are not a trace-context prefix; keep the payload.
		return &Message{Message: msg}
	}

	msg.Payload = lines[len(fields)]
	return &Message{Message: msg, SpanContext: sc}
}
// WithContext returns ctx with the message's span context attached as a
// remote parent, so spans started from the result continue the publisher's
// trace.
//
// When the message carried no valid span context, ctx is returned unchanged.
// This keeps ctx's existing span from being replaced by an empty one, which
// would turn new spans into roots. It matches propagation.TraceContext.Extract,
// which likewise leaves ctx untouched when it finds no valid span context.
func (m *Message) WithContext(ctx context.Context) context.Context {
	if !m.SpanContext.IsValid() {
		return ctx
	}
	return trace.ContextWithRemoteSpanContext(ctx, m.SpanContext)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment