Skip to content

Instantly share code, notes, and snippets.

@dnephin
Last active November 15, 2019 16:49
Show Gist options
  • Save dnephin/4b8106cf5b09494d1caecdf8f74ce88e to your computer and use it in GitHub Desktop.
Save dnephin/4b8106cf5b09494d1caecdf8f74ce88e to your computer and use it in GitHub Desktop.
Go - Logger from context.Context
package logging
import (
"context"
"github.com/sirupsen/logrus"
)
// fieldsKey is the context key type under which logging fields are stored.
type fieldsKey struct{}

// loggerKey is the context key type under which a logger is stored.
type loggerKey struct{}
// FromContext returns a logger for ctx. The logger stored under loggerKey is
// used when one is present; otherwise logrus.StandardLogger() is used. In both
// cases the fields recorded on ctx (see WithField and WithFields) are attached
// to the returned logger.
func FromContext(ctx context.Context) logrus.FieldLogger {
	attached := getFields(ctx)
	if stored := ctx.Value(loggerKey{}); stored != nil {
		return stored.(logrus.FieldLogger).WithFields(attached)
	}
	return logrus.StandardLogger().WithFields(attached)
}
// getFields returns the fields stored on ctx under fieldsKey, or a new empty
// logrus.Fields when none have been stored. The stored map itself is returned
// rather than a copy, so writes to it are visible to every context that
// shares that value.
func getFields(ctx context.Context) logrus.Fields {
	if stored := ctx.Value(fieldsKey{}); stored != nil {
		return stored.(logrus.Fields)
	}
	return logrus.Fields{}
}
// WithLogger returns a context derived from ctx that carries logger.
// FromContext uses that logger instead of the logrus standard logger.
func WithLogger(ctx context.Context, logger logrus.FieldLogger) context.Context {
	var key loggerKey
	return context.WithValue(ctx, key, logger)
}
// WithField returns a context derived from ctx that carries key/value in
// addition to any fields already recorded on ctx. The fields are added to the
// logger retrieved with FromContext.
//
// The fields already on ctx are copied into a new map before key is set, so
// ctx itself (and any other context derived from it) keeps its original
// fields and no map is shared for writing between contexts.
func WithField(ctx context.Context, key string, value interface{}) context.Context {
	existing := getFields(ctx)
	merged := make(logrus.Fields, len(existing)+1)
	for k, v := range existing {
		merged[k] = v
	}
	merged[key] = value
	return context.WithValue(ctx, fieldsKey{}, merged)
}
// WithFields returns a context derived from ctx that carries fields in
// addition to any fields already recorded on ctx; on a key collision the
// value from fields wins. The fields are added to the logger retrieved with
// FromContext.
//
// Both the fields already on ctx and the fields argument are copied into a
// new map, so neither ctx nor the caller's map is modified or shared with the
// returned context.
func WithFields(ctx context.Context, fields logrus.Fields) context.Context {
	existing := getFields(ctx)
	merged := make(logrus.Fields, len(existing)+len(fields))
	for k, v := range existing {
		merged[k] = v
	}
	for k, v := range fields {
		merged[k] = v
	}
	return context.WithValue(ctx, fieldsKey{}, merged)
}
package logging
import (
"context"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/sirupsen/logrus"
"gotest.tools/assert"
)
// formatterStub is a logrus.Formatter used in tests: it records every entry
// passed to Format instead of rendering it.
type formatterStub struct {
	entries []*logrus.Entry // entries received by Format, in call order
}

// Format appends entry to the recorded entries and produces no output; it
// always returns a nil byte slice and a nil error.
func (stub *formatterStub) Format(entry *logrus.Entry) ([]byte, error) {
	recorded := append(stub.entries, entry)
	stub.entries = recorded
	return nil, nil
}
// TestFromContext checks the logger returned by FromContext for a bare
// context and for a context carrying fields added with WithField and
// WithFields. Entries are captured through a formatterStub and compared with
// cmpEntry.
func TestFromContext(t *testing.T) {
	base := context.Background()

	t.Run("returns a default value before init", func(t *testing.T) {
		const message = "no default fields yet"
		stub := new(formatterStub)

		log := FromContext(base)
		setFormatter(t, log, stub)
		log.Warn(message)

		want := []*logrus.Entry{{
			Level:   logrus.WarnLevel,
			Message: message,
			Data:    logrus.Fields{},
		}}
		assert.DeepEqual(t, stub.entries, want, cmpEntry)
	})

	t.Run("returns a logger with fields", func(t *testing.T) {
		const message = "with fields"
		stub := new(formatterStub)

		// Add one field individually and one through a Fields map.
		fieldCtx := WithFields(
			WithField(base, "key", 12345),
			logrus.Fields{"another-key": "ok"},
		)

		log := FromContext(fieldCtx)
		setFormatter(t, log, stub)
		log.Info(message)

		want := []*logrus.Entry{{
			Level:   logrus.InfoLevel,
			Message: message,
			Data: logrus.Fields{
				"key":         12345,
				"another-key": "ok",
			},
		}}
		assert.DeepEqual(t, stub.entries, want, cmpEntry)
	})
}
// cmpEntry is a go-cmp option that treats two logrus entries as equal when
// their Message, Level and Data match; no other Entry field is compared.
var cmpEntry = cmp.Comparer(func(got, want logrus.Entry) bool {
	if got.Message != want.Message {
		return false
	}
	if got.Level != want.Level {
		return false
	}
	return cmp.Equal(got.Data, want.Data)
})
// setFormatter installs formatter on the logrus.Logger behind logger so that
// a test can capture the entries it emits. Only loggers whose concrete type
// is *logrus.Entry are supported; any other type fails the test immediately.
func setFormatter(t *testing.T, logger logrus.FieldLogger, formatter logrus.Formatter) {
	// Report failures at the calling test's line rather than inside this helper.
	t.Helper()
	switch logger := logger.(type) {
	case *logrus.Entry:
		// Use the setter, which takes the logger's mutex, instead of writing
		// the Formatter field directly.
		logger.Logger.SetFormatter(formatter)
	default:
		t.Fatalf("unexpected logger type %T", logger)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment